/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec;

import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.AbstractMapJoinOperator;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.NodeUtils;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.TerminalOperator;
import org.apache.hadoop.hive.ql.exec.UnionOperator;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.SemiJoinBranchInfo;
import org.apache.hadoop.hive.ql.parse.spark.SparkPartitionPruningSinkOperator;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.TableScanDesc;
import org.apache.hadoop.mapred.OutputCollector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for searching, walking and editing Hive operator DAGs.
 *
 * <p>Methods whose name contains "Upstream" follow {@code getParentOperators()} links; the
 * others follow {@code getChildOperators()} links unless documented otherwise.
 */
public class OperatorUtils {
    private static final Logger LOG = LoggerFactory.getLogger(OperatorUtils.class);

    /**
     * Collects every operator reachable downstream from {@code start} (including {@code start}
     * itself) that is an instance of {@code clazz}.
     *
     * @param start root of the walk
     * @param clazz type to match
     * @return the matching operators (never {@code null})
     */
    public static <T> Set<T> findOperators(Operator<?> start, Class<T> clazz) {
        return findOperators(start, clazz, new HashSet<T>());
    }

    /**
     * Returns the single downstream operator of type {@code clazz}.
     *
     * @return the match, or {@code null} when there is no match or more than one
     */
    public static <T> T findSingleOperator(Operator<?> start, Class<T> clazz) {
        Set<T> found = findOperators(start, clazz, new HashSet<T>());
        return found.size() == 1 ? found.iterator().next() : null;
    }

    /**
     * Multi-root variant of {@link #findOperators(Operator, Class)}; {@code null} roots are
     * skipped.
     */
    public static <T> Set<T> findOperators(Collection<Operator<?>> starts, Class<T> clazz) {
        Set<T> found = new HashSet<T>();
        for (Operator<?> start : starts) {
            if (start != null) {
                findOperators(start, clazz, found);
            }
        }
        return found;
    }

    /** Depth-first downstream walk that adds matches to {@code found} and returns it. */
    private static <T> Set<T> findOperators(Operator<?> start, Class<T> clazz, Set<T> found) {
        if (clazz.isInstance(start)) {
            found.add(clazz.cast(start));
        }
        List<? extends Operator<?>> children = start.getChildOperators();
        if (children != null) {
            for (Operator<?> child : children) {
                findOperators(child, clazz, found);
            }
        }
        return found;
    }

    /**
     * Collects every operator reachable upstream from {@code start} (including {@code start}
     * itself) that is an instance of {@code clazz}.
     */
    public static <T> Set<T> findOperatorsUpstream(Operator<?> start, Class<T> clazz) {
        return findOperatorsUpstream(start, clazz, new HashSet<T>());
    }

    /**
     * Returns the single upstream operator of type {@code clazz}.
     *
     * @return the match, or {@code null} when there is no match or more than one
     */
    public static <T> T findSingleOperatorUpstream(Operator<?> start, Class<T> clazz) {
        Set<T> found = findOperatorsUpstream(start, clazz, new HashSet<T>());
        return found.size() == 1 ? found.iterator().next() : null;
    }

    /**
     * Returns an upstream operator of type {@code clazz}, found via
     * {@link #findOperatorsUpstreamJoinAccounted(Operator, Class, Set)}.
     *
     * <p>When several operators match, an arbitrary one (first in {@link HashSet} iteration
     * order) is returned.
     *
     * @return a match, or {@code null} when there is none
     */
    public static <T> T findSingleOperatorUpstreamJoinAccounted(Operator<?> start, Class<T> clazz) {
        Set<T> found = findOperatorsUpstreamJoinAccounted(start, clazz, new HashSet<T>());
        return found.isEmpty() ? null : found.iterator().next();
    }

    /**
     * Multi-root variant of {@link #findOperatorsUpstream(Operator, Class)}; {@code null} roots
     * are skipped, consistent with {@link #findOperators(Collection, Class)}.
     */
    public static <T> Set<T> findOperatorsUpstream(Collection<Operator<?>> starts, Class<T> clazz) {
        Set<T> found = new HashSet<T>();
        for (Operator<?> start : starts) {
            if (start != null) {
                findOperatorsUpstream(start, clazz, found);
            }
        }
        return found;
    }

