package org.apache.hadoop.hive.ql.exec.tez;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.TimeUnit;
import jodd.util.StringPool;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.tez.client.registry.AMRecord;
import org.apache.tez.client.registry.AMRegistryClientListener;
import org.apache.tez.client.registry.zookeeper.ZkAMRegistryClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/hive/ql/exec/tez/TezExternalSessionsRegistryClient.class */
/**
 * {@link ExternalSessionsRegistry} backed by the Tez AM ZooKeeper registry.
 *
 * <p>AM records reported through a {@link ZkAMRegistryClient} listener are kept in a priority
 * queue of available sessions, ordered according to the configured {@link SelectionStrategy}.
 * {@link #getSession()} moves a record from the available queue into the taken set and
 * {@link #returnSession(String)} moves it back.
 *
 * <p>Registry callbacks and session requests may run on different threads, so every compound
 * read/update of {@code available}, {@code taken} and {@code appIdAmRecordMap} is performed while
 * holding {@code stateLock}.
 */
public class TezExternalSessionsRegistryClient implements ExternalSessionsRegistry {
    private static final Logger LOG = LoggerFactory.getLogger(TezExternalSessionsRegistryClient.class);
    private static final int DEFAULT_QUEUE_CAPACITY = 16;

    /** Sessions ready to be handed out, ordered by the configured selection strategy. */
    private final PriorityBlockingQueue<AMRecord> available;
    /** Number of seconds {@link #getSession()} waits for an available session. */
    private final int maxAttempts;
    private final SelectionStrategy selectionStrategy;
    private final ZkAMRegistryClient tezRegistryClient;
    /** Application id (string form) to the AM record registered under it. */
    private final Map<String, AMRecord> appIdAmRecordMap = new HashMap<>();
    /** Sessions handed out by {@link #getSession()} and not yet returned or removed. */
    private final Set<AMRecord> taken = new HashSet<>();
    /** Guards check-then-act sequences spanning available, taken and appIdAmRecordMap. */
    private final Object stateLock = new Object();

    /**
     * Orders AM records by compute ordinal, then by host ordinal, so sessions are packed onto the
     * lowest-numbered compute first.
     *
     * <p>NOTE(review): returning 0 whenever either ordinal is unparseable makes this ordering
     * non-transitive; the queue still works but ordering among such records is unspecified.
     */
    @VisibleForTesting
    static class AMRecordBinPackingComputeComparator implements Comparator<AMRecord> {
        AMRecordBinPackingComputeComparator() {
        }

        @Override
        public int compare(AMRecord left, AMRecord right) {
            int leftCompute = getComputeOrdinal(left.getComputeName());
            int rightCompute = getComputeOrdinal(right.getComputeName());
            if (leftCompute < 0 || rightCompute < 0) {
                return 0;
            }
            if (leftCompute != rightCompute) {
                return Integer.compare(leftCompute, rightCompute);
            }
            // Host ordinals are only computed when the compute ordinals tie.
            int leftHost = getHostOrdinal(left.getHost());
            int rightHost = getHostOrdinal(right.getHost());
            if (leftHost < 0 || rightHost < 0) {
                return 0;
            }
            return Integer.compare(leftHost, rightHost);
        }
    }

    /**
     * Orders AM records by their {@code hashCode()}. This is an arbitrary but deterministic order,
     * not a truly random one.
     */
    @VisibleForTesting
    static class AMRecordRandomComputeComparator implements Comparator<AMRecord> {
        AMRecordRandomComputeComparator() {
        }

        @Override
        public int compare(AMRecord left, AMRecord right) {
            // Integer.compare instead of subtraction: a - b overflows for hash codes of opposite sign.
            return Integer.compare(left.hashCode(), right.hashCode());
        }
    }

    /**
     * Orders AM records by host ordinal, then by compute ordinal, so consecutive sessions are spread
     * across hosts.
     *
     * <p>NOTE(review): returning 0 whenever either ordinal is unparseable makes this ordering
     * non-transitive; the queue still works but ordering among such records is unspecified.
     */
    @VisibleForTesting
    static class AMRecordRoundRobinComputeComparator implements Comparator<AMRecord> {
        AMRecordRoundRobinComputeComparator() {
        }

