package org.apache.hadoop.hive.ql.parse;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import jodd.util.StringPool;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.Warehouse;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.metastore.api.MetaException;
import org.apache.hadoop.hive.ql.Context;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.QueryState;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.HiveUtils;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.parse.RewriteSemanticAnalyzer;
import org.apache.hadoop.hive.ql.session.SessionState;

/* loaded from: input_file:org/apache/hadoop/hive/ql/parse/MergeSemanticAnalyzer.class */
/**
 * Semantic analyzer for {@code MERGE} statements.
 *
 * <p>The statement is rewritten into one multi-insert query. The target is joined with the
 * source on the ON clause, and every WHEN clause contributes a branch. An update contributes
 * two branches when split update is enabled. An optional trailing branch feeds
 * {@code cardinality_violation()} when the cardinality check is on. The rewritten text is then
 * re-parsed and analyzed, with each branch tagged by its destination prefix.
 */
public class MergeSemanticAnalyzer extends RewriteSemanticAnalyzer {

    // Parser token-type ids this class switches on. The numeric values are the ones this class
    // was compiled against. The names are believed to mirror the corresponding HiveParser
    // constants (TODO: verify against the generated HiveParser for this build and reference it
    // directly, so a regenerated grammar cannot silently break these checks).
    private static final int DOT = 16;
    private static final int QUERY_HINT = 387;
    private static final int TOK_DEFAULT_VALUE = 868;
    private static final int TOK_DELETE = 869;
    private static final int TOK_FUNCTION = 911;
    private static final int TOK_INSERT = 928;
    private static final int TOK_MATCHED = 961;
    private static final int TOK_MERGE = 962;
    private static final int TOK_NOT_MATCHED = 967;
    private static final int TOK_SUBQUERY = 1100;
    private static final int TOK_TABCOLNAME = 1111;
    private static final int TOK_TABLE_OR_COL = 1133;
    private static final int TOK_UPDATE = 1165;

    /** Number of WHEN MATCHED ... UPDATE clauses seen by the most recent analyzeMerge call. */
    private int numWhenMatchedUpdateClauses;
    /** Number of WHEN MATCHED ... DELETE clauses seen by the most recent analyzeMerge call. */
    private int numWhenMatchedDeleteClauses;

    /**
     * Collects the column references in a MERGE ON clause, grouped by the table they are
     * qualified with. The result is used to build a "no target row matched" predicate for the
     * WHEN NOT MATCHED ... INSERT branch.
     *
     * <p>The decompiler noted that this class and some of its members were originally private.
     */
    public static final class OnClauseAnalyzer {
        private final ASTNode onClause;
        /** Lower-cased table name -> column names referenced through that table. */
        private final Map<String, List<String>> table2column = new HashMap<>();
        /** Column references that carried no table qualifier. */
        private final List<String> unresolvedColumns = new ArrayList<>();
        /** Data columns followed by partition columns of the target table. */
        private final List<FieldSchema> allTargetTableColumns = new ArrayList<>();
        /** Distinct lower-cased table names found in the ON clause. */
        private final Set<String> tableNamesFound = new HashSet<>();
        private final String targetTableNameInSourceQuery;
        private final HiveConf conf;
        private final String onClauseAsString;

        /**
         * @param onClause the ON clause AST
         * @param targetTable the MERGE target table; its data and partition columns are used to
         *     resolve unqualified column references
         * @param targetTableNameInSourceQuery the name/alias of the target in the query; it is
         *     unescaped before use
         * @param conf configuration used when unparsing identifiers
         * @param onClauseAsString the ON clause text, used in error messages
         */
        OnClauseAnalyzer(ASTNode onClause, Table targetTable, String targetTableNameInSourceQuery,
                HiveConf conf, String onClauseAsString) {
            this.onClause = onClause;
            this.allTargetTableColumns.addAll(targetTable.getCols());
            this.allTargetTableColumns.addAll(targetTable.getPartCols());
            this.targetTableNameInSourceQuery = BaseSemanticAnalyzer.unescapeIdentifier(targetTableNameInSourceQuery);
            this.conf = conf;
            this.onClauseAsString = onClauseAsString;
        }

