package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.UnmodifiableIterator;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSemiJoin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSemiJoinRule.class */
public class HiveSemiJoinRule {
    /** Rule instance matching a {@code Project} over a join whose right input is an {@code Aggregate}. */
    public static final HiveProjectJoinToSemiJoinRule INSTANCE_PROJECT = new HiveProjectJoinToSemiJoinRule();
    /** Rule instance matching an {@code Aggregate} over a join whose right input is an {@code Aggregate}. */
    public static final HiveAggregateJoinToSemiJoinRule INSTANCE_AGGREGATE = new HiveAggregateJoinToSemiJoinRule();
    /** Rule instance matching a {@code Project} over a join whose left input is an {@code Aggregate}. */
    public static final HiveProjectJoinToSemiJoinRuleSwapInputs INSTANCE_PROJECT_SWAPPED = new HiveProjectJoinToSemiJoinRuleSwapInputs();
    /** Rule instance matching an {@code Aggregate} over a join whose left input is an {@code Aggregate}. */
    public static final HiveAggregateJoinToSemiJoinRuleSwapInputs INSTANCE_AGGREGATE_SWAPPED = new HiveAggregateJoinToSemiJoinRuleSwapInputs();

    /**
     * Semi-join rule whose top operator is an {@link Aggregate}; the aggregate being matched
     * away is the right input of the join.
     */
    public static class HiveAggregateJoinToSemiJoinRule extends HiveSemiJoinRuleBase<Aggregate> {
        protected HiveAggregateJoinToSemiJoinRule() {
            super(Aggregate.class, HiveRelFactories.HIVE_BUILDER);
        }

        /**
         * Collects the input fields referenced by the top aggregate.
         *
         * @param aggregate the top aggregate
         * @return the result of {@link HiveCalciteUtil#extractRefs} for {@code aggregate}
         */
        @Override
        public ImmutableBitSet extractUsedFields(Aggregate aggregate) {
            return HiveCalciteUtil.extractRefs(aggregate);
        }

        /**
         * Rebuilds the top aggregate over {@code relNode}, shifting every field index by {@code iArr}.
         *
         * @param relBuilder builder used to create the new aggregate
         * @param rexBuilder unused for aggregates
         * @param iArr       per-field index offsets (old index {@code i} maps to {@code i + iArr[i]})
         * @param aggregate  the original top aggregate
         * @param relNode    the new input
         * @return the rebuilt aggregate
         * @throws ClassCastException if the builder hands back something other than an {@link Aggregate}
         */
        @Override
        public Aggregate recreateTopOperator(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Aggregate aggregate, RelNode relNode) {
            // The shared helper is declared to return RelNode; narrow it to honour this method's contract.
            return (Aggregate) HiveSemiJoinRule.recreateAggregateOperator(relBuilder, iArr, aggregate, relNode);
        }

        /**
         * Same rebuild as {@link #recreateTopOperator}, but returns whatever node the builder produced
         * without narrowing its type.
         */
        @Override
        public RelNode recreateTopOperatorUnforced(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Aggregate aggregate, RelNode relNode) {
            return HiveSemiJoinRule.recreateAggregateOperator(relBuilder, iArr, aggregate, relNode);
        }
    }

    /**
     * Semi-join rule whose top operator is an {@link Aggregate}; the aggregate being matched
     * away is the left input of the join, so the join inputs are exchanged before the rewrite.
     */
    public static class HiveAggregateJoinToSemiJoinRuleSwapInputs extends HiveToSemiJoinRuleSwapInputs<Aggregate> {
        protected HiveAggregateJoinToSemiJoinRuleSwapInputs() {
            super(Aggregate.class, HiveRelFactories.HIVE_BUILDER);
        }

        /**
         * Collects the input fields referenced by the top aggregate.
         *
         * @param aggregate the top aggregate
         * @return the result of {@link HiveCalciteUtil#extractRefs} for {@code aggregate}
         */
        @Override
        public ImmutableBitSet extractUsedFields(Aggregate aggregate) {
            return HiveCalciteUtil.extractRefs(aggregate);
        }

