/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HiveJoinConstraintsRule
extends RelOptRule {
    protected static final Logger LOG = LoggerFactory.getLogger(HiveJoinConstraintsRule.class);
    public static final HiveJoinConstraintsRule INSTANCE = new HiveJoinConstraintsRule(HiveRelFactories.HIVE_BUILDER);

    protected HiveJoinConstraintsRule(RelBuilderFactory relBuilder) {
        super(HiveJoinConstraintsRule.operand(Project.class, (RelOptRuleOperandChildren)HiveJoinConstraintsRule.some((RelOptRuleOperand)HiveJoinConstraintsRule.operand(Join.class, (RelOptRuleOperandChildren)HiveJoinConstraintsRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0])), relBuilder, "HiveJoinConstraintsRule");
    }

    public void onMatch(RelOptRuleCall call) {
        Mode mode;
        RelNode fkInput;
        Project project = (Project)call.rel(0);
        RexBuilder rexBuilder = project.getCluster().getRexBuilder();
        List<Object> topProjExprs = project.getChildExps();
        Join join = (Join)call.rel(1);
        JoinRelType joinType = join.getJoinType();
        RelNode leftInput = join.getLeft();
        RelNode rightInput = join.getRight();
        RexNode cond = join.getCondition();
        ImmutableBitSet topRefs = RelOptUtil.InputFinder.bits((List)topProjExprs, null);
        ImmutableBitSet leftBits = ImmutableBitSet.range((int)leftInput.getRowType().getFieldCount());
        ImmutableBitSet rightBits = ImmutableBitSet.range((int)leftInput.getRowType().getFieldCount(), (int)join.getRowType().getFieldCount());
        boolean leftInputPotentialFK = topRefs.intersects(leftBits);
        boolean rightInputPotentialFK = topRefs.intersects(rightBits);
        if (leftInputPotentialFK && rightInputPotentialFK && joinType == JoinRelType.INNER) {
            int joinFieldCount = join.getRowType().getFieldCount();
            Mapping mappingLR = Mappings.create((MappingType)MappingType.PARTIAL_FUNCTION, (int)joinFieldCount, (int)joinFieldCount);
            Mapping mappingRL = Mappings.create((MappingType)MappingType.PARTIAL_FUNCTION, (int)joinFieldCount, (int)joinFieldCount);
            for (RexNode conj : RelOptUtil.conjunctions((RexNode)cond)) {
                if (!conj.isA(SqlKind.EQUALS)) continue;
                RexCall eq = (RexCall)conj;
                RexNode rexNode = (RexNode)eq.getOperands().get(0);
                RexNode op2 = (RexNode)eq.getOperands().get(1);
                if (!(rexNode instanceof RexInputRef) || !(op2 instanceof RexInputRef)) continue;
                int ref1 = ((RexInputRef)rexNode).getIndex();
                int ref2 = ((RexInputRef)op2).getIndex();
                int leftRef = -1;
                int rightRef = -1;
                if (leftBits.get(ref1) && rightBits.get(ref2)) {
                    leftRef = ref1;
                    rightRef = ref2;
                } else if (rightBits.get(ref1) && leftBits.get(ref2)) {
                    leftRef = ref2;
                    rightRef = ref1;
                }
                if (leftRef == -1 || rightRef == -1) continue;
                if (mappingLR.getTargetOpt(leftRef) == -1) {
                    mappingLR.set(leftRef, rightRef);
                }
                if (mappingRL.getTargetOpt(rightRef) != -1) continue;
                mappingRL.set(rightRef, leftRef);
            }
            if (mappingLR.size() != 0) {
                for (int i = 0; i < joinFieldCount; ++i) {
                    if (mappingLR.getTargetOpt(i) == -1) {
                        mappingLR.set(i, i);
                    }
                    if (mappingRL.getTargetOpt(i) != -1) continue;
                    mappingRL.set(i, i);
                }
                List swappedTopProjExprs = topProjExprs.stream().map(projExpr -> (RexNode)projExpr.accept((RexVisitor)new RexPermuteInputsShuttle((Mappings.TargetMapping)mappingRL, new RelNode[]{call.rel(1)}))).collect(Collectors.toList());
                rightInputPotentialFK = RelOptUtil.InputFinder.bits(swappedTopProjExprs, null).intersects(rightBits);
                if (!rightInputPotentialFK) {
                    topProjExprs = swappedTopProjExprs;
                } else {
                    swappedTopProjExprs = topProjExprs.stream().map(projExpr -> (RexNode)projExpr.accept((RexVisitor)new RexPermuteInputsShuttle((Mappings.TargetMapping)mappingLR, new RelNode[]{call.rel(1)}))).collect(Collectors.toList());
                    leftInputPotentialFK = RelOptUtil.InputFinder.bits(swappedTopProjExprs, null).intersects(leftBits);
                    if (!leftInputPotentialFK) {
                        topProjExprs = swappedTopProjExprs;
                    }
                }
            }
        } else if (!leftInputPotentialFK && !rightInputPotentialFK) {
            leftInputPotentialFK = true;
        }
        switch (joinType) {
            case INNER: {
                if (leftInputPotentialFK && rightInputPotentialFK) {
                    return;
                }
                fkInput = leftInputPotentialFK ? leftInput : rightInput;
                mode = Mode.REMOVE;
                break;
            }
            case LEFT: {
                fkInput = leftInput;
                mode = leftInputPotentialFK && !rightInputPotentialFK ? Mode.REMOVE : Mode.TRANSFORM;
                break;
            }
            case RIGHT: {
                fkInput = rightInput;
                mode = !leftInputPotentialFK && rightInputPotentialFK ? Mode.REMOVE : Mode.TRANSFORM;
                break;
            }
            default: {
                return;
            }
        }
        HiveRelOptUtil.RewritablePKFKJoinInfo r = HiveRelOptUtil.isRewritablePKFKJoin(join, leftInput == fkInput, call.getMetadataQuery());
        if (r.rewritable) {
            List<Object> nullableNodes = r.nullableNodes;
            if (mode == Mode.REMOVE) {
                if (rightInputPotentialFK) {
                    nullableNodes = nullableNodes.stream().map(node -> RexUtil.shift((RexNode)node, (int)0, (int)(-leftInput.getRowType().getFieldCount()))).collect(Collectors.toList());
                    topProjExprs = topProjExprs.stream().map(node -> RexUtil.shift((RexNode)node, (int)0, (int)(-leftInput.getRowType().getFieldCount()))).collect(Collectors.toList());
                }
                topProjExprs = HiveCalciteUtil.fixNullability(rexBuilder, topProjExprs, (List<RelDataType>)RelOptUtil.getFieldTypeList((RelDataType)fkInput.getRowType()));
                if (nullableNodes.isEmpty()) {
                    call.transformTo(call.builder().push(fkInput).project(topProjExprs).convert(project.getRowType(), false).build());
                } else {
                    RexNode newFilterCond;
                    if (nullableNodes.size() == 1) {
                        newFilterCond = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.IS_NOT_NULL, new RexNode[]{(RexNode)nullableNodes.get(0)});
                    } else {
                        ArrayList<RexNode> isNotNullConds = new ArrayList<RexNode>();
                        for (RexNode rexNode : nullableNodes) {
                            isNotNullConds.add(rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.IS_NOT_NULL, new RexNode[]{rexNode}));
                        }
                        newFilterCond = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, isNotNullConds);
                    }
                    call.transformTo(call.builder().push(fkInput).filter(new RexNode[]{newFilterCond}).project(topProjExprs).convert(project.getRowType(), false).build());
                }
            } else {
                call.transformTo(call.builder().push(leftInput).push(rightInput).join(JoinRelType.INNER, join.getCondition()).convert(call.rel(1).getRowType(), false).project((Iterable)project.getChildExps()).build());
            }
        }
    }

    private static enum Mode {
        REMOVE,
        TRANSFORM;

    }
}

