/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec.tez;

import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.protobuf.ByteString;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.tez.CustomEdgeConfiguration;
import org.apache.hadoop.hive.ql.exec.tez.CustomPartitionEdge;
import org.apache.hadoop.hive.ql.exec.tez.CustomVertexConfiguration;
import org.apache.hadoop.hive.ql.exec.tez.DataInputByteBuffer;
import org.apache.hadoop.hive.ql.exec.tez.SplitGrouper;
import org.apache.hadoop.hive.ql.plan.TezWork;
import org.apache.hadoop.hive.shims.ShimLoader;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.split.TezGroupedSplit;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPlugin;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.mapreduce.hadoop.MRInputHelpers;
import org.apache.tez.mapreduce.protos.MRRuntimeProtos;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputSpecUpdate;
import org.apache.tez.runtime.api.events.InputConfigureVertexTasksEvent;
import org.apache.tez.runtime.api.events.InputDataInformationEvent;
import org.apache.tez.runtime.api.events.InputUpdatePayloadEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;

/**
 * Tez {@link VertexManagerPlugin} that routes bucketed root-input splits to tasks.
 *
 * <p>For each root input, the incoming file splits are grouped by bucket file name, assigned a
 * bucket id (modulo {@code numBuckets}), and grouped per bucket with {@link SplitGrouper}.
 *
 * <p>The input named by the vertex configuration defines the bucket-to-task routing table. When
 * that name is empty, every input takes this path. Each grouped split of the main input becomes
 * one task. Other ("side") inputs are routed through the routing table. If a side input arrives
 * before the table exists, it is queued until the main input has been processed.
 *
 * <p>The vertex parallelism is set once the number of inputs given in the vertex configuration
 * have been seen.
 */
