package org.apache.spark.sql.execution.adaptive;

import org.apache.commons.io.FileUtils;
import org.apache.spark.MapOutputStatistics;
import org.apache.spark.sql.catalyst.plans.Cross$;
import org.apache.spark.sql.catalyst.plans.Inner$;
import org.apache.spark.sql.catalyst.plans.JoinType;
import org.apache.spark.sql.catalyst.plans.LeftAnti$;
import org.apache.spark.sql.catalyst.plans.LeftOuter$;
import org.apache.spark.sql.catalyst.plans.LeftSemi$;
import org.apache.spark.sql.catalyst.plans.RightOuter$;
import org.apache.spark.sql.catalyst.rules.Rule;
import org.apache.spark.sql.execution.CoalescedPartitionSpec;
import org.apache.spark.sql.execution.CoalescedPartitionSpec$;
import org.apache.spark.sql.execution.PartialReducerPartitionSpec;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS$;
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike;
import org.apache.spark.sql.execution.exchange.ShuffleOrigin;
import org.apache.spark.sql.internal.SQLConf$;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Product;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric$LongIsIntegral$;
import scala.math.Ordering$Long$;
import scala.math.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.runtime.RichLong$;

/* compiled from: OptimizeSkewedJoin.scala */
/* loaded from: input_file:org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin$.class */
/**
 * Adaptive-query-execution rule that looks for skewed shuffle partitions on the two sides of a
 * join and splits them into smaller partial-reducer partitions.
 *
 * <p>This is the singleton module class of the Scala {@code object OptimizeSkewedJoin}; the
 * instance is published through {@link #MODULE$}. The rule is only attempted when
 * {@code SKEW_JOIN_ENABLED} is set and the plan contains exactly two shuffle query stages (see
 * {@link #apply(SparkPlan)}), and it only supports shuffles whose origin is
 * {@code ENSURE_REQUIREMENTS} (see {@link #supportedShuffleOrigins()}).
 */
public final class OptimizeSkewedJoin$ extends Rule<SparkPlan> implements AQEShuffleReadRule {
    /** Singleton instance, assigned in the private constructor. */
    public static OptimizeSkewedJoin$ MODULE$;
    /** Shuffle origins this rule may rewrite; only {@code ENSURE_REQUIREMENTS}. */
    private final Seq<ShuffleOrigin> supportedShuffleOrigins;

    static {
        new OptimizeSkewedJoin$();
    }

    /**
     * Delegates to the default {@link AQEShuffleReadRule#isSupported} implementation.
     *
     * <p>The decompiled forwarder previously called {@code isSupported(shuffleExchangeLike)} on
     * itself, which recursed until stack overflow; the call must go to the trait's default
     * method instead.
     */
    @Override // org.apache.spark.sql.execution.adaptive.AQEShuffleReadRule
    public boolean isSupported(ShuffleExchangeLike shuffleExchangeLike) {
        return AQEShuffleReadRule.super.isSupported(shuffleExchangeLike);
    }

    /** Returns the shuffle origins this rule supports (only {@code ENSURE_REQUIREMENTS}). */
    @Override // org.apache.spark.sql.execution.adaptive.AQEShuffleReadRule
    public Seq<ShuffleOrigin> supportedShuffleOrigins() {
        return this.supportedShuffleOrigins;
    }

    /**
     * Computes the size above which a partition is considered skewed.
     *
     * @param j median partition size in bytes
     * @return {@code max(SKEW_JOIN_SKEWED_PARTITION_THRESHOLD, j * SKEW_JOIN_SKEWED_PARTITION_FACTOR)}
     */
    public long getSkewThreshold(long j) {
        return RichLong$.MODULE$.max$extension(Predef$.MODULE$.longWrapper(BoxesRunTime.unboxToLong(conf().getConf(SQLConf$.MODULE$.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD()))), j * BoxesRunTime.unboxToInt(conf().getConf(SQLConf$.MODULE$.SKEW_JOIN_SKEWED_PARTITION_FACTOR())));
    }

