package org.apache.spark.ml.classification;

import org.apache.spark.ml.classification.ProbabilisticClassificationModel;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.DoubleArrayParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.collection.immutable.StringOps;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;

/* compiled from: ProbabilisticClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Ed!\u0002\t\u0012\u0003\u0003a\u0002\"\u0002\u001d\u0001\t\u0003I\u0004\"\u0002\u001e\u0001\t\u0003Y\u0004\"B%\u0001\t\u0003Q\u0005\"\u0002*\u0001\t\u0003\u001a\u0006\"\u00020\u0001\t\u0003z\u0006\"B=\u0001\r#Q\bbBA\u0004\u0001\u0011E\u0011\u0011\u0002\u0005\b\u0003\u001b\u0001A\u0011KA\b\u0011\u001d\t\u0019\u0002\u0001C\u0001\u0003+Aq!!\f\u0001\t#\tyc\u0002\u0005\u00026EA\taEA\u001c\r\u001d\u0001\u0012\u0003#\u0001\u0014\u0003sAa\u0001\u000f\u0007\u0005\u0002\u0005\u001d\u0003bBA%\u0019\u0011\u0005\u00111\n\u0005\n\u0003;b\u0011\u0011!C\u0005\u0003?\u0012\u0001\u0005\u0015:pE\u0006\u0014\u0017\u000e\\5ti&\u001c7\t\\1tg&4\u0017nY1uS>tWj\u001c3fY*\u0011!cE\u0001\u000fG2\f7o]5gS\u000e\fG/[8o\u0015\t!R#\u0001\u0002nY*\u0011acF\u0001\u0006gB\f'o\u001b\u0006\u00031e\ta!\u00199bG\",'\"\u0001\u000e\u0002\u0007=\u0014xm\u0001\u0001\u0016\u0007u!\u0013gE\u0002\u0001=U\u0002Ba\b\u0011#a5\t\u0011#\u0003\u0002\"#\t\u00192\t\\1tg&4\u0017nY1uS>tWj\u001c3fYB\u00111\u0005\n\u0007\u0001\t\u0015)\u0003A1\u0001'\u000511U-\u0019;ve\u0016\u001cH+\u001f9f#\t9S\u0006\u0005\u0002)W5\t\u0011FC\u0001+\u0003\u0015\u00198-\u00197b\u0013\ta\u0013FA\u0004O_RD\u0017N\\4\u0011\u0005!r\u0013BA\u0018*\u0005\r\te.\u001f\t\u0003GE\"QA\r\u0001C\u0002M\u0012\u0011!T\t\u0003OQ\u0002Ba\b\u0001#aA\u0011qDN\u0005\u0003oE\u0011Q\u0004\u0015:pE\u0006\u0014\u0017\u000e\\5ti&\u001c7\t\\1tg&4\u0017.\u001a:QCJ\fWn]\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u0003Q\n\u0011c]3u!J|'-\u00192jY&$\u0018pQ8m)\t\u0001D\bC\u0003>\u0005\u0001\u0007a(A\u0003wC2,X\r\u0005\u0002@\r:\u0011\u0001\t\u0012\t\u0003\u0003&j\u0011A\u0011\u0006\u0003\u0007n\ta\u0001\u0010:p_Rt\u0014BA#*\u0003\u0019\u0001&/\u001a3fM&\u0011q\t\u0013\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005\u0015K\u0013!D:fiRC'/Z:i_2$7\u000f\u0006\u00021\u0017\")Qh\u0001a\u0001\u0019B\u0019\u0001&T(\n\u00059K#!B!se\u0006L\bC\u0001\u0015Q\u0013\t\t\u0016F\u0001\u0004E_V\u0014G.Z\u0001\u0010iJ\fgn\u001d4pe6\u001c6\r[3nCR\u0011A\u000b\u0018\t\u0003+jk\u0011A\u0016\u0006\u0003/b\u000bQ\u0001^=qKNT!!W\u000b\u0002\u0007M\fH.\u0003\u0002\\-\nQ1\u000b\u001e:vGR$\u0016\u0010]3\t\u000bu#\u0001\u0019\u0001+\u0002\rM\u001c\u0007.Z7b\u0003%!(/\u00198tM>\u0014X\u000e\u0006\u0002a_B\u0011\u0011\r\u001c\b\u0003E*t!aY5\u000f\u0005\u0011DgBA3h\u001d\t\te-C\u0001\u001b\u0013\tA\u0012$\u0003\u0002\u0017/%\u0011\u0011,F\u0005\u0003Wb\u000bq\u0001]1dW\u0006<W-\u0003\u0002n]\nIA)\u0019;b\rJ\fW.\u001a\u0006\u0003WbCQ\u0001]\u0003A\u0002E\fq\u0001Z1uCN,G\u000f\r\u0002soB\u00191\u000f\u001e<\u000e\u0003aK!!\u001e-\u0003\u000f\u0011\u000bG/Y:fiB\u00111e\u001e\u0003\nq>\f\t\u0011!A\u0003\u0002\u0019\u00121a\u0018\u00132\u0003Y\u0011\u0018m\u001e\u001aqe>\u0014\u0017MY5mSRL\u0018J\u001c)mC\u000e,GcA>\u0002\u0004A\u0011Ap`\u0007\u0002{*\u0011apE\u0001\u0007Y&t\u0017\r\\4\n\u0007\u0005\u0005QP\u0001\u0004WK\u000e$xN\u001d\u0005\u0007\u0003\u000b1\u0001\u0019A>\u0002\u001bI\fw\u000f\u0015:fI&\u001cG/[8o\u0003=\u0011\u0018m\u001e\u001aqe>\u0014\u0017MY5mSRLHcA>\u0002\f!1\u0011QA\u0004A\u0002m\faB]1xeA\u0014X\rZ5di&|g\u000eF\u0002P\u0003#Aa!!\u0002\t\u0001\u0004Y\u0018A\u00059sK\u0012L7\r\u001e)s_\n\f'-\u001b7jif$2a_A\f\u0011\u0019\tI\"\u0003a\u0001E\u0005Aa-Z1ukJ,7\u000fK\u0003\n\u0003;\tI\u0003\u0005\u0003\u0002 
\u0005\u0015RBAA\u0011\u0015\r\t\u0019#F\u0001\u000bC:tw\u000e^1uS>t\u0017\u0002BA\u0014\u0003C\u0011QaU5oG\u0016\f#!a\u000b\u0002\u000bMr\u0003G\f\u0019\u0002-A\u0014xNY1cS2LG/\u001f\u001aqe\u0016$\u0017n\u0019;j_:$2aTA\u0019\u0011\u0019\t\u0019D\u0003a\u0001w\u0006Y\u0001O]8cC\nLG.\u001b;z\u0003\u0001\u0002&o\u001c2bE&d\u0017n\u001d;jG\u000ec\u0017m]:jM&\u001c\u0017\r^5p]6{G-\u001a7\u0011\u0005}a1#\u0002\u0007\u0002<\u0005\u0005\u0003c\u0001\u0015\u0002>%\u0019\u0011qH\u0015\u0003\r\u0005s\u0017PU3g!\rA\u00131I\u0005\u0004\u0003\u000bJ#\u0001D*fe&\fG.\u001b>bE2,GCAA\u001c\u0003}qwN]7bY&TX\rV8Qe>\u0014\u0017MY5mSRLWm]%o!2\f7-\u001a\u000b\u0005\u0003\u001b\n\u0019\u0006E\u0002)\u0003\u001fJ1!!\u0015*\u0005\u0011)f.\u001b;\t\u000f\u0005Uc\u00021\u0001\u0002X\u0005\ta\u000fE\u0002}\u00033J1!a\u0017~\u0005-!UM\\:f-\u0016\u001cGo\u001c:\u0002\u0017I,\u0017\r\u001a*fg>dg/\u001a\u000b\u0003\u0003C\u0002B!a\u0019\u0002n5\u0011\u0011Q\r\u0006\u0005\u0003O\nI'\u0001\u0003mC:<'BAA6\u0003\u0011Q\u0017M^1\n\t\u0005=\u0014Q\r\u0002\u0007\u001f\nTWm\u0019;")
/* loaded from: input_file:org/apache/spark/ml/classification/ProbabilisticClassificationModel.class */
/**
 * Base class for classification models that, in addition to raw predictions and predicted
 * labels, can output per-class probabilities.
 *
 * <p>This Java source is a decompiled rendering of Scala bytecode. The {@code org$apache$...}
 * methods and the trait {@code $init$} calls in the constructor are compiler-generated glue that
 * mixes in the {@link ProbabilisticClassifierParams} trait together with its
 * {@code HasProbabilityCol} and {@code HasThresholds} members.
 *
 * @param <FeaturesType> type of a single feature value passed to {@code predictRaw}
 * @param <M> concrete model type, returned by the fluent setters
 */