        /**
         * Rebuilds the top aggregate over {@code relNode}, shifting every field index by {@code iArr}.
         *
         * @param relBuilder builder used to create the new aggregate
         * @param rexBuilder unused for aggregates
         * @param iArr       per-field index offsets (old index {@code i} maps to {@code i + iArr[i]})
         * @param aggregate  the original top aggregate
         * @param relNode    the new input
         * @return the rebuilt aggregate
         * @throws ClassCastException if the builder hands back something other than an {@link Aggregate}
         */
        @Override
        public Aggregate recreateTopOperator(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Aggregate aggregate, RelNode relNode) {
            // The shared helper is declared to return RelNode; narrow it to honour this method's contract.
            return (Aggregate) HiveSemiJoinRule.recreateAggregateOperator(relBuilder, iArr, aggregate, relNode);
        }

        /**
         * Same rebuild as {@link #recreateTopOperator}, but returns whatever node the builder produced
         * without narrowing its type.
         */
        @Override
        public RelNode recreateTopOperatorUnforced(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Aggregate aggregate, RelNode relNode) {
            return HiveSemiJoinRule.recreateAggregateOperator(relBuilder, iArr, aggregate, relNode);
        }
    }

    /**
     * Semi-join rule whose top operator is a {@link Project}; the aggregate being matched
     * away is the right input of the join.
     */
    public static class HiveProjectJoinToSemiJoinRule extends HiveSemiJoinRuleBase<Project> {
        protected HiveProjectJoinToSemiJoinRule() {
            super(Project.class, HiveRelFactories.HIVE_BUILDER);
        }

        /**
         * Collects the input fields referenced by the projection expressions.
         *
         * @param project the top project
         * @return bit set of every input field index used by {@code project}
         */
        @Override
        public ImmutableBitSet extractUsedFields(Project project) {
            return RelOptUtil.InputFinder.bits(project.getChildExps(), (RexNode) null);
        }

        /**
         * Rebuilds the projection over {@code relNode}, shifting input references by {@code iArr}.
         * The builder's third {@code project} argument is passed as {@code true} here
         * (the "force" flag in Calcite's {@code RelBuilder} API), which is what makes the
         * narrowing cast below reasonable.
         *
         * @param relBuilder builder used to create the new project
         * @param rexBuilder used to rewrite the projection expressions
         * @param iArr       per-field index offsets (old index {@code i} maps to {@code i + iArr[i]})
         * @param project    the original top project
         * @param relNode    the new input
         * @return the rebuilt project
         * @throws ClassCastException if the builder hands back something other than a {@link Project}
         */
        @Override
        public Project recreateTopOperator(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Project project, RelNode relNode) {
            // The shared helper is declared to return RelNode; narrow it to honour this method's contract.
            return (Project) HiveSemiJoinRule.recreateProjectOperator(relBuilder, rexBuilder, iArr, project, relNode, true);
        }

        /**
         * Same rebuild as {@link #recreateTopOperator} but with the builder flag set to {@code false},
         * so the result is returned as a plain {@link RelNode}.
         */
        @Override
        public RelNode recreateTopOperatorUnforced(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Project project, RelNode relNode) {
            return HiveSemiJoinRule.recreateProjectOperator(relBuilder, rexBuilder, iArr, project, relNode, false);
        }
    }

    /**
     * Semi-join rule whose top operator is a {@link Project}; the aggregate being matched
     * away is the left input of the join, so the join inputs are exchanged before the rewrite.
     */
    public static class HiveProjectJoinToSemiJoinRuleSwapInputs extends HiveToSemiJoinRuleSwapInputs<Project> {
        protected HiveProjectJoinToSemiJoinRuleSwapInputs() {
            super(Project.class, HiveRelFactories.HIVE_BUILDER);
        }

