package org.apache.spark.ml.feature;

import java.io.IOException;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.Attribute$;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.attribute.BinaryAttribute;
import org.apache.spark.ml.attribute.NominalAttribute;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.Symbols;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.runtime.BoxesRunTime;

/* compiled from: OneHotEncoder.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\rc\u0001B\u0001\u0003\u00015\u0011Qb\u00148f\u0011>$XI\\2pI\u0016\u0014(BA\u0002\u0005\u0003\u001d1W-\u0019;ve\u0016T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M)\u0001A\u0004\n\u001b;A\u0011q\u0002E\u0007\u0002\t%\u0011\u0011\u0003\u0002\u0002\f)J\fgn\u001d4pe6,'\u000f\u0005\u0002\u001415\tAC\u0003\u0002\u0016-\u000511\u000f[1sK\u0012T!a\u0006\u0003\u0002\u000bA\f'/Y7\n\u0005e!\"a\u0003%bg&s\u0007/\u001e;D_2\u0004\"aE\u000e\n\u0005q!\"\u0001\u0004%bg>+H\u000f];u\u0007>d\u0007C\u0001\u0010\"\u001b\u0005y\"B\u0001\u0011\u0005\u0003\u0011)H/\u001b7\n\u0005\tz\"!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7o\u0016:ji\u0006\u0014G.\u001a\u0005\tI\u0001\u0011)\u0019!C!K\u0005\u0019Q/\u001b3\u0016\u0003\u0019\u0002\"aJ\u0017\u000f\u0005!ZS\"A\u0015\u000b\u0003)\nQa]2bY\u0006L!\u0001L\u0015\u0002\rA\u0013X\rZ3g\u0013\tqsF\u0001\u0004TiJLgn\u001a\u0006\u0003Y%B\u0001\"\r\u0001\u0003\u0002\u0003\u0006IAJ\u0001\u0005k&$\u0007\u0005C\u00034\u0001\u0011\u0005A'\u0001\u0004=S:LGO\u0010\u000b\u0003k]\u0002\"A\u000e\u0001\u000e\u0003\tAQ\u0001\n\u001aA\u0002\u0019BQa\r\u0001\u0005\u0002e\"\u0012!\u000e\u0005\bw\u0001\u0011\r\u0011\"\u0002=\u0003!!'o\u001c9MCN$X#A\u001f\u0011\u0005yzT\"\u0001\f\n\u0005\u00013\"\u0001\u0004\"p_2,\u0017M\u001c)be\u0006l\u0007B\u0002\"\u0001A\u00035Q(A\u0005ee>\u0004H*Y:uA!)A\t\u0001C\u0001\u000b\u0006Y1/\u001a;Ee>\u0004H*Y:u)\t1u)D\u0001\u0001\u0011\u0015A5\t1\u0001J\u0003\u00151\u0018\r\\;f!\tA#*\u0003\u0002LS\t9!i\\8mK\u0006t\u0007\"B'\u0001\t\u0003q\u0015aC:fi&s\u0007/\u001e;D_2$\"AR(\t\u000b!c\u0005\u0019\u0001\u0014\t\u000bE\u0003A\u0011\u0001*\u0002\u0019M,GoT;uaV$8i\u001c7\u0015\u0005\u0019\u001b\u0006\"\u0002%Q\u0001\u00041\u0003\"B+\u0001\t\u00032\u0016a\u0004;sC:\u001chm\u001c:n'\u000eDW-\\1\u0015\u0005]{\u0006C\u0001-^\u001b\u0005I&B\u0001.\\\u0003\u0015!\u0018\u0010]3t\u0015\taf!A
\u0002tc2L!AX-\u0003\u0015M#(/^2u)f\u0004X\rC\u0003a)\u0002\u0007q+\u0001\u0004tG\",W.\u0019\u0005\u0006E\u0002!\teY\u0001\niJ\fgn\u001d4pe6$\"\u0001\u001a5\u0011\u0005\u00154W\"A.\n\u0005\u001d\\&!\u0003#bi\u00064%/Y7f\u0011\u0015I\u0017\r1\u0001e\u0003\u001d!\u0017\r^1tKRDQa\u001b\u0001\u0005B1\fAaY8qsR\u0011Q'\u001c\u0005\u0006]*\u0004\ra\\\u0001\u0006Kb$(/\u0019\t\u0003}AL!!\u001d\f\u0003\u0011A\u000b'/Y7NCBD#\u0001A:\u0011\u0005Q<X\"A;\u000b\u0005Y4\u0011AC1o]>$\u0018\r^5p]&\u0011\u00010\u001e\u0002\r\u000bb\u0004XM]5nK:$\u0018\r\\\u0004\u0006u\nA\ta_\u0001\u000e\u001f:,\u0007j\u001c;F]\u000e|G-\u001a:\u0011\u0005Ybh!B\u0001\u0003\u0011\u0003i8C\u0002?\u007f\u0003\u0007\tI\u0001\u0005\u0002)\u007f&\u0019\u0011\u0011A\u0015\u0003\r\u0005s\u0017PU3g!\u0011q\u0012QA\u001b\n\u0007\u0005\u001dqDA\u000bEK\u001a\fW\u000f\u001c;QCJ\fWn\u001d*fC\u0012\f'\r\\3\u0011\u0007!\nY!C\u0002\u0002\u000e%\u0012AbU3sS\u0006d\u0017N_1cY\u0016Daa\r?\u0005\u0002\u0005EA#A>\t\u000f\u0005UA\u0010\"\u0011\u0002\u0018\u0005!An\\1e)\r)\u0014\u0011\u0004\u0005\b\u00037\t\u0019\u00021\u0001'\u0003\u0011\u0001\u0018\r\u001e5)\r\u0005M\u0011qDA\u0013!\r!\u0018\u0011E\u0005\u0004\u0003G)(!B*j]\u000e,\u0017EAA\u0014\u0003\u0015\tdF\u000e\u00181\u0011%\tY\u0003`A\u0001\n\u0013\ti#A\u0006sK\u0006$'+Z:pYZ,GCAA\u0018!\u0011\t\t$a\u000f\u000e\u0005\u0005M\"\u0002BA\u001b\u0003o\tA\u0001\\1oO*\u0011\u0011\u0011H\u0001\u0005U\u00064\u0018-\u0003\u0003\u0002>\u0005M\"AB(cU\u0016\u001cG\u000fK\u0003}\u0003?\t)\u0003K\u0003z\u0003?\t)\u0003")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/feature/OneHotEncoder.class */
/*
 * Transformer that maps a column of category indices to a one-hot encoded vector column.
 *
 * - The input column (inputCol) must be of DoubleType; transformSchema checks this with
 *   SchemaUtils.checkColumnType.
 * - The output column (outputCol) is appended to the input schema as a vector column whose
 *   metadata is an AttributeGroup describing the categories.
 * - The param dropLast ("whether to drop the last category") defaults to true (see constructor).
 *
 * This file is decompiled from OneHotEncoder.scala. The synthetic OneHotEncoder$$anonfun$N
 * classes, the companion object OneHotEncoder$, and the *.Cclass trait-implementation classes
 * it references are generated by scalac and are defined outside this file.
 * NOTE(review): several constructs below look like decompiler artifacts rather than valid Java
 * source: the casts in the setters, assignments to final fields in the trait setter methods,
 * mirror.universe2(), and new TypeCreator(this). Treat the Scala source as authoritative.
 */