        @Override
        public int compare(AMRecord left, AMRecord right) {
            int leftHost = getHostOrdinal(left.getHost());
            int rightHost = getHostOrdinal(right.getHost());
            if (leftHost < 0 || rightHost < 0) {
                return 0;
            }
            if (leftHost != rightHost) {
                return Integer.compare(leftHost, rightHost);
            }
            // Compute ordinals are only computed when the host ordinals tie.
            int leftCompute = getComputeOrdinal(left.getComputeName());
            int rightCompute = getComputeOrdinal(right.getComputeName());
            if (leftCompute < 0 || rightCompute < 0) {
                return 0;
            }
            return Integer.compare(leftCompute, rightCompute);
        }
    }

    /**
     * Registry listener that mirrors AM additions/removals into the shared session state.
     * All compound updates happen under the lock shared with the enclosing client.
     */
    private static class AMStateListener implements AMRegistryClientListener {
        private final PriorityBlockingQueue<AMRecord> available;
        private final Set<AMRecord> taken;
        private final Map<String, AMRecord> appIdAmRecordMap;
        private final Object lock;

        AMStateListener(Map<String, AMRecord> appIdAmRecordMap, PriorityBlockingQueue<AMRecord> available,
                Set<AMRecord> taken, Object lock) {
            this.appIdAmRecordMap = appIdAmRecordMap;
            this.available = available;
            this.taken = taken;
            this.lock = lock;
        }

        /**
         * Makes a newly registered AM available, unless it is already tracked as available or taken.
         * Records without an application id are ignored.
         */
        public void onAdd(AMRecord record) {
            if (record == null || record.getApplicationId() == null) {
                return;
            }
            String applicationId = record.getApplicationId().toString();
            synchronized (lock) {
                // Check and offer atomically so a concurrent returnSession cannot double-offer.
                if (taken.contains(record) || available.contains(record)) {
                    return;
                }
                available.offer(record);
                appIdAmRecordMap.putIfAbsent(applicationId, record);
            }
            LOG.info("Adding external session with applicationId: {} in compute: {}", applicationId,
                record.getComputeName());
        }

        /**
         * Forgets an AM that left the registry, whether it was available or taken.
         * Records without an application id are ignored.
         */
        public void onRemove(AMRecord record) {
            if (record == null || record.getApplicationId() == null) {
                return;
            }
            String applicationId = record.getApplicationId().toString();
            synchronized (lock) {
                available.remove(record);
                taken.remove(record);
                appIdAmRecordMap.remove(applicationId);
            }
            LOG.info("Removed external session with applicationId: {} in compute: {}", applicationId,
                record.getComputeName());
        }
    }

    /** How the available-session queue is ordered; parsed case-insensitively from configuration. */
    enum SelectionStrategy {
        /** Compute ordinal first, then host ordinal. */
        BIN_PACK,
        /** Host ordinal first, then compute ordinal. */
        ROUND_ROBIN,
        /** {@code hashCode()} order. */
        RANDOM
    }

    /**
     * Creates the client and registers its listener with the Tez AM registry.
     *
     * @param configuration source of the wait time and assignment strategy settings
     * @throws IllegalArgumentException if the configured assignment strategy is unknown
     * @throws RuntimeException if the registry client cannot be created or the listener registered
     */
    public TezExternalSessionsRegistryClient(Configuration configuration) {
        this.maxAttempts = HiveConf.getIntVar(configuration,
            HiveConf.ConfVars.HIVE_SERVER2_TEZ_EXTERNAL_SESSIONS_WAIT_MAX_ATTEMPTS);
        this.selectionStrategy = parseSelectionStrategy(HiveConf.getVar(configuration,
            HiveConf.ConfVars.HIVE_SERVER2_TEZ_EXTERNAL_SESSIONS_ASSIGNMENT_STRATEGY));
        this.available = new PriorityBlockingQueue<>(DEFAULT_QUEUE_CAPACITY, comparatorFor(this.selectionStrategy));
        try {
            this.tezRegistryClient = ZkAMRegistryClient.getClient(configuration);
            this.tezRegistryClient.addListener(
                new AMStateListener(this.appIdAmRecordMap, this.available, this.taken, this.stateLock));
        } catch (Exception e) {
            throw new RuntimeException("Failed to initialize the Tez AM registry client", e);
        }
        LOG.info("Using selection strategy: {}", this.selectionStrategy);
    }

