/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import org.apache.spark.ml.ann.ANNGradient;
import org.apache.spark.ml.ann.ANNUpdater;
import org.apache.spark.ml.ann.DataStacker;
import org.apache.spark.ml.ann.Topology;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorImplicits$;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.GradientDescent;
import org.apache.spark.mllib.optimization.LBFGS;
import org.apache.spark.mllib.optimization.Optimizer;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.rdd.RDD;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u0005me!B\u0001\u0003\u0001\u0011a!A\u0005$fK\u00124uN]<be\u0012$&/Y5oKJT!a\u0001\u0003\u0002\u0007\u0005tgN\u0003\u0002\u0006\r\u0005\u0011Q\u000e\u001c\u0006\u0003\u000f!\tQa\u001d9be.T!!\u0003\u0006\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005Y\u0011aA8sON\u0019\u0001!D\n\u0011\u00059\tR\"A\b\u000b\u0003A\tQa]2bY\u0006L!AE\b\u0003\r\u0005s\u0017PU3g!\tqA#\u0003\u0002\u0016\u001f\ta1+\u001a:jC2L'0\u00192mK\"Aq\u0003\u0001B\u0001B\u0003%\u0011$\u0001\u0005u_B|Gn\\4z\u0007\u0001\u0001\"AG\u000e\u000e\u0003\tI!\u0001\b\u0002\u0003\u0011Q{\u0007o\u001c7pOfD\u0001B\b\u0001\u0003\u0006\u0004%\taH\u0001\nS:\u0004X\u000f^*ju\u0016,\u0012\u0001\t\t\u0003\u001d\u0005J!AI\b\u0003\u0007%sG\u000f\u0003\u0005%\u0001\t\u0005\t\u0015!\u0003!\u0003)Ig\u000e];u'&TX\r\t\u0005\tM\u0001\u0011)\u0019!C\u0001?\u0005Qq.\u001e;qkR\u001c\u0016N_3\t\u0011!\u0002!\u0011!Q\u0001\n\u0001\n1b\\;uaV$8+\u001b>fA!)!\u0006\u0001C\u0001W\u00051A(\u001b8jiz\"B\u0001L\u0017/_A\u0011!\u0004\u0001\u0005\u0006/%\u0002\r!\u0007\u0005\u0006=%\u0002\r\u0001\t\u0005\u0006M%\u0002\r\u0001\t\u0005\bc\u0001\u0001\r\u0011\"\u00033\u0003\u0015y6/Z3e+\u0005\u0019\u0004C\u0001\b5\u0013\t)tB\u0001\u0003M_:<\u0007bB\u001c\u0001\u0001\u0004%I\u0001O\u0001\n?N,W\rZ0%KF$\"!\u000f\u001f\u0011\u00059Q\u0014BA\u001e\u0010\u0005\u0011)f.\u001b;\t\u000fu2\u0014\u0011!a\u0001g\u0005\u0019\u0001\u0010J\u0019\t\r}\u0002\u0001\u0015)\u00034\u0003\u0019y6/Z3eA!9\u0011\t\u0001a\u0001\n\u0013\u0011\u0015\u0001C0xK&<\u0007\u000e^:\u0016\u0003\r\u0003\"\u0001R$\u000e\u0003\u0015S!A\u0012\u0003\u0002\r1Lg.\u00197h\u0013\tAUI\u0001\u0004WK\u000e$xN\u001d\u0005\b\u0015\u0002\u0001\r\u0011\"\u0003L\u00031yv/Z5hQR\u001cx\fJ3r)\tID\nC\u0004>\u0013\u0006\u0005\t\u0019A\"\t\r9\u0003\u0001\u0015)\u0003D\u0003%yv/Z5hQR\u001c\b\u0005C\u0004Q\u0001\u0001\u0007I\u0011B\u0010\u0002\u0015}\u001bH/Y2l'&TX\rC\u0004S\u0001\u0001\u0007I\u0011B*\u0002\u001d}\u001bH/Y2l'&TXm\u0018\u0013fcR\u0011\u0011\b\u0016\u0005\b{E\u000b\t\u00111\u0001!\u0011\u00191\u0006\u0001)Q\u0005A\u0005Yql\u001d;bG.\u001c\u0016N_3!\u0011\u001dA\u0006\u00011A\u0005\ne\u000b1\u0002Z1uCN#\u0018mY6feV\t!\f\u0005\u0002\u001b7&\u0011AL\u0001\u0002\f\t\u0006$\u0018m\u0015;bG.,'\u000fC\u0004_\u0001\u0001\u0007I\u0011B0\u0002\u001f\u0011\fG/Y*uC\u000e\\WM]0%KF$\"!\u000f1\t\u000fuj\u0016\u0011!a\u00015\"1!\r\u0001Q!\ni\u000bA\u0002Z1uCN#\u0018mY6fe\u0002Bq\u0001\u001a\u0001A\u0002\u0013%Q-A\u0005`OJ\fG-[3oiV\ta\r\u0005\u0002hY6\t\u0001N\u0003\u0002jU\u0006aq\u000e\u001d;j[&T\u0018\r^5p]*\u00111NB\u0001\u0006[2d\u0017NY\u0005\u0003[\"\u0014\u0001b\u0012:bI&,g\u000e\u001e\u0005\b_\u0002\u0001\r\u0011\"\u0003q\u00035yvM]1eS\u0016tGo\u0018\u0013fcR\u0011\u0011(\u001d\u0005\b{9\f\t\u00111\u0001g\u0011\u0019\u0019\b\u0001)Q\u0005M\u0006Qql\u001a:bI&,g\u000e\u001e\u0011\t\u000fU\u0004\u0001\u0019!C\u0005m\u0006Aq,\u001e9eCR,'/F\u0001x!\t9\u00070\u0003\u0002zQ\n9Q\u000b\u001d3bi\u0016\u0014\bbB>\u0001\u0001\u0004%I\u0001`\u0001\r?V\u0004H-\u0019;fe~#S-\u001d\u000b\u0003suDq!\u0010>\u0002\u0002\u0003\u0007q\u000f\u0003\u0004\u0000\u0001\u0001\u0006Ka^\u0001\n?V\u0004H-\u0019;fe\u0002B\u0011\"a\u0001\u0001\u0001\u0004%I!!\u0002\u0002\u0013=\u0004H/[7ju\u0016\u0014XCAA\u0004!\r9\u0017\u0011B\u0005\u0004\u0003\u0017A'!C(qi&l\u0017N_3s\u0011%\ty\u0001\u0001a\u0001\n\u0013\t\t\"A\u0007paRLW.\u001b>fe~#S-\u001d\u000b\u0004s\u0005M\u0001\"C\u001f\u0002\u000e\u0005\u0005\t\u0019AA\u0004\u0011!\t9\u0002\u0001Q!\n\u0005\u001d\u0011AC8qi&l\u0017N_3sA!1\u00111\u0004\u0001\u0005\u0002I\nqaZ3u'\u0016,G\rC\u0004\u0002 \u0001!\t!!\t\u0002\u000fM,GoU3fIR!\u00111EA\u0013\u001b\u0005\u0001\u0001bBA\u0014\u0003;\u0001\raM\u0001\u0006m\u0006dW/\u001a\u0005\u0007\u0003W\u0001A\u0011\u0001\"\u0002\u0015\u001d,GoV3jO\"$8\u000fC\u0004\u00020\u0001!\t!!\r\u0002\u0015M,GoV3jO\"$8\u000f\u0006\u0003\u0002$\u0005M\u0002bBA\u0014\u0003[\u0001\ra\u0011\u0005\b\u0003o\u0001A\u0011AA\u001d\u00031\u0019X\r^*uC\u000e\\7+\u001b>f)\u0011\t\u0019#a\u000f\t\u000f\u0005\u001d\u0012Q\u0007a\u0001A!9\u0011q\b\u0001\u0005\u0002\u0005\u0005\u0013\u0001D*H\t>\u0003H/[7ju\u0016\u0014XCAA\"!\r9\u0017QI\u0005\u0004\u0003\u000fB'aD$sC\u0012LWM\u001c;EKN\u001cWM\u001c;\t\u000f\u0005-\u0003\u0001\"\u0001\u0002N\u0005qAJ\u0011$H'>\u0003H/[7ju\u0016\u0014XCAA(!\r9\u0017\u0011K\u0005\u0004\u0003'B'!\u0002'C\r\u001e\u001b\u0006bBA,\u0001\u0011\u0005\u0011\u0011L\u0001\u000bg\u0016$X\u000b\u001d3bi\u0016\u0014H\u0003BA\u0012\u00037Bq!a\n\u0002V\u0001\u0007q\u000fC\u0004\u0002`\u0001!\t!!\u0019\u0002\u0017M,Go\u0012:bI&,g\u000e\u001e\u000b\u0005\u0003G\t\u0019\u0007C\u0004\u0002(\u0005u\u0003\u0019\u00014\t\u0011\u0005\u001d\u0004\u0001)C\u0005\u0003S\na\"\u001e9eCR,wI]1eS\u0016tG\u000fF\u0002:\u0003WBq!!\u001c\u0002f\u0001\u0007a-\u0001\u0005he\u0006$\u0017.\u001a8u\u0011!\t\t\b\u0001Q\u0005\n\u0005M\u0014!D;qI\u0006$X-\u00169eCR,'\u000fF\u0002:\u0003kBq!a\u001e\u0002p\u0001\u0007q/A\u0004va\u0012\fG/\u001a:\t\u000f\u0005m\u0004\u0001\"\u0001\u0002~\u0005)AO]1j]R!\u0011qPAC!\rQ\u0012\u0011Q\u0005\u0004\u0003\u0007\u0013!!\u0004+pa>dwnZ=N_\u0012,G\u000e\u0003\u0005\u0002\b\u0006e\u0004\u0019AAE\u0003\u0011!\u0017\r^1\u0011\r\u0005-\u0015\u0011SAK\u001b\t\tiIC\u0002\u0002\u0010\u001a\t1A\u001d3e\u0013\u0011\t\u0019*!$\u0003\u0007I#E\tE\u0003\u000f\u0003/\u001b5)C\u0002\u0002\u001a>\u0011a\u0001V;qY\u0016\u0014\u0004")
public class FeedForwardTrainer
implements Serializable {
    private final Topology topology;
    private final int inputSize;
    private final int outputSize;
    private long _seed;
    private Vector _weights;
    private int _stackSize;
    private DataStacker dataStacker;
    private Gradient _gradient;
    private Updater _updater;
    private Optimizer optimizer;

    public int inputSize() {
        return this.inputSize;
    }

    public int outputSize() {
        return this.outputSize;
    }

    private long _seed() {
        return this._seed;
    }

    private void _seed_$eq(long x$1) {
        this._seed = x$1;
    }

    private Vector _weights() {
        return this._weights;
    }

    private void _weights_$eq(Vector x$1) {
        this._weights = x$1;
    }

    private int _stackSize() {
        return this._stackSize;
    }

    private void _stackSize_$eq(int x$1) {
        this._stackSize = x$1;
    }

    private DataStacker dataStacker() {
        return this.dataStacker;
    }

    private void dataStacker_$eq(DataStacker x$1) {
        this.dataStacker = x$1;
    }

    private Gradient _gradient() {
        return this._gradient;
    }

    private void _gradient_$eq(Gradient x$1) {
        this._gradient = x$1;
    }

    private Updater _updater() {
        return this._updater;
    }

    private void _updater_$eq(Updater x$1) {
        this._updater = x$1;
    }

    private Optimizer optimizer() {
        return this.optimizer;
    }

    private void optimizer_$eq(Optimizer x$1) {
        this.optimizer = x$1;
    }

    public long getSeed() {
        return this._seed();
    }

    public FeedForwardTrainer setSeed(long value) {
        this._seed_$eq(value);
        return this;
    }

    public Vector getWeights() {
        return this._weights();
    }

    public FeedForwardTrainer setWeights(Vector value) {
        this._weights_$eq(value);
        return this;
    }

    public FeedForwardTrainer setStackSize(int value) {
        this._stackSize_$eq(value);
        this.dataStacker_$eq(new DataStacker(value, this.inputSize(), this.outputSize()));
        return this;
    }

    /*
     * WARNING - void declaration
     */
    public GradientDescent SGDOptimizer() {
        void var1_1;
        GradientDescent sgd = new GradientDescent(this._gradient(), this._updater());
        this.optimizer_$eq(sgd);
        return var1_1;
    }

    /*
     * WARNING - void declaration
     */
    public LBFGS LBFGSOptimizer() {
        void var1_1;
        LBFGS lbfgs = new LBFGS(this._gradient(), this._updater());
        this.optimizer_$eq(lbfgs);
        return var1_1;
    }

    public FeedForwardTrainer setUpdater(Updater value) {
        this._updater_$eq(value);
        this.updateUpdater(value);
        return this;
    }

    public FeedForwardTrainer setGradient(Gradient value) {
        this._gradient_$eq(value);
        this.updateGradient(value);
        return this;
    }

    private void updateGradient(Gradient gradient2) {
        Optimizer optimizer;
        block4: {
            block3: {
                block2: {
                    optimizer = this.optimizer();
                    if (!(optimizer instanceof LBFGS)) break block2;
                    LBFGS lBFGS = (LBFGS)optimizer;
                    lBFGS.setGradient(gradient2);
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                    break block3;
                }
                if (!(optimizer instanceof GradientDescent)) break block4;
                GradientDescent gradientDescent = (GradientDescent)optimizer;
                gradientDescent.setGradient(gradient2);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            }
            return;
        }
        throw new UnsupportedOperationException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Only LBFGS and GradientDescent are supported but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{optimizer.getClass()})));
    }

    private void updateUpdater(Updater updater) {
        Optimizer optimizer;
        block4: {
            block3: {
                block2: {
                    optimizer = this.optimizer();
                    if (!(optimizer instanceof LBFGS)) break block2;
                    LBFGS lBFGS = (LBFGS)optimizer;
                    lBFGS.setUpdater(updater);
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                    break block3;
                }
                if (!(optimizer instanceof GradientDescent)) break block4;
                GradientDescent gradientDescent = (GradientDescent)optimizer;
                gradientDescent.setUpdater(updater);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            }
            return;
        }
        throw new UnsupportedOperationException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Only LBFGS and GradientDescent are supported but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{optimizer.getClass()})));
    }

    public TopologyModel train(RDD<Tuple2<Vector, Vector>> data) {
        Vector w2 = this.getWeights() == null ? this.topology.model(this._seed()).weights() : this.getWeights();
        org.apache.spark.mllib.linalg.Vector newWeights = this.optimizer().optimize((RDD<Tuple2<Object, org.apache.spark.mllib.linalg.Vector>>)this.dataStacker().stack(data).map((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final Tuple2<Object, org.apache.spark.mllib.linalg.Vector> apply(Tuple2<Object, Vector> v) {
                return new Tuple2((Object)BoxesRunTime.boxToDouble((double)v._1$mcD$sp()), (Object)Vectors$.MODULE$.fromML((Vector)v._2()));
            }
        }, ClassTag$.MODULE$.apply(Tuple2.class)), VectorImplicits$.MODULE$.mlVectorToMLlibVector(w2));
        return this.topology.model(VectorImplicits$.MODULE$.mllibVectorToMLVector(newWeights));
    }

    public FeedForwardTrainer(Topology topology, int inputSize, int outputSize) {
        this.topology = topology;
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        this._seed = this.getClass().getName().hashCode();
        this._weights = null;
        this._stackSize = 128;
        this.dataStacker = new DataStacker(this._stackSize(), inputSize, outputSize);
        this._gradient = new ANNGradient(topology, this.dataStacker());
        this._updater = new ANNUpdater();
        this.optimizer = this.LBFGSOptimizer().setConvergenceTol(1.0E-4).setNumIterations(100);
    }
}