public class CustomPartitionVertex
extends VertexManagerPlugin {
    private static final Log LOG = LogFactory.getLog(CustomPartitionVertex.class.getName());
    VertexManagerPluginContext context;
    // Last configure-tasks event received on a root input; it is only logged here.
    private InputConfigureVertexTasksEvent configureVertexTaskEvent;
    private int numBuckets = -1;
    // Configuration decoded from the MR payload of the most recently initialized root input.
    private Configuration conf = null;
    private final SplitGrouper grouper = new SplitGrouper();
    // Number of tasks created so far from the main input; also the next task index to assign.
    private int taskCount = 0;
    private TezWork.VertexType vertexType;
    // Input that defines the routing table; an empty name sends every input down the main path.
    private String mainWorkName;
    // Routing table: bucket id -> indices of the tasks that process that bucket.
    private final Multimap<Integer, Integer> bucketToTaskMap = HashMultimap.create();
    // Side inputs whose grouped splits arrived before the routing table was built.
    private final Map<String, Multimap<Integer, InputSplit>> inputToGroupedSplitMap = new HashMap<String, Multimap<Integer, InputSplit>>();
    private int numInputsAffectingRootInputSpecUpdate = 1;
    private int numInputsSeenSoFar = 0;
    // Edge managers for incoming CUSTOM edges that use CustomPartitionEdge, keyed by source vertex.
    private final Map<String, EdgeManagerPluginDescriptor> emMap = Maps.newHashMap();
    // Main-input grouped splits in task order; used for location hints, cleared once parallelism is set.
    private final List<InputSplit> finalSplits = Lists.newLinkedList();
    // Per-input physical-input counts, handed to the context together with the parallelism.
    private final Map<String, InputSpecUpdate> inputNameInputSpecMap = new HashMap<String, InputSpecUpdate>();

    public CustomPartitionVertex(VertexManagerPluginContext context) {
        super(context);
    }

    /**
     * Reads the {@link CustomVertexConfiguration} from this vertex's user payload. The payload
     * supplies the bucket count, main input name, vertex type and number of inputs.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the payload cannot be read
     */
    public void initialize() {
        this.context = getContext();
        ByteBuffer payload = context.getUserPayload().getPayload();
        CustomVertexConfiguration vertexConf = new CustomVertexConfiguration();
        DataInputByteBuffer dibb = new DataInputByteBuffer();
        dibb.reset(payload);
        try {
            vertexConf.readFields(dibb);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        this.numBuckets = vertexConf.getNumBuckets();
        this.mainWorkName = vertexConf.getInputName();
        this.vertexType = vertexConf.getVertexType();
        this.numInputsAffectingRootInputSpecUpdate = vertexConf.getNumInputs();
    }

    /**
     * Schedules every task of this vertex at once, without location hints.
     *
     * @param completions ignored
     */
    public void onVertexStarted(Map<String, List<Integer>> completions) {
        int numTasks = context.getVertexNumTasks(context.getVertexName());
        List<VertexManagerPluginContext.TaskWithLocationHint> scheduledTasks = new ArrayList<VertexManagerPluginContext.TaskWithLocationHint>(numTasks);
        for (int i = 0; i < numTasks; ++i) {
            scheduledTasks.add(new VertexManagerPluginContext.TaskWithLocationHint(Integer.valueOf(i), null));
        }
        context.scheduleVertexTasks(scheduledTasks);
    }

    /** No-op: source task completions do not affect this vertex manager. */
    public void onSourceTaskCompleted(String srcVertexName, Integer attemptId) {
    }

    /** No-op: vertex manager events are not used by this vertex manager. */
    public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
    }

    /**
     * Handles the initialization of one root input.
     *
     * <p>Grouping is enabled on the input's MR payload and its configuration is kept for split
     * grouping. The input's file splits are bucketed and grouped. The input is then processed
     * either as the main input (building the routing table) or as a side input.
     *
     * @param inputName name of the root input
     * @param inputDescriptor descriptor whose user payload is rewritten with grouping enabled
     * @param events events produced by the input initializer
     * @throws RuntimeException if the payload cannot be parsed or grouping/routing fails
     * @throws IllegalStateException on an {@link InputUpdatePayloadEvent}, or on a configure event
     *     that follows a data event
     */
    public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor, List<Event> events) {
        ++numInputsSeenSoFar;
        LOG.info("On root vertex initialized " + inputName);
        try {
            MRRuntimeProtos.MRInputUserPayloadProto protoPayload = MRInputHelpers.parseMRInputPayload(inputDescriptor.getUserPayload());
            conf = TezUtils.createConfFromByteString(protoPayload.getConfigurationBytes());
            MRRuntimeProtos.MRInputUserPayloadProto updatedPayload = MRRuntimeProtos.MRInputUserPayloadProto.newBuilder(protoPayload).setGroupingEnabled(true).build();
            inputDescriptor.setUserPayload(UserPayload.create(updatedPayload.toByteString().asReadOnlyByteBuffer()));
        } catch (IOException e) {
            // Rethrown with the cause attached; no separate stack trace dump to stderr.
            throw new RuntimeException(e);
        }
        Map<String, Set<FileSplit>> pathFileSplitsMap = collectSplitsByBucketFile(inputName, events);
        LOG.info("Path file splits map for input name: " + inputName + " is " + pathFileSplitsMap);
        Multimap<Integer, InputSplit> bucketToInitialSplitMap = getBucketSplitMapForPath(pathFileSplitsMap);
        try {
            int totalResource = context.getTotalAvailableResource().getMemory();
            int taskResource = context.getVertexTaskResource().getMemory();
            float waves = conf.getFloat("tez.grouping.split-waves", 1.7f);
            int availableSlots = totalResource / taskResource;
            LOG.info("Grouping splits. " + availableSlots + " available slots, " + waves + " waves. Bucket initial splits map: " + bucketToInitialSplitMap);
            JobConf jobConf = new JobConf(conf);
            ShimLoader.getHadoopShims().getMergedCredentials(jobConf);
            Multimap<Integer, InputSplit> bucketToGroupedSplitMap = HashMultimap.create();
            boolean secondLevelGroupingDone = false;
            if (mainWorkName.isEmpty() || inputName.equals(mainWorkName)) {
                for (Integer key : bucketToInitialSplitMap.keySet()) {
                    InputSplit[] inputSplitArray = bucketToInitialSplitMap.get(key).toArray(new InputSplit[0]);
                    Multimap<Integer, InputSplit> groupedSplit = grouper.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, availableSlots, inputName, mainWorkName.isEmpty());
                    if (!mainWorkName.isEmpty()) {
                        // Regroup this bucket's grouped splits once more, using the configured
                        // TEZ_SMB_NUMBER_WAVES value.
                        Multimap<Integer, InputSplit> singleBucketToGroupedSplit = HashMultimap.create();
                        singleBucketToGroupedSplit.putAll(key, groupedSplit.values());
                        groupedSplit = grouper.group(jobConf, singleBucketToGroupedSplit, availableSlots, HiveConf.getFloatVar(conf, HiveConf.ConfVars.TEZ_SMB_NUMBER_WAVES));
                        secondLevelGroupingDone = true;
                    }
                    bucketToGroupedSplitMap.putAll(key, groupedSplit.values());
                }
                processAllEvents(inputName, bucketToGroupedSplitMap, secondLevelGroupingDone);
            } else {
                for (Integer key : bucketToInitialSplitMap.keySet()) {
                    InputSplit[] inputSplitArray = bucketToInitialSplitMap.get(key).toArray(new InputSplit[0]);
                    Multimap<Integer, InputSplit> groupedSplit = grouper.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, availableSlots, inputName, false);
                    bucketToGroupedSplitMap.putAll(key, groupedSplit.values());
                }
                LOG.info("This is the side work - multi-mr work.");
                processAllSideEventsSetParallelism(inputName, bucketToGroupedSplitMap);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Extracts the {@link FileSplit}s carried by the data-information events of one input.
     *
     * <p>Splits are keyed by bucket file name. Within a key they are ordered by path and then by
     * start offset, and a split with the same path and start as an earlier one is dropped. Any
     * configure-tasks event is recorded in {@code configureVertexTaskEvent}.
     *
     * @throws IllegalStateException on an {@link InputUpdatePayloadEvent}, or on a configure event
     *     that follows a data event
     * @throws RuntimeException if an event's split cannot be deserialized
     */
    private TreeMap<String, Set<FileSplit>> collectSplitsByBucketFile(String inputName, List<Event> events) {
        boolean dataInformationEventSeen = false;
        TreeMap<String, Set<FileSplit>> pathFileSplitsMap = new TreeMap<String, Set<FileSplit>>();
        for (Event event : events) {
            if (event instanceof InputConfigureVertexTasksEvent) {
                LOG.info("Got a input configure vertex event for input: " + inputName);
                Preconditions.checkState(!dataInformationEventSeen);
                configureVertexTaskEvent = (InputConfigureVertexTasksEvent) event;
                LOG.info("Configure task for input name: " + inputName + " num tasks: " + configureVertexTaskEvent.getNumTasks());
            }
            if (event instanceof InputUpdatePayloadEvent) {
                throw new IllegalStateException("Unexpected InputUpdatePayloadEvent for input: " + inputName);
            }
            if (!(event instanceof InputDataInformationEvent)) {
                continue;
            }
            dataInformationEventSeen = true;
            InputDataInformationEvent diEvent = (InputDataInformationEvent) event;
            FileSplit fileSplit;
            try {
                fileSplit = getFileSplitFromEvent(diEvent);
            } catch (IOException e) {
                throw new RuntimeException("Failed to get file split for event: " + diEvent, e);
            }
            String bucketFileName = Utilities.getBucketFileNameFromPathSubString(fileSplit.getPath().getName());
            Set<FileSplit> fsList = pathFileSplitsMap.get(bucketFileName);
            if (fsList == null) {
                fsList = new TreeSet<FileSplit>(new PathComparatorForSplit());
                pathFileSplitsMap.put(bucketFileName, fsList);
            }
            fsList.add(fileSplit);
        }
        return pathFileSplitsMap;
    }

    /**
     * Routes a side input through the routing table and attempts to set parallelism.
     *
     * <p>If the routing table has not been built yet, the input's splits are queued instead, to be
     * processed once the main input arrives.
     */
    private void processAllSideEventsSetParallelism(String inputName, Multimap<Integer, InputSplit> bucketToGroupedSplitMap) throws IOException {
        LOG.info("Processing events for input " + inputName);
        if (bucketToTaskMap.isEmpty()) {
            LOG.info("We don't have a routing table yet. Will need to wait for the main input initialization");
            inputToGroupedSplitMap.put(inputName, bucketToGroupedSplitMap);
            return;
        }
        processAllSideEvents(inputName, bucketToGroupedSplitMap);
        setVertexParallelismAndRootInputSpec(inputNameInputSpecMap);
    }

    /**
     * Sends every split of a side input's bucket to each task that the routing table maps to that
     * bucket.
     *
     * <p>The number of splits each task receives is recorded in {@code inputNameInputSpecMap}.
     * Tasks that receive no split from this input are recorded as having 0 physical inputs.
     */
    private void processAllSideEvents(String inputName, Multimap<Integer, InputSplit> bucketToGroupedSplitMap) throws IOException {
        List<InputDataInformationEvent> taskEvents = new ArrayList<InputDataInformationEvent>();
        LOG.info("We have a routing table and we are going to set the destination tasks for the multi mr inputs. " + bucketToTaskMap);
        Integer[] numSplitsForTask = new Integer[taskCount];
        // Without this, tasks whose bucket has no split for this input would stay null.
        Arrays.fill(numSplitsForTask, 0);
        Multimap<Integer, ByteBuffer> bucketToSerializedSplitMap = LinkedListMultimap.create();
        for (Map.Entry<Integer, Collection<InputSplit>> entry : bucketToGroupedSplitMap.asMap().entrySet()) {
            for (InputSplit split : entry.getValue()) {
                MRRuntimeProtos.MRSplitProto serializedSplit = MRInputHelpers.createSplitProto(split);
                bucketToSerializedSplitMap.put(entry.getKey(), serializedSplit.toByteString().asReadOnlyByteBuffer());
            }
        }
        for (Map.Entry<Integer, Collection<ByteBuffer>> entry : bucketToSerializedSplitMap.asMap().entrySet()) {
            Collection<Integer> destTasks = bucketToTaskMap.get(entry.getKey());
            for (Integer task : destTasks) {
                int count = 0;
                for (ByteBuffer byteBuffer : entry.getValue()) {
                    // Source indices for each task start at 1.
                    ++count;
                    InputDataInformationEvent diEvent = InputDataInformationEvent.createWithSerializedPayload(count, byteBuffer);
                    diEvent.setTargetIndex(task.intValue());
                    taskEvents.add(diEvent);
                }
                numSplitsForTask[task.intValue()] = count;
            }
        }
        inputNameInputSpecMap.put(inputName, InputSpecUpdate.createPerTaskInputSpecUpdate(Arrays.asList(numSplitsForTask)));
        LOG.info("For input name: " + inputName + " task events size is " + taskEvents.size());
        context.addRootInputEvents(inputName, taskEvents);
    }

    /**
     * Builds the routing table from the main input and sends its splits to their tasks.
     *
     * <p>Each grouped split becomes one task, and task indices are assigned consecutively, bucket by
     * bucket. For the INITIALIZED_EDGES vertex types, a {@link CustomPartitionEdge} descriptor
     * carrying the routing table is prepared for every incoming CUSTOM edge that uses
     * CustomPartitionEdge. One data event is sent per task, or one per sub-split when second-level
     * grouping was done. Side inputs queued earlier are then processed.
     *
     * @throws IOException if a split cannot be serialized, or a sub-split is not a
     *     {@link TezGroupedSplit} after second-level grouping
     */
    private void processAllEvents(String inputName, Multimap<Integer, InputSplit> bucketToGroupedSplitMap, boolean secondLevelGroupingDone) throws IOException {
        int totalInputsCount = 0;
        List<Integer> numSplitsForTask = new ArrayList<Integer>();
        for (Map.Entry<Integer, Collection<InputSplit>> entry : bucketToGroupedSplitMap.asMap().entrySet()) {
            int bucketNum = entry.getKey();
            Collection<InputSplit> initialSplits = entry.getValue();
            finalSplits.addAll(initialSplits);
            for (InputSplit inputSplit : initialSplits) {
                bucketToTaskMap.put(bucketNum, taskCount);
                if (secondLevelGroupingDone) {
                    int groupSize = ((TezGroupedSplit) inputSplit).getGroupedSplits().size();
                    numSplitsForTask.add(groupSize);
                    totalInputsCount += groupSize;
                } else {
                    numSplitsForTask.add(1);
                    ++totalInputsCount;
                }
                ++taskCount;
            }
        }
        inputNameInputSpecMap.put(inputName, InputSpecUpdate.createPerTaskInputSpecUpdate(numSplitsForTask));
        EdgeManagerPluginDescriptor hiveEdgeManagerDesc = null;
        if (vertexType == TezWork.VertexType.MULTI_INPUT_INITIALIZED_EDGES || vertexType == TezWork.VertexType.INITIALIZED_EDGES) {
            hiveEdgeManagerDesc = EdgeManagerPluginDescriptor.create(CustomPartitionEdge.class.getName());
            hiveEdgeManagerDesc.setUserPayload(getBytePayload(bucketToTaskMap));
        }
        for (Map.Entry<String, EdgeProperty> edgeEntry : context.getInputVertexEdgeProperties().entrySet()) {
            EdgeProperty edgeProperty = edgeEntry.getValue();
            if (edgeProperty.getDataMovementType() == EdgeProperty.DataMovementType.CUSTOM
                    && edgeProperty.getEdgeManagerDescriptor().getClassName().equals(CustomPartitionEdge.class.getName())) {
                emMap.put(edgeEntry.getKey(), hiveEdgeManagerDesc);
            }
        }
        LOG.info("Task count is " + taskCount + " for input name: " + inputName);
        List<InputDataInformationEvent> taskEvents = Lists.newArrayListWithCapacity(totalInputsCount);
        int count = 0;
        for (InputSplit inputSplit : finalSplits) {
            if (secondLevelGroupingDone) {
                TezGroupedSplit tezGroupedSplit = (TezGroupedSplit) inputSplit;
                for (InputSplit subSplit : tezGroupedSplit.getGroupedSplits()) {
                    if (!(subSplit instanceof TezGroupedSplit)) {
                        throw new IOException("Unexpected split type found: " + subSplit.getClass().getCanonicalName());
                    }
                    taskEvents.add(createTaskEvent(subSplit, count));
                }
            } else {
                taskEvents.add(createTaskEvent(inputSplit, count));
            }
            ++count;
        }
        LOG.info("For input name: " + inputName + " task events size is " + taskEvents.size());
        context.addRootInputEvents(inputName, taskEvents);
        if (!inputToGroupedSplitMap.isEmpty()) {
            for (Map.Entry<String, Multimap<Integer, InputSplit>> entry : inputToGroupedSplitMap.entrySet()) {
                processAllSideEvents(entry.getKey(), entry.getValue());
            }
            setVertexParallelismAndRootInputSpec(inputNameInputSpecMap);
            inputToGroupedSplitMap.clear();
        }
        if (numInputsAffectingRootInputSpecUpdate == 1) {
            setVertexParallelismAndRootInputSpec(inputNameInputSpecMap);
        }
    }

    /** Serializes {@code split} into a data event whose source and target index are both {@code taskIndex}. */
    private static InputDataInformationEvent createTaskEvent(InputSplit split, int taskIndex) throws IOException {
        MRRuntimeProtos.MRSplitProto serializedSplit = MRInputHelpers.createSplitProto(split);
        InputDataInformationEvent diEvent = InputDataInformationEvent.createWithSerializedPayload(taskIndex, serializedSplit.toByteString().asReadOnlyByteBuffer());
        diEvent.setTargetIndex(taskIndex);
        return diEvent;
    }

    /**
     * Sets the vertex parallelism once every expected input has been seen; otherwise does nothing.
     *
     * <p>The call passes the task count, location hints computed from {@code finalSplits}, the edge
     * managers and the per-input specs. {@code finalSplits} is cleared afterwards.
     */
    private void setVertexParallelismAndRootInputSpec(Map<String, InputSpecUpdate> rootInputSpecUpdate) throws IOException {
        if (numInputsAffectingRootInputSpecUpdate != numInputsSeenSoFar) {
            return;
        }
        LOG.info("Setting vertex parallelism since we have seen all inputs.");
        InputSplit[] splits = finalSplits.toArray(new InputSplit[finalSplits.size()]);
        context.setVertexParallelism(taskCount, VertexLocationHint.create(grouper.createTaskLocationHints(splits)), emMap, rootInputSpecUpdate);
        finalSplits.clear();
    }

    /**
     * Serializes the routing table into a {@link CustomEdgeConfiguration} payload.
     *
     * <p>Only the bytes actually written are included, not the buffer's unused capacity.
     */
    UserPayload getBytePayload(Multimap<Integer, Integer> routingTable) throws IOException {
        CustomEdgeConfiguration edgeConf = new CustomEdgeConfiguration(routingTable.keySet().size(), routingTable);
        DataOutputBuffer dob = new DataOutputBuffer();
        edgeConf.write(dob);
        byte[] serialized = Arrays.copyOf(dob.getData(), dob.getLength());
        return UserPayload.create(ByteBuffer.wrap(serialized));
    }

    /**
     * Returns the split carried by {@code event}.
     *
     * <p>The event's deserialized payload is used when present. Otherwise the split is decoded from
     * the serialized {@link MRRuntimeProtos.MRSplitProto} payload.
     *
     * @throws UnsupportedOperationException if the split is not a {@link FileSplit}
     */
    private FileSplit getFileSplitFromEvent(InputDataInformationEvent event) throws IOException {
        InputSplit inputSplit;
        if (event.getDeserializedUserPayload() != null) {
            inputSplit = (InputSplit) event.getDeserializedUserPayload();
        } else {
            MRRuntimeProtos.MRSplitProto splitProto = MRRuntimeProtos.MRSplitProto.parseFrom(ByteString.copyFrom(event.getUserPayload()));
            SerializationFactory serializationFactory = new SerializationFactory(new Configuration());
            inputSplit = MRInputHelpers.createOldFormatSplitFromUserPayload(splitProto, serializationFactory);
        }
        if (!(inputSplit instanceof FileSplit)) {
            throw new UnsupportedOperationException("Cannot handle splits other than FileSplit for the moment. Current input split type: " + inputSplit.getClass().getSimpleName());
        }
        return (FileSplit) inputSplit;
    }

    /**
     * Assigns bucket ids to bucket files in the map's iteration order, wrapping modulo
     * {@code numBuckets}.
     *
     * <p>When there are fewer files than buckets, each missing bucket reuses the splits of the
     * bucket {@code loopedBucketId} (0, 1, 2, ...) at the time it is filled.
     */
    private Multimap<Integer, InputSplit> getBucketSplitMapForPath(Map<String, Set<FileSplit>> pathFileSplitsMap) {
        int bucketNum = 0;
        Multimap<Integer, InputSplit> bucketToInitialSplitMap = ArrayListMultimap.create();
        for (Map.Entry<String, Set<FileSplit>> entry : pathFileSplitsMap.entrySet()) {
            int bucketId = bucketNum % numBuckets;
            for (FileSplit fsplit : entry.getValue()) {
                bucketToInitialSplitMap.put(bucketId, fsplit);
            }
            ++bucketNum;
        }
        // loopedBucketId is always smaller than bucketNum, so the view being read is never the
        // one being written.
        int loopedBucketId = 0;
        while (bucketNum < numBuckets) {
            for (InputSplit fsplit : bucketToInitialSplitMap.get(loopedBucketId)) {
                bucketToInitialSplitMap.put(bucketNum, fsplit);
            }
            ++loopedBucketId;
            ++bucketNum;
        }
        return bucketToInitialSplitMap;
    }

    /**
     * Orders {@link FileSplit}s by path, then by start offset.
     *
     * <p>Both arguments must be {@link FileSplit}s. Splits with equal path and start compare as
     * equal.
     */
    public class PathComparatorForSplit
    implements Comparator<InputSplit> {
        @Override
        public int compare(InputSplit inp1, InputSplit inp2) {
            FileSplit fs1 = (FileSplit) inp1;
            FileSplit fs2 = (FileSplit) inp2;
            int retval = fs1.getPath().compareTo(fs2.getPath());
            if (retval != 0) {
                return retval;
            }
            // Long.compare instead of a narrowing subtraction: offsets of files larger than
            // 2 GB would otherwise overflow and flip the sign of the result.
            return Long.compare(fs1.getStart(), fs2.getStart());
        }
    }
}