        /**
         * Collects the input fields referenced by the projection expressions.
         *
         * @param project the top project
         * @return bit set of every input field index used by {@code project}
         */
        @Override
        public ImmutableBitSet extractUsedFields(Project project) {
            return RelOptUtil.InputFinder.bits(project.getChildExps(), (RexNode) null);
        }

        /**
         * Rebuilds the projection over {@code relNode}, shifting input references by {@code iArr}.
         * The builder's third {@code project} argument is passed as {@code true} here
         * (the "force" flag in Calcite's {@code RelBuilder} API), which is what makes the
         * narrowing cast below reasonable.
         *
         * @param relBuilder builder used to create the new project
         * @param rexBuilder used to rewrite the projection expressions
         * @param iArr       per-field index offsets (old index {@code i} maps to {@code i + iArr[i]})
         * @param project    the original top project
         * @param relNode    the new input
         * @return the rebuilt project
         * @throws ClassCastException if the builder hands back something other than a {@link Project}
         */
        @Override
        public Project recreateTopOperator(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Project project, RelNode relNode) {
            // The shared helper is declared to return RelNode; narrow it to honour this method's contract.
            return (Project) HiveSemiJoinRule.recreateProjectOperator(relBuilder, rexBuilder, iArr, project, relNode, true);
        }

        /**
         * Same rebuild as {@link #recreateTopOperator} but with the builder flag set to {@code false},
         * so the result is returned as a plain {@link RelNode}.
         */
        @Override
        public RelNode recreateTopOperatorUnforced(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Project project, RelNode relNode) {
            return HiveSemiJoinRule.recreateProjectOperator(relBuilder, rexBuilder, iArr, project, relNode, false);
        }
    }

    /**
     * Shared implementation of the rewrite
     * {@code Top(Join(left, Aggregate(input)))  ->  Top'(SemiJoin(left', right'))}.
     *
     * <p>The rewrite is attempted only when the top operator reads nothing from the join's right
     * side, the join is an equi-join, and the right join keys are exactly the aggregate's
     * grouping columns. Subclasses supply the top-operator specifics.
     *
     * @param <T> type of the operator sitting on top of the join
     */
    public abstract static class HiveSemiJoinRuleBase<T extends RelNode> extends RelOptRule {
        protected static final Logger LOG = LoggerFactory.getLogger(HiveSemiJoinRuleBase.class);

        /**
         * Creates a rule matching {@code cls} over a join whose second input is an aggregate.
         *
         * @param cls               class of the top operator
         * @param relBuilderFactory factory for the builders used while rewriting
         */
        protected HiveSemiJoinRuleBase(Class<T> cls, RelBuilderFactory relBuilderFactory) {
            super(
                operand(cls,
                    operand(Join.class,
                        operand(RelNode.class, any()),
                        operand(Aggregate.class,
                            operand(RelNode.class, any())))),
                relBuilderFactory, (String) null);
        }

        /**
         * Creates a rule with a caller-supplied operand tree.
         *
         * @param relOptRuleOperand root operand of the pattern to match
         * @param relBuilderFactory factory for the builders used while rewriting
         */
        protected HiveSemiJoinRuleBase(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory) {
            super(relOptRuleOperand, relBuilderFactory, (String) null);
        }

        /**
         * Unpacks the matched operators (top, join, left, aggregate, aggregate input) and
         * delegates to {@link #perform}. Joins that are already {@link HiveSemiJoin}s are skipped.
         */
        @Override
        public void onMatch(RelOptRuleCall relOptRuleCall) {
            final T topOperator = relOptRuleCall.rel(0);
            final Join join = relOptRuleCall.rel(1);
            if (join instanceof HiveSemiJoin) {
                return;
            }
            final RelNode left = relOptRuleCall.rel(2);
            final Aggregate aggregate = relOptRuleCall.rel(3);
            final RelNode aggregateInput = relOptRuleCall.rel(4);
            perform(relOptRuleCall, extractUsedFields(topOperator), topOperator, join, left, aggregate, aggregateInput);
        }

