/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.translator;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationImpl;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSort;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.spark_project.guava.collect.ImmutableList;
import org.spark_project.guava.collect.ImmutableMap;

public class PlanModifierForASTConv {
    private static final Log LOG = LogFactory.getLog(PlanModifierForASTConv.class);

    /**
     * Entry point: reshapes an optimized plan so that it can be translated into an AST.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>if the root is neither a {@link Project} nor a {@link Sort}, wrap it in a
     *       derived table (see {@code introduceDerivedTable});</li>
     *   <li>walk the tree and insert derived tables wherever a parent/child pairing is
     *       rejected by the {@code valid*} checks;</li>
     *   <li>trim the top-level order-by schema ({@code fixTopOBSchema});</li>
     *   <li>rename the top-level select's columns after {@code resultSchema}.</li>
     * </ol>
     * The plan is dumped to the debug log after each step.
     *
     * @param rel          root of the plan to convert
     * @param resultSchema expected output columns of the query
     * @return the (possibly new) root of the converted plan
     * @throws CalciteSemanticException if the plan's top-level projection cannot be
     *         reconciled with {@code resultSchema}
     */
    public static RelNode convertOpTree(RelNode rel, List<FieldSchema> resultSchema) throws CalciteSemanticException {
        RelNode root = rel;
        logPlanIfDebug("Original plan for PlanModifier\n ", root);

        boolean rootIsSelectOrSort = root instanceof Project || root instanceof Sort;
        if (!rootIsSelectOrSort) {
            root = introduceDerivedTable(root);
            logPlanIfDebug("Plan after top-level introduceDerivedTable\n ", root);
        }

        // The cast selects the (RelNode, RelNode) overload; a bare null would be ambiguous.
        convertOpTree(root, (RelNode) null);
        logPlanIfDebug("Plan after nested convertOpTree\n ", root);

        fixTopOBSchema(root, HiveCalciteUtil.getTopLevelSelect(root), resultSchema);
        logPlanIfDebug("Plan after fixTopOBSchema\n ", root);

        // The top-level select is looked up again because fixTopOBSchema may have replaced it.
        root = renameTopLevelSelectInResultSchema(root, HiveCalciteUtil.getTopLevelSelect(root), resultSchema);
        logPlanIfDebug("Final plan after modifier\n ", root);
        return root;
    }