    /** Depth-first upstream walk that adds matches to {@code found} and returns it. */
    private static <T> Set<T> findOperatorsUpstream(Operator<?> start, Class<T> clazz, Set<T> found) {
        if (clazz.isInstance(start)) {
            found.add(clazz.cast(start));
        }
        List<? extends Operator<?>> parents = start.getParentOperators();
        if (parents != null) {
            for (Operator<?> parent : parents) {
                findOperatorsUpstream(parent, clazz, found);
            }
        }
        return found;
    }

    /**
     * Upstream search that, at an {@link AbstractMapJoinOperator}, follows only the parent at
     * index {@code getConf().getPosBigTable()} instead of every parent.
     *
     * @param start root of the walk
     * @param clazz type to match
     * @param found accumulator that matches are added to
     * @return {@code found}
     */
    public static <T> Set<T> findOperatorsUpstreamJoinAccounted(Operator<?> start, Class<T> clazz, Set<T> found) {
        if (clazz.isInstance(start)) {
            found.add(clazz.cast(start));
        }
        // -1 means "follow every parent".
        int onlyIncludeIndex = -1;
        if (start instanceof AbstractMapJoinOperator) {
            MapJoinDesc desc = ((AbstractMapJoinOperator<?>) start).getConf();
            onlyIncludeIndex = desc.getPosBigTable();
        }
        List<? extends Operator<?>> parents = start.getParentOperators();
        if (parents != null) {
            for (int i = 0; i < parents.size(); ++i) {
                if (onlyIncludeIndex < 0 || onlyIncludeIndex == i) {
                    findOperatorsUpstreamJoinAccounted(parents.get(i), clazz, found);
                }
            }
        }
        return found;
    }

    /**
     * Walks downstream from {@code childOperators} and sets {@code out} as the output collector
     * of every operator whose name equals {@link ReduceSinkOperator#getOperatorName()}. The
     * walk does not descend below such operators.
     */
    public static void setChildrenCollector(List<Operator<? extends OperatorDesc>> childOperators, OutputCollector out) {
        if (childOperators == null) {
            return;
        }
        for (Operator<? extends OperatorDesc> op : childOperators) {
            if (op.getName().equals(ReduceSinkOperator.getOperatorName())) {
                op.setOutputCollector(out);
            } else {
                setChildrenCollector(op.getChildOperators(), out);
            }
        }
    }

    /**
     * Walks downstream from {@code childOperators}; for each operator with
     * {@code getIsReduceSink()} true, sets the collector mapped to its
     * {@code getReduceOutputName()} in {@code outMap}. Reduce sinks with no entry in
     * {@code outMap} are left unchanged. The walk does not descend below reduce sinks.
     */
    public static void setChildrenCollector(List<Operator<? extends OperatorDesc>> childOperators, Map<String, OutputCollector> outMap) {
        if (childOperators == null) {
            return;
        }
        for (Operator<? extends OperatorDesc> op : childOperators) {
            if (op.getIsReduceSink()) {
                String outputName = op.getReduceOutputName();
                if (outMap.containsKey(outputName)) {
                    LOG.info("Setting output collector: {} --> {}", op, outputName);
                    op.setOutputCollector(outMap.get(outputName));
                }
            } else {
                setChildrenCollector(op.getChildOperators(), outMap);
            }
        }
    }

    /**
     * Follows the single-child chain starting at {@code op} and returns the last operator on it
     * that is an instance of {@code clazz}. The chain ends at an operator that has zero, several
     * or a {@code null} list of children.
     *
     * @return the last match on the chain, or {@code null} if none
     */
    public static <T> T findLastOperator(Operator<?> op, Class<T> clazz) {
        Operator<?> currentOp = op;
        T lastOp = null;
        while (currentOp != null) {
            if (clazz.isInstance(currentOp)) {
                lastOp = clazz.cast(currentOp);
            }
            List<? extends Operator<?>> children = currentOp.getChildOperators();
            currentOp = (children != null && children.size() == 1) ? children.get(0) : null;
        }
        return lastOp;
    }