        /** Walks the subtree rooted at {@code node}, recording every table/column reference. */
        private void visit(ASTNode node) {
            if (node.getType() == TOK_TABLE_OR_COL) {
                ASTNode parent = node.getParent();
                if (parent != null && parent.getType() == DOT) {
                    // Qualified "t.col": the table name is this node's child and the column is
                    // the DOT's right child.
                    if (parent.getParent() != null && parent.getParent().getType() == DOT) {
                        throw new IllegalArgumentException("Found unexpected db.table.col reference in " + this.onClauseAsString);
                    }
                    addColumn2Table(node.getChild(0).getText(), parent.getChild(1).getText());
                } else {
                    // Bare column name; resolved against the target table in handleUnresolvedColumns().
                    this.unresolvedColumns.add(node.getChild(0).getText());
                }
            }
            if (node.getChildCount() == 0) {
                return;
            }
            for (Node child : node.getChildren()) {
                visit((ASTNode) child);
            }
        }

        /**
         * Analyzes the ON clause.
         *
         * @throws IllegalArgumentException if more than two distinct tables are referenced,
         *     either before or after resolving unqualified columns
         */
        public void analyze() {
            visit(this.onClause);
            if (this.tableNamesFound.size() > 2) {
                throw new IllegalArgumentException("Found > 2 table refs in ON clause.  Found " + this.tableNamesFound + " in " + this.onClauseAsString);
            }
            handleUnresolvedColumns();
            if (this.tableNamesFound.size() > 2) {
                throw new IllegalArgumentException("Found > 2 table refs in ON clause (incl unresolved).  Found " + this.tableNamesFound + " in " + this.onClauseAsString);
            }
        }

        /**
         * Attributes each unqualified column to the target table if the target has a column of
         * that name (case-insensitive). Columns that do not match are ignored.
         */
        private void handleUnresolvedColumns() {
            for (String column : this.unresolvedColumns) {
                for (FieldSchema targetColumn : this.allTargetTableColumns) {
                    if (column.equalsIgnoreCase(targetColumn.getName())) {
                        addColumn2Table(this.targetTableNameInSourceQuery.toLowerCase(), column);
                        break;
                    }
                }
            }
        }

        /** Records that {@code column} was referenced through {@code table} (table is lower-cased). */
        private void addColumn2Table(String table, String column) {
            String key = table.toLowerCase();
            this.tableNamesFound.add(key);
            this.table2column.computeIfAbsent(key, k -> new ArrayList<>()).add(column);
        }

        /**
         * Builds {@code <target>.<col> IS NULL AND ...} over every target column referenced in
         * the ON clause. In the outer join this selects source rows that found no target row.
         *
         * @throws IllegalArgumentException if the ON clause references no target column
         */
        public String getPredicate() {
            List<String> targetColumns = this.table2column.get(this.targetTableNameInSourceQuery.toLowerCase());
            if (targetColumns == null) {
                throw new IllegalArgumentException(ErrorMsg.INVALID_TABLE_IN_ON_CLAUSE_OF_MERGE.format(this.targetTableNameInSourceQuery, this.onClauseAsString));
            }
            StringBuilder predicate = new StringBuilder();
            for (String column : targetColumns) {
                if (predicate.length() > 0) {
                    predicate.append(" AND ");
                }
                predicate.append(HiveUtils.unparseIdentifier(this.targetTableNameInSourceQuery, this.conf))
                        .append(".")
                        .append(HiveUtils.unparseIdentifier(column, this.conf))
                        .append(" IS NULL");
            }
            return predicate.toString();
        }
    }

    public MergeSemanticAnalyzer(QueryState queryState) throws SemanticException {
        super(queryState);
    }

    /** The target table name node is the first child of the MERGE tree. */
    @Override
    protected ASTNode getTargetTableNode(ASTNode tree) {
        return tree.getChild(0);
    }

    /**
     * Entry point: validates that {@code tree} is a MERGE and rewrites/analyzes it.
     *
     * @throws RuntimeException if {@code tree} is not a MERGE token
     */
    @Override
    public void analyze(ASTNode tree, Table targetTable, ASTNode tableNameNode) throws SemanticException {
        if (tree.getToken().getType() != TOK_MERGE) {
            throw new RuntimeException("Asked to parse token " + tree.getName() + " in MergeSemanticAnalyzer");
        }
        this.ctx.setOperation(Context.Operation.MERGE);
        analyzeMerge(tree, targetTable, tableNameNode);
    }