        /**
         * Decides whether the aggregate's input must be wrapped in a projection before it can
         * serve as the semi-join's right input.
         *
         * <p>NOTE(review): equal field counts are treated as "the aggregate is an identity over its
         * input"; that presumably relies on the aggregate having no aggregate calls (compare the
         * assertion in {@code buildProjectRightInput}) - confirm.
         *
         * @param input     the aggregate's input
         * @param aggregate the aggregate itself
         * @return {@code true} if {@code input} is a {@link Join} or the two row types differ in width
         */
        private boolean needProject(RelNode input, RelNode aggregate) {
            return (input instanceof Join)
                || input.getRowType().getFieldCount() != aggregate.getRowType().getFieldCount();
        }

        /**
         * Applies the rewrite if all preconditions hold; otherwise returns without transforming.
         *
         * @param relOptRuleCall  the rule call to register the transformation on
         * @param immutableBitSet join-output fields referenced by the top operator
         * @param t               the top operator
         * @param join            the join below the top operator
         * @param relNode         the join's left input
         * @param aggregate       the aggregate that is the join's right input
         * @param relNode2        the aggregate's input
         */
        protected void perform(RelOptRuleCall relOptRuleCall, ImmutableBitSet immutableBitSet, T t, Join join, RelNode relNode, Aggregate aggregate, RelNode relNode2) {
            LOG.debug("Matched HiveSemiJoinRule");
            final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
            final int leftFieldCount = relNode.getRowType().getFieldCount();

            // The top operator must not read any field coming from the right side of the join.
            final ImmutableBitSet rightFields = ImmutableBitSet.range(leftFieldCount, join.getRowType().getFieldCount());
            if (immutableBitSet.intersects(rightFields)) {
                return;
            }

            // The right-side join keys must be exactly the grouping columns, and the join must be an equi-join.
            final JoinInfo joinInfo = join.analyzeCondition();
            if (!joinInfo.rightSet().equals(ImmutableBitSet.range(aggregate.getGroupCount()))) {
                return;
            }
            if (!joinInfo.isEqui()) {
                return;
            }

            if (join.getJoinType() == JoinRelType.LEFT) {
                // Nothing is read from the right side and it is keyed on exactly the join keys,
                // so the top operator is simply re-attached to the left input.
                relOptRuleCall.transformTo(t.copy(t.getTraitSet(), ImmutableList.of(relNode)));
                return;
            }
            if (join.getJoinType() != JoinRelType.INNER) {
                return;
            }

            LOG.debug("All conditions matched for HiveSemiJoinRule. Going to apply transformation.");

            // Left side: keep only what the top operator or the join condition needs.
            final ImmutableBitSet neededLeftFields = immutableBitSet.union(ImmutableBitSet.of(joinInfo.leftKeys));
            final boolean needsLeftProject = neededLeftFields.cardinality() != leftFieldCount;
            final RelNode newLeft = needsLeftProject
                ? buildProjectLeftInput(relNode, neededLeftFields, rexBuilder, relOptRuleCall.builder())
                : relNode;

            // Right side: the aggregate is replaced by its input, projected onto the grouping columns if required.
            final RelNode newRight = needProject(relNode2, aggregate)
                ? buildProjectRightInput(aggregate, rexBuilder, relOptRuleCall.builder())
                : relNode2;

            final ImmutableIntList newLeftKeys = needsLeftProject
                ? remapKeys(joinInfo.leftKeys, neededLeftFields)
                : joinInfo.leftKeys;

            final RexNode semiJoinCondition =
                RelOptUtil.createEquiJoinCondition(newLeft, newLeftKeys, newRight, joinInfo.rightKeys, rexBuilder);
            final RelNode semiJoin = relOptRuleCall.builder()
                .push(newLeft)
                .push(newRight)
                .semiJoin(semiJoinCondition)
                .build();

            // Offset from each old left field index to its position among the kept fields.
            // Fields that were not kept get an offset that is never looked up by the top operator.
            final int[] adjustments = new int[leftFieldCount];
            for (int i = 0; i < adjustments.length; i++) {
                adjustments[i] = neededLeftFields.indexOf(i) - i;
            }
            relOptRuleCall.transformTo(
                recreateTopOperatorUnforced(relOptRuleCall.builder(), rexBuilder, adjustments, t, semiJoin));
        }