    /**
     * Applies {@code function} to {@code operator} and to every operator upstream of it, each
     * at most once.
     */
    public static void iterateParents(Operator<?> operator, NodeUtils.Function<Operator<?>> function) {
        iterateParents(operator, function, new HashSet<Operator<?>>());
    }

    /** Recursive worker for {@link #iterateParents(Operator, NodeUtils.Function)}. */
    private static void iterateParents(Operator<?> operator, NodeUtils.Function<Operator<?>> function, Set<Operator<?>> visited) {
        if (!visited.add(operator)) {
            return;
        }
        function.apply(operator);
        if (operator.getNumParent() > 0) {
            for (Operator<?> parent : operator.getParentOperators()) {
                iterateParents(parent, function, visited);
            }
        }
    }

    /**
     * Breadth-first downstream walk from {@code start}, grouping every visited operator under
     * each class in {@code classes} that it is an instance of.
     */
    public static Multimap<Class<? extends Operator<?>>, Operator<?>> classifyOperators(Operator<?> start, Set<Class<? extends Operator<?>>> classes) {
        return classify(start, classes, false);
    }

    /** Upstream counterpart of {@link #classifyOperators(Operator, Set)}. */
    public static Multimap<Class<? extends Operator<?>>, Operator<?>> classifyOperatorsUpstream(Operator<?> start, Set<Class<? extends Operator<?>>> classes) {
        return classify(start, classes, true);
    }

    /**
     * Shared level-by-level walk for the classify methods; {@code upstream} picks parent links
     * over child links. A {@code null} neighbour list ends that branch.
     */
    private static Multimap<Class<? extends Operator<?>>, Operator<?>> classify(
            Operator<?> start, Set<Class<? extends Operator<?>>> classes, boolean upstream) {
        ImmutableMultimap.Builder<Class<? extends Operator<?>>, Operator<?>> resultMap = ImmutableMultimap.builder();
        List<Operator<?>> level = new ArrayList<Operator<?>>();
        level.add(start);
        while (!level.isEmpty()) {
            List<Operator<?>> nextLevel = new ArrayList<Operator<?>>();
            for (Operator<?> operator : level) {
                for (Class<? extends Operator<?>> clazz : classes) {
                    if (clazz.isInstance(operator)) {
                        resultMap.put(clazz, operator);
                    }
                }
                List<? extends Operator<?>> neighbours =
                        upstream ? operator.getParentOperators() : operator.getChildOperators();
                if (neighbours != null) {
                    nextLevel.addAll(neighbours);
                }
            }
            level = nextLevel;
        }
        return resultMap.build();
    }

    /**
     * Counts the distinct upstream operators (including {@code start}) that are instances of at
     * least one class in {@code classes}.
     */
    public static int countOperatorsUpstream(Operator<?> start, Set<Class<? extends Operator<?>>> classes) {
        Multimap<Class<? extends Operator<?>>, Operator<?>> ops = classifyOperatorsUpstream(start, classes);
        // The same operator can be listed under several classes or reached by several paths.
        return new HashSet<Operator<?>>(ops.values()).size();
    }

    /**
     * Sets {@code memoryAvailableToTask} as the max available memory on the conf of every
     * operator in {@code operators} and all operators downstream of them. Operators with a
     * {@code null} conf are skipped.
     */
    public static void setMemoryAvailable(List<Operator<? extends OperatorDesc>> operators, long memoryAvailableToTask) {
        if (operators == null) {
            return;
        }
        for (Operator<? extends OperatorDesc> op : operators) {
            if (op.getConf() != null) {
                op.getConf().setMaxMemoryAvailable(memoryAvailableToTask);
            }
            if (op.getChildOperators() != null && !op.getChildOperators().isEmpty()) {
                setMemoryAvailable(op.getChildOperators(), memoryAvailableToTask);
            }
        }
    }

    /**
     * Adds to {@code roots} every operator upstream of {@code op} (or {@code op} itself) that
     * has no parents. A root reachable by several paths is added once per path.
     */
    public static void findRoots(Operator<?> op, Collection<Operator<?>> roots) {
        List<? extends Operator<?>> parents = op.getParentOperators();
        if (parents == null || parents.isEmpty()) {
            roots.add(op);
            return;
        }
        for (Operator<?> p : parents) {
            findRoots(p, roots);
        }
    }

