/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.mllib.util.MLUtils$;
import scala.Array$;
import scala.Function0;
import scala.Function2;
import scala.MatchError;
import scala.NotImplementedError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.AbstractFunction2;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;

@ScalaSignature(bytes="\u0006\u000194A!\u0001\u0002\u0005\u001b\t\u0011Bj\\4jgRL7-Q4he\u0016<\u0017\r^8s\u0015\t\u0019A!\u0001\bdY\u0006\u001c8/\u001b4jG\u0006$\u0018n\u001c8\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0003ta\u0006\u00148N\u0003\u0002\n\u0015\u00051\u0011\r]1dQ\u0016T\u0011aC\u0001\u0004_J<7\u0001A\n\u0004\u00019!\u0002CA\b\u0013\u001b\u0005\u0001\"\"A\t\u0002\u000bM\u001c\u0017\r\\1\n\u0005M\u0001\"AB!osJ+g\r\u0005\u0002\u0010+%\u0011a\u0003\u0005\u0002\r'\u0016\u0014\u0018.\u00197ju\u0006\u0014G.\u001a\u0005\t1\u0001\u0011)\u0019!C\u00053\u0005Ya.^7GK\u0006$XO]3t+\u0005Q\u0002CA\b\u001c\u0013\ta\u0002CA\u0002J]RD\u0001B\b\u0001\u0003\u0002\u0003\u0006IAG\u0001\r]Vlg)Z1ukJ,7\u000f\t\u0005\tA\u0001\u0011\t\u0011)A\u00055\u0005Qa.^7DY\u0006\u001c8/Z:\t\u0011\t\u0002!\u0011!Q\u0001\n\r\nABZ5u\u0013:$XM]2faR\u0004\"a\u0004\u0013\n\u0005\u0015\u0002\"a\u0002\"p_2,\u0017M\u001c\u0005\u0006O\u0001!\t\u0001K\u0001\u0007y%t\u0017\u000e\u001e \u0015\t%ZC&\f\t\u0003U\u0001i\u0011A\u0001\u0005\u00061\u0019\u0002\rA\u0007\u0005\u0006A\u0019\u0002\rA\u0007\u0005\u0006E\u0019\u0002\ra\t\u0005\b_\u0001\u0001\r\u0011\"\u00031\u0003%9X-[4iiN+X.F\u00012!\ty!'\u0003\u00024!\t1Ai\\;cY\u0016Dq!\u000e\u0001A\u0002\u0013%a'A\u0007xK&<\u0007\u000e^*v[~#S-\u001d\u000b\u0003oi\u0002\"a\u0004\u001d\n\u0005e\u0002\"\u0001B+oSRDqa\u000f\u001b\u0002\u0002\u0003\u0007\u0011'A\u0002yIEBa!\u0010\u0001!B\u0013\t\u0014AC<fS\u001eDGoU;nA!9q\b\u0001a\u0001\n\u0013\u0001\u0014a\u00027pgN\u001cV/\u001c\u0005\b\u0003\u0002\u0001\r\u0011\"\u0003C\u0003-awn]:Tk6|F%Z9\u0015\u0005]\u001a\u0005bB\u001eA\u0003\u0003\u0005\r!\r\u0005\u0007\u000b\u0002\u0001\u000b\u0015B\u0019\u0002\u00111|7o]*v[\u0002Bqa\u0012\u0001C\u0002\u0013%\u0001*\u0001\the\u0006$\u0017.\u001a8u'Vl\u0017I\u001d:bsV\t\u0011\nE\u0002\u0010\u0015FJ!a\u0013\t\u0003\u000b\u0005\u0013(/Y=\t\r5\u0003\u0001\u0015!\u0003J\u0003E9'/\u00193jK:$8+^7BeJ\f\u0017\u0010\t\u0005\u0006\u001f\u0002!\t\u0001U\u0001\u0004C\u0012$G\u0003B)S
5\nl\u0011\u0001\u0001\u0005\u0006':\u0003\r\u0001V\u0001\tS:\u001cH/\u00198dKB\u0011Q\u000bW\u0007\u0002-*\u0011q\u000bB\u0001\bM\u0016\fG/\u001e:f\u0013\tIfK\u0001\u0005J]N$\u0018M\\2f\u0011\u0015Yf\n1\u0001]\u00031\u0019w.\u001a4gS\u000eLWM\u001c;t!\ti\u0006-D\u0001_\u0015\tyF!\u0001\u0004mS:\fGnZ\u0005\u0003Cz\u0013aAV3di>\u0014\b\"B2O\u0001\u0004I\u0015a\u00034fCR,(/Z:Ti\u0012DQ!\u001a\u0001\u0005\u0002\u0019\fQ!\\3sO\u0016$\"!U4\t\u000b!$\u0007\u0019A\u0015\u0002\u000b=$\b.\u001a:\t\u000b)\u0004A\u0011\u0001\u0019\u0002\t1|7o\u001d\u0005\u0006Y\u0002!\t!\\\u0001\tOJ\fG-[3oiV\tA\f")
public class LogisticAggregator
implements Serializable {
    private final int org$apache$spark$ml$classification$LogisticAggregator$$numFeatures;
    private final int numClasses;
    private final boolean fitIntercept;
    private double org$apache$spark$ml$classification$LogisticAggregator$$weightSum;
    private double lossSum;
    private final double[] gradientSumArray;

    public int org$apache$spark$ml$classification$LogisticAggregator$$numFeatures() {
        return this.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures;
    }

    public double org$apache$spark$ml$classification$LogisticAggregator$$weightSum() {
        return this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum;
    }

    private void org$apache$spark$ml$classification$LogisticAggregator$$weightSum_$eq(double x$1) {
        this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum = x$1;
    }

    private double lossSum() {
        return this.lossSum;
    }

    private void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    private double[] gradientSumArray() {
        return this.gradientSumArray;
    }

    public LogisticAggregator add(Instance instance, Vector coefficients2, double[] featuresStd) {
        Instance instance2 = instance;
        if (instance2 != null) {
            double label = instance2.label();
            double weight = instance2.weight();
            Vector features = instance2.features();
            Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures() == features.size(), (Function0)new Serializable(this, features){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ LogisticAggregator $outer;
                private final Vector features$1;

                public final String apply() {
                    return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Dimensions mismatch when adding new instance."})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{" Expecting ", " but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures()), BoxesRunTime.boxToInteger((int)this.features$1.size())}))).toString();
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                    this.features$1 = features$1;
                }
            });
            Predef$.MODULE$.require(weight >= 0.0, (Function0)new Serializable(this, weight){
                public static final long serialVersionUID = 0L;
                private final double weight$2;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"instance weight, ", " has to be >= 0.0"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.weight$2)}));
                }
                {
                    this.weight$2 = weight$2;
                }
            });
            if (weight == 0.0) {
                return this;
            }
            Vector vector = coefficients2;
            if (vector instanceof DenseVector) {
                double[] dArray;
                DenseVector denseVector = (DenseVector)vector;
                double[] coefficientsArray = dArray = denseVector.values();
                double[] localGradientSumArray = this.gradientSumArray();
                int n = this.numClasses;
                switch (n) {
                    default: {
                        NotImplementedError notImplementedError = new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports binary classification for now.");
                        break;
                    }
                    case 2: {
                        DoubleRef sum = DoubleRef.create((double)0.0);
                        features.foreachActive((Function2)new Serializable(this, featuresStd, coefficientsArray, sum){
                            public static final long serialVersionUID = 0L;
                            private final double[] featuresStd$2;
                            private final double[] coefficientsArray$1;
                            private final DoubleRef sum$1;

                            public final void apply(int index2, double value) {
                                this.apply$mcVID$sp(index2, value);
                            }

                            public void apply$mcVID$sp(int index2, double value) {
                                if (this.featuresStd$2[index2] != 0.0 && value != 0.0) {
                                    this.sum$1.elem += this.coefficientsArray$1[index2] * (value / this.featuresStd$2[index2]);
                                }
                            }
                            {
                                this.featuresStd$2 = featuresStd$2;
                                this.coefficientsArray$1 = coefficientsArray$1;
                                this.sum$1 = sum$1;
                            }
                        });
                        double margin = -(sum.elem + (this.fitIntercept ? coefficientsArray[this.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures()] : 0.0));
                        double multiplier = weight * (1.0 / (1.0 + package$.MODULE$.exp(margin)) - label);
                        features.foreachActive((Function2)new Serializable(this, featuresStd, localGradientSumArray, multiplier){
                            public static final long serialVersionUID = 0L;
                            private final double[] featuresStd$2;
                            private final double[] localGradientSumArray$1;
                            private final double multiplier$1;

                            public final void apply(int index2, double value) {
                                this.apply$mcVID$sp(index2, value);
                            }

                            public void apply$mcVID$sp(int index2, double value) {
                                if (this.featuresStd$2[index2] != 0.0 && value != 0.0) {
                                    this.localGradientSumArray$1[index2] = this.localGradientSumArray$1[index2] + this.multiplier$1 * (value / this.featuresStd$2[index2]);
                                }
                            }
                            {
                                this.featuresStd$2 = featuresStd$2;
                                this.localGradientSumArray$1 = localGradientSumArray$1;
                                this.multiplier$1 = multiplier$1;
                            }
                        });
                        if (this.fitIntercept) {
                            localGradientSumArray[this.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures()] = localGradientSumArray[this.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures()] + multiplier;
                        }
                        if (label > 0.0) {
                            this.lossSum_$eq(this.lossSum() + weight * MLUtils$.MODULE$.log1pExp(margin));
                        } else {
                            this.lossSum_$eq(this.lossSum() + weight * (MLUtils$.MODULE$.log1pExp(margin) - margin));
                        }
                        NotImplementedError notImplementedError = BoxedUnit.UNIT;
                    }
                }
                this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum_$eq(this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() + weight);
                LogisticAggregator logisticAggregator = this;
                return logisticAggregator;
            }
            throw new IllegalArgumentException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"coefficients only supports dense vector but got type ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{coefficients2.getClass()})));
        }
        throw new MatchError((Object)instance2);
    }

    public LogisticAggregator merge(LogisticAggregator other) {
        Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures() == other.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures(), (Function0)new Serializable(this, other){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LogisticAggregator $outer;
            private final LogisticAggregator other$1;

            public final String apply() {
                return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Dimensions mismatch when merging with another "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"LeastSquaresAggregator. Expecting ", " but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures()), BoxesRunTime.boxToInteger((int)this.other$1.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures())}))).toString();
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.other$1 = other$1;
            }
        });
        if (other.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() != 0.0) {
            this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum_$eq(this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() + other.org$apache$spark$ml$classification$LogisticAggregator$$weightSum());
            this.lossSum_$eq(this.lossSum() + other.lossSum());
            double[] localThisGradientSumArray = this.gradientSumArray();
            double[] localOtherGradientSumArray = other.gradientSumArray();
            int len = localThisGradientSumArray.length;
            for (int i = 0; i < len; ++i) {
                int n = i;
                localThisGradientSumArray[n] = localThisGradientSumArray[n] + localOtherGradientSumArray[i];
            }
        }
        return this;
    }

    public double loss() {
        Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() > 0.0, (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LogisticAggregator $outer;

            public final String apply() {
                return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"The effective number of instances should be "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"greater than 0.0, but ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$weightSum())}))).toString();
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
            }
        });
        return this.lossSum() / this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum();
    }

    /*
     * WARNING - void declaration
     */
    public Vector gradient() {
        void var1_1;
        Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() > 0.0, (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LogisticAggregator $outer;

            public final String apply() {
                return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"The effective number of instances should be "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"greater than 0.0, but ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$weightSum())}))).toString();
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
            }
        });
        Vector result = Vectors$.MODULE$.dense((double[])this.gradientSumArray().clone());
        BLAS$.MODULE$.scal(1.0 / this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum(), result);
        return var1_1;
    }

    public LogisticAggregator(int numFeatures, int numClasses, boolean fitIntercept) {
        this.org$apache$spark$ml$classification$LogisticAggregator$$numFeatures = numFeatures;
        this.numClasses = numClasses;
        this.fitIntercept = fitIntercept;
        this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum = 0.0;
        this.lossSum = 0.0;
        this.gradientSumArray = (double[])Array$.MODULE$.ofDim(fitIntercept ? numFeatures + 1 : numFeatures, ClassTag$.MODULE$.Double());
    }
}