        /** Translates key indexes into their positions within {@code keptFields}. */
        private static ImmutableIntList remapKeys(ImmutableIntList keys, ImmutableBitSet keptFields) {
            final List<Integer> remapped = new ArrayList<>(keys.size());
            for (int key : keys) {
                remapped.add(keptFields.indexOf(key));
            }
            return ImmutableIntList.copyOf(remapped);
        }

        /** Projects {@code left} onto the fields in {@code keptFields}, in ascending index order. */
        private RelNode buildProjectLeftInput(RelNode left, ImmutableBitSet keptFields, RexBuilder rexBuilder, RelBuilder relBuilder) {
            final List<RexNode> projects = new ArrayList<>(keptFields.cardinality());
            for (int fieldIndex : keptFields) {
                projects.add(rexBuilder.makeInputRef(left, fieldIndex));
            }
            return relBuilder.push(left).project(projects).build();
        }

        /**
         * Projects the aggregate's input onto the aggregate's grouping columns.
         * Expects a simple group-by with no aggregate calls (checked only when assertions are enabled).
         */
        private RelNode buildProjectRightInput(Aggregate aggregate, RexBuilder rexBuilder, RelBuilder relBuilder) {
            assert aggregate.getGroupType() == Aggregate.Group.SIMPLE && aggregate.getAggCallList().isEmpty();
            final RelNode input = aggregate.getInput();
            final List<Integer> groupColumns = aggregate.getGroupSet().asList();
            final List<RexNode> projects = new ArrayList<>(groupColumns.size());
            for (int column : groupColumns) {
                projects.add(rexBuilder.makeInputRef(input, column));
            }
            return relBuilder.push(input).project(projects).build();
        }

        /**
         * Returns the join-output fields that the top operator references.
         *
         * @param t the top operator
         */
        protected abstract ImmutableBitSet extractUsedFields(T t);

        /**
         * Rebuilds the top operator over a new input and returns it with its specific type.
         *
         * @param relBuilder builder to create the operator with
         * @param rexBuilder builder for rewritten expressions
         * @param iArr       per-field index offsets (old index {@code i} maps to {@code i + iArr[i]})
         * @param t          the original top operator
         * @param relNode    the new input
         */
        protected abstract T recreateTopOperator(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, T t, RelNode relNode);

        /**
         * Rebuilds the top operator over a new input without any guarantee on the type of node returned.
         *
         * @param relBuilder builder to create the operator with
         * @param rexBuilder builder for rewritten expressions
         * @param iArr       per-field index offsets (old index {@code i} maps to {@code i + iArr[i]})
         * @param t          the original top operator
         * @param relNode    the new input
         */
        protected abstract RelNode recreateTopOperatorUnforced(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, T t, RelNode relNode);
    }

    /**
     * Variant of the semi-join rewrite for joins whose <em>left</em> input is the aggregate.
     * The join inputs are exchanged first so that the base implementation, which expects the
     * aggregate on the right, can be reused.
     *
     * @param <T> type of the operator sitting on top of the join
     */
    private abstract static class HiveToSemiJoinRuleSwapInputs<T extends RelNode> extends HiveSemiJoinRuleBase<T> {
        /**
         * Creates a rule matching {@code cls} over a join whose first input is an aggregate.
         *
         * @param cls               class of the top operator
         * @param relBuilderFactory factory for the builders used while rewriting
         */
        protected HiveToSemiJoinRuleSwapInputs(Class<T> cls, RelBuilderFactory relBuilderFactory) {
            super(
                operand(cls,
                    operand(Join.class,
                        operand(Aggregate.class,
                            operand(RelNode.class, any())),
                        operand(RelNode.class, any()))),
                relBuilderFactory);
        }