    /**
     * Walks up from {@code op} through operators with at most one child, following the first
     * parent each time. At the first operator with more than one child, it detaches the child
     * the walk came from. If a parentless operator is reached first, nothing is changed.
     */
    public static void removeBranch(SparkPartitionPruningSinkOperator op) {
        Operator<?> child = op;
        Operator<?> curr = op;
        while (curr.getChildOperators().size() <= 1) {
            child = curr;
            List<? extends Operator<?>> parents = curr.getParentOperators();
            if (parents == null || parents.isEmpty()) {
                return;
            }
            curr = parents.get(0);
        }
        curr.removeChild(child);
    }

    /** Disconnects {@code op} from all of its parents and children. */
    public static void removeOperator(Operator<?> op) {
        // Iterate over copies: the remove calls may edit op's own parent/child lists
        // (NOTE(review): behaviour of removeChild/removeParent not visible here).
        if (op.getNumParent() != 0) {
            List<Operator<?>> allParent = new ArrayList<Operator<?>>(op.getParentOperators());
            for (Operator<?> parent : allParent) {
                parent.removeChild(op);
            }
        }
        if (op.getNumChild() != 0) {
            List<Operator<?>> allChildren = new ArrayList<Operator<?>>(op.getChildOperators());
            for (Operator<?> child : allChildren) {
                child.removeParent(op);
            }
        }
    }

    /** Returns {@code op.toString()}, followed by " (alias)" for table scans. */
    public static String getOpNamePretty(Operator<?> op) {
        if (op instanceof TableScanOperator) {
            return op.toString() + " (" + ((TableScanOperator) op).getConf().getAlias() + ")";
        }
        return op.toString();
    }

    /**
     * Returns {@code true} when walking up from {@code op} (first parent each time) reaches an
     * operator with more than one child before reaching a parentless operator.
     */
    public static boolean isInBranch(SparkPartitionPruningSinkOperator op) {
        Operator<?> curr = op;
        while (curr.getChildOperators().size() <= 1) {
            List<? extends Operator<?>> parents = curr.getParentOperators();
            if (parents == null || parents.isEmpty()) {
                return false;
            }
            curr = parents.get(0);
        }
        return true;
    }

    /**
     * Returns the operators of {@code work} that are instances of {@code clazz}. For a
     * {@link MapWork}, the operators are gathered by walking downstream from
     * {@code getAliasToWork()}; otherwise {@code work.getAllOperators()} is used.
     */
    public static Set<Operator<?>> getOp(BaseWork work, Class<?> clazz) {
        Set<Operator<?>> ops = new HashSet<Operator<?>>();
        if (work instanceof MapWork) {
            ArrayDeque<Operator<?>> pending =
                    new ArrayDeque<Operator<?>>(((MapWork) work).getAliasToWork().values());
            while (!pending.isEmpty()) {
                Operator<?> operator = pending.pop();
                // Already expanded via another path: its children were queued then.
                if (!ops.add(operator)) {
                    continue;
                }
                List<? extends Operator<?>> children = operator.getChildOperators();
                if (children != null) {
                    pending.addAll(children);
                }
            }
        } else {
            ops.addAll(work.getAllOperators());
        }
        Set<Operator<?>> matchingOps = new HashSet<Operator<?>>();
        for (Operator<?> operator : ops) {
            if (clazz.isInstance(operator)) {
                matchingOps.add(operator);
            }
        }
        return matchingOps;
    }

    /**
     * Breadth-first downstream search for the first operator whose {@code getMarker()} equals
     * {@code marker}.
     *
     * @return the operator, or {@code null} if none matches
     */
    public static Operator<?> findOperatorByMarker(Operator<?> start, String marker) {
        ArrayDeque<Operator<?>> queue = new ArrayDeque<Operator<?>>();
        queue.add(start);
        while (!queue.isEmpty()) {
            Operator<?> op = queue.remove();
            if (marker.equals(op.getMarker())) {
                return op;
            }
            if (op.getChildOperators() != null) {
                queue.addAll(op.getChildOperators());
            }
        }
        return null;
    }