    /**
     * Builds the multi-insert rewrite of a MERGE, re-parses it and analyzes the result.
     *
     * <p>Children of {@code tree}: 0 = target, 1 = source, 2 = ON clause, optional query hint,
     * then the WHEN clauses.
     *
     * @throws SemanticException on more than one WHEN MATCHED DELETE/UPDATE clause, or when two
     *     WHEN MATCHED clauses are given and the first has no predicate
     */
    private void analyzeMerge(ASTNode tree, Table targetTable, ASTNode targetNameNode) throws SemanticException {
        ASTNode source = (ASTNode) tree.getChild(1);
        String targetName = getSimpleTableName(targetNameNode);
        String sourceName = getSimpleTableName(source);
        ASTNode onClause = (ASTNode) tree.getChild(2);
        String onClauseAsText = getMatchedText(onClause);

        // An optional query hint sits between the ON clause and the first WHEN clause.
        int whenClauseBegins = 3;
        boolean hasHint = false;
        ASTNode hint = (ASTNode) tree.getChild(3);
        if (hint.getType() == QUERY_HINT) {
            hasHint = true;
            whenClauseBegins++;
        }
        List<ASTNode> whenClauses = findWhenClauses(tree, whenClauseBegins);

        StringBuilder rewrittenQueryStr = createRewrittenQueryStrBuilder();
        appendTarget(rewrittenQueryStr, targetNameNode, targetName);
        rewrittenQueryStr.append(Utilities.INDENT).append(chooseJoinType(whenClauses)).append("\n");
        appendSource(rewrittenQueryStr, source, sourceName);
        rewrittenQueryStr.append(Utilities.INDENT).append("ON ").append(onClauseAsText).append('\n');

        // The hint is attached to the first branch emitted only.
        String hintStr = hasHint ? " /*+ " + hint.getText() + " */ " : null;
        boolean splitUpdateEarly = HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.SPLIT_UPDATE)
                || HiveConf.getBoolVar(this.conf, HiveConf.ConfVars.MERGE_SPLIT_UPDATE);
        // WHEN predicate of the first WHEN MATCHED clause; a second WHEN MATCHED branch excludes
        // the rows it selects via "AND NOT(...)".
        String extraPredicate = null;
        int numInsertClauses = 0;
        this.numWhenMatchedUpdateClauses = 0;
        this.numWhenMatchedDeleteClauses = 0;
        boolean hintProcessed = false;
        for (ASTNode whenClause : whenClauses) {
            String branchHint = hintProcessed ? null : hintStr;
            switch (getWhenClauseOperation(whenClause).getType()) {
                case TOK_DELETE:
                    this.numWhenMatchedDeleteClauses++;
                    String deletePredicate = handleDelete(whenClause, rewrittenQueryStr, targetNameNode,
                            onClauseAsText, extraPredicate, branchHint, false);
                    hintProcessed = true;
                    if (this.numWhenMatchedUpdateClauses + this.numWhenMatchedDeleteClauses == 1) {
                        extraPredicate = deletePredicate;
                    }
                    break;
                case TOK_INSERT:
                    numInsertClauses++;
                    handleInsert(whenClause, rewrittenQueryStr, targetNameNode, onClause, targetTable,
                            targetName, onClauseAsText, branchHint);
                    hintProcessed = true;
                    break;
                case TOK_UPDATE:
                    this.numWhenMatchedUpdateClauses++;
                    String updatePredicate = handleUpdate(whenClause, rewrittenQueryStr, targetNameNode,
                            onClauseAsText, targetTable, extraPredicate, branchHint, splitUpdateEarly);
                    hintProcessed = true;
                    if (this.numWhenMatchedUpdateClauses + this.numWhenMatchedDeleteClauses == 1) {
                        extraPredicate = updatePredicate;
                    }
                    break;
                default:
                    throw new IllegalStateException("Unexpected WHEN clause type: " + whenClause.getType() + addParseInfo(whenClause));
            }
            if (this.numWhenMatchedDeleteClauses > 1) {
                throw new SemanticException(ErrorMsg.MERGE_TOO_MANY_DELETE, this.ctx.getCmd());
            }
            if (this.numWhenMatchedUpdateClauses > 1) {
                throw new SemanticException(ErrorMsg.MERGE_TOO_MANY_UPDATE, this.ctx.getCmd());
            }
            assert numInsertClauses < 2 : "too many Insert clauses";
        }
        if (this.numWhenMatchedDeleteClauses + this.numWhenMatchedUpdateClauses == 2 && extraPredicate == null) {
            throw new SemanticException(ErrorMsg.MERGE_PREDIACTE_REQUIRED, this.ctx.getCmd());
        }

