package org.apache.spark.ml.classification;

import java.io.IOException;
import java.util.Objects;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.ann.FeedForwardTopology$;
import org.apache.spark.ml.ann.FeedForwardTrainer;
import org.apache.spark.ml.classification.MultilayerPerceptronParams;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntArrayParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasMaxIter;
import org.apache.spark.ml.param.shared.HasSeed;
import org.apache.spark.ml.param.shared.HasSolver;
import org.apache.spark.ml.param.shared.HasStepSize;
import org.apache.spark.ml.param.shared.HasTol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.Instrumentation;
import org.apache.spark.ml.util.Instrumentation$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import scala.Function1;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: MultilayerPerceptronClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005=f\u0001B\u0001\u0003\u00015\u0011a$T;mi&d\u0017-_3s!\u0016\u00148-\u001a9ue>t7\t\\1tg&4\u0017.\u001a:\u000b\u0005\r!\u0011AD2mCN\u001c\u0018NZ5dCRLwN\u001c\u0006\u0003\u000b\u0019\t!!\u001c7\u000b\u0005\u001dA\u0011!B:qCJ\\'BA\u0005\u000b\u0003\u0019\t\u0007/Y2iK*\t1\"A\u0002pe\u001e\u001c\u0001a\u0005\u0003\u0001\u001dqy\u0002#B\b\u0011%aIR\"\u0001\u0002\n\u0005E\u0011!a\u0006)s_\n\f'-\u001b7jgRL7m\u00117bgNLg-[3s!\t\u0019b#D\u0001\u0015\u0015\t)B!\u0001\u0004mS:\fGnZ\u0005\u0003/Q\u0011aAV3di>\u0014\bCA\b\u0001!\ty!$\u0003\u0002\u001c\u0005\t9S*\u001e7uS2\f\u00170\u001a:QKJ\u001cW\r\u001d;s_:\u001cE.Y:tS\u001aL7-\u0019;j_:lu\u000eZ3m!\tyQ$\u0003\u0002\u001f\u0005\tQR*\u001e7uS2\f\u00170\u001a:QKJ\u001cW\r\u001d;s_:\u0004\u0016M]1ngB\u0011\u0001eI\u0007\u0002C)\u0011!\u0005B\u0001\u0005kRLG.\u0003\u0002%C\t)B)\u001a4bk2$\b+\u0019:b[N<&/\u001b;bE2,\u0007\u0002\u0003\u0014\u0001\u0005\u000b\u0007I\u0011I\u0014\u0002\u0007ULG-F\u0001)!\tIsF\u0004\u0002+[5\t1FC\u0001-\u0003\u0015\u00198-\u00197b\u0013\tq3&\u0001\u0004Qe\u0016$WMZ\u0005\u0003aE\u0012aa\u0015;sS:<'B\u0001\u0018,Q\r)3'\u000f\t\u0003i]j\u0011!\u000e\u0006\u0003m\u0019\t!\"\u00198o_R\fG/[8o\u0013\tATGA\u0003TS:\u001cW-I\u0001;\u0003\u0015\td&\u000e\u00181\u0011!a\u0004A!A!\u0002\u0013A\u0013\u0001B;jI\u0002B3aO\u001a:\u0011\u0015y\u0004\u0001\"\u0001A\u0003\u0019a\u0014N\\5u}Q\u0011\u0001$\u0011\u0005\u0006My\u0002\r\u0001\u000b\u0015\u0004\u0003NJ\u0004f\u0001 
4s!)q\b\u0001C\u0001\u000bR\t\u0001\u0004K\u0002EgeBQ\u0001\u0013\u0001\u0005\u0002%\u000b\u0011b]3u\u0019\u0006LXM]:\u0015\u0005)[U\"\u0001\u0001\t\u000b1;\u0005\u0019A'\u0002\u000bY\fG.^3\u0011\u0007)r\u0005+\u0003\u0002PW\t)\u0011I\u001d:bsB\u0011!&U\u0005\u0003%.\u00121!\u00138uQ\r95'\u000f\u0005\u0006+\u0002!\tAV\u0001\rg\u0016$(\t\\8dWNK'0\u001a\u000b\u0003\u0015^CQ\u0001\u0014+A\u0002AC3\u0001V\u001a:\u0011\u0015Q\u0006\u0001\"\u0001\\\u0003%\u0019X\r^*pYZ,'\u000f\u0006\u0002K9\")A*\u0017a\u0001Q!\u001a\u0011l\r0\"\u0003}\u000bQA\r\u00181]ABQ!\u0019\u0001\u0005\u0002\t\f!b]3u\u001b\u0006D\u0018\n^3s)\tQ5\rC\u0003MA\u0002\u0007\u0001\u000bK\u0002ageBQA\u001a\u0001\u0005\u0002\u001d\faa]3u)>dGC\u0001&i\u0011\u0015aU\r1\u0001j!\tQ#.\u0003\u0002lW\t1Ai\\;cY\u0016D3!Z\u001a:\u0011\u0015q\u0007\u0001\"\u0001p\u0003\u001d\u0019X\r^*fK\u0012$\"A\u00139\t\u000b1k\u0007\u0019A9\u0011\u0005)\u0012\u0018BA:,\u0005\u0011auN\\4)\u00075\u001c\u0014\bC\u0003w\u0001\u0011\u0005q/A\ttKRLe.\u001b;jC2<V-[4iiN$\"A\u0013=\t\u000b1+\b\u0019\u0001\n)\u0007U\u001cd\fC\u0003|\u0001\u0011\u0005A0A\u0006tKR\u001cF/\u001a9TSj,GC\u0001&~\u0011\u0015a%\u00101\u0001jQ\rQ8G\u0018\u0005\b\u0003\u0003\u0001A\u0011IA\u0002\u0003\u0011\u0019w\u000e]=\u0015\u0007a\t)\u0001C\u0004\u0002\b}\u0004\r!!\u0003\u0002\u000b\u0015DHO]1\u0011\t\u0005-\u0011\u0011C\u0007\u0003\u0003\u001bQ1!a\u0004\u0005\u0003\u0015\u0001\u0018M]1n\u0013\u0011\t\u0019\"!\u0004\u0003\u0011A\u000b'/Y7NCBD3a`\u001a:\u0011\u001d\tI\u0002\u0001C)\u00037\tQ\u0001\u001e:bS:$2!GA\u000f\u0011!\ty\"a\u0006A\u0002\u0005\u0005\u0012a\u00023bi\u0006\u001cX\r\u001e\u0019\u0005\u0003G\t\u0019\u0004\u0005\u0004\u0002&\u0005-\u0012qF\u0007\u0003\u0003OQ1!!\u000b\u0007\u0003\r\u0019\u0018\u000f\\\u0005\u0005\u0003[\t9CA\u0004ECR\f7/\u001a;\u0011\t\u0005E\u00121\u0007\u0007\u0001\t1\t)$!\b\u0002\u0002\u0003\u0005)\u0011AA\u001c\u0005\ryF%M\t\u0005\u0003s\ty\u0004E\u0002+\u0003wI1!!\u0010,\u0005\u001dqu\u000e\u001e5j]\u001e\u00042AKA!\u0013\r\
t\u0019e\u000b\u0002\u0004\u0003:L\bf\u0001\u00014s\u001d9\u0011\u0011\n\u0002\t\u0002\u0005-\u0013AH'vYRLG.Y=feB+'oY3qiJ|gn\u00117bgNLg-[3s!\ry\u0011Q\n\u0004\u0007\u0003\tA\t!a\u0014\u0014\u0011\u00055\u0013\u0011KA,\u0003;\u00022AKA*\u0013\r\t)f\u000b\u0002\u0007\u0003:L(+\u001a4\u0011\t\u0001\nI\u0006G\u0005\u0004\u00037\n#!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7OU3bI\u0006\u0014G.\u001a\t\u0004U\u0005}\u0013bAA1W\ta1+\u001a:jC2L'0\u00192mK\"9q(!\u0014\u0005\u0002\u0005\u0015DCAA&\u0011-\tI'!\u0014C\u0002\u0013\u0005!!a\u001b\u0002\u000b1\u0013eiR*\u0016\u0005\u00055\u0004\u0003BA8\u0003sj!!!\u001d\u000b\t\u0005M\u0014QO\u0001\u0005Y\u0006twM\u0003\u0002\u0002x\u0005!!.\u0019<b\u0013\r\u0001\u0014\u0011\u000f\u0005\n\u0003{\ni\u0005)A\u0005\u0003[\na\u0001\u0014\"G\u000fN\u0003\u0003bCAA\u0003\u001b\u0012\r\u0011\"\u0001\u0003\u0003W\n!a\u0012#\t\u0013\u0005\u0015\u0015Q\nQ\u0001\n\u00055\u0014aA$EA!Y\u0011\u0011RA'\u0005\u0004%\tAAAF\u0003A\u0019X\u000f\u001d9peR,GmU8mm\u0016\u00148/\u0006\u0002\u0002\u000eB!!FTA7\u0011%\t\t*!\u0014!\u0002\u0013\ti)A\ttkB\u0004xN\u001d;fIN{GN^3sg\u0002B\u0001\"!&\u0002N\u0011\u0005\u0013qS\u0001\u0005Y>\fG\rF\u0002\u0019\u00033Cq!a'\u0002\u0014\u0002\u0007\u0001&\u0001\u0003qCRD\u0007\u0006BAJgyC!\"!)\u0002N\u0005\u0005I\u0011BAR\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005\u0015\u0006\u0003BA8\u0003OKA!!+\u0002r\t1qJ\u00196fGRDC!!\u00144=\"\"\u0011qI\u001a_\u0001")
/* loaded from: input_file:org/apache/spark/ml/classification/MultilayerPerceptronClassifier.class */
/**
 * Estimator that trains a feed-forward multilayer perceptron classifier and produces a
 * {@link MultilayerPerceptronClassificationModel}.
 *
 * <p>This source was decompiled from {@code MultilayerPerceptronClassifier.scala}. The mangled
 * {@code org$apache$spark$...$_setter_$...$eq} methods are scalac-generated setters for fields
 * declared in Scala traits. The {@code *.Cclass} calls delegate to the trait implementation
 * classes.
 *
 * <p>The {@code layers} param describes the network. Its first entry is the input width, logged
 * as the number of features. Its last entry is the output width, logged as the number of
 * classes. The optimizer is chosen by the {@code solver} param:
 * <ul>
 *   <li>the companion's {@code LBFGS} value, configured via {@code LBFGSOptimizer()};</li>
 *   <li>the companion's {@code GD} value, configured via {@code SGDOptimizer()}.</li>
 * </ul>
 */