    /** Parses a strategy name case-insensitively, independent of the JVM default locale. */
    private static SelectionStrategy parseSelectionStrategy(String value) {
        try {
            return SelectionStrategy.valueOf(value.trim().toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Unsupported external sessions assignment strategy: '" + value
                + "'. Expected one of " + Arrays.toString(SelectionStrategy.values()), e);
        }
    }

    /** Returns the queue comparator implementing the given strategy. */
    private static Comparator<AMRecord> comparatorFor(SelectionStrategy strategy) {
        switch (strategy) {
            case BIN_PACK:
                return new AMRecordBinPackingComputeComparator();
            case ROUND_ROBIN:
                return new AMRecordRoundRobinComputeComparator();
            case RANDOM:
                return new AMRecordRandomComputeComparator();
            default:
                throw new IllegalStateException("Unhandled selection strategy: " + strategy);
        }
    }

    /** Closes the underlying registry client, if one was created. */
    @Override
    public void close() {
        if (this.tezRegistryClient != null) {
            this.tezRegistryClient.close();
        }
    }

    /**
     * Takes the highest-priority available session, waiting up to {@code maxAttempts} seconds.
     *
     * @return the application id of the session handed out
     * @throws IOException if no session became available in time
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public String getSession() throws Exception {
        AMRecord record = this.available.poll(this.maxAttempts, TimeUnit.SECONDS);
        if (record == null) {
            throw new IOException("Cannot get a session after " + this.maxAttempts + " seconds");
        }
        synchronized (this.stateLock) {
            this.taken.add(record);
        }
        // onAdd only queues records with a non-null application id.
        String applicationId = record.getApplicationId().toString();
        LOG.info("External session taken: {} in compute: {}. available: {}",
            applicationId, record.getComputeName(), this.available.size());
        return applicationId;
    }

    /**
     * Puts a previously taken session back into the available queue. Unknown ids and sessions that
     * are not currently taken (e.g. already returned, or removed from the registry) are ignored.
     *
     * @param str application id previously returned by {@link #getSession()}
     */
    @Override
    public void returnSession(String str) {
        AMRecord record;
        boolean returned;
        synchronized (this.stateLock) {
            record = this.appIdAmRecordMap.get(str);
            returned = record != null && this.taken.remove(record);
            if (returned) {
                this.available.offer(record);
            }
        }
        if (returned) {
            LOG.info("External session returned: {} back to compute: {}. available: {}",
                str, record.getComputeName(), this.available.size());
        }
    }

    /**
     * Returns the numeric suffix after the last dash of a compute name (e.g. {@code compute-3} gives
     * 3), or -1 with a warning when there is none or the name is null.
     */
    public static Integer getComputeOrdinal(String str) {
        int ordinal = parseTrailingOrdinal(str);
        if (ordinal < 0) {
            LOG.warn("Returning negative ordinal for compute: {}", str);
        }
        return ordinal;
    }

    /**
     * Returns the numeric suffix after the last dash of the first DNS label of a host name
     * (e.g. {@code worker-7.example.com} gives 7), or -1 with a warning when there is none or the
     * host is null.
     */
    public static Integer getHostOrdinal(String str) {
        String shortHost = str;
        if (str != null) {
            // indexOf instead of split("\\."): split returns an empty array for all-dot input.
            int dot = str.indexOf('.');
            if (dot >= 0) {
                shortHost = str.substring(0, dot);
            }
        }
        int ordinal = parseTrailingOrdinal(shortHost);
        if (ordinal < 0) {
            LOG.warn("Returning negative ordinal for host: {}", shortHost);
        }
        return ordinal;
    }

    /**
     * Parses the integer after the last dash of {@code name}. Returns -1 when the name is null,
     * has no dash after position 0, or the suffix is not a valid int. The suffix cannot contain
     * '-', so a successful parse is never negative.
     */
    private static int parseTrailingOrdinal(String name) {
        if (name == null) {
            return -1;
        }
        int dash = name.lastIndexOf(StringPool.DASH);
        if (dash <= 0) {
            return -1;
        }
        try {
            return Integer.parseInt(name.substring(dash + 1));
        } catch (NumberFormatException ignored) {
            // Non-numeric or out-of-range suffix: callers treat -1 as "no ordinal".
            return -1;
        }
    }
}