    /**
     * Returns the median of the given partition sizes, clamped to at least 1.
     *
     * <p>For an even number of sizes the two middle values are averaged (integer division). The
     * array must be non-empty: an empty array makes the middle-element lookup throw
     * {@link ArrayIndexOutOfBoundsException}.
     */
    private long medianSize(long[] jArr) {
        int length = jArr.length;
        long[] jArr2 = (long[]) new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(jArr)).sorted(Ordering$Long$.MODULE$);
        if (length % 2 == 0) {
            // Even count: average the two middle elements.
            return package$.MODULE$.max((jArr2[length / 2] + jArr2[(length / 2) - 1]) / 2, 1L);
        }
        return package$.MODULE$.max(jArr2[length / 2], 1L);
    }

    /**
     * Chooses the target size for the pieces a skewed partition is split into.
     *
     * @param jArr partition sizes in bytes
     * @param j skew threshold; sizes {@code <= j} count as non-skewed
     * @return the average size of the non-skewed partitions, but never less than
     *     {@code ADVISORY_PARTITION_SIZE_IN_BYTES}; if every partition is skewed, the advisory
     *     size itself
     */
    private long targetSize(long[] jArr, long j) {
        long unboxToLong = BoxesRunTime.unboxToLong(conf().getConf(SQLConf$.MODULE$.ADVISORY_PARTITION_SIZE_IN_BYTES()));
        long[] jArr2 = (long[]) new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(jArr)).filter(j2 -> {
            return j2 <= j;
        });
        return new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(jArr2)).isEmpty() ? unboxToLong : package$.MODULE$.max(unboxToLong, BoxesRunTime.unboxToLong(new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(jArr2)).sum(Numeric$LongIsIntegral$.MODULE$)) / jArr2.length);
    }

    /** Returns true for join types whose left side may be split: Inner, Cross, LeftSemi, LeftAnti, LeftOuter. */
    private boolean canSplitLeftSide(JoinType joinType) {
        Inner$ inner$ = Inner$.MODULE$;
        if (joinType != null ? !joinType.equals(inner$) : inner$ != null) {
            Cross$ cross$ = Cross$.MODULE$;
            if (joinType != null ? !joinType.equals(cross$) : cross$ != null) {
                LeftSemi$ leftSemi$ = LeftSemi$.MODULE$;
                if (joinType != null ? !joinType.equals(leftSemi$) : leftSemi$ != null) {
                    LeftAnti$ leftAnti$ = LeftAnti$.MODULE$;
                    if (joinType != null ? !joinType.equals(leftAnti$) : leftAnti$ != null) {
                        LeftOuter$ leftOuter$ = LeftOuter$.MODULE$;
                        if (joinType != null ? !joinType.equals(leftOuter$) : leftOuter$ != null) {
                            return false;
                        }
                    }
                }
            }
        }
        return true;
    }

    /** Returns true for join types whose right side may be split: Inner, Cross, RightOuter. */
    private boolean canSplitRightSide(JoinType joinType) {
        Inner$ inner$ = Inner$.MODULE$;
        if (joinType != null ? !joinType.equals(inner$) : inner$ != null) {
            Cross$ cross$ = Cross$.MODULE$;
            if (joinType != null ? !joinType.equals(cross$) : cross$ != null) {
                RightOuter$ rightOuter$ = RightOuter$.MODULE$;
                if (joinType != null ? !joinType.equals(rightOuter$) : rightOuter$ != null) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Formats median/max/min/avg partition sizes for debug logging; {@code jArr} must be non-empty. */
    private String getSizeInfo(long j, long[] jArr) {
        return new StringBuilder(49).append("median size: ").append(j).append(", max size: ").append(new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(jArr)).max(Ordering$Long$.MODULE$)).append(", min size: ").append(new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(jArr)).min(Ordering$Long$.MODULE$)).append(", avg size: ").append(BoxesRunTime.unboxToLong(new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(jArr)).sum(Numeric$LongIsIntegral$.MODULE$)) / jArr.length).toString();
    }

    /**
     * Attempts to split skewed partitions of the two join children.
     *
     * <p>For each reducer partition, a side that may be split (per the join type) and whose size
     * exceeds its skew threshold is replaced by the partial-reducer specs returned from
     * {@code ShufflePartitionsUtil.createSkewPartitionSpecs}; otherwise it keeps a single
     * coalesced spec covering just that partition. Every left spec is then paired with every
     * right spec of the same partition (cross product), so the two resulting spec lists stay
     * aligned index by index.
     *
     * <p>Both stages' {@code mapStats()} are unwrapped with {@code get()}, so callers are expected
     * to pass stages whose map statistics are defined. The two stages must have the same number
     * of partitions (asserted).
     *
     * @return {@code Some((leftRead, rightRead))} wrapping both stages in
     *     {@code AQEShuffleReadExec} when at least one partition was split on either side;
     *     otherwise {@code None}
     */
    public Option<Tuple2<SparkPlan, SparkPlan>> org$apache$spark$sql$execution$adaptive$OptimizeSkewedJoin$$tryOptimizeJoinChildren(ShuffleQueryStageExec shuffleQueryStageExec, ShuffleQueryStageExec shuffleQueryStageExec2, JoinType joinType) {
        boolean canSplitLeftSide = canSplitLeftSide(joinType);
        boolean canSplitRightSide = canSplitRightSide(joinType);
        if (!canSplitLeftSide && !canSplitRightSide) {
            return None$.MODULE$;
        }
        long[] bytesByPartitionId = ((MapOutputStatistics) shuffleQueryStageExec.mapStats().get()).bytesByPartitionId();
        long[] bytesByPartitionId2 = ((MapOutputStatistics) shuffleQueryStageExec2.mapStats().get()).bytesByPartitionId();
        Predef$.MODULE$.assert(bytesByPartitionId.length == bytesByPartitionId2.length);
        int length = bytesByPartitionId.length;
        long medianSize = medianSize(bytesByPartitionId);
        long medianSize2 = medianSize(bytesByPartitionId2);
        logDebug(() -> {
            return new StringOps(Predef$.MODULE$.augmentString(new StringBuilder(148).append("\n         |Optimizing skewed join.\n         |Left side partitions size info:\n         |").append(MODULE$.getSizeInfo(medianSize, bytesByPartitionId)).append("\n         |Right side partitions size info:\n         |").append(MODULE$.getSizeInfo(medianSize2, bytesByPartitionId2)).append("\n      ").toString())).stripMargin();
        });
        long skewThreshold = getSkewThreshold(medianSize);
        long skewThreshold2 = getSkewThreshold(medianSize2);
        long targetSize = targetSize(bytesByPartitionId, skewThreshold);
        long targetSize2 = targetSize(bytesByPartitionId2, skewThreshold2);
        // Output spec lists for the left (empty) and right (empty2) reads; kept index-aligned.
        ArrayBuffer empty = ArrayBuffer$.MODULE$.empty();
        ArrayBuffer empty2 = ArrayBuffer$.MODULE$.empty();
        // Counters of partitions actually split on the left (create) and right (create2).
        IntRef create = IntRef.create(0);
        IntRef create2 = IntRef.create(0);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), length).foreach$mVc$sp(i -> {
            Seq seq;
            Seq seq2;
            long j = bytesByPartitionId[i];
            boolean z = canSplitLeftSide && j > skewThreshold;
            long j2 = bytesByPartitionId2[i];
            boolean z2 = canSplitRightSide && j2 > skewThreshold2;
            Seq apply = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new CoalescedPartitionSpec[]{CoalescedPartitionSpec$.MODULE$.apply(i, i + 1, j)}));
            Seq apply2 = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new CoalescedPartitionSpec[]{CoalescedPartitionSpec$.MODULE$.apply(i, i + 1, j2)}));
            if (z) {
                Option<Seq<PartialReducerPartitionSpec>> createSkewPartitionSpecs = ShufflePartitionsUtil$.MODULE$.createSkewPartitionSpecs(((MapOutputStatistics) shuffleQueryStageExec.mapStats().get()).shuffleId(), i, targetSize);
                if (createSkewPartitionSpecs.isDefined()) {
                    MODULE$.logDebug(() -> {
                        return new StringBuilder(56).append("Left side partition ").append(i).append(" ").append("(").append(FileUtils.byteCountToDisplaySize(j)).append(") is skewed, ").append("split it into ").append(((SeqLike) createSkewPartitionSpecs.get()).length()).append(" parts.").toString();
                    });
                    create.elem++;
                }
                // Fall back to the unsplit spec when no split specs could be created.
                seq = (Seq) createSkewPartitionSpecs.getOrElse(() -> {
                    return apply;
                });
            } else {
                seq = apply;
            }
            Seq seq3 = seq;
            if (z2) {
                Option<Seq<PartialReducerPartitionSpec>> createSkewPartitionSpecs2 = ShufflePartitionsUtil$.MODULE$.createSkewPartitionSpecs(((MapOutputStatistics) shuffleQueryStageExec2.mapStats().get()).shuffleId(), i, targetSize2);
                if (createSkewPartitionSpecs2.isDefined()) {
                    MODULE$.logDebug(() -> {
                        return new StringBuilder(57).append("Right side partition ").append(i).append(" ").append("(").append(FileUtils.byteCountToDisplaySize(j2)).append(") is skewed, ").append("split it into ").append(((SeqLike) createSkewPartitionSpecs2.get()).length()).append(" parts.").toString();
                    });
                    create2.elem++;
                }
                seq2 = (Seq) createSkewPartitionSpecs2.getOrElse(() -> {
                    return apply2;
                });
            } else {
                seq2 = apply2;
            }
            Seq seq4 = seq2;
            // Pair every left spec with every right spec of this partition.
            seq3.foreach(product -> {
                $anonfun$tryOptimizeJoinChildren$7(seq4, empty, empty2, product);
                return BoxedUnit.UNIT;
            });
        });
        logDebug(() -> {
            return new StringBuilder(42).append("number of skewed partitions: left ").append(create.elem).append(", right ").append(create2.elem).toString();
        });
        return (create.elem > 0 || create2.elem > 0) ? new Some(new Tuple2(AQEShuffleReadExec$.MODULE$.apply((SparkPlan) shuffleQueryStageExec, empty.toSeq()), AQEShuffleReadExec$.MODULE$.apply((SparkPlan) shuffleQueryStageExec2, empty2.toSeq()))) : None$.MODULE$;
    }

    /**
     * Rewrites the plan bottom-up with the partial function
     * {@code OptimizeSkewedJoin$$anonfun$optimizeSkewJoin$1}. That class is not visible here;
     * presumably it matches eligible joins and calls {@code tryOptimizeJoinChildren}. TODO confirm.
     */
    public SparkPlan optimizeSkewJoin(SparkPlan sparkPlan) {
        return sparkPlan.transformUp(new OptimizeSkewedJoin$$anonfun$optimizeSkewJoin$1());
    }

    /**
     * Entry point of the rule: optimizes only when {@code SKEW_JOIN_ENABLED} is true and the plan
     * contains exactly two shuffle query stages; otherwise returns the plan unchanged.
     */
    public SparkPlan apply(SparkPlan sparkPlan) {
        if (BoxesRunTime.unboxToBoolean(conf().getConf(SQLConf$.MODULE$.SKEW_JOIN_ENABLED())) && collectShuffleStages$1(sparkPlan).length() == 2) {
            return optimizeSkewJoin(sparkPlan);
        }
        return sparkPlan;
    }

    /**
     * For one left spec {@code product}, appends it to the left buffer once per right spec in
     * {@code seq}, appending that right spec to the right buffer alongside it.
     */
    public static final /* synthetic */ void $anonfun$tryOptimizeJoinChildren$7(Seq seq, ArrayBuffer arrayBuffer, ArrayBuffer arrayBuffer2, Product product) {
        seq.foreach(product2 -> {
            arrayBuffer.$plus$eq(product);
            return arrayBuffer2.$plus$eq(product2);
        });
    }

    /**
     * Collects the {@code ShuffleQueryStageExec} nodes of a plan; a stage is returned as-is and
     * its children are not searched further.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final Seq collectShuffleStages$1(SparkPlan sparkPlan) {
        return sparkPlan instanceof ShuffleQueryStageExec ? (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new ShuffleQueryStageExec[]{(ShuffleQueryStageExec) sparkPlan})) : (Seq) sparkPlan.children().flatMap(sparkPlan2 -> {
            return collectShuffleStages$1(sparkPlan2);
        }, Seq$.MODULE$.canBuildFrom());
    }

    /** Initializes the singleton, runs the trait initializer, and sets the supported origins. */
    private OptimizeSkewedJoin$() {
        MODULE$ = this;
        AQEShuffleReadRule.$init$(this);
        this.supportedShuffleOrigins = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new ENSURE_REQUIREMENTS$[]{ENSURE_REQUIREMENTS$.MODULE$}));
    }
}