    /** Writes {@code header} followed by the textual plan to the debug log, if debug is enabled. */
    private static void logPlanIfDebug(String header, RelNode plan) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(header + RelOptUtil.toString(plan));
        }
    }

    /**
     * Recursively visits {@code rel} and its inputs, inserting a derived table (a
     * pass-through project, see {@code introduceDerivedTable}) wherever the
     * {@code valid*} checks reject a parent/child combination.
     *
     * <p>Planner-internal nodes ({@link HepRelVertex}, {@link MultiJoin},
     * {@link RelSubset}) are not expected here and cause a {@link RuntimeException}.
     *
     * @param rel    node being visited
     * @param parent node whose input is {@code rel}, or {@code null} for the root
     */
    private static void convertOpTree(RelNode rel, RelNode parent) {
        if (rel instanceof HepRelVertex) {
            throw new RuntimeException("Found HepRelVertex");
        }

        if (rel instanceof Join) {
            if (!validJoinParent(rel, parent)) {
                introduceDerivedTable(rel, parent);
            }
        } else if (rel instanceof MultiJoin) {
            throw new RuntimeException("Found MultiJoin");
        } else if (rel instanceof RelSubset) {
            throw new RuntimeException("Found RelSubset");
        } else if (rel instanceof SetOp) {
            if (!validSetopParent(rel, parent)) {
                introduceDerivedTable(rel, parent);
            }
            SetOp setop = (SetOp) rel;
            // NOTE(review): inputs are replaced while this list is being iterated; this
            // relies on replaceInput() not structurally modifying the list returned by
            // getInputs() - confirm against the Calcite version in use.
            for (RelNode inputRel : setop.getInputs()) {
                if (!validSetopChild(inputRel)) {
                    introduceDerivedTable(inputRel, setop);
                }
            }
        } else if (rel instanceof SingleRel) {
            if (rel instanceof Filter) {
                if (!validFilterParent(rel, parent)) {
                    introduceDerivedTable(rel, parent);
                }
            } else if (rel instanceof HiveSort) {
                HiveSort sort = (HiveSort) rel;
                if (!validSortParent(rel, parent)) {
                    introduceDerivedTable(rel, parent);
                }
                if (!validSortChild(sort)) {
                    introduceDerivedTable(sort.getInput(), rel);
                }
            } else if (rel instanceof HiveAggregate) {
                // If a derived table is inserted above the aggregate, that new project
                // becomes the aggregate's effective parent for the empty-group rewrite.
                RelNode effectiveParent = parent;
                if (!validGBParent(rel, parent)) {
                    effectiveParent = introduceDerivedTable(rel, parent);
                }
                if (isEmptyGrpAggr(rel)) {
                    replaceEmptyGroupAggr(rel, effectiveParent);
                }
            }
        }

        // Read the inputs only now, so that any derived tables inserted below rel are visited.
        List<RelNode> childNodes = rel.getInputs();
        if (childNodes != null) {
            for (RelNode child : childNodes) {
                convertOpTree(child, rel);
            }
        }
    }

    /**
     * Trims the projection under a top-level ordering sort when it carries more
     * columns than the result schema.
     *
     * <p>Nothing is done unless the key of {@code topSelparentPair} is a {@link Sort}
     * for which {@code HiveCalciteUtil.orderRelNode} returns true, and the projection
     * (the pair's value) has more fields than {@code resultSchema}. Otherwise, every
     * extra projection column (index {@code >= resultSchema.size()}) that is referenced
     * by the sort's collation is recorded in an index-to-expression map; the projection
     * is replaced by one that keeps only the first {@code resultSchema.size()}
     * expressions, and the map is handed to the sort via {@code setInputRefToCallMap}.
     *
     * @param rootRel          root of the plan (not used by this method)
     * @param topSelparentPair pair of (parent of top-level select, top-level select)
     * @param resultSchema     expected output columns of the query
     * @throws CalciteSemanticException if the extra columns are not all accounted for by
     *         the sort's collation, i.e. the projection cannot be cut down to exactly
     *         {@code resultSchema.size()} columns
     */
    private static void fixTopOBSchema(RelNode rootRel, Pair<RelNode, RelNode> topSelparentPair, List<FieldSchema> resultSchema) throws CalciteSemanticException {
        RelNode selectParent = topSelparentPair.getKey();
        if (!(selectParent instanceof Sort) || !HiveCalciteUtil.orderRelNode(selectParent)) {
            return;
        }
        HiveSort obRel = (HiveSort) selectParent;
        Project obChild = (Project) topSelparentPair.getValue();

        int resultFieldCount = resultSchema.size();
        int projFieldCount = obChild.getRowType().getFieldCount();
        if (projFieldCount <= resultFieldCount) {
            return;
        }

        List<RexNode> projExps = obChild.getChildExps();
        HashSet<Integer> collationInputRefs = new HashSet<Integer>(RelCollationImpl.ordinals(obRel.getCollation()));
        ImmutableMap.Builder<Integer, RexNode> inputRefToCallMapBldr = ImmutableMap.builder();
        // Only columns beyond the result schema are candidates for removal.
        for (int i = resultFieldCount; i < projFieldCount; ++i) {
            if (collationInputRefs.contains(i)) {
                inputRefToCallMapBldr.put(i, projExps.get(i));
            }
        }
        ImmutableMap<Integer, RexNode> inputRefToCallMap = inputRefToCallMapBldr.build();

        if (projFieldCount - inputRefToCallMap.size() != resultFieldCount) {
            LOG.error(generateInvalidSchemaMessage(obChild, resultSchema, inputRefToCallMap.size()));
            throw new CalciteSemanticException("Result Schema didn't match Optimized Op Tree Schema");
        }

        HiveProject replacementProjectRel = HiveProject.create(
                obChild.getInput(),
                projExps.subList(0, resultFieldCount),
                obChild.getRowType().getFieldNames().subList(0, resultFieldCount));
        obRel.replaceInput(0, replacementProjectRel);
        obRel.setInputRefToCallMap(inputRefToCallMap);
    }

    /**
     * Builds the diagnostic text logged when the result schema and the top-level
     * projection disagree.
     *
     * <p>The message lists each result-schema field as {@code [name:type]}, then each
     * projection expression as {@code [expr:type]}, and, when {@code fieldsForOB} is
     * non-zero, a note on how many fields were removed due to ORDER BY. The final two
     * characters of the assembled text are always dropped.
     *
     * @param topLevelProj projection whose expressions are listed
     * @param resultSchema expected output columns
     * @param fieldsForOB  number of projection fields removed because of ORDER BY
     * @return the assembled message
     */
    private static String generateInvalidSchemaMessage(Project topLevelProj, List<FieldSchema> resultSchema, int fieldsForOB) {
        StringBuilder msg = new StringBuilder("Result Schema didn't match Calcite Optimized Op Tree; schema: ");
        for (FieldSchema field : resultSchema) {
            msg.append("[").append(field.getName()).append(":").append(field.getType()).append("], ");
        }
        msg.append(" projection fields: ");
        for (RexNode projExp : topLevelProj.getChildExps()) {
            msg.append("[").append(projExp.toString()).append(":").append(projExp.getType()).append("], ");
        }
        if (fieldsForOB != 0) {
            msg.append(fieldsForOB).append(" fields removed due to ORDER BY  ");
        }
        // Drop the last two characters (the trailing ", " or the two trailing spaces).
        msg.setLength(msg.length() - 2);
        return msg.toString();
    }

    /**
     * Replaces the top-level select with an equivalent projection whose column aliases
     * are taken from {@code resultSchema}; a single leading {@code "_"} is stripped from
     * each alias.
     *
     * @param rootRel          root of the plan
     * @param topSelparentPair pair of (parent of top-level select, top-level select)
     * @param resultSchema     expected output columns; must have exactly as many entries
     *                         as the top-level select has expressions
     * @return the new projection if the top-level select was the root itself,
     *         otherwise {@code rootRel} (with the select replaced as input 0 of its parent)
     * @throws CalciteSemanticException if the column counts differ
     */
    private static RelNode renameTopLevelSelectInResultSchema(RelNode rootRel, Pair<RelNode, RelNode> topSelparentPair, List<FieldSchema> resultSchema) throws CalciteSemanticException {
        RelNode parentOforiginalProjRel = topSelparentPair.getKey();
        HiveProject originalProjRel = (HiveProject) topSelparentPair.getValue();

        List<RexNode> rootChildExps = originalProjRel.getChildExps();
        if (resultSchema.size() != rootChildExps.size()) {
            LOG.error(generateInvalidSchemaMessage(originalProjRel, resultSchema, 0));
            throw new CalciteSemanticException("Result Schema didn't match Optimized Op Tree Schema");
        }

        // Sizes are equal at this point, so walking resultSchema yields one alias per expression.
        List<String> newSelAliases = new ArrayList<String>(resultSchema.size());
        for (FieldSchema field : resultSchema) {
            String colAlias = field.getName();
            if (colAlias.startsWith("_")) {
                colAlias = colAlias.substring(1);
            }
            newSelAliases.add(colAlias);
        }

        HiveProject replacementProjectRel = HiveProject.create(originalProjRel.getInput(), rootChildExps, newSelAliases);
        if (rootRel == originalProjRel) {
            return replacementProjectRel;
        }
        parentOforiginalProjRel.replaceInput(0, replacementProjectRel);
        return rootRel;
    }

    /**
     * Wraps {@code rel} in a new {@link HiveProject} built from the input-reference
     * projections returned by {@code HiveCalciteUtil.getProjsFromBelowAsInputRef},
     * reusing {@code rel}'s cluster, row type and collation list.
     *
     * @param rel node to wrap
     * @return the new project whose input is {@code rel}
     */
    private static RelNode introduceDerivedTable(RelNode rel) {
        List<RexNode> passThroughExps = HiveCalciteUtil.getProjsFromBelowAsInputRef(rel);
        return HiveProject.create(
                rel.getCluster(),
                rel,
                passThroughExps,
                rel.getRowType(),
                rel.getCollationList());
    }

    /**
     * Inserts a derived table between {@code parent} and its input {@code rel}.
     *
     * <p>{@code rel} is located among {@code parent}'s inputs by reference identity,
     * wrapped via {@code introduceDerivedTable(rel)}, and the wrapper is installed at the
     * same input position.
     *
     * @param rel    input of {@code parent} to wrap
     * @param parent node that currently has {@code rel} as one of its inputs
     * @return the newly created wrapper node
     * @throws RuntimeException if {@code rel} is not one of {@code parent}'s inputs
     */
    private static RelNode introduceDerivedTable(RelNode rel, RelNode parent) {
        List<RelNode> siblings = parent.getInputs();
        int pos = -1;
        for (int i = 0; i < siblings.size(); i++) {
            // Identity comparison on purpose: we want this exact node, not an equal one.
            if (siblings.get(i) == rel) {
                pos = i;
                break;
            }
        }
        if (pos == -1) {
            throw new RuntimeException("Couldn't find child node in parent's inputs");
        }

        RelNode select = introduceDerivedTable(rel);
        parent.replaceInput(pos, select);
        return select;
    }

    /**
     * Tells whether a join may sit directly under {@code parent}.
     *
     * @param joinNode the join being checked
     * @param parent   its parent, possibly {@code null}
     * @return {@code false} when {@code joinNode} is the right input of a {@link Join}
     *         parent, or when the parent is a {@link SetOp}; {@code true} otherwise
     */
    private static boolean validJoinParent(RelNode joinNode, RelNode parent) {
        if (parent instanceof Join) {
            return ((Join) parent).getRight() != joinNode;
        }
        return !(parent instanceof SetOp);
    }

    /**
     * Tells whether a filter may sit directly under {@code parent}.
     *
     * @param filterNode the filter being checked (not inspected)
     * @param parent     its parent, possibly {@code null}
     * @return {@code false} when the parent is a {@link Filter}, {@link Join} or
     *         {@link SetOp}; {@code true} otherwise
     */
    private static boolean validFilterParent(RelNode filterNode, RelNode parent) {
        return !(parent instanceof Filter
                || parent instanceof Join
                || parent instanceof SetOp);
    }

    /**
     * Tells whether an aggregate may sit directly under {@code parent}.
     *
     * @param gbNode the aggregate being checked; cast to {@link Aggregate} only when the
     *               parent is a {@link Filter}
     * @param parent its parent, possibly {@code null}
     * @return {@code false} when the parent is a {@link Join}, {@link SetOp} or
     *         {@link Aggregate}, or when the parent is a {@link Filter} and the
     *         aggregate's group set is empty; {@code true} otherwise
     */
    private static boolean validGBParent(RelNode gbNode, RelNode parent) {
        if (parent instanceof Join || parent instanceof SetOp || parent instanceof Aggregate) {
            return false;
        }
        boolean filterOverEmptyGroupSet =
                parent instanceof Filter && ((Aggregate) gbNode).getGroupSet().isEmpty();
        return !filterOverEmptyGroupSet;
    }

    /**
     * Tells whether a sort may sit directly under {@code parent}.
     *
     * @param sortNode the sort being checked (not inspected)
     * @param parent   its parent, possibly {@code null}
     * @return {@code true} when there is no parent, when the parent is a {@link Project}
     *         or a {@link Sort}, or when {@code HiveCalciteUtil.orderRelNode(parent)}
     *         holds; {@code false} otherwise
     */
    private static boolean validSortParent(RelNode sortNode, RelNode parent) {
        if (parent == null) {
            return true;
        }
        return parent instanceof Project
                || parent instanceof Sort
                || HiveCalciteUtil.orderRelNode(parent);
    }

    /**
     * Tells whether the input of {@code sortNode} is acceptable as a sort's child.
     *
     * @param sortNode the sort whose input is checked
     * @return {@code true} when {@code HiveCalciteUtil.limitRelNode(sortNode)} and
     *         {@code HiveCalciteUtil.orderRelNode(input)} both hold, or when the input is
     *         a {@link Project}; {@code false} otherwise
     */
    private static boolean validSortChild(HiveSort sortNode) {
        RelNode input = sortNode.getInput();
        if (HiveCalciteUtil.limitRelNode(sortNode) && HiveCalciteUtil.orderRelNode(input)) {
            return true;
        }
        return input instanceof Project;
    }

    /**
     * Tells whether a set operation may sit directly under {@code parent}.
     *
     * @param setop  the set operation being checked (not inspected)
     * @param parent its parent, possibly {@code null}
     * @return {@code true} when there is no parent or the parent is a {@link Project};
     *         {@code false} otherwise
     */
    private static boolean validSetopParent(RelNode setop, RelNode parent) {
        return parent == null || parent instanceof Project;
    }

    /**
     * Tells whether a node is acceptable as a direct input of a set operation.
     *
     * @param setopChild the input being checked
     * @return {@code true} only when the input is a {@link Project}
     */
    private static boolean validSetopChild(RelNode setopChild) {
        return setopChild instanceof Project;
    }

    /**
     * Tells whether an aggregate has neither grouping columns nor aggregate calls.
     *
     * @param gbNode node to inspect; must be an {@link Aggregate}
     * @return {@code true} when both the group set and the aggregate-call list are empty
     */
    private static boolean isEmptyGrpAggr(RelNode gbNode) {
        Aggregate aggregate = (Aggregate) gbNode;
        if (!aggregate.getGroupSet().isEmpty()) {
            return false;
        }
        return aggregate.getAggCallList().isEmpty();
    }

    /**
     * Replaces an aggregate that has no group keys and no aggregate calls with a copy
     * carrying a single dummy {@code count} call, wrapped in a derived table.
     *
     * <p>The parent's expressions must all be literals; the copy of the aggregate (same
     * traits, input, indicator flag and group sets) gets one {@code count} call over
     * argument index 0, is wrapped via {@code introduceDerivedTable}, and is installed as
     * input 0 of {@code parent}.
     *
     * @param rel    the aggregate to replace; must be a {@link HiveAggregate}
     * @param parent node whose input 0 is replaced
     * @throws RuntimeException if any of {@code parent}'s expressions is not a literal
     */
    private static void replaceEmptyGroupAggr(RelNode rel, RelNode parent) {
        List<RexNode> parentExps = parent.getChildExps();
        for (RexNode rexNode : parentExps) {
            if (rexNode.getKind() != SqlKind.LITERAL) {
                throw new RuntimeException("We expect " + parent.toString() + " to contain only constants. However, " + rexNode.toString() + " is " + rexNode.getKind());
            }
        }

        HiveAggregate oldAggRel = (HiveAggregate) rel;
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory);
        RelDataType intType = TypeConverter.convert(TypeInfoFactory.intTypeInfo, typeFactory);

        // Dummy aggregate call: count with an int-typed argument, returning long, over input column 0.
        SqlAggFunction countFn = SqlFunctionConverter.getCalciteAggFn("count", ImmutableList.of(intType), longType);
        List<Integer> argList = ImmutableList.of(0);
        AggregateCall dummyCall = new AggregateCall(countFn, false, argList, longType, null);

        Aggregate newAggRel = oldAggRel.copy(
                oldAggRel.getTraitSet(),
                oldAggRel.getInput(),
                oldAggRel.indicator,
                oldAggRel.getGroupSet(),
                oldAggRel.getGroupSets(),
                ImmutableList.of(dummyCall));

        RelNode select = introduceDerivedTable(newAggRel);
        parent.replaceInput(0, select);
    }
}