public class OneHotEncoder extends Transformer implements HasInputCol, HasOutputCol, DefaultParamsWritable {
    // Unique identifier of this pipeline stage (returned by uid()).
    private final String uid;
    // Boolean param "dropLast"; its default (true) is registered in the constructor.
    private final BooleanParam dropLast;
    // Backing fields for the HasOutputCol / HasInputCol params. They are assigned through the
    // trait setter methods below, presumably from the trait $init$ calls in the constructor.
    private final Param<String> outputCol;
    private final Param<String> inputCol;

    /** Returns a reader for this class; delegates to the companion object OneHotEncoder$. */
    public static MLReader<OneHotEncoder> read() {
        return OneHotEncoder$.MODULE$.read();
    }

    /**
     * Loads a previously saved OneHotEncoder; delegates to the companion object OneHotEncoder$.
     *
     * @param str location to load from (presumably a filesystem path; TODO confirm)
     */
    public static OneHotEncoder load(String str) {
        return OneHotEncoder$.MODULE$.load(str);
    }

    /** Returns a writer; delegates to the DefaultParamsWritable trait implementation. */
    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        return DefaultParamsWritable.Cclass.write(this);
    }

    /** Saves this instance; delegates to the MLWritable trait implementation. */
    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        MLWritable.Cclass.save(this, str);
    }

    /** Returns the output-column param. */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final Param<String> outputCol() {
        return this.outputCol;
    }

    // Trait field setter generated by scalac for HasOutputCol.outputCol.
    // NOTE(review): assigning a final field outside a constructor is not legal Java source;
    // this is an artifact of decompiling Scala's trait-field encoding.
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param param) {
        this.outputCol = param;
    }

    /** Returns the current value of outputCol; delegates to the HasOutputCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final String getOutputCol() {
        return HasOutputCol.Cclass.getOutputCol(this);
    }

    /** Returns the input-column param. */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    // Trait field setter generated by scalac for HasInputCol.inputCol.
    // NOTE(review): same decompiler artifact as the outputCol setter above (final field assignment).
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param param) {
        this.inputCol = param;
    }

    /** Returns the current value of inputCol; delegates to the HasInputCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final String getInputCol() {
        return HasInputCol.Cclass.getInputCol(this);
    }

    /** Returns the unique identifier passed to the constructor. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    /** Returns the "dropLast" param ("whether to drop the last category"; default true). */
    public final BooleanParam dropLast() {
        return this.dropLast;
    }

    /**
     * Sets dropLast and returns this stage (the result of set(...) cast back to OneHotEncoder).
     * NOTE(review): the (BooleanParam) cast on the boxed Boolean is a decompiler artifact; the
     * value actually passed is BoxesRunTime.boxToBoolean(z).
     */
    public OneHotEncoder setDropLast(boolean z) {
        return (OneHotEncoder) set((Param<BooleanParam>) dropLast(), (BooleanParam) BoxesRunTime.boxToBoolean(z));
    }

    /**
     * Sets inputCol and returns this stage.
     * NOTE(review): the generic casts here are decompiler artifacts; the call is set(inputCol(), str).
     */
    public OneHotEncoder setInputCol(String str) {
        return (OneHotEncoder) set((Param<Param<String>>) inputCol(), (Param<String>) str);
    }

    /**
     * Sets outputCol and returns this stage.
     * NOTE(review): the generic casts here are decompiler artifacts; the call is set(outputCol(), str).
     */
    public OneHotEncoder setOutputCol(String str) {
        return (OneHotEncoder) set((Param<Param<String>>) outputCol(), (Param<String>) str);
    }

    /**
     * Validates the input schema and returns it with the output column appended.
     *
     * The input column must be DoubleType. The require(...) call fails if any existing field
     * matches the predicate in OneHotEncoder$$anonfun$transformSchema$2, presumably "field name
     * equals outputCol" (not visible; TODO confirm). Category names for the output AttributeGroup
     * are derived from the input column's ML attribute:
     * - NominalAttribute: its values if defined; otherwise numValues mapped through
     *   OneHotEncoder$$anonfun$3 if defined; otherwise unknown (None).
     * - BinaryAttribute: its values if defined; otherwise 2 names built by OneHotEncoder$$anonfun$4.
     * - NumericAttribute: rejected with a RuntimeException.
     * - Anything else: unknown (None).
     * If names are known they are post-processed by OneHotEncoder$$anonfun$5 (presumably applying
     * dropLast; TODO confirm) and turned into attributes by OneHotEncoder$$anonfun$6. Otherwise
     * the output AttributeGroup carries only the column name.
     *
     * @param structType the input schema
     * @return the input schema's fields followed by the output column's StructField
     * @throws RuntimeException if the input column carries a NumericAttribute
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        Option option;
        String str = (String) $(inputCol());
        String str2 = (String) $(outputCol());
        SchemaUtils$.MODULE$.checkColumnType(structType, str, DoubleType$.MODULE$, SchemaUtils$.MODULE$.checkColumnType$default$4());
        StructField[] fields = structType.fields();
        Predef$.MODULE$.require(!Predef$.MODULE$.refArrayOps(fields).exists(new OneHotEncoder$$anonfun$transformSchema$2(this, str2)), new OneHotEncoder$$anonfun$transformSchema$1(this, str2));
        // Inspect the ML attribute metadata attached to the input column.
        Attribute fromStructField = Attribute$.MODULE$.fromStructField(structType.apply(str));
        if (fromStructField instanceof NominalAttribute) {
            NominalAttribute nominalAttribute = (NominalAttribute) fromStructField;
            option = nominalAttribute.values().isDefined() ? nominalAttribute.values() : nominalAttribute.numValues().isDefined() ? nominalAttribute.numValues().map(new OneHotEncoder$$anonfun$3(this)) : None$.MODULE$;
        } else if (fromStructField instanceof BinaryAttribute) {
            BinaryAttribute binaryAttribute = (BinaryAttribute) fromStructField;
            option = binaryAttribute.values().isDefined() ? binaryAttribute.values() : new Some(Array$.MODULE$.tabulate(2, new OneHotEncoder$$anonfun$4(this), ClassTag$.MODULE$.apply(String.class)));
        } else {
            if (fromStructField instanceof NumericAttribute) {
                throw new RuntimeException(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"The input column ", " cannot be numeric."})).s(Predef$.MODULE$.genericWrapArray(new Object[]{str})));
            }
            // No usable attribute metadata: the number of categories is unknown at schema time.
            option = None$.MODULE$;
        }
        Option map = option.map(new OneHotEncoder$$anonfun$5(this, str));
        // Append the output column: an AttributeGroup with per-category attributes when names are
        // known, otherwise an AttributeGroup that only carries the output column name.
        return new StructType((StructField[]) Predef$.MODULE$.refArrayOps(fields).$colon$plus((map.isDefined() ? new AttributeGroup((String) $(outputCol()), (Attribute[]) Predef$.MODULE$.refArrayOps((Object[]) map.get()).map(new OneHotEncoder$$anonfun$6(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Attribute.class)))) : new AttributeGroup((String) $(outputCol()))).toStructField(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))));
    }

    /**
     * Returns a DataFrame with all input columns plus the one-hot encoded output column.
     *
     * First calls transformSchema (which validates the schema) and reads the output column's
     * AttributeGroup from the result. If that group's size is negative (presumably meaning the
     * attribute count is unknown; TODO confirm AttributeGroup.size semantics), this method runs
     * an aggregate job over the input column. The job casts the column to double, maps it via
     * OneHotEncoder$$anonfun$7 and folds it with OneHotEncoder$$anonfun$1/$2 starting from 0.0;
     * presumably it computes the maximum index. It then builds (int) result + 1 category names
     * via OneHotEncoder$$anonfun$8, drops the last name when dropLast is true, and builds the
     * attributes via OneHotEncoder$$anonfun$9.
     *
     * The output column is produced by a UDF (OneHotEncoder$$anonfun$10) from the input column
     * cast to double, returning org.apache.spark.mllib.linalg.Vector. It is aliased to outputCol
     * with the AttributeGroup's metadata.
     *
     * @param dataFrame the input dataset
     * @return dataFrame with the output column appended
     */
    @Override // org.apache.spark.ml.Transformer
    public DataFrame transform(DataFrame dataFrame) {
        String str = (String) $(inputCol());
        String str2 = (String) $(outputCol());
        boolean unboxToBoolean = BoxesRunTime.unboxToBoolean($(dropLast()));
        AttributeGroup fromStructField = AttributeGroup$.MODULE$.fromStructField(transformSchema(dataFrame.schema()).apply(str2));
        if (fromStructField.size() < 0) {
            // Category count unknown from metadata: derive it from the data (this triggers a Spark job).
            String[] strArr = (String[]) Array$.MODULE$.tabulate(((int) BoxesRunTime.unboxToDouble(dataFrame.select(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str).cast(DoubleType$.MODULE$)})).map(new OneHotEncoder$$anonfun$7(this), ClassTag$.MODULE$.Double()).aggregate(BoxesRunTime.boxToDouble(0.0d), new OneHotEncoder$$anonfun$1(this, str), new OneHotEncoder$$anonfun$2(this), ClassTag$.MODULE$.Double()))) + 1, new OneHotEncoder$$anonfun$8(this), ClassTag$.MODULE$.apply(String.class));
            fromStructField = new AttributeGroup(str2, (Attribute[]) Predef$.MODULE$.refArrayOps(unboxToBoolean ? (String[]) Predef$.MODULE$.refArrayOps(strArr).dropRight(1) : strArr).map(new OneHotEncoder$$anonfun$9(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Attribute.class))));
        }
        // The UDF closure captures the group size, a one-element value array {1.0}, and empty
        // double/int arrays (presumably used to build sparse vectors; TODO confirm in $$anonfun$10).
        return dataFrame.select(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col("*"), functions$.MODULE$.udf(new OneHotEncoder$$anonfun$10(this, fromStructField.size(), new double[]{1.0d}, (double[]) Array$.MODULE$.apply(Nil$.MODULE$, ClassTag$.MODULE$.Double()), (int[]) Array$.MODULE$.apply(Nil$.MODULE$, ClassTag$.MODULE$.Int())), ((TypeTags) scala.reflect.runtime.package$.MODULE$.universe()).TypeTag().apply((Mirror) scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(OneHotEncoder.class.getClassLoader()), new TypeCreator(this) { // from class: org.apache.spark.ml.feature.OneHotEncoder$$typecreator1$1
            // Type creator for the UDF's return TypeTag: resolves org.apache.spark.mllib.linalg.Vector.
            @Override // scala.reflect.api.TypeCreator
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                // NOTE(review): "universe2()" appears to be a decompiler rendering of mirror.universe();
                // its result is discarded here.
                mirror.universe2();
                return ((Symbols.TypeSymbolApi) ((Symbols.TypeSymbolApi) mirror.staticClass("org.apache.spark.mllib.linalg.Vector")).asType()).toTypeConstructor();
            }
        }), ((TypeTags) scala.reflect.runtime.package$.MODULE$.universe()).TypeTag().Double()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str).cast(DoubleType$.MODULE$)})).as(str2, fromStructField.toMetadata())}));
    }

    /** Returns a copy of this stage with the extra params applied, via defaultCopy. */
    @Override // org.apache.spark.ml.Transformer, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public OneHotEncoder copy(ParamMap paramMap) {
        return (OneHotEncoder) defaultCopy(paramMap);
    }

    /**
     * Creates an encoder with the given uid. Runs the trait initializers for HasInputCol,
     * HasOutputCol, MLWritable and DefaultParamsWritable, creates the dropLast param, and sets
     * its default to true.
     *
     * @param str the unique identifier of this stage
     */
    public OneHotEncoder(String str) {
        this.uid = str;
        HasInputCol.Cclass.$init$(this);
        HasOutputCol.Cclass.$init$(this);
        MLWritable.Cclass.$init$(this);
        DefaultParamsWritable.Cclass.$init$(this);
        this.dropLast = new BooleanParam(this, "dropLast", "whether to drop the last category");
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{dropLast().$minus$greater(BoxesRunTime.boxToBoolean(true))}));
    }

    /** Creates an encoder with a random uid prefixed with "oneHot". */
    public OneHotEncoder() {
        this(Identifiable$.MODULE$.randomUID("oneHot"));
    }
}
