package org.apache.hadoop.hive.ql.optimizer.calcite.rules.views;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.views.HiveAggregateIncrementalRewritingRuleBase.IncrementalComputePlan;

/**
 * Base rule that rewrites an {@link Aggregate} over a {@link Union} by merging the aggregate values
 * of the union's second input with the values produced by a subclass-supplied right input.
 *
 * <p>The produced plan is:
 * <ol>
 *   <li>a Project over the union's second input that keeps every column and appends a TRUE
 *       literal (an "existence flag");</li>
 *   <li>a RIGHT OUTER join of that Project with the subclass-provided right input, matching every
 *       group key with {@code IS NOT DISTINCT FROM};</li>
 *   <li>a Filter whose condition is supplied by {@link #createFilterCondition};</li>
 *   <li>a Project emitting the right side's group keys followed by one merged expression per
 *       aggregate call (see {@link #createAggregateNode}).</li>
 * </ol>
 *
 * <p>NOTE(review): the union's second input is presumably the existing materialized-view rows and
 * the right input the newly computed aggregate — confirm against the concrete subclasses.
 *
 * @param <T> plan type returned by {@link #createJoinRightInput} and handed back to
 *     {@link #createFilterCondition}
 */
public abstract class HiveAggregateIncrementalRewritingRuleBase<T extends IncrementalComputePlan> extends RelOptRule {
    /** Ordinal of the {@link Aggregate} among the rels matched by this rule's operand tree. */
    private final int aggregateIndex;

    /**
     * Result of {@code createJoinRightInput}: holds the relational expression used as the right
     * side of the join built by {@code onMatch}. Subclasses may extend it to carry extra state
     * needed when building the filter condition.
     */
    protected static class IncrementalComputePlan {
        /** Right input of the RIGHT OUTER join. */
        protected final RelNode rightInput;

        public IncrementalComputePlan(RelNode rightInput) {
            this.rightInput = rightInput;
        }
    }