    /**
     * Collects the operators connected to {@code start} without crossing a
     * {@link ReduceSinkOperator} going upstream or a {@link TerminalOperator} going downstream.
     *
     * <p>A terminal operator whose only child is a {@link GroupByOperator}, whose only child is
     * a {@link ReduceSinkOperator} with a non-null entry in {@code rsToSemiJoinBranchInfo}, is
     * treated as a semi-join edge. The group-by and reduce sink are added to the result and the
     * reduce sink to {@code semiJoinOps}. Any other terminal operator is added to
     * {@code terminalOps}.
     */
    public static Set<Operator<?>> findWorkOperatorsAndSemiJoinEdges(Operator<?> start, Map<ReduceSinkOperator, SemiJoinBranchInfo> rsToSemiJoinBranchInfo, Set<ReduceSinkOperator> semiJoinOps, Set<TerminalOperator<?>> terminalOps) {
        Set<Operator<?>> found = new HashSet<Operator<?>>();
        findWorkOperatorsAndSemiJoinEdges(start, found, rsToSemiJoinBranchInfo, semiJoinOps, terminalOps);
        return found;
    }

    /** Recursive worker; {@code found} doubles as the visited set. */
    private static void findWorkOperatorsAndSemiJoinEdges(Operator<?> start, Set<Operator<?>> found, Map<ReduceSinkOperator, SemiJoinBranchInfo> rsToSemiJoinBranchInfo, Set<ReduceSinkOperator> semiJoinOps, Set<TerminalOperator<?>> terminalOps) {
        found.add(start);
        List<? extends Operator<?>> parents = start.getParentOperators();
        if (parents != null) {
            for (Operator<?> parent : parents) {
                if (parent instanceof ReduceSinkOperator || found.contains(parent)) {
                    continue;
                }
                findWorkOperatorsAndSemiJoinEdges(parent, found, rsToSemiJoinBranchInfo, semiJoinOps, terminalOps);
            }
        }
        if (start instanceof TerminalOperator) {
            boolean semiJoin = false;
            List<? extends Operator<?>> children = start.getChildOperators();
            if (children.size() == 1) {
                Operator<?> gb = children.get(0);
                if (gb instanceof GroupByOperator && gb.getChildOperators().size() == 1) {
                    Operator<?> rs = gb.getChildOperators().get(0);
                    if (rs instanceof ReduceSinkOperator && rsToSemiJoinBranchInfo.get(rs) != null) {
                        found.add(start);
                        found.add(gb);
                        found.add(rs);
                        semiJoinOps.add((ReduceSinkOperator) rs);
                        semiJoin = true;
                    }
                }
            }
            if (!semiJoin) {
                terminalOps.add((TerminalOperator<?>) start);
            }
            return;
        }
        List<? extends Operator<?>> children = start.getChildOperators();
        if (children != null) {
            for (Operator<?> child : children) {
                if (found.contains(child)) {
                    continue;
                }
                findWorkOperatorsAndSemiJoinEdges(child, found, rsToSemiJoinBranchInfo, semiJoinOps, terminalOps);
            }
        }
    }

    /**
     * Backtracks each expression from {@code start} to {@code terminal}.
     *
     * @return the backtracked expressions in input order, or {@code null} if any expression
     *     backtracks to {@code null} or backtracking throws a {@link SemanticException}
     */
    private static List<ExprNodeDesc> backtrackAll(List<ExprNodeDesc> exprs, Operator<? extends OperatorDesc> start, Operator<? extends OperatorDesc> terminal) {
        List<ExprNodeDesc> backtrackedExprs = new ArrayList<ExprNodeDesc>();
        try {
            for (ExprNodeDesc expr : exprs) {
                ExprNodeDesc backtrackedExpr = ExprNodeDescUtils.backtrack(expr, start, terminal);
                if (backtrackedExpr == null) {
                    return null;
                }
                backtrackedExprs.add(backtrackedExpr);
            }
        } catch (SemanticException e) {
            // Best effort: callers treat null as "this parent is not compatible".
            LOG.debug("Could not backtrack expressions from {} to {}", start, terminal, e);
            return null;
        }
        return backtrackedExprs;
    }

