package org.talend.dataquality.sampling.parallel;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;

/* loaded from: input_file:org/talend/dataquality/sampling/parallel/SparkSamplingUtil.class */
/**
 * Draws a bounded sample from a Spark {@link JavaRDD} or {@link DataFrame}.
 *
 * <p>Each partition is run through a {@code ReservoirSamplerWithBinaryHeap}, which yields
 * {@code (Double, element)} pairs; the per-partition results are then reduced with
 * {@code top(i, comparator)}, keeping the {@code i} pairs ranked highest by their {@code Double}
 * component. The meaning of that {@code Double} is defined by the sampler, which is not part of
 * this file (presumably a random key/weight — confirm against the sampler).
 *
 * <p>The class is {@link Serializable} because its inner function/comparator objects are handed
 * to Spark and, being inner classes, carry a reference to the enclosing instance.
 *
 * @param <T> element type of the sampled RDD
 */
public class SparkSamplingUtil<T> implements Serializable {
    /**
     * Seed passed to every per-partition sampler. {@code null} means no seed was configured, in
     * which case each {@code call} draws its own random seed.
     */
    private final Long seed;

    /**
     * Orders sample pairs by their {@code Double} (left) component, ascending.
     *
     * <p>Decompiler output notes that this class was originally declared {@code private}.
     *
     * @param <E> type of the sampled element carried on the right side of the pair
     */
    public class PairComparator<E> implements Serializable, Comparator<ImmutablePair<Double, E>> {
        private PairComparator() {
        }

        /**
         * Compares two pairs by their left component.
         *
         * <p>{@link Double#compare(double, double)} is used rather than {@code <}/{@code >} so
         * that the ordering stays total (and therefore a valid {@link Comparator}) even if a key
         * is NaN. A {@code null} left component results in a {@link NullPointerException} on
         * unboxing.
         */
        @Override // java.util.Comparator
        public int compare(ImmutablePair<Double, E> immutablePair, ImmutablePair<Double, E> immutablePair2) {
            return Double.compare(immutablePair.left, immutablePair2.left);
        }
    }

    /**
     * Per-partition function: feeds every element of one partition into a freshly created
     * {@code ReservoirSamplerWithBinaryHeap} and returns the pairs that sampler reports.
     *
     * <p>Decompiler output notes that this class was originally declared {@code private}.
     *
     * @param <E> element type of the partition being sampled
     */
    public class SamplingMapFunction<E> implements FlatMapFunction<Iterator<E>, ImmutablePair<Double, E>> {
        /** Value handed to the sampler's constructor (presumably the reservoir capacity). */
        private final int nbSamples;

        /**
         * @param i number of samples requested from each partition's sampler
         */
        public SamplingMapFunction(int i) {
            this.nbSamples = i;
        }

        /**
         * Samples one partition.
         *
         * <p>When the enclosing utility has a seed, that same seed is used for every partition;
         * otherwise a new random seed is drawn for this call only. The enclosing instance is
         * never modified.
         *
         * @param it iterator over the elements of the partition
         * @return the {@code (Double, element)} pairs produced by the sampler
         * @throws Exception declared by the Spark function interface
         */
        // The sampler's declaration is not visible here, so it is used as a raw type exactly as
        // before; the suppression is limited to this method.
        @SuppressWarnings({"rawtypes", "unchecked"})
        public Iterable<ImmutablePair<Double, E>> call(Iterator<E> it) throws Exception {
            Long configuredSeed = SparkSamplingUtil.this.seed;
            long effectiveSeed = configuredSeed != null ? configuredSeed.longValue() : new Random().nextLong();

            ReservoirSamplerWithBinaryHeap sampler = new ReservoirSamplerWithBinaryHeap(this.nbSamples, effectiveSeed);
            sampler.clear();
            while (it.hasNext()) {
                sampler.onNext(it.next());
            }
            sampler.onCompleted(true);
            return sampler.samplePairs();
        }
    }

    /** Creates a utility without a configured seed; each partition call draws a random one. */
    public SparkSamplingUtil() {
        this(null);
    }

    /**
     * Creates a utility with the given seed.
     *
     * @param l seed passed to every per-partition sampler, or {@code null} for a random seed
     */
    public SparkSamplingUtil(Long l) {
        this.seed = l;
    }

    /**
     * Samples an RDD and returns the selected elements together with their {@code Double} keys.
     *
     * @param javaRDD the RDD to sample
     * @param i       number of samples requested per partition and maximum size of the result
     * @return the {@code i} highest-ranked pairs across all partitions, as returned by {@code top}
     */
    public List<ImmutablePair<Double, T>> getSamplePairList(JavaRDD<T> javaRDD, int i) {
        return javaRDD.mapPartitions(new SamplingMapFunction<T>(i)).top(i, new PairComparator<T>());
    }

    /**
     * Samples the rows of a {@link DataFrame}; same procedure as the RDD overload, applied to
     * {@code dataFrame.javaRDD()}.
     *
     * @param dataFrame the data frame to sample
     * @param i         number of samples requested per partition and maximum size of the result
     * @return the {@code i} highest-ranked {@code (Double, Row)} pairs across all partitions
     */
    public List<ImmutablePair<Double, Row>> getSamplePairList(DataFrame dataFrame, int i) {
        return dataFrame.javaRDD().mapPartitions(new SamplingMapFunction<Row>(i)).top(i, new PairComparator<Row>());
    }

    /**
     * Samples an RDD and returns only the sampled elements, in the order of
     * {@link #getSamplePairList(JavaRDD, int)}.
     *
     * @param javaRDD the RDD to sample
     * @param i       number of samples requested
     * @return a new mutable list holding the right-hand element of every sampled pair
     */
    public List<T> getSampleList(JavaRDD<T> javaRDD, int i) {
        List<ImmutablePair<Double, T>> samplePairs = getSamplePairList(javaRDD, i);
        List<T> samples = new ArrayList<T>(samplePairs.size());
        for (ImmutablePair<Double, T> pair : samplePairs) {
            samples.add(pair.getRight());
        }
        return samples;
    }
}