    /**
     * Creates the rule.
     *
     * @param operand root operand of the pattern; the rel at ordinal 1 must be a {@link Union}
     * @param relBuilderFactory factory for the builders used to create the rewritten plan
     * @param description rule description
     * @param aggregateIndex ordinal of the {@link Aggregate} among the matched rels
     */
    public HiveAggregateIncrementalRewritingRuleBase(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory,
            String description, int aggregateIndex) {
        super(operand, relBuilderFactory, description);
        this.aggregateIndex = aggregateIndex;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = (Aggregate) call.rel(aggregateIndex);
        final Union union = (Union) call.rel(1);
        final RelBuilder relBuilder = call.builder();
        final RexBuilder rexBuilder = relBuilder.getRexBuilder();

        final RelNode unionInput = union.getInput(1);
        final T plan = createJoinRightInput(call);
        if (plan == null) {
            // Subclass declined the rewrite.
            return;
        }
        final RelNode rightInput = plan.rightInput;

        // 1) Keep every column of the union's second input and append a TRUE literal. After the
        //    RIGHT join below this flag is TRUE when the right row found a match, NULL otherwise.
        final int unionFieldCount = unionInput.getRowType().getFieldCount();
        final List<RexNode> leftExprs = new ArrayList<>(unionFieldCount + 1);
        for (int i = 0; i < unionFieldCount; i++) {
            leftExprs.add(rexBuilder.makeInputRef(unionInput.getRowType().getFieldList().get(i).getType(), i));
        }
        leftExprs.add(rexBuilder.makeLiteral(true));
        final RelNode left = relBuilder.push(unionInput).project(leftExprs).build();

        final int groupCount = aggregate.getGroupCount();
        final int aggCallCount = aggregate.getAggCallList().size();
        // In the join row, right-input field j sits right after all left fields (flag included).
        final int rightOffset = left.getRowType().getFieldCount();

        // 2) Join on every group key; IS NOT DISTINCT FROM lets NULL keys match each other.
        //    The output keeps the right side's group keys.
        final List<RexNode> projExprs = new ArrayList<>(groupCount + aggCallCount);
        final List<RexNode> joinConds = new ArrayList<>(groupCount);
        for (int key = 0; key < groupCount; key++) {
            final RexNode leftRef = rexBuilder.makeInputRef(left.getRowType().getFieldList().get(key).getType(), key);
            final RexNode rightRef = rexBuilder.makeInputRef(
                    rightInput.getRowType().getFieldList().get(key).getType(), rightOffset + key);
            projExprs.add(rightRef);
            joinConds.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
                    ImmutableList.of(leftRef, rightRef)));
        }
        final RelNode join = relBuilder.push(left)
                .push(rightInput)
                .join(JoinRelType.RIGHT, RexUtil.composeConjunction(rexBuilder, joinConds))
                .build();

        // 3) Merge each aggregate column: CASE WHEN left IS NULL THEN right
        //    WHEN right IS NULL THEN left ELSE merge(left, right) END.
        // NOTE(review): assumes both inputs lay out group keys first, then aggregates in
        // getAggCallList() order.
        for (int call_i = 0; call_i < aggCallCount; call_i++) {
            final int field = groupCount + call_i;
            final RexNode leftRef = rexBuilder.makeInputRef(left.getRowType().getFieldList().get(field).getType(), field);
            final RexNode rightRef = rexBuilder.makeInputRef(
                    rightInput.getRowType().getFieldList().get(field).getType(), rightOffset + field);
            final RexNode merged = createAggregateNode(
                    aggregate.getAggCallList().get(call_i).getAggregation(), leftRef, rightRef, rexBuilder);
            projExprs.add(rexBuilder.makeCall(SqlStdOperatorTable.CASE,
                    rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, leftRef), rightRef,
                    rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, rightRef), leftRef,
                    merged));
        }

        // 4) Filter (subclass-defined) then project the merged row. The flag is the last left field,
        //    so its position in the join row equals its position in the left input.
        final int flagIndex = rightOffset - 1;
        final RexNode flag = rexBuilder.makeInputRef(join.getRowType().getFieldList().get(flagIndex).getType(), flagIndex);
        call.transformTo(relBuilder.push(join)
                .filter(createFilterCondition(plan, flag, projExprs, relBuilder))
                .project(projExprs)
                .build());
    }

    /**
     * Creates the right input of the join built by {@link #onMatch}.
     *
     * @param call the current rule call
     * @return the plan holding the right input, or {@code null} to skip the rewrite entirely
     */
    protected abstract T createJoinRightInput(RelOptRuleCall call);

    /**
     * Builds the expression that combines a value from the left (union) side with the
     * corresponding value from the right input. SUM and COUNT are merged by addition
     * ({@code right + left}).
     *
     * @param aggFunction aggregate function of the aggregate call being merged
     * @param leftValue reference to the value from the left side of the join
     * @param rightValue reference to the value from the right input
     * @param rexBuilder builder used to create the expression
     * @return the merge expression
     * @throws AssertionError if the function kind is neither SUM nor COUNT
     */
    public RexNode createAggregateNode(SqlAggFunction aggFunction, RexNode leftValue, RexNode rightValue,
            RexBuilder rexBuilder) {
        switch (aggFunction.getKind()) {
            case SUM:
            case COUNT:
                return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, ImmutableList.of(rightValue, leftValue));
            default:
                throw new AssertionError("Found an aggregation that could not be recognized: " + aggFunction);
        }
    }

    /**
     * Builds the condition of the Filter placed on top of the join, before the final projection.
     *
     * @param plan the plan returned by {@link #createJoinRightInput}
     * @param flagNode reference to the existence flag: TRUE when the right row matched a left row,
     *     NULL otherwise
     * @param projExprs output expressions: the right side's group keys followed by one merged
     *     expression per aggregate call; the same list is then used for the final projection
     * @param relBuilder builder whose top of stack is the join
     * @return the filter condition
     */
    protected abstract RexNode createFilterCondition(T plan, RexNode flagNode, List<RexNode> projExprs,
            RelBuilder relBuilder);
}