public class MultilayerPerceptronClassifier extends ProbabilisticClassifier<Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel> implements MultilayerPerceptronParams, DefaultParamsWritable {
    private final String uid;

    // Not `final`: these fields are assigned through the trait setter methods below. The
    // constructor calls those setters directly or via the trait $init$ helpers. Java only allows
    // a final field to be assigned in a constructor or initializer.
    private IntArrayParam layers;
    private IntParam blockSize;
    private Param<String> solver;
    private Param<Vector> initialWeights;
    private DoubleParam stepSize;
    private DoubleParam tol;
    private IntParam maxIter;
    private LongParam seed;

    /** Returns a reader for persisted instances; delegates to the companion object. */
    public static MLReader<MultilayerPerceptronClassifier> read() {
        return MultilayerPerceptronClassifier$.MODULE$.read();
    }

    /**
     * Loads a persisted instance from {@code path}; delegates to the companion object.
     *
     * @param path location previously written by {@link #save(String)} / {@link #write()}
     */
    public static MultilayerPerceptronClassifier load(String path) {
        return MultilayerPerceptronClassifier$.MODULE$.load(path);
    }

    /** Returns a writer for this estimator; delegates to {@code DefaultParamsWritable}. */
    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        return DefaultParamsWritable.Cclass.write(this);
    }

    /**
     * Saves this estimator to {@code path}; delegates to {@code MLWritable}.
     *
     * @throws IOException propagated from the underlying writer
     */
    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String path) throws IOException {
        MLWritable.Cclass.save(this, path);
    }

    // ---- MultilayerPerceptronParams accessors -------------------------------------------------

    /** Param holding the layer sizes, from the input layer to the output layer. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final IntArrayParam layers() {
        return this.layers;
    }

    /** Param whose value is passed to {@code FeedForwardTrainer.setStackSize} during training. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final IntParam blockSize() {
        return this.blockSize;
    }

    /** Param selecting the optimizer used by {@link #train(Dataset)}. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams, org.apache.spark.ml.param.shared.HasSolver
    public final Param<String> solver() {
        return this.solver;
    }

    /** Optional param with explicit initial weights; when defined, the seed is not used. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final Param<Vector> initialWeights() {
        return this.initialWeights;
    }

    /** Scala trait-field setter for {@code layers}; called during construction. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$layers_$eq(IntArrayParam intArrayParam) {
        this.layers = intArrayParam;
    }

    /** Scala trait-field setter for {@code blockSize}; called during construction. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$blockSize_$eq(IntParam intParam) {
        this.blockSize = intParam;
    }

    /** Scala trait-field setter for {@code solver}; called during construction. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$solver_$eq(Param param) {
        this.solver = param;
    }

    /** Scala trait-field setter for {@code initialWeights}; called during construction. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$initialWeights_$eq(Param param) {
        this.initialWeights = param;
    }

    /** Returns the current value of {@code layers}. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final int[] getLayers() {
        return MultilayerPerceptronParams.Cclass.getLayers(this);
    }

    /** Returns the current value of {@code blockSize}. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final int getBlockSize() {
        return MultilayerPerceptronParams.Cclass.getBlockSize(this);
    }

    /** Returns the current value of {@code initialWeights}. */
    @Override // org.apache.spark.ml.classification.MultilayerPerceptronParams
    public final Vector getInitialWeights() {
        return MultilayerPerceptronParams.Cclass.getInitialWeights(this);
    }

    // ---- Shared-param accessors ---------------------------------------------------------------

    /**
     * Scala trait-field setter for the {@code solver} field declared by {@code HasSolver}.
     *
     * <p>Intentionally a no-op. The {@code solver} field is assigned only through the
     * {@code MultilayerPerceptronParams} setter (presumably because that trait redeclares
     * {@code solver}). The {@code Param} the constructor passes here is therefore discarded.
     */
    @Override // org.apache.spark.ml.param.shared.HasSolver
    public void org$apache$spark$ml$param$shared$HasSolver$_setter_$solver_$eq(Param param) {
    }

    /** Returns the current value of {@code solver}. */
    @Override // org.apache.spark.ml.param.shared.HasSolver
    public final String getSolver() {
        return HasSolver.Cclass.getSolver(this);
    }

    /** Param with the step size; only used by the GD solver. */
    @Override // org.apache.spark.ml.param.shared.HasStepSize
    public DoubleParam stepSize() {
        return this.stepSize;
    }

    /** Scala trait-field setter for {@code stepSize}; called during construction. */
    @Override // org.apache.spark.ml.param.shared.HasStepSize
    public void org$apache$spark$ml$param$shared$HasStepSize$_setter_$stepSize_$eq(DoubleParam doubleParam) {
        this.stepSize = doubleParam;
    }

    /** Returns the current value of {@code stepSize}. */
    @Override // org.apache.spark.ml.param.shared.HasStepSize
    public final double getStepSize() {
        return HasStepSize.Cclass.getStepSize(this);
    }

    /** Param with the convergence tolerance passed to either optimizer. */
    @Override // org.apache.spark.ml.param.shared.HasTol
    public final DoubleParam tol() {
        return this.tol;
    }

    /** Scala trait-field setter for {@code tol}; called during construction. */
    @Override // org.apache.spark.ml.param.shared.HasTol
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam doubleParam) {
        this.tol = doubleParam;
    }

    /** Returns the current value of {@code tol}. */
    @Override // org.apache.spark.ml.param.shared.HasTol
    public final double getTol() {
        return HasTol.Cclass.getTol(this);
    }

    /** Param with the iteration limit passed to either optimizer. */
    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final IntParam maxIter() {
        return this.maxIter;
    }

    /** Scala trait-field setter for {@code maxIter}; called during construction. */
    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam intParam) {
        this.maxIter = intParam;
    }

    /** Returns the current value of {@code maxIter}. */
    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final int getMaxIter() {
        return HasMaxIter.Cclass.getMaxIter(this);
    }

    /** Param with the random seed; only used when {@code initialWeights} is not defined. */
    @Override // org.apache.spark.ml.param.shared.HasSeed
    public final LongParam seed() {
        return this.seed;
    }

    /** Scala trait-field setter for {@code seed}; called during construction. */
    @Override // org.apache.spark.ml.param.shared.HasSeed
    public final void org$apache$spark$ml$param$shared$HasSeed$_setter_$seed_$eq(LongParam longParam) {
        this.seed = longParam;
    }

    /** Returns the current value of {@code seed}. */
    @Override // org.apache.spark.ml.param.shared.HasSeed
    public final long getSeed() {
        return HasSeed.Cclass.getSeed(this);
    }

    /** Returns the unique identifier given to the constructor; trained models reuse it. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    // ---- Fluent param setters -----------------------------------------------------------------

    /**
     * Sets the layer sizes.
     *
     * @param value sizes from the input layer (number of features) to the output layer (number
     *     of classes)
     * @return this estimator
     */
    public MultilayerPerceptronClassifier setLayers(int[] value) {
        return (MultilayerPerceptronClassifier) set(layers(), value);
    }

    /**
     * Sets the value later passed to {@code FeedForwardTrainer.setStackSize}.
     *
     * @return this estimator
     */
    public MultilayerPerceptronClassifier setBlockSize(int value) {
        return (MultilayerPerceptronClassifier) set(blockSize(), BoxesRunTime.boxToInteger(value));
    }

    /**
     * Sets the solver. Training recognizes the companion's {@code LBFGS} and {@code GD} values;
     * any other value makes {@link #train(Dataset)} throw {@link IllegalArgumentException}.
     *
     * @return this estimator
     */
    public MultilayerPerceptronClassifier setSolver(String value) {
        return (MultilayerPerceptronClassifier) set(solver(), value);
    }

    /**
     * Sets the maximum number of optimizer iterations.
     *
     * @return this estimator
     */
    public MultilayerPerceptronClassifier setMaxIter(int value) {
        return (MultilayerPerceptronClassifier) set(maxIter(), BoxesRunTime.boxToInteger(value));
    }

    /**
     * Sets the optimizer convergence tolerance.
     *
     * @return this estimator
     */
    public MultilayerPerceptronClassifier setTol(double value) {
        return (MultilayerPerceptronClassifier) set(tol(), BoxesRunTime.boxToDouble(value));
    }

    /**
     * Sets the random seed used to initialize weights when no initial weights are given.
     *
     * @return this estimator
     */
    public MultilayerPerceptronClassifier setSeed(long value) {
        return (MultilayerPerceptronClassifier) set(seed(), BoxesRunTime.boxToLong(value));
    }

    /**
     * Sets explicit initial weights. While defined, they take precedence over the seed.
     *
     * @return this estimator
     */
    public MultilayerPerceptronClassifier setInitialWeights(Vector value) {
        return (MultilayerPerceptronClassifier) set(initialWeights(), value);
    }

    /**
     * Sets the step size used by the GD solver.
     *
     * @return this estimator
     */
    public MultilayerPerceptronClassifier setStepSize(double value) {
        return (MultilayerPerceptronClassifier) set(stepSize(), BoxesRunTime.boxToDouble(value));
    }

    /**
     * Returns a copy of this estimator with {@code paramMap} applied on top of its params.
     *
     * @param paramMap extra param values for the copy
     */
    @Override // org.apache.spark.ml.Predictor, org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public MultilayerPerceptronClassifier copy(ParamMap paramMap) {
        return (MultilayerPerceptronClassifier) defaultCopy(paramMap);
    }

    /**
     * Trains a model on {@code dataset}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Log the params.</li>
     *   <li>Map each labeled point to a (features, target) vector pair.</li>
     *   <li>Build a multilayer-perceptron topology from {@code layers}.</li>
     *   <li>Initialize its weights from {@code initialWeights}, or from {@code seed} when no
     *       initial weights are defined.</li>
     *   <li>Configure the optimizer selected by {@code solver}.</li>
     *   <li>Wrap the trained weights in a model that reuses this estimator's uid.</li>
     * </ol>
     *
     * <p>javac generates the {@code PredictionModel train(Dataset)} bridge method automatically
     * from this covariant override.
     *
     * @param dataset training data, read through {@code extractLabeledPoints}
     * @return the trained model
     * @throws IllegalArgumentException if {@code solver} is neither the LBFGS nor the GD solver
     */
    @Override // org.apache.spark.ml.Predictor
    public MultilayerPerceptronClassificationModel train(Dataset<?> dataset) {
        Instrumentation<MultilayerPerceptronClassifier> instr = Instrumentation$.MODULE$.create(this, dataset);
        instr.logParams(Predef$.MODULE$.wrapRefArray(new Param[]{labelCol(), featuresCol(), predictionCol(), layers(), maxIter(), tol(), blockSize(), solver(), stepSize(), seed()}));

        int[] layerSizes = $(layers());
        // Output layer width == number of classes; input layer width == number of features.
        int numClasses = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps(layerSizes).last());
        instr.logNumClasses(numClasses);
        int numFeatures = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps(layerSizes).head());
        instr.logNumFeatures(numFeatures);

        // The anonymous function is compiled from the Scala source and is not visible here.
        // Presumably it encodes each label as a length-numClasses target vector.
        RDD<Tuple2<Vector, Vector>> data = extractLabeledPoints(dataset).map(
                new MultilayerPerceptronClassifier$$anonfun$3(this, numClasses),
                ClassTag$.MODULE$.apply(Tuple2.class));

        FeedForwardTrainer trainer = new FeedForwardTrainer(
                FeedForwardTopology$.MODULE$.multiLayerPerceptron(layerSizes, true),
                numFeatures,
                numClasses);
        // Explicit initial weights win; the seed is only consulted when none are defined.
        if (isDefined(initialWeights())) {
            trainer.setWeights($(initialWeights()));
        } else {
            trainer.setSeed(BoxesRunTime.unboxToLong($(seed())));
        }
        configureOptimizer(trainer);
        trainer.setStackSize(BoxesRunTime.unboxToInt($(blockSize())));

        MultilayerPerceptronClassificationModel model =
                new MultilayerPerceptronClassificationModel(uid(), layerSizes, trainer.train(data).weights());
        instr.logSuccess(model);
        return model;
    }

    /**
     * Configures the optimizer of {@code trainer} that matches the current {@code solver} value.
     *
     * <ul>
     *   <li>LBFGS receives {@code tol} and {@code maxIter}.</li>
     *   <li>GD receives {@code maxIter}, {@code tol} and {@code stepSize}.</li>
     * </ul>
     *
     * @throws IllegalArgumentException naming the configured solver value if it is unsupported
     */
    private void configureOptimizer(FeedForwardTrainer trainer) {
        String solverName = $(solver());
        if (Objects.equals(solverName, MultilayerPerceptronClassifier$.MODULE$.LBFGS())) {
            trainer.LBFGSOptimizer()
                    .setConvergenceTol(BoxesRunTime.unboxToDouble($(tol())))
                    .setNumIterations(BoxesRunTime.unboxToInt($(maxIter())));
        } else if (Objects.equals(solverName, MultilayerPerceptronClassifier$.MODULE$.GD())) {
            trainer.SGDOptimizer()
                    .setNumIterations(BoxesRunTime.unboxToInt($(maxIter())))
                    .setConvergenceTol(BoxesRunTime.unboxToDouble($(tol())))
                    .setStepSize(BoxesRunTime.unboxToDouble($(stepSize())));
        } else {
            // Report the configured value, not the Param object itself.
            throw new IllegalArgumentException(
                    "The solver " + solverName + " is not supported by MultilayerPerceptronClassifier.");
        }
    }

    /**
     * Creates an estimator with the given uid.
     *
     * <p>The constructor initializes the shared params (seed, maxIter, tol, stepSize, solver) and
     * then runs the {@code MultilayerPerceptronParams} and writable trait initializers. Those
     * trait initializers are not visible here.
     *
     * @param uid unique identifier of this estimator
     */
    public MultilayerPerceptronClassifier(String uid) {
        this.uid = uid;
        HasSeed.Cclass.$init$(this);
        org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gtEq(0.0d)));
        org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms (>= 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gtEq(0.0d)));
        org$apache$spark$ml$param$shared$HasStepSize$_setter_$stepSize_$eq(new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization (> 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gt(0.0d)));
        // The HasSolver setter is a no-op in this class, so this Param is discarded.
        org$apache$spark$ml$param$shared$HasSolver$_setter_$solver_$eq(new Param<String>(this, "solver", "the solver algorithm for optimization"));
        MultilayerPerceptronParams.Cclass.$init$(this);
        MLWritable.Cclass.$init$(this);
        DefaultParamsWritable.Cclass.$init$(this);
    }

    /** Creates an estimator with a random uid prefixed by {@code "mlpc"}. */
    public MultilayerPerceptronClassifier() {
        this(Identifiable$.MODULE$.randomUID("mlpc"));
    }
}
