/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.SparseMatrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001u3Q!\u0004\b\u0001%iA\u0001\u0002\f\u0001\u0003\u0002\u0003\u0006IA\f\u0005\tc\u0001\u0011\t\u0011)A\u0005e!AQ\u0007\u0001B\u0001B\u0003%a\u0007C\u0003C\u0001\u0011\u00051\tC\u0004I\u0001\t\u0007I\u0011K%\t\r5\u0003\u0001\u0015!\u0003K\u0011\u001dq\u0005A1A\u0005\n%Caa\u0014\u0001!\u0002\u0013Q\u0005b\u0002)\u0001\u0005\u0004%I!\u0015\u0005\u0007%\u0002\u0001\u000b\u0011\u0002\u001a\t\u0011M\u0003\u0001R1A\u0005\nQCQ!\u0017\u0001\u0005\u0002i\u0013AC\u00117pG.DUOY3s\u0003\u001e<'/Z4bi>\u0014(BA\b\u0011\u0003)\twm\u001a:fO\u0006$xN\u001d\u0006\u0003#I\tQa\u001c9uS6T!a\u0005\u000b\u0002\u00055d'BA\u000b\u0017\u0003\u0015\u0019\b/\u0019:l\u0015\t9\u0002$\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u00023\u0005\u0019qN]4\u0014\u0007\u0001Y\u0012\u0005\u0005\u0002\u001d?5\tQDC\u0001\u001f\u0003\u0015\u00198-\u00197b\u0013\t\u0001SD\u0001\u0004B]f\u0014VM\u001a\t\u0005E\r*3&D\u0001\u000f\u0013\t!cB\u0001\u000fES\u001a4WM]3oi&\f'\r\\3M_N\u001c\u0018iZ4sK\u001e\fGo\u001c:\u0011\u0005\u0019JS\"A\u0014\u000b\u0005!\u0012\u0012a\u00024fCR,(/Z\u0005\u0003U\u001d\u0012Q\"\u00138ti\u0006t7-\u001a\"m_\u000e\\\u0007C\u0001\u0012\u0001\u000311\u0017\u000e^%oi\u0016\u00148-\u001a9u\u0007\u0001\u0001\"\u0001H\u0018\n\u0005Aj\"a\u0002\"p_2,\u0017M\\\u0001\bKB\u001c\u0018\u000e\\8o!\ta2'\u0003\u00025;\t1Ai\\;cY\u0016\fABY2QCJ\fW.\u001a;feN\u00042a\u000e\u001e=\u001b\u0005A$BA\u001d\u0015\u0003%\u0011'o\\1eG\u0006\u001cH/\u0003\u0002<q\tI!I]8bI\u000e\f7\u000f\u001e\t\u0003{\u0001k\u0011A\u0010\u0006\u0003\u007fI\ta\u0001\\5oC2<\u0017BA!?\u0005\u00191Vm\u0019;pe\u00061A(\u001b8jiz\"2\u0001\u0012$H)\tYS\tC\u00036\t\u0001\u0007a\u0007C\u0003-\t\u0001\u0007a\u0006C\u00032\t\u0001\u0007!'A\u0002eS6,\u0012A\u0013\t\u00039-K!\u0001T\u000f\u0003\u0007%sG/\u0001\u0003eS6\u0004\u0013a\u00038v[\u001a+\u0017\r^;sKN\fAB\\;n\r\u0016\fG/\u001e:fg\u0002\n\u0011\"\u001b8uKJ\u001cW\r\u001d;\u0016\u0003I\n!\"\u001b8uKJ\u001cW\r\u001d;!\u0003\u0019a\u0017N\\3beV\tA\b\u000
b\u0002\f-B\u0011AdV\u0005\u00031v\u0011\u0011\u0002\u001e:b]NLWM\u001c;\u0002\u0007\u0005$G\r\u0006\u0002,7\")A\f\u0004a\u0001K\u0005)!\r\\8dW\u0002")
/*
 * Aggregator that accumulates a Huber-type regression loss, the total instance weight
 * and the loss gradient over blocks of instances (InstanceBlock) for a linear model.
 *
 * Parameter layout (read from bcParameters; the gradient in gradientSumArray uses the
 * same layout):
 *   [0, numFeatures)   linear coefficients
 *   dim - 2            intercept (only when fitIntercept is true)
 *   dim - 1            scale parameter sigma
 *
 * Per row with weight w, margin m and residual r = label - m (see add()):
 *   |r| <= sigma * epsilon : loss = 0.5 * w * (sigma + r^2 / sigma)
 *   otherwise              : loss = 0.5 * w * (sigma + 2 * epsilon * |r| - sigma * epsilon^2)
 * NOTE(review): this looks like the scaled Huber formulation used by Spark's
 * LinearRegression with loss = "huber" - confirm against the caller.
 *
 * This file is CFR-decompiled Scala bytecode:
 *   - Members named like "x$lzycompute" together with the "bitmap$" flags are the Scala
 *     encoding of lazy vals.
 *   - Calls such as DifferentiableLossAggregator.merge$(this, ...) are forwarders to the
 *     trait's method implementations.
 *   - weightSum_$eq / lossSum_$eq are the Scala setters for the corresponding vars.
 *
 * add() mutates lossSum, weightSum and gradientSumArray without locking. Only the lazy
 * initializers below synchronize.
 */
