/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.util.Comparator;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexDynamicParam;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveRulesRegistry;
import org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator;
import org.apache.hadoop.hive.ql.optimizer.calcite.stats.HiveRelMdSize;

public class HiveFilterSortPredicates
extends RelOptRule {
    public static final HiveFilterSortPredicates INSTANCE = new HiveFilterSortPredicates();

    private HiveFilterSortPredicates() {
        super(HiveFilterSortPredicates.operand(Filter.class, (RelOptRuleOperand)HiveFilterSortPredicates.operand(RelNode.class, (RelOptRuleOperandChildren)HiveFilterSortPredicates.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]));
    }

    public boolean matches(RelOptRuleCall call) {
        Filter filter = (Filter)call.rel(0);
        HiveRulesRegistry registry = (HiveRulesRegistry)call.getPlanner().getContext().unwrap(HiveRulesRegistry.class);
        return registry == null || !registry.getVisited(this).contains(filter);
    }

    public void onMatch(RelOptRuleCall call) {
        Filter filter = (Filter)call.rel(0);
        RelNode input = call.rel(1);
        HiveRulesRegistry registry = (HiveRulesRegistry)call.getPlanner().getContext().unwrap(HiveRulesRegistry.class);
        if (registry != null) {
            registry.registerVisited(this, (RelNode)filter);
        }
        RexNode originalCond = filter.getCondition();
        RexSortPredicatesShuttle sortPredicatesShuttle = new RexSortPredicatesShuttle(input, filter.getCluster().getMetadataQuery());
        RexNode newCond = (RexNode)originalCond.accept((RexVisitor)sortPredicatesShuttle);
        if (!sortPredicatesShuttle.modified) {
            return;
        }
        Filter newFilter = filter.copy(filter.getTraitSet(), input, newCond);
        if (registry != null) {
            registry.registerVisited(this, (RelNode)newFilter);
        }
        call.transformTo((RelNode)newFilter);
    }

    private static class RexFunctionCost
    extends RexVisitorImpl<Double> {
        private RexFunctionCost() {
            super(true);
        }

        public Double visitCall(RexCall call) {
            if (!this.deep) {
                return null;
            }
            Double cost = 0.0;
            for (RexNode operand : call.operands) {
                Double operandCost = (Double)operand.accept((RexVisitor)this);
                if (operandCost == null) {
                    return null;
                }
                cost = cost + operandCost;
                Double size = operand.isA(SqlKind.LITERAL) ? Double.valueOf(HiveRelMdSize.INSTANCE.typeValueSize(operand.getType(), (Comparable)((RexLiteral)operand).getValueAs(Comparable.class))) : HiveRelMdSize.INSTANCE.averageTypeValueSize(operand.getType());
                if (size == null) {
                    return null;
                }
                cost = cost + size;
            }
            return cost + RexFunctionCost.functionCost(call);
        }

        private static Double functionCost(RexCall call) {
            switch (call.getKind()) {
                case EQUALS: 
                case NOT_EQUALS: 
                case LESS_THAN: 
                case GREATER_THAN: 
                case LESS_THAN_OR_EQUAL: 
                case GREATER_THAN_OR_EQUAL: 
                case IS_NOT_NULL: 
                case IS_NULL: 
                case IS_TRUE: 
                case IS_NOT_TRUE: 
                case IS_FALSE: 
                case IS_NOT_FALSE: {
                    return 1.0;
                }
                case BETWEEN: {
                    return 3.0;
                }
                case IN: {
                    return 2.0 * (double)(call.getOperands().size() - 1);
                }
                case AND: 
                case OR: {
                    return 1.0 * (double)call.getOperands().size();
                }
                case CAST: {
                    return 8.0;
                }
            }
            return 32.0;
        }

        public Double visitInputRef(RexInputRef inputRef) {
            return 0.0;
        }

        public Double visitFieldAccess(RexFieldAccess fieldAccess) {
            return 0.0;
        }

        public Double visitLiteral(RexLiteral literal) {
            return 0.0;
        }

        public Double visitDynamicParam(RexDynamicParam dynamicParam) {
            return 0.0;
        }
    }

    private static class RexSortPredicatesShuttle
    extends RexShuttle {
        private FilterSelectivityEstimator selectivityEstimator;
        private boolean modified;

        private RexSortPredicatesShuttle(RelNode inputRel, RelMetadataQuery mq) {
            this.selectivityEstimator = new FilterSelectivityEstimator(inputRel, mq);
            this.modified = false;
        }

        public RexNode visitCall(RexCall call) {
            switch (call.getKind()) {
                case AND: {
                    List newAndOperands = call.getOperands().stream().map(pred -> new Pair(pred, (Object)this.rankingAnd((RexNode)pred))).sorted(Comparator.comparing(Pair::getValue, Comparator.nullsLast(Double::compare))).map(Pair::getKey).collect(Collectors.toList());
                    if (call.getOperands().equals(newAndOperands)) break;
                    this.modified = true;
                    return call.clone(call.getType(), newAndOperands);
                }
                case OR: {
                    List newOrOperands = call.getOperands().stream().map(pred -> new Pair(pred, (Object)this.rankingOr((RexNode)pred))).sorted(Comparator.comparing(Pair::getValue, Comparator.nullsLast(Double::compare))).map(Pair::getKey).collect(Collectors.toList());
                    if (call.getOperands().equals(newOrOperands)) break;
                    this.modified = true;
                    return call.clone(call.getType(), newOrOperands);
                }
            }
            return call;
        }

        private Double rankingAnd(RexNode e) {
            Double selectivity = this.selectivityEstimator.estimateSelectivity(e);
            if (selectivity == null) {
                return null;
            }
            Double costPerTuple = this.costPerTuple(e);
            if (costPerTuple == null) {
                return null;
            }
            return (selectivity - 1.0) / costPerTuple;
        }

        private Double rankingOr(RexNode e) {
            Double selectivity = this.selectivityEstimator.estimateSelectivity(e);
            if (selectivity == null) {
                return null;
            }
            Double costPerTuple = this.costPerTuple(e);
            if (costPerTuple == null) {
                return null;
            }
            return -selectivity.doubleValue() / costPerTuple;
        }

        private Double costPerTuple(RexNode e) {
            return (Double)e.accept((RexVisitor)new RexFunctionCost());
        }
    }
}