        /**
         * Checks the preconditions against the original (unswapped) join, exchanges the join
         * inputs, and hands the swapped tree to {@link #perform}.
         */
        @Override
        public void onMatch(RelOptRuleCall relOptRuleCall) {
            final T topOperator = relOptRuleCall.rel(0);
            final Join join = relOptRuleCall.rel(1);
            if (join instanceof HiveSemiJoin) {
                return;
            }
            final Aggregate aggregate = relOptRuleCall.rel(2);
            final RelNode aggregateInput = relOptRuleCall.rel(3);
            final RelNode right = relOptRuleCall.rel(4);

            final JoinInfo joinInfo = join.analyzeCondition();
            if (!joinInfo.isEqui()) {
                return;
            }
            // The left-side join keys must be exactly the aggregate's grouping columns.
            if (!joinInfo.leftSet().equals(ImmutableBitSet.range(aggregate.getGroupCount()))) {
                return;
            }
            // The top operator must not read anything produced by the aggregate side.
            final ImmutableBitSet aggregateSideFields =
                ImmutableBitSet.range(0, join.getLeft().getRowType().getFieldCount());
            if (extractUsedFields(topOperator).intersects(aggregateSideFields)) {
                return;
            }

            final T swappedTop = swapInputs(join, topOperator, relOptRuleCall.builder());
            // swapInputs places the swapped join directly beneath the rebuilt top operator.
            final Join swappedJoin = (Join) swappedTop.getInput(0);
            perform(relOptRuleCall, extractUsedFields(swappedTop), swappedTop, swappedJoin, right, aggregate, aggregateInput);
        }

        /**
         * Builds an equivalent tree in which the join's inputs are exchanged and the top
         * operator's field references are shifted accordingly.
         *
         * <p>The join type is mirrored together with the inputs (LEFT becomes RIGHT and vice
         * versa). Keeping the original type would turn {@code A LEFT JOIN B} into
         * {@code B LEFT JOIN A}, which is a different query and would let {@link #perform}
         * drop the wrong side.
         *
         * @param join       the join whose inputs are to be exchanged
         * @param t          the operator on top of {@code join}
         * @param relBuilder builder used for the new join and the new top operator
         * @return the top operator recreated over the swapped join
         */
        protected T swapInputs(Join join, T t, RelBuilder relBuilder) {
            final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
            final int rightFieldCount = join.getRight().getRowType().getFieldCount();
            final int leftFieldCount = join.getLeft().getRowType().getFieldCount();
            final List<RelDataTypeField> joinFields = join.getRowType().getFieldList();

            // Former left fields move up past the former right fields; former right fields move down to the front.
            final int[] adjustments = new int[joinFields.size()];
            for (int i = 0; i < adjustments.length; i++) {
                adjustments[i] = i < leftFieldCount ? rightFieldCount : -leftFieldCount;
            }

            final RexNode swappedCondition = join.getCondition()
                .accept(new RelOptUtil.RexInputConverter(rexBuilder, joinFields, joinFields, adjustments));
            final RelNode swappedJoin = relBuilder
                .push(join.getRight())
                .push(join.getLeft())
                .join(join.getJoinType().swap(), swappedCondition)
                .build();
            return recreateTopOperator(relBuilder, rexBuilder, adjustments, t, swappedJoin);
        }
    }

    /** Not instantiable: this class only holds the rule instances and the shared static helpers. */
    private HiveSemiJoinRule() {
    }

