/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.sgd.objectives;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.function.DoubleUnaryOperator;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.regression.sgd.RegressionObjective;

/**
 * Huber loss for regression.
 * <p>
 * For each output dimension, with {@code d = truth - prediction}, the loss is
 * {@code 0.5 * d * d} when {@code |d| <= cost} and
 * {@code cost * |d| - 0.5 * cost * cost} otherwise, i.e. quadratic close to the
 * target and linear beyond {@code cost}. The per-dimension losses are summed.
 */
public class Huber implements RegressionObjective {
    /**
     * The default cost beyond which the loss is linear.
     */
    public static final double DEFAULT_COST = 5.0;

    @Config(description="Cost beyond which the loss function is linear.")
    private double cost = DEFAULT_COST;

    // Per-element loss applied to an absolute difference; built in postConfig
    // so that it is available after both direct and configuration-driven construction.
    private DoubleUnaryOperator lossFunc;

    /**
     * Constructs a Huber objective using {@link #DEFAULT_COST}.
     */
    public Huber() {
        this.postConfig();
    }

    /**
     * Constructs a Huber objective with the supplied cost.
     *
     * @param cost The threshold beyond which the loss is linear. Must be positive.
     * @throws PropertyException If {@code cost} is not a positive number.
     */
    public Huber(double cost) {
        this.cost = cost;
        this.postConfig();
    }

    /**
     * Validates {@code cost} and builds the per-element loss function.
     *
     * @throws PropertyException If {@code cost} is zero, negative or NaN.
     */
    @Override
    public void postConfig() {
        // Written as !(cost > 0) rather than (cost <= 0) so that NaN is rejected too:
        // every ordered comparison against NaN is false.
        if (!(this.cost > 0.0)) {
            throw new PropertyException("", "cost", "Cost must be a positive value, found " + this.cost);
        }
        // The argument is expected to be an absolute difference (see lossAndGradient).
        this.lossFunc = a -> (a > this.cost)
                ? this.cost * a - 0.5 * this.cost * this.cost
                : 0.5 * a * a;
    }

    /**
     * Delegates to {@link #lossAndGradient(DenseVector, SGDVector)}.
     *
     * @param truth The true regression values.
     * @param prediction The predicted regression values.
     * @return The value returned by {@code lossAndGradient}.
     * @deprecated Use {@link #lossAndGradient(DenseVector, SGDVector)} instead.
     */
    @Override
    @Deprecated
    public Pair<Double, SGDVector> loss(DenseVector truth, SGDVector prediction) {
        return this.lossAndGradient(truth, prediction);
    }

    /**
     * Computes the summed Huber loss together with the clamped difference vector.
     *
     * @param truth The true regression values.
     * @param prediction The predicted regression values.
     * @return A pair whose first element is the total loss and whose second element is
     *         {@code truth - prediction} with every element of magnitude greater than
     *         {@code cost} replaced by {@code cost} carrying the element's sign.
     */
    @Override
    public Pair<Double, SGDVector> lossAndGradient(DenseVector truth, SGDVector prediction) {
        DenseVector difference = truth.subtract(prediction);

        // The loss function operates on magnitudes, so work on an absolute-valued copy
        // and leave the signed difference intact for the second element of the result.
        DenseVector absoluteDifference = difference.copy();
        absoluteDifference.foreachInPlace(Math::abs);
        double loss = absoluteDifference.reduce(0.0, this.lossFunc, Double::sum);

        // Math.signum is specified to return exactly +/-1.0 for non-zero finite input,
        // unlike Double.compare whose contract only fixes the sign of its result.
        difference.foreachInPlace(a -> (Math.abs(a) > this.cost) ? Math.signum(a) * this.cost : a);

        return new Pair<>(loss, difference);
    }

    @Override
    public String toString() {
        return "Huber(cost=" + this.cost + ")";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "RegressionObjective");
    }
}