public class BlockHuberAggregator
implements DifferentiableLossAggregator<InstanceBlock, BlockHuberAggregator> {
    // Dense copy of the first numFeatures entries of the broadcast parameters (the linear
    // coefficients). It is built lazily by linear$lzycompute(). Because it is transient, it
    // is not serialized with the aggregator and is rebuilt from the broadcast on demand.
    private transient Vector linear;
    // Whether the parameter vector carries an intercept entry at index dim - 2.
    private final boolean fitIntercept;
    // Huber threshold factor: residuals with |r| <= sigma * epsilon use the quadratic branch.
    private final double epsilon;
    // Broadcast parameter vector laid out as [coefficients..., (intercept), sigma].
    private final Broadcast<Vector> bcParameters;
    // Total length of the parameter (and gradient) vector.
    private final int dim;
    // Number of linear coefficients: dim - 2 with an intercept, dim - 1 without.
    private final int numFeatures;
    // Intercept value read from the parameters, or 0.0 when fitIntercept is false.
    private final double intercept;
    // Running sum of instance weights (backing field of weightSum()/weightSum_$eq).
    private double weightSum;
    // Running sum of weighted loss (backing field of lossSum()/lossSum_$eq).
    private double lossSum;
    // Running gradient sums, indexed up to dim - 1 in add(). Lazily initialized.
    private double[] gradientSumArray;
    // Lazy-val "initialized" flag for the transient field linear. Being transient, it resets
    // to false on deserialization, which forces linear to be rebuilt.
    private volatile transient boolean bitmap$trans$0;
    // Lazy-val "initialized" flag for gradientSumArray.
    private volatile boolean bitmap$0;

    /**
     * Trait forwarder: merges {@code other} into this aggregator using
     * DifferentiableLossAggregator.merge$ (implementation not visible here).
     */
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    /**
     * Trait forwarder: the gradient is computed by DifferentiableLossAggregator.gradient$
     * from the accumulated state (presumably gradientSumArray normalized by weightSum;
     * confirm in the trait).
     */
    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    /** Trait forwarder to DifferentiableLossAggregator.weight$ (implementation not visible here). */
    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    /** Trait forwarder to DifferentiableLossAggregator.loss$ (implementation not visible here). */
    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    /** Returns the running sum of instance weights added so far. */
    @Override
    public double weightSum() {
        return this.weightSum;
    }

    /** Scala setter for weightSum. */
    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    /** Returns the running sum of weighted loss added so far. */
    @Override
    public double lossSum() {
        return this.lossSum;
    }

    /** Scala setter for lossSum. */
    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Slow path of the gradientSumArray lazy val.
     *
     * <p>Under the monitor of {@code this}, it re-checks bitmap$0. If the array is still
     * uninitialized, it takes the initial array from DifferentiableLossAggregator.gradientSumArray$
     * (presumably a zero array of length dim; not visible here) and sets the flag.
     */
    private double[] gradientSumArray$lzycompute() {
        BlockHuberAggregator blockHuberAggregator = this;
        synchronized (blockHuberAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    /**
     * Lazy accessor for the gradient accumulator.
     *
     * <p>It reads the volatile flag without locking and enters the synchronized initializer
     * only while the flag is still unset.
     */
    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    /** Returns the total length of the parameter vector (coefficients, optional intercept, sigma). */
    @Override
    public int dim() {
        return this.dim;
    }

    /** Returns the number of linear coefficients. */
    private int numFeatures() {
        return this.numFeatures;
    }

    /** Returns the intercept value, which is 0.0 when fitIntercept is false. */
    private double intercept() {
        return this.intercept;
    }

    /**
     * Slow path of the linear lazy val.
     *
     * <p>Under the monitor of {@code this}, it builds a dense vector from the first numFeatures
     * entries of the broadcast parameter vector, then sets the transient flag.
     */
    private Vector linear$lzycompute() {
        BlockHuberAggregator blockHuberAggregator = this;
        synchronized (blockHuberAggregator) {
            if (!this.bitmap$trans$0) {
                this.linear = Vectors$.MODULE$.dense((double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(((Vector)this.bcParameters.value()).toArray())).take(this.numFeatures()));
                this.bitmap$trans$0 = true;
            }
        }
        return this.linear;
    }

    /** Lazy accessor for the linear coefficient vector (see linear$lzycompute). */
    private Vector linear() {
        return !this.bitmap$trans$0 ? this.linear$lzycompute() : this.linear;
    }

    /**
     * Adds the loss, weight and gradient contribution of one block of instances.
     *
     * <p>For each row i it computes the margin m = intercept + x_i . linear and the residual
     * r = label - m. The scale parameter is sigma = parameters(dim - 1). For rows with
     * weight w &gt; 0:
     * <ul>
     *   <li>{@code |r| <= sigma * epsilon}: loss 0.5*w*(sigma + r^2/sigma), d/dm = -w*r/sigma,
     *       d/dsigma = 0.5*w*(1 - (r/sigma)^2);</li>
     *   <li>otherwise: loss 0.5*w*(sigma + 2*epsilon*|r| - sigma*epsilon^2),
     *       d/dm = -w*epsilon*sign(r), d/dsigma = 0.5*w*(1 - epsilon^2).</li>
     * </ul>
     * Rows with zero weight contribute nothing.
     *
     * <p>The coefficient gradient is matrix^T times the per-row d/dm values. The intercept
     * gradient (when fitIntercept is set) is the sum of those values.
     *
     * <p>NOTE(review): the quadratic branch divides by sigma. This code assumes the optimizer
     * keeps sigma &gt; 0, which should be confirmed with the caller's bounds.
     *
     * @param block instances to add. Its matrix must be transposed and its numFeatures must
     *        equal this aggregator's numFeatures.
     * @return this aggregator, updated in place
     * @throws IllegalArgumentException (via Predef.require) if the matrix is not transposed,
     *         the feature count differs, or any weight is negative
     * @throws scala.MatchError if the block's matrix is neither a DenseMatrix nor a SparseMatrix
     */
    @Override
    public BlockHuberAggregator add(InstanceBlock block) {
        // Labeled block emitted by the decompiler. "break block8" below is simply the
        // negation of the original "if (fitIntercept) { ... }".
        block8: {
            // The dense branch below passes values() straight to dgemv with leading dimension
            // numCols, presumably relying on the transposed layout this checks for.
            Predef$.MODULE$.require(block.matrix().isTransposed());
            Predef$.MODULE$.require(this.numFeatures() == block.numFeatures(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(block.numFeatures()).append(".").toString());
            Predef$.MODULE$.require(block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$1 -> x$1 >= 0.0), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(34).append("instance weights ").append(block.weightIter().mkString("[", ",", "]")).append(" has to be >= 0.0").toString());
            // All weights are zero: nothing to add, so lossSum, weightSum and the gradient stay unchanged.
            if (block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$2 -> x$2 == 0.0)) {
                return this;
            }
            int size = block.size();
            // The scale parameter sigma is the last entry of the parameter vector.
            double sigma = ((Vector)this.bcParameters.value()).apply(this.dim() - 1);
            // vec starts as the intercept (or 0) for every row. The gemv below adds
            // matrix * linear to it, so afterwards vec holds the per-row margins.
            DenseVector vec = this.fitIntercept ? Vectors$.MODULE$.dense((double[])Array$.MODULE$.fill(size, (Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)() -> this.intercept(), ClassTag$.MODULE$.Double())).toDense() : Vectors$.MODULE$.zeros(size).toDense();
            BLAS$.MODULE$.gemv(1.0, block.matrix(), this.linear(), 1.0, vec);
            double sigmaGradSum = 0.0;
            double localLossSum = 0.0;
            // Overwrite each margin in vec, in place, with its multiplier d(loss_i)/d(margin_i),
            // while accumulating the block loss and the gradient with respect to sigma.
            for (int i = 0; i < size; ++i) {
                double weight = block.getWeight().apply$mcDI$sp(i);
                if (weight > 0.0) {
                    double multiplier;
                    double margin;
                    double label = block.getLabel(i);
                    double linearLoss = label - (margin = vec.apply(i));
                    // Quadratic region: the residual is within sigma * epsilon.
                    if (package$.MODULE$.abs(linearLoss) <= sigma * this.epsilon) {
                        double multiplier2;
                        localLossSum += 0.5 * weight * (sigma + package$.MODULE$.pow(linearLoss, 2.0) / sigma);
                        double linearLossDivSigma = linearLoss / sigma;
                        // multiplier2 is never read again (decompiler artifact). The stored value is -w * r / sigma.
                        vec.values()[i] = multiplier2 = -1.0 * weight * linearLossDivSigma;
                        sigmaGradSum += 0.5 * weight * (1.0 - package$.MODULE$.pow(linearLossDivSigma, 2.0));
                        continue;
                    }
                    // Linear region: the residual exceeds sigma * epsilon.
                    localLossSum += 0.5 * weight * (sigma + 2.0 * this.epsilon * package$.MODULE$.abs(linearLoss) - sigma * this.epsilon * this.epsilon);
                    // sign is the negated sign of r, so the multiplier is -w * epsilon * sign(r).
                    double sign = linearLoss >= 0.0 ? -1.0 : 1.0;
                    // multiplier is never read again (decompiler artifact).
                    vec.values()[i] = multiplier = weight * sign * this.epsilon;
                    sigmaGradSum += 0.5 * weight * (1.0 - this.epsilon * this.epsilon);
                    continue;
                }
                // Zero-weight rows contribute no gradient.
                vec.values()[i] = 0.0;
            }
            this.lossSum_$eq(this.lossSum() + localLossSum);
            this.weightSum_$eq(this.weightSum() + BoxesRunTime.unboxToDouble((Object)block.weightIter().sum((Numeric)Numeric.DoubleIsFractional$.MODULE$)));
            // Accumulate matrix^T * multipliers into gradientSumArray[0, numFeatures).
            Matrix matrix = block.matrix();
            if (matrix instanceof DenseMatrix) {
                DenseMatrix denseMatrix = (DenseMatrix)matrix;
                // Native dgemv over the raw values (leading dimension numCols), accumulating
                // directly into gradientSumArray (beta = 1.0).
                BLAS$.MODULE$.nativeBLAS().dgemv("N", denseMatrix.numCols(), denseMatrix.numRows(), 1.0, denseMatrix.values(), denseMatrix.numCols(), vec.values(), 1, 1.0, this.gradientSumArray(), 1);
                // Unused: an artifact of compiling a Scala match expression.
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (matrix instanceof SparseMatrix) {
                SparseMatrix sparseMatrix = (SparseMatrix)matrix;
                // Compute transpose * multipliers into a temporary vector, then daxpy it into
                // the first numFeatures gradient entries.
                DenseVector linearGradSumVec = Vectors$.MODULE$.zeros(this.numFeatures()).toDense();
                BLAS$.MODULE$.gemv(1.0, (Matrix)sparseMatrix.transpose(), (Vector)vec, 0.0, linearGradSumVec);
                BLAS$.MODULE$.getBLAS(this.numFeatures()).daxpy(this.numFeatures(), 1.0, linearGradSumVec.values(), 1, this.gradientSumArray(), 1);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                // The Scala match had no case for other Matrix implementations.
                throw new MatchError((Object)matrix);
            }
            // The gradient with respect to sigma lives in the last slot.
            int n = this.dim() - 1;
            this.gradientSumArray()[n] = this.gradientSumArray()[n] + sigmaGradSum;
            if (!this.fitIntercept) break block8;
            // Intercept gradient: the sum of all per-row multipliers.
            int n2 = this.dim() - 2;
            this.gradientSumArray()[n2] = this.gradientSumArray()[n2] + BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(vec.values())).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
        }
        return this;
    }

    /**
     * Creates an aggregator over the given broadcast parameters. In Scala source the
     * broadcast is presumably a second parameter list; bytecode flattens it into this
     * single constructor.
     *
     * <p>DifferentiableLossAggregator.$init$ runs before dim is assigned. gradientSumArray is
     * lazy, so it is not built at that point.
     *
     * @param fitIntercept whether the parameter vector includes an intercept at index dim - 2
     * @param epsilon Huber threshold factor. Residuals with {@code |r| <= sigma * epsilon} are
     *        penalized quadratically; larger ones are penalized linearly.
     * @param bcParameters broadcast parameter vector laid out as [coefficients..., (intercept), sigma]
     */
    public BlockHuberAggregator(boolean fitIntercept, double epsilon, Broadcast<Vector> bcParameters) {
        this.fitIntercept = fitIntercept;
        this.epsilon = epsilon;
        this.bcParameters = bcParameters;
        DifferentiableLossAggregator.$init$(this);
        this.dim = ((Vector)bcParameters.value()).size();
        // The last slot is sigma, and with an intercept the one before it is the intercept.
        this.numFeatures = fitIntercept ? this.dim() - 2 : this.dim() - 1;
        this.intercept = fitIntercept ? ((Vector)bcParameters.value()).apply(this.dim() - 2) : 0.0;
    }
}