    /**
     * Rebuilds a projection on top of a new input, shifting every input reference in the
     * projection expressions by the given per-field offsets.
     *
     * @param relBuilder builder used to create the projection
     * @param rexBuilder builder used while rewriting the expressions
     * @param iArr       per-field index offsets (old index {@code i} maps to {@code i + iArr[i]})
     * @param project    the original projection whose expressions are rewritten
     * @param relNode    the new input; its field list supplies the types of the rewritten references
     * @param z          passed through as the third argument of {@code RelBuilder.project}
     *                   (the "force" flag in Calcite's API)
     * @return the node produced by the builder
     */
    protected static RelNode recreateProjectOperator(RelBuilder relBuilder, RexBuilder rexBuilder, int[] iArr, Project project, RelNode relNode, boolean z) {
        final List<RelDataTypeField> inputFields = relNode.getRowType().getFieldList();
        // The converter depends only on loop-invariant values, so one instance serves all expressions.
        final RelOptUtil.RexInputConverter converter =
            new RelOptUtil.RexInputConverter(rexBuilder, inputFields, inputFields, iArr);

        final List<RexNode> shiftedProjects = new ArrayList<>(project.getProjects().size());
        for (RexNode expression : project.getProjects()) {
            shiftedProjects.add(expression.accept(converter));
        }
        return relBuilder.push(relNode).project(shiftedProjects, ImmutableList.of(), z).build();
    }

    /**
     * Rebuilds an aggregate on top of a new input, shifting the grouping columns, grouping sets,
     * aggregate-call arguments, filter arguments and collation fields by the given offsets.
     *
     * @param relBuilder builder used to create the aggregate
     * @param iArr       per-field index offsets (old index {@code i} maps to {@code i + iArr[i]})
     * @param aggregate  the original aggregate
     * @param relNode    the new input
     * @return the node produced by the builder
     */
    protected static RelNode recreateAggregateOperator(RelBuilder relBuilder, int[] iArr, Aggregate aggregate, RelNode relNode) {
        // Push the input exactly once, before the group key is created (the builder presumably
        // resolves group-key fields against the top of its stack). A second push would leave a
        // stale copy of the input behind on the caller's builder after build().
        relBuilder.push(relNode);

        final ImmutableBitSet shiftedGroupSet = shiftBits(aggregate.getGroupSet(), iArr);
        final RelBuilder.GroupKey groupKey;
        if (aggregate.getGroupType() == Aggregate.Group.SIMPLE) {
            groupKey = relBuilder.groupKey(shiftedGroupSet);
        } else {
            final List<ImmutableBitSet> shiftedGroupSets = new ArrayList<>(aggregate.getGroupSets().size());
            for (ImmutableBitSet groupSet : aggregate.getGroupSets()) {
                shiftedGroupSets.add(shiftBits(groupSet, iArr));
            }
            groupKey = relBuilder.groupKey(shiftedGroupSet, shiftedGroupSets);
        }

        final List<AggregateCall> shiftedCalls = new ArrayList<>(aggregate.getAggCallList().size());
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            final List<Integer> shiftedArgs = aggregateCall.getArgList().stream()
                .map(arg -> arg + iArr[arg])
                .collect(Collectors.toList());
            // -1 means "no filter" and must be preserved as is.
            final int shiftedFilterArg = aggregateCall.filterArg != -1
                ? aggregateCall.filterArg + iArr[aggregateCall.filterArg]
                : -1;
            shiftedCalls.add(aggregateCall.copy(
                shiftedArgs,
                shiftedFilterArg,
                aggregateCall.getCollation() != null
                    ? RelCollations.of(aggregateCall.getCollation().getFieldCollations().stream()
                        .map(fieldCollation -> fieldCollation.copy(
                            fieldCollation.getFieldIndex() + iArr[fieldCollation.getFieldIndex()]))
                        .collect(Collectors.toList()))
                    : null));
        }
        return relBuilder.aggregate(groupKey, shiftedCalls).build();
    }

    /** Returns a copy of {@code bits} in which every set position {@code p} is moved to {@code p + adjustments[p]}. */
    private static ImmutableBitSet shiftBits(ImmutableBitSet bits, int[] adjustments) {
        final ImmutableBitSet.Builder shifted = ImmutableBitSet.builder();
        for (int position : bits) {
            shifted.set(position + adjustments[position]);
        }
        return shifted.build();
    }
}