    /**
     * Returns {@code false} if {@code backtrackedExprs} is {@code null}, has a different size
     * from {@code orgexprs}, or pairs two column expressions whose table aliases are both
     * non-null and differ. Otherwise returns {@code true}.
     */
    private static boolean areBacktrackedExprsCompatible(List<ExprNodeDesc> orgexprs, List<ExprNodeDesc> backtrackedExprs) {
        if (backtrackedExprs == null || backtrackedExprs.size() != orgexprs.size()) {
            return false;
        }
        for (int i = 0; i < orgexprs.size(); ++i) {
            ExprNodeDesc org = orgexprs.get(i);
            ExprNodeDesc back = backtrackedExprs.get(i);
            if (!(org instanceof ExprNodeColumnDesc) || !(back instanceof ExprNodeColumnDesc)) {
                continue;
            }
            String orgTabAlias = ((ExprNodeColumnDesc) org).getTabAlias();
            String backTabAlias = ((ExprNodeColumnDesc) back).getTabAlias();
            if (orgTabAlias != null && backTabAlias != null && !orgTabAlias.equals(backTabAlias)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Walks upstream from {@code start}, carrying {@code exprs} along by backtracking, and
     * returns the furthest-upstream {@link ReduceSinkOperator} reached.
     *
     * <p>At each step the walk moves to the first parent for which the backtracked expressions
     * are compatible. The walk stops at a {@link UnionOperator}, at an operator with no parents,
     * or when no parent is compatible.
     *
     * @return the source reduce sink, or {@code null} if none is found
     */
    public static Operator<? extends OperatorDesc> findSourceRS(Operator<?> start, List<ExprNodeDesc> exprs) {
        Operator<? extends OperatorDesc> currRS = null;
        if (start instanceof ReduceSinkOperator) {
            currRS = start;
        }
        if (start instanceof UnionOperator) {
            return currRS;
        }
        List<? extends Operator<?>> parents = start.getParentOperators();
        // Short-circuit ||: the original non-short-circuit | threw an NPE when parents was null.
        if (parents == null || parents.isEmpty()) {
            return null;
        }
        Operator<?> nextOp = null;
        List<ExprNodeDesc> backtrackedExprs = null;
        for (Operator<?> parent : parents) {
            backtrackedExprs = backtrackAll(exprs, start, parent);
            if (areBacktrackedExprsCompatible(exprs, backtrackedExprs)) {
                nextOp = parent;
                break;
            }
        }
        if (nextOp != null) {
            Operator<? extends OperatorDesc> nextRS = findSourceRS(nextOp, backtrackedExprs);
            if (nextRS != null) {
                currRS = nextRS;
            }
        }
        return currRS;
    }

    /**
     * Walks up the single-parent chain from {@code reduceSideGbOp} and returns the first
     * {@link GroupByOperator} found.
     *
     * @return that group-by, or {@code null} if the chain reaches an operator with several
     *     parents or ends before a group-by is found
     */
    public static GroupByOperator findMapSideGb(GroupByOperator reduceSideGbOp) {
        Operator<?> parentOp = reduceSideGbOp;
        while (parentOp.getParentOperators() != null && !parentOp.getParentOperators().isEmpty()) {
            if (parentOp.getParentOperators().size() > 1) {
                return null;
            }
            parentOp = parentOp.getParentOperators().get(0);
            if (parentOp instanceof GroupByOperator) {
                return (GroupByOperator) parentOp;
            }
        }
        return null;
    }

    /**
     * Returns {@code true} when the upstream table scans of {@code tree1} and {@code tree2}
     * share no qualified table name.
     */
    public static boolean treesWithIndependentInputs(Operator<?> tree1, Operator<?> tree2) {
        Set<String> tables1 = signaturesOf(findOperatorsUpstream(tree1, TableScanOperator.class));
        Set<String> tables2 = signaturesOf(findOperatorsUpstream(tree2, TableScanOperator.class));
        tables1.retainAll(tables2);
        return tables1.isEmpty();
    }

    /** Returns the qualified table names of the given table scans as a mutable set. */
    private static Set<String> signaturesOf(Set<TableScanOperator> ops) {
        Set<String> ret = new HashSet<String>();
        for (TableScanOperator o : ops) {
            ret.add(o.getConf().getQualifiedTable());
        }
        return ret;
    }
}