        boolean validating = handleCardinalityViolation(rewrittenQueryStr, targetNameNode, onClauseAsText,
                targetTable, this.numWhenMatchedDeleteClauses == 0 && this.numWhenMatchedUpdateClauses == 0);
        RewriteSemanticAnalyzer.ReparseResult reparsed = parseRewrittenQuery(rewrittenQueryStr, this.ctx.getCmd());
        Context rewrittenCtx = reparsed.rewrittenCtx;
        ASTNode rewrittenTree = reparsed.rewrittenTree;
        rewrittenCtx.setOperation(Context.Operation.MERGE);
        registerDestNamePrefixes(rewrittenCtx, rewrittenTree, whenClauses, splitUpdateEarly, validating);
        analyzeRewrittenTree(rewrittenTree, rewrittenCtx);
        updateOutputs(targetTable);
    }

    /**
     * Appends the MERGE source: a subquery's text verbatim, or the table's full SQL name plus its
     * alias when one was given. Ends the line.
     */
    private void appendSource(StringBuilder rewrittenQueryStr, ASTNode source, String sourceName)
            throws SemanticException {
        rewrittenQueryStr.append(Utilities.INDENT);
        if (source.getType() == TOK_SUBQUERY) {
            rewrittenQueryStr.append(getMatchedText(source));
        } else {
            rewrittenQueryStr.append(getFullTableNameForSQL(source));
            if (isAliased(source)) {
                rewrittenQueryStr.append(" ").append(sourceName);
            }
        }
        rewrittenQueryStr.append('\n');
    }

    /**
     * Tags each branch of the re-parsed multi-insert with its destination prefix.
     *
     * <p>Branches are numbered from 1 (child 0 is not a WHEN branch; presumably the shared FROM)
     * and appear in WHEN-clause order. A split update occupies two consecutive branches: INSERT,
     * then DELETE. When {@code validating}, the last child is the cardinality-check branch.
     */
    private void registerDestNamePrefixes(Context rewrittenCtx, ASTNode rewrittenTree, List<ASTNode> whenClauses,
            boolean splitUpdateEarly, boolean validating) {
        for (int branchIdx = 1, whenIdx = 0;
                branchIdx < rewrittenTree.getChildCount() - (validating ? 1 : 0);
                branchIdx++, whenIdx++) {
            switch (getWhenClauseOperation(whenClauses.get(whenIdx)).getType()) {
                case TOK_DELETE:
                    rewrittenCtx.addDestNamePrefix(branchIdx, Context.DestClausePrefix.DELETE);
                    break;
                case TOK_INSERT:
                    rewrittenCtx.addDestNamePrefix(branchIdx, Context.DestClausePrefix.INSERT);
                    break;
                case TOK_UPDATE:
                    if (splitUpdateEarly) {
                        rewrittenCtx.addDestNamePrefix(branchIdx, Context.DestClausePrefix.INSERT);
                        branchIdx++; // the next branch is the DELETE half of the same update
                        rewrittenCtx.addDeleteOfUpdateDestNamePrefix(branchIdx, Context.DestClausePrefix.DELETE);
                    } else {
                        rewrittenCtx.addDestNamePrefix(branchIdx, Context.DestClausePrefix.UPDATE);
                    }
                    break;
                default:
                    assert false;
                    break;
            }
        }
        if (validating) {
            rewrittenCtx.addDestNamePrefix(rewrittenTree.getChildCount() - 1, Context.DestClausePrefix.INSERT);
        }
    }

    /**
     * Picks the join type for target-to-source.
     *
     * <p>The target is emitted on the left, so a RIGHT OUTER JOIN keeps source rows with no target
     * match. The WHEN NOT MATCHED INSERT branch needs those rows.
     */
    private String chooseJoinType(List<ASTNode> whenClauses) {
        for (ASTNode whenClause : whenClauses) {
            if (getWhenClauseOperation(whenClause).getType() == TOK_INSERT) {
                return "RIGHT OUTER JOIN";
            }
        }
        return "INNER JOIN";
    }

    /**
     * Appends the cardinality-check branch when the check is enabled and there is at least one
     * WHEN MATCHED clause. The branch groups matches by target ROW__ID (plus partition columns)
     * and passes any group with count > 1 to {@code cardinality_violation()}. The temporary
     * {@code merge_tmp_table} it writes to is created on first use.
     *
     * @return true if the branch was appended
     * @throws SemanticException wrapping metastore/Hive failures while creating the temp table
     */
    private boolean handleCardinalityViolation(StringBuilder rewrittenQueryStr, ASTNode target,
            String onClauseAsString, Table targetTable, boolean onlyHaveWhenNotMatchedClause)
            throws SemanticException {
        if (!this.conf.getBoolVar(HiveConf.ConfVars.MERGE_CARDINALITY_VIOLATION_CHECK)) {
            LOG.info("Merge statement cardinality violation check is disabled: " + HiveConf.ConfVars.MERGE_CARDINALITY_VIOLATION_CHECK.varname);
            return false;
        }
        if (onlyHaveWhenNotMatchedClause) {
            return false;
        }
        String tableName = "merge_tmp_table";
        String targetName = getSimpleTableName(target);
        List<FieldSchema> partCols = targetTable.getPartCols();
        rewrittenQueryStr.append("INSERT INTO ").append(tableName)
                .append("\n  SELECT cardinality_violation(").append(targetName).append(".ROW__ID");
        addPartitionColsToSelect(partCols, rewrittenQueryStr, target);
        rewrittenQueryStr.append(")\n WHERE ").append(onClauseAsString)
                .append(" GROUP BY ").append(targetName).append(".ROW__ID");
        addPartitionColsToSelect(partCols, rewrittenQueryStr, target);
        rewrittenQueryStr.append(" HAVING count(*) > 1");
        try {
            if (this.db.getTable(tableName, false) == null) {
                StorageFormat format = new StorageFormat(this.conf);
                format.processStorageFormat("TextFile");
                Table tmpTable = this.db.newTable(tableName);
                tmpTable.setSerializationLib(format.getSerde());
                List<FieldSchema> fields = new ArrayList<>();
                fields.add(new FieldSchema("val", "int", null));
                tmpTable.setFields(fields);
                tmpTable.setDataLocation(Warehouse.getDnsPath(
                        new Path(SessionState.get().getTempTableSpace(), tableName), this.conf));
                tmpTable.getTTable().setTemporary(true);
                tmpTable.setStoredAsSubDirectories(false);
                tmpTable.setInputFormatClass(format.getInputFormat());
                tmpTable.setOutputFormatClass(format.getOutputFormat());
                this.db.createTable(tmpTable, true);
            }
            return true;
        } catch (MetaException | HiveException e) {
            throw new SemanticException(e.getMessage(), e);
        }
    }

    /**
     * Emits the branch(es) for a WHEN MATCHED ... UPDATE clause.
     *
     * <p>Without split update, one UPDATE-style branch selects ROW__ID plus every data column
     * (SET expression or current value), sorted by ROW__ID. With split update, an INSERT of the
     * new row is emitted first, followed by a DELETE branch for the old one.
     *
     * @param extraPredicate predicate of an earlier WHEN MATCHED clause to exclude, or null
     * @param hintStr hint to attach to this branch, or null
     * @return this clause's own WHEN predicate, or null if it has none
     */
    private String handleUpdate(ASTNode whenMatchedUpdateClause, StringBuilder rewrittenQueryStr, ASTNode target,
            String onClauseAsString, Table targetTable, String extraPredicate, String hintStr,
            boolean splitUpdateEarly) throws SemanticException {
        assert whenMatchedUpdateClause.getType() == TOK_MATCHED;
        assert getWhenClauseOperation(whenMatchedUpdateClause).getType() == TOK_UPDATE;
        String targetName = getSimpleTableName(target);
        List<FieldSchema> nonPartCols = targetTable.getCols();
        List<String> values = new ArrayList<>(nonPartCols.size() + (splitUpdateEarly ? 0 : 1));
        if (!splitUpdateEarly) {
            values.add(targetName + ".ROW__ID");
        }
        Map<String, ASTNode> setColsExprs = collectSetColumnsAndExpressions(
                (ASTNode) getWhenClauseOperation(whenMatchedUpdateClause).getChild(0), null, targetTable);
        Map<String, String> colNameToDefaultValue = getColNameToDefaultValueMap(targetTable);
        for (FieldSchema col : nonPartCols) {
            String name = col.getName();
            if (!setColsExprs.containsKey(name)) {
                // Column not in SET: carry the current value forward.
                values.add(targetName + "." + HiveUtils.unparseIdentifier(name, this.conf));
                continue;
            }
            ASTNode setExpr = setColsExprs.get(name);
            if (isDefaultValueReference(setExpr)) {
                // Rewrite the DEFAULT keyword in the token stream to the column's default value.
                UnparseTranslator defaultValueTranslator = new UnparseTranslator(this.conf);
                defaultValueTranslator.enable();
                defaultValueTranslator.addDefaultValueTranslation(setExpr, colNameToDefaultValue.get(name));
                defaultValueTranslator.applyTranslations(this.ctx.getTokenRewriteStream());
            }
            values.add(stripTrailingSeparator(getMatchedText(setExpr)));
        }
        addPartitionColsAsValues(targetTable.getPartCols(), targetName, values);
        rewrittenQueryStr.append("    -- update clause").append(splitUpdateEarly ? " (insert part)" : "").append("\n");
        appendInsertBranch(rewrittenQueryStr, hintStr, values);
        String whenClausePredicate = getWhenClausePredicate(whenMatchedUpdateClause);
        appendMatchedWhereClause(rewrittenQueryStr, onClauseAsString, whenClausePredicate, extraPredicate);
        if (!splitUpdateEarly) {
            appendSortBy(rewrittenQueryStr, Collections.singletonList(targetName + ".ROW__ID "));
        }
        rewrittenQueryStr.append("\n");
        setUpAccessControlInfoForUpdate(targetTable, setColsExprs);
        if (splitUpdateEarly) {
            rewrittenQueryStr.append("    -- update clause (delete part)\n");
            handleDelete(whenMatchedUpdateClause, rewrittenQueryStr, target, onClauseAsString, extraPredicate, hintStr, true);
        }
        return whenClausePredicate;
    }

    /**
     * True if a SET expression is a lone TOK_TABLE_OR_COL wrapping TOK_DEFAULT_VALUE (presumably
     * {@code col = DEFAULT}).
     */
    private static boolean isDefaultValueReference(ASTNode setExpr) {
        return setExpr.getType() == TOK_TABLE_OR_COL
                && setExpr.getChildCount() == 1
                && setExpr.getChild(0).getType() == TOK_DEFAULT_VALUE;
    }

    /**
     * Drops one trailing ',' or '\n'. The matched text of a SET expression can include the
     * separator that follows it (e.g. the ',' in "set a=5, b=8").
     */
    private static String stripTrailingSeparator(String text) {
        if (text.isEmpty()) {
            return text;
        }
        char last = text.charAt(text.length() - 1);
        return (last == ',' || last == '\n') ? text.substring(0, text.length() - 1) : text;
    }

    /**
     * Appends {@code WHERE <on> [AND <whenPredicate>] [AND NOT(<excluded>)]} for a WHEN MATCHED
     * branch. Null predicates are omitted.
     */
    private static void appendMatchedWhereClause(StringBuilder rewrittenQueryStr, String onClauseAsString,
            String whenClausePredicate, String excludedPredicate) {
        rewrittenQueryStr.append(Utilities.INDENT).append("WHERE ").append(onClauseAsString);
        if (whenClausePredicate != null) {
            rewrittenQueryStr.append(" AND ").append(whenClausePredicate);
        }
        if (excludedPredicate != null) {
            rewrittenQueryStr.append(" AND NOT(").append(excludedPredicate).append(")");
        }
    }

    /**
     * Emits a DELETE branch keyed by the target's ROW__ID and sorted by it.
     *
     * <p>It serves a WHEN MATCHED ... DELETE clause, or the delete half of a split update when
     * {@code splitUpdateEarly} is true.
     *
     * @param extraPredicate predicate of an earlier WHEN MATCHED clause to exclude, or null
     * @return the clause's own WHEN predicate, or null if it has none
     */
    private String handleDelete(ASTNode whenMatchedClause, StringBuilder rewrittenQueryStr, ASTNode target,
            String onClauseAsString, String extraPredicate, String hintStr, boolean splitUpdateEarly)
            throws SemanticException {
        assert whenMatchedClause.getType() == TOK_MATCHED;
        assert (splitUpdateEarly && getWhenClauseOperation(whenMatchedClause).getType() == TOK_UPDATE)
                || getWhenClauseOperation(whenMatchedClause).getType() == TOK_DELETE;
        String targetName = getSimpleTableName(target);
        appendDeleteBranch(rewrittenQueryStr, hintStr, targetName, Collections.singletonList(targetName + ".ROW__ID"));
        String whenClausePredicate = getWhenClausePredicate(whenMatchedClause);
        appendMatchedWhereClause(rewrittenQueryStr, onClauseAsString, whenClausePredicate, extraPredicate);
        appendSortBy(rewrittenQueryStr, Collections.singletonList(targetName + ".ROW__ID "));
        return whenClausePredicate;
    }

    /** Formats the source position of {@code node} for error messages. */
    private static String addParseInfo(ASTNode node) {
        return " at " + ErrorMsg.renderPosition(node);
    }

    /** True if {@code node} is a WHEN MATCHED or WHEN NOT MATCHED clause. */
    private static boolean isWhenClause(ASTNode node) {
        return node.getType() == TOK_MATCHED || node.getType() == TOK_NOT_MATCHED;
    }

    /**
     * Returns the children of the MERGE tree from index {@code start} on, i.e. the WHEN clauses.
     *
     * @throws SemanticException if there are none
     */
    private List<ASTNode> findWhenClauses(ASTNode tree, int start) throws SemanticException {
        assert tree.getType() == TOK_MERGE;
        List<ASTNode> whenClauses = new ArrayList<>();
        for (int idx = start; idx < tree.getChildCount(); idx++) {
            ASTNode whenClause = (ASTNode) tree.getChild(idx);
            assert isWhenClause(whenClause)
                    : "Unexpected node type found: " + whenClause.getType() + addParseInfo(whenClause);
            whenClauses.add(whenClause);
        }
        if (whenClauses.isEmpty()) {
            throw new SemanticException("Must have at least 1 WHEN clause in MERGE statement");
        }
        return whenClauses;
    }

    /** Returns the operation (INSERT/UPDATE/DELETE) node of a WHEN clause, i.e. its first child. */
    private ASTNode getWhenClauseOperation(ASTNode whenClause) {
        if (!isWhenClause(whenClause)) {
            throw raiseWrongType("Expected TOK_MATCHED|TOK_NOT_MATCHED", whenClause);
        }
        return (ASTNode) whenClause.getChild(0);
    }

    /** Returns the text of a WHEN clause's "AND <predicate>", or null if the clause has none. */
    private String getWhenClausePredicate(ASTNode whenClause) {
        if (!isWhenClause(whenClause)) {
            throw raiseWrongType("Expected TOK_MATCHED|TOK_NOT_MATCHED", whenClause);
        }
        return whenClause.getChildCount() == 2 ? getMatchedText((ASTNode) whenClause.getChild(1)) : null;
    }

    /**
     * Emits the INSERT branch for a WHEN NOT MATCHED ... INSERT clause.
     *
     * <p>The branch has the optional column list and the VALUES list, and selects only source
     * rows with no matching target row (see {@link OnClauseAnalyzer#getPredicate()}).
     *
     * @throws SemanticException if the VALUES list is missing or its length differs from the
     *     column list
     */
    private void handleInsert(ASTNode whenNotMatchedClause, StringBuilder rewrittenQueryStr, ASTNode target,
            ASTNode onClause, Table targetTable, String targetTableNameInSourceQuery, String onClauseAsString,
            String hintStr) throws SemanticException {
        ASTNode whenClauseOperation = getWhenClauseOperation(whenNotMatchedClause);
        assert whenNotMatchedClause.getType() == TOK_NOT_MATCHED;
        assert whenClauseOperation.getType() == TOK_INSERT;

        List<Node> children = whenClauseOperation.getChildren();
        ASTNode valuesNode = (ASTNode) children.stream()
                .filter(node -> ((ASTNode) node).getType() == TOK_FUNCTION)
                .findFirst()
                .orElseThrow(() -> new SemanticException(
                        "Missing VALUES list in WHEN NOT MATCHED clause" + addParseInfo(whenNotMatchedClause)));
        ASTNode columnListNode = (ASTNode) children.stream()
                .filter(node -> ((ASTNode) node).getType() == TOK_TABCOLNAME)
                .findFirst()
                .orElse(null);
        // The VALUES node's first child is not one of the values, hence the "- 1".
        if (columnListNode != null && columnListNode.getChildCount() != valuesNode.getChildCount() - 1) {
            throw new SemanticException(String.format("Column schema must have the same length as values (%d vs %d)",
                    columnListNode.getChildCount(), valuesNode.getChildCount() - 1));
        }

        rewrittenQueryStr.append("INSERT INTO ").append(getFullTableNameForSQL(target));
        if (columnListNode != null) {
            rewrittenQueryStr.append(' ').append(getMatchedText(columnListNode));
        }
        rewrittenQueryStr.append("    -- insert clause\n  SELECT ");
        if (hintStr != null) {
            rewrittenQueryStr.append(hintStr);
        }

        OnClauseAnalyzer onClauseAnalyzer = new OnClauseAnalyzer(onClause, targetTable,
                targetTableNameInSourceQuery, this.conf, onClauseAsString);
        onClauseAnalyzer.analyze();

        UnparseTranslator defaultValueTranslator = new UnparseTranslator(this.conf);
        defaultValueTranslator.enable();
        collectDefaultValues(valuesNode, targetTable,
                processTableColumnNames(columnListNode, targetTable.getFullyQualifiedName()), defaultValueTranslator);
        defaultValueTranslator.applyTranslations(this.ctx.getTokenRewriteStream());

        // Drop the enclosing delimiters of the VALUES text so it can serve as a SELECT list.
        String valuesText = getMatchedText(valuesNode);
        rewrittenQueryStr.append(valuesText.substring(1, valuesText.length() - 1))
                .append("\n   WHERE ").append(onClauseAnalyzer.getPredicate());
        String whenClausePredicate = getWhenClausePredicate(whenNotMatchedClause);
        if (whenClausePredicate != null) {
            rewrittenQueryStr.append(" AND ").append(whenClausePredicate);
        }
        rewrittenQueryStr.append('\n');
    }

    /**
     * Registers a translation that replaces VALUES entry {@code i + 1} with the i-th default
     * constraint.
     */
    private void collectDefaultValues(ASTNode valuesNode, Table targetTable, List<String> targetColumns,
            UnparseTranslator unparseTranslator) throws SemanticException {
        List<String> defaultConstraints = getDefaultConstraints(targetTable, targetColumns);
        for (int i = 0; i < defaultConstraints.size(); i++) {
            unparseTranslator.addDefaultValueTranslation((ASTNode) valuesNode.getChild(i + 1), defaultConstraints.get(i));
        }
    }

    /**
     * A split update writes the target twice, as an INSERT branch and a DELETE branch, so multiple
     * outputs to it are allowed in that mode.
     */
    @Override
    protected boolean allowOutputMultipleTimes() {
        return this.conf.getBoolVar(HiveConf.ConfVars.SPLIT_UPDATE) || this.conf.getBoolVar(HiveConf.ConfVars.MERGE_SPLIT_UPDATE);
    }

    /** Column statistics are collected only for insert-only MERGEs. */
    @Override
    protected boolean enableColumnStatsCollecting() {
        return this.numWhenMatchedUpdateClauses == 0 && this.numWhenMatchedDeleteClauses == 0;
    }
}