public abstract class ProbabilisticClassificationModel<FeaturesType, M extends ProbabilisticClassificationModel<FeaturesType, M>> extends ClassificationModel<FeaturesType, M> implements ProbabilisticClassifierParams {
    // These fields are populated during construction through the synthetic *_setter_$*_$eq
    // methods below. Those are ordinary methods, so the fields cannot be Java `final`.
    private DoubleArrayParam thresholds;
    private Param<String> probabilityCol;

    /**
     * Static forwarder to {@code ProbabilisticClassificationModel$.normalizeToProbabilitiesInPlace}.
     * Judging by its name, it rescales {@code denseVector} in place into a probability
     * distribution. The actual behavior lives in the companion object, which is not shown here.
     *
     * @param denseVector vector to normalize (mutated by the companion implementation)
     */
    public static void normalizeToProbabilitiesInPlace(DenseVector denseVector) {
        ProbabilisticClassificationModel$.MODULE$.normalizeToProbabilitiesInPlace(denseVector);
    }

    /**
     * Super accessor required by the {@link ProbabilisticClassifierParams} trait for its own
     * {@code super.validateAndTransformSchema} call. It dispatches to the superclass
     * implementation.
     */
    @Override // org.apache.spark.ml.classification.ProbabilisticClassifierParams
    public /* synthetic */ StructType org$apache$spark$ml$classification$ProbabilisticClassifierParams$$super$validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        // Must be an explicit super call. An unqualified call would re-enter this class's
        // override and recurse forever.
        return super.validateAndTransformSchema(structType, z, dataType);
    }

    /**
     * Validates {@code structType} and returns the output schema. The work is delegated to the
     * {@link ProbabilisticClassifierParams} trait implementation.
     */
    @Override // org.apache.spark.ml.classification.ClassificationModel, org.apache.spark.ml.PredictionModel, org.apache.spark.ml.PredictorParams
    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return ProbabilisticClassifierParams.super.validateAndTransformSchema(structType, z, dataType);
    }

    /** Returns the current thresholds array. Delegates to the trait default implementation. */
    @Override // org.apache.spark.ml.param.shared.HasThresholds
    public double[] getThresholds() {
        return ProbabilisticClassifierParams.super.getThresholds();
    }

    /** Returns the configured probability column name. Delegates to the trait default implementation. */
    @Override // org.apache.spark.ml.param.shared.HasProbabilityCol
    public final String getProbabilityCol() {
        return ProbabilisticClassifierParams.super.getProbabilityCol();
    }

    /** Returns the per-class thresholds param. */
    @Override // org.apache.spark.ml.param.shared.HasThresholds
    public DoubleArrayParam thresholds() {
        return this.thresholds;
    }

    /** Synthetic setter used during trait initialization to install the thresholds param. */
    @Override // org.apache.spark.ml.param.shared.HasThresholds
    public void org$apache$spark$ml$param$shared$HasThresholds$_setter_$thresholds_$eq(DoubleArrayParam doubleArrayParam) {
        this.thresholds = doubleArrayParam;
    }

    /** Returns the probability output column param. */
    @Override // org.apache.spark.ml.param.shared.HasProbabilityCol
    public final Param<String> probabilityCol() {
        return this.probabilityCol;
    }

    /** Synthetic setter used during trait initialization to install the probabilityCol param. */
    @Override // org.apache.spark.ml.param.shared.HasProbabilityCol
    public final void org$apache$spark$ml$param$shared$HasProbabilityCol$_setter_$probabilityCol_$eq(Param<String> param) {
        this.probabilityCol = param;
    }

    /**
     * Sets the probability output column.
     *
     * @param str output column name for class probabilities. An empty name disables that column
     *     in {@link #transform} and {@link #transformSchema}.
     * @return the result of {@code set(...)} cast to {@code M}. This assumes {@code Params.set}
     *     returns this instance, for chaining; confirm against the Params API.
     */
    @SuppressWarnings("unchecked")
    public M setProbabilityCol(String str) {
        return (M) set(probabilityCol(), str);
    }

    /**
     * Sets the per-class thresholds used by {@link #probability2prediction}.
     *
     * @param dArr one threshold per class
     * @return the result of {@code set(...)} cast to {@code M}
     * @throws IllegalArgumentException (via {@code Predef.require}) if
     *     {@code dArr.length != numClasses()}
     */
    @SuppressWarnings("unchecked")
    public M setThresholds(double[] dArr) {
        Predef$.MODULE$.require(dArr.length == numClasses(), () -> {
            return new StringBuilder(115).append(this.getClass().getSimpleName()).append(".setThresholds() called with non-matching numClasses and thresholds.length.").append(" numClasses=").append(this.numClasses()).append(", but thresholds has length ").append(dArr.length).toString();
        });
        return (M) set(thresholds(), dArr);
    }

    /**
     * Extends the parent output schema. When the probability column is enabled, its attribute
     * group size is set to {@code numClasses()}.
     */
    @Override // org.apache.spark.ml.classification.ClassificationModel, org.apache.spark.ml.PredictionModel, org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        StructType outputSchema = super.transformSchema(structType);
        if (isNonEmptyColumn(probabilityCol())) {
            outputSchema = SchemaUtils$.MODULE$.updateAttributeGroupSize(outputSchema, (String) $(probabilityCol()), numClasses());
        }
        return outputSchema;
    }

    /**
     * Appends the enabled output columns (rawPrediction, probability, prediction) to
     * {@code dataset}. A column is enabled when its name param is non-empty.
     *
     * <p>Later columns are derived from earlier ones when possible: probability from
     * rawPrediction, and prediction from rawPrediction or probability. This avoids re-running
     * the model on the features. If no output column is enabled, a warning is logged and the
     * input is returned unchanged as a DataFrame.
     *
     * @throws IllegalArgumentException (via {@code Predef.require}) if thresholds are defined
     *     and their length differs from {@code numClasses()}
     */
    @Override // org.apache.spark.ml.classification.ClassificationModel, org.apache.spark.ml.PredictionModel, org.apache.spark.ml.Transformer
    @SuppressWarnings("unchecked")
    public Dataset<Row> transform(Dataset<?> dataset) {
        StructType outputSchema = transformSchema(dataset.schema(), true);
        if (isDefined(thresholds())) {
            Predef$.MODULE$.require(((double[]) $(thresholds())).length == numClasses(), () -> {
                return new StringBuilder(111).append(this.getClass().getSimpleName()).append(".transform() called with non-matching numClasses and thresholds.length.").append(" numClasses=").append(this.numClasses()).append(", but thresholds has length ").append(((double[]) this.$(this.thresholds())).length).toString();
            });
        }
        boolean wantRawPrediction = isNonEmptyColumn(rawPredictionCol());
        boolean wantProbability = isNonEmptyColumn(probabilityCol());
        Dataset<?> outputData = dataset;
        int numColsOutput = 0;
        if (wantRawPrediction) {
            Column rawPredictionColumn = functions$.MODULE$.udf((Object features) -> {
                return this.predictRaw((FeaturesType) features);
            }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), vectorTypeCreator()), package$.MODULE$.universe().TypeTag().Any()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(getFeaturesCol())}));
            outputData = outputData.withColumn(getRawPredictionCol(), rawPredictionColumn, outputSchema.apply((String) $(rawPredictionCol())).metadata());
            numColsOutput++;
        }
        if (wantProbability) {
            Column probabilityColumn;
            if (wantRawPrediction) {
                // Reuse the rawPrediction column computed above.
                probabilityColumn = functions$.MODULE$.udf((Vector rawPrediction) -> {
                    return this.raw2probability(rawPrediction);
                }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), vectorTypeCreator()), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), vectorTypeCreator())).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(rawPredictionCol()))}));
            } else {
                probabilityColumn = functions$.MODULE$.udf((Object features) -> {
                    return this.predictProbability((FeaturesType) features);
                }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), vectorTypeCreator()), package$.MODULE$.universe().TypeTag().Any()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(featuresCol()))}));
            }
            outputData = outputData.withColumn((String) $(probabilityCol()), probabilityColumn, outputSchema.apply((String) $(probabilityCol())).metadata());
            numColsOutput++;
        }
        if (isNonEmptyColumn(predictionCol())) {
            Column predictionColumn;
            if (wantRawPrediction) {
                predictionColumn = functions$.MODULE$.udf((Vector rawPrediction) -> {
                    return BoxesRunTime.boxToDouble(this.raw2prediction(rawPrediction));
                }, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), vectorTypeCreator())).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(rawPredictionCol()))}));
            } else if (wantProbability) {
                predictionColumn = functions$.MODULE$.udf((Vector probability) -> {
                    return BoxesRunTime.boxToDouble(this.probability2prediction(probability));
                }, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(ProbabilisticClassificationModel.class.getClassLoader()), vectorTypeCreator())).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(probabilityCol()))}));
            } else {
                predictionColumn = functions$.MODULE$.udf((Object features) -> {
                    return BoxesRunTime.boxToDouble(this.predict((FeaturesType) features));
                }, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Any()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(featuresCol()))}));
            }
            outputData = outputData.withColumn((String) $(predictionCol()), predictionColumn, outputSchema.apply((String) $(predictionCol())).metadata());
            numColsOutput++;
        }
        if (numColsOutput == 0) {
            logWarning(() -> {
                return new StringBuilder(95).append(this.uid()).append(": ProbabilisticClassificationModel.transform() does nothing").append(" because no output columns were set.").toString();
            });
        }
        return outputData.toDF();
    }

    /**
     * Converts a raw prediction vector into class probabilities. Subclasses implement this.
     * The name suggests the argument may be overwritten. {@link #raw2probability} therefore
     * passes a copy when the caller's vector must stay intact.
     */
    public abstract Vector raw2probabilityInPlace(Vector vector);

    /** Like {@link #raw2probabilityInPlace}, but operates on a copy of {@code vector}. */
    public Vector raw2probability(Vector vector) {
        return raw2probabilityInPlace(vector.copy());
    }

    /**
     * Predicts a class index from a raw prediction vector. Without thresholds, this is the
     * argmax of the raw vector. Otherwise the vector is converted to probabilities and passed
     * to {@link #probability2prediction}.
     */
    @Override // org.apache.spark.ml.classification.ClassificationModel
    public double raw2prediction(Vector vector) {
        return !isDefined(thresholds()) ? vector.argmax() : probability2prediction(raw2probability(vector));
    }

    /**
     * Computes class probabilities for one feature value. The vector returned by
     * {@code predictRaw} is handed straight to {@link #raw2probabilityInPlace}; no copy is made.
     */
    public Vector predictProbability(FeaturesType featurestype) {
        return raw2probabilityInPlace(predictRaw(featurestype));
    }

    /**
     * Predicts a class index from a probability vector.
     *
     * <p>Without thresholds, this returns {@code vector.argmax()}. Otherwise it returns the
     * index {@code i} maximizing {@code p[i] / t[i]}, using a strict {@code >} so the first
     * maximum wins ties. A zero threshold yields {@code +Infinity} when {@code p[i] > 0}. When
     * {@code p[i] == 0} it yields {@code NaN}, which never compares greater and so is never
     * selected. If no ratio exceeds {@code -Infinity}, index 0 is returned.
     */
    public double probability2prediction(Vector vector) {
        if (!isDefined(thresholds())) {
            return vector.argmax();
        }
        double[] classThresholds = getThresholds();
        int argMax = 0;
        double max = Double.NEGATIVE_INFINITY;
        int size = vector.size();
        for (int k = 0; k < size; k++) {
            double scaled = vector.apply(k) / classThresholds[k];
            if (scaled > max) {
                max = scaled;
                argMax = k;
            }
        }
        return argMax;
    }

    /**
     * Runs the trait initializers. These install the probabilityCol param, the thresholds
     * param (with {@link #isValidThresholds} as its validator), and the remaining
     * {@link ProbabilisticClassifierParams} state.
     */
    public ProbabilisticClassificationModel() {
        HasProbabilityCol.$init$((HasProbabilityCol) this);
        org$apache$spark$ml$param$shared$HasThresholds$_setter_$thresholds_$eq(new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", dArr -> {
            return BoxesRunTime.boxToBoolean(isValidThresholds(dArr));
        }));
        ProbabilisticClassifierParams.$init$((ProbabilisticClassifierParams) this);
    }

    /** Returns whether the string param {@code param} currently holds a non-empty column name. */
    private boolean isNonEmptyColumn(Param<String> param) {
        return new StringOps(Predef$.MODULE$.augmentString((String) $(param))).nonEmpty();
    }

    /**
     * Creates a TypeCreator for {@code org.apache.spark.ml.linalg.Vector}, used to build UDF
     * TypeTags. It is static so the serializable creator does not capture the model instance.
     */
    private static TypeCreator vectorTypeCreator() {
        return new TypeCreator() {
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }
        };
    }

    /**
     * Checks the rule stated in the thresholds param description: every value is {@code >= 0}
     * (so NaN is rejected) and at most one value is exactly 0.
     */
    private static boolean isValidThresholds(double[] values) {
        int zeros = 0;
        for (double v : values) {
            if (!(v >= 0)) {
                return false;
            }
            if (v == 0) {
                zeros++;
            }
        }
        return zeros <= 1;
    }
}
