package org.apache.flink.runtime.iterative.task;

import java.io.IOException;
import java.io.Serializable;
import java.util.Map;
import java.util.concurrent.Future;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.LongSumAggregator;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.common.operators.util.JoinHashMap;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.core.fs.Path;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.io.network.api.reader.MutableReader;
import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannel;
import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannelBroker;
import org.apache.flink.runtime.iterative.concurrent.IterationAggregatorBroker;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetBroker;
import org.apache.flink.runtime.iterative.convergence.WorksetEmptyConvergenceCriterion;
import org.apache.flink.runtime.iterative.io.SolutionSetObjectsUpdateOutputCollector;
import org.apache.flink.runtime.iterative.io.SolutionSetUpdateOutputCollector;
import org.apache.flink.runtime.iterative.io.WorksetUpdateOutputCollector;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.runtime.operators.Driver;
import org.apache.flink.runtime.operators.ResettableDriver;
import org.apache.flink.runtime.operators.hash.CompactingHashTable;
import org.apache.flink.runtime.operators.util.DistributedRuntimeUDFContext;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.MutableObjectIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for tasks that run inside an iteration.
 *
 * <p>Each call to {@link #run()} executes one superstep. The superstep counter starts at 1 and is
 * advanced via {@link #incrementIterationCounter()}. Per-iteration shared state (the workset back
 * channel, the aggregator registry and the solution set) is looked up from brokers under the key
 * returned by {@link #brokerKey()}.
 *
 * @param <S> the user function type
 * @param <OT> the output record type
 */
public abstract class AbstractIterativeTask<S extends Function, OT> extends BatchTask<S, OT>
        implements Terminable {

    private static final Logger log = LoggerFactory.getLogger(AbstractIterativeTask.class);

    /** Workset element counter; only set when this task updates the workset of a workset iteration. */
    protected LongSumAggregator worksetAggregator;

    /** Back channel taken from {@link BlockingBackChannelBroker} when this task updates the workset. */
    protected BlockingBackChannel worksetBackChannel;

    /** Read from the last chained task's config in {@link #initialize()}. */
    protected boolean isWorksetIteration;

    /** Read from the last chained task's config in {@link #initialize()}. */
    protected boolean isWorksetUpdate;

    /** Read from the last chained task's config in {@link #initialize()}. */
    protected boolean isSolutionSetUpdate;

    /** Lazily fetched from {@link IterationAggregatorBroker}; see {@link #getIterationAggregators()}. */
    private RuntimeAggregatorRegistry iterationAggregators;

    /** Lazily computed cache for {@link #brokerKey()}. */
    private String brokerKey;

    /** 1-based number of the current superstep. */
    private int superstepNum = 1;

    /** Set by {@link #requestTermination()}; volatile so the write is visible to other threads. */
    private volatile boolean terminationRequested;

    /**
     * Runtime context handed to user functions inside the iteration. It exposes the current
     * superstep number and the iteration aggregators of the enclosing task.
     */
    private class IterativeRuntimeUdfContext extends DistributedRuntimeUDFContext
            implements IterationRuntimeContext {

        public IterativeRuntimeUdfContext(
                TaskInfo taskInfo,
                ClassLoader userCodeClassLoader,
                ExecutionConfig executionConfig,
                Map<String, Future<Path>> distributedCacheEntries,
                Map<String, Accumulator<?, ?>> accumulators,
                MetricGroup metrics) {
            super(taskInfo, userCodeClassLoader, executionConfig, distributedCacheEntries,
                    accumulators, metrics);
        }

        /** Returns the 1-based number of the superstep currently executing. */
        @Override
        public int getSuperstepNumber() {
            return AbstractIterativeTask.this.superstepNum;
        }

        /** Returns the aggregator registered under {@code name} in the iteration's registry. */
        @Override
        @SuppressWarnings("unchecked")
        public <T extends Aggregator<?>> T getIterationAggregator(String name) {
            return (T) AbstractIterativeTask.this.getIterationAggregators().getAggregator(name);
        }

        /** Returns the global aggregate computed for {@code name} in the previous superstep. */
        @Override
        @SuppressWarnings("unchecked")
        public <T extends Value> T getPreviousIterationAggregate(String name) {
            return (T) AbstractIterativeTask.this.getIterationAggregators()
                    .getPreviousGlobalAggregate(name);
        }

        /**
         * Registers the accumulator only during the first superstep; calls made in later
         * supersteps are silently ignored. NOTE(review): presumably because user functions are set
         * up again every superstep and would otherwise register the same name repeatedly — confirm.
         */
        @Override
        public <V, A extends Serializable> void addAccumulator(
                String name, Accumulator<V, A> accumulator) {
            if (AbstractIterativeTask.this.inFirstIteration()) {
                super.addAccumulator(name, accumulator);
            }
        }
    }

    /**
     * Initializes the task and reads this task's iteration role from the config of the last task
     * in its chain.
     *
     * <p>Inputs that a {@link ResettableDriver} declares resettable are excluded from the generic
     * input reset. If this task performs the workset update, the back channel is taken (and
     * removed) from {@link BlockingBackChannelBroker}; in a workset iteration the workset-count
     * aggregator must additionally be registered.
     *
     * @throws RuntimeException if the workset-count aggregator is missing
     */
    @Override
    public void initialize() throws Exception {
        super.initialize();

        if (this.driver instanceof ResettableDriver) {
            final ResettableDriver<?, ?> resettableDriver = (ResettableDriver<?, ?>) this.driver;
            for (int i = 0; i < resettableDriver.getNumberOfInputs(); i++) {
                if (resettableDriver.isInputResettable(i)) {
                    // presumably the driver resets this input itself across supersteps
                    excludeFromReset(i);
                }
            }
        }

        final TaskConfig lastTasksConfig = getLastTasksConfig();
        this.isWorksetIteration = lastTasksConfig.getIsWorksetIteration();
        this.isWorksetUpdate = lastTasksConfig.getIsWorksetUpdate();
        this.isSolutionSetUpdate = lastTasksConfig.getIsSolutionSetUpdate();

        if (this.isWorksetUpdate) {
            this.worksetBackChannel = BlockingBackChannelBroker.instance().getAndRemove(brokerKey());

            if (this.isWorksetIteration) {
                this.worksetAggregator = (LongSumAggregator) getIterationAggregators()
                        .getAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME);
                if (this.worksetAggregator == null) {
                    throw new RuntimeException("Missing workset elements count aggregator.");
                }
            }
        }
    }

    /**
     * Executes one superstep.
     *
     * <p>In the first superstep a {@link ResettableDriver} is initialized. In later supersteps the
     * driver is reset (or re-created), all inputs are reset, and the iterative broadcast inputs
     * are re-read for the current superstep. After the superstep, the iterative broadcast
     * variables of this superstep are released.
     */
    @Override
    public void run() throws Exception {
        if (inFirstIteration()) {
            if (this.driver instanceof ResettableDriver) {
                ((ResettableDriver<?, ?>) this.driver).initialize();
            }
        } else {
            reinstantiateDriver();
            resetAllInputs();

            for (int inputNum : this.iterativeBroadcastInputs) {
                final String name = getTaskConfig().getBroadcastInputName(inputNum);
                readAndSetBroadcastInput(inputNum, name, this.runtimeUdfContext, this.superstepNum);
            }
        }

        super.run();

        for (int inputNum : this.iterativeBroadcastInputs) {
            final String name = getTaskConfig().getBroadcastInputName(inputNum);
            releaseBroadcastVariables(name, this.superstepNum, this.runtimeUdfContext);
        }
    }

    /**
     * Closes local strategies and caches and then tears down a {@link ResettableDriver}.
     *
     * <p>The teardown runs in a {@code finally} so it happens exactly once, whether or not the
     * superclass cleanup fails. Teardown failures are logged, not propagated.
     */
    @Override
    public void closeLocalStrategiesAndCaches() {
        try {
            super.closeLocalStrategiesAndCaches();
        } finally {
            // a resettable driver is reset (not re-created) between supersteps, so it is torn
            // down only here
            if (this.driver instanceof ResettableDriver) {
                try {
                    ((ResettableDriver<?, ?>) this.driver).teardown();
                } catch (Throwable t) {
                    log.error("Error while shutting down an iterative operator.", t);
                }
            }
        }
    }

    /** Creates the iteration-aware runtime context passed to user functions. */
    @Override
    public DistributedRuntimeUDFContext createRuntimeContext(MetricGroup metrics) {
        final Environment env = getEnvironment();
        return new IterativeRuntimeUdfContext(
                env.getTaskInfo(),
                getUserCodeClassLoader(),
                getExecutionConfig(),
                env.getDistributedCacheEntries(),
                this.accumulatorMap,
                metrics);
    }

    /** Returns whether the current superstep is the first one. */
    public boolean inFirstIteration() {
        return this.superstepNum == 1;
    }

    /** Returns the 1-based number of the current superstep. */
    public int currentIteration() {
        return this.superstepNum;
    }

    /** Advances the superstep counter by one. */
    public void incrementIterationCounter() {
        this.superstepNum++;
    }

    /**
     * Returns the key under which this subtask's iteration state is registered in the brokers,
     * formatted as {@code <jobId>#<iterationId>#<subtaskIndex>}. Computed once and cached.
     */
    public String brokerKey() {
        if (this.brokerKey == null) {
            this.brokerKey = getEnvironment().getJobID().toString() + '#'
                    + this.config.getIterationId() + '#'
                    + getEnvironment().getTaskInfo().getIndexOfThisSubtask();
        }
        return this.brokerKey;
    }

    /**
     * Prepares the driver for the next superstep: a {@link ResettableDriver} is reset in place,
     * any other driver is instantiated anew from the task config and set up.
     *
     * @throws Exception wrapping any failure of the new driver's setup
     */
    private void reinstantiateDriver() throws Exception {
        if (this.driver instanceof ResettableDriver) {
            ((ResettableDriver<?, ?>) this.driver).reset();
            return;
        }

        this.driver = (Driver) InstantiationUtil.instantiate(this.config.getDriver(), Driver.class);
        try {
            this.driver.setup(this);
        } catch (Throwable t) {
            throw new Exception("The pact driver setup for '" + getEnvironment().getTaskInfo().getTaskName()
                    + "' , caused an error: " + t.getMessage(), t);
        }
    }

    /** Returns the aggregator registry for this iteration, fetching it from the broker on first use. */
    public RuntimeAggregatorRegistry getIterationAggregators() {
        if (this.iterationAggregators == null) {
            this.iterationAggregators = IterationAggregatorBroker.instance().get(brokerKey());
        }
        return this.iterationAggregators;
    }

    /**
     * Brings all iterative inputs to the end-of-superstep state and arms them for the next
     * superstep.
     *
     * <p>Unfinished iterative inputs that have not reached the end of the superstep are drained
     * (remaining records are read and dropped). Unfinished iterative broadcast inputs must already
     * be at the end of the superstep.
     *
     * @throws IllegalStateException if the task has no iterative input at all, or an iterative
     *     broadcast input was not fully consumed
     * @throws IOException if draining an input fails
     */
    public void verifyEndOfSuperstepState() throws IOException {
        if (this.iterativeInputs.length == 0 && this.iterativeBroadcastInputs.length == 0) {
            throw new IllegalStateException("Error: Iterative task without a single iterative input.");
        }

        for (int inputNum : this.iterativeInputs) {
            final MutableReader<?> reader = this.inputReaders[inputNum];
            if (reader.isFinished()) {
                continue;
            }

            if (reader.hasReachedEndOfSuperstep()) {
                reader.startNextSuperstep();
                continue;
            }

            // read and drop all non-consumed records until the end of the superstep
            @SuppressWarnings("unchecked")
            final MutableObjectIterator<Object> iterator =
                    (MutableObjectIterator<Object>) this.inputIterators[inputNum];
            Object reuse = this.inputSerializers[inputNum].getSerializer().createInstance();
            while ((reuse = iterator.next(reuse)) != null) {
                // drained record is discarded
            }

            if (!reader.isFinished()) {
                reader.startNextSuperstep();
            }
        }

        for (int inputNum : this.iterativeBroadcastInputs) {
            final MutableReader<?> reader = this.broadcastInputReaders[inputNum];
            if (!reader.isFinished()) {
                if (!reader.hasReachedEndOfSuperstep()) {
                    throw new IllegalStateException("An iterative broadcast input has not been fully consumed.");
                }
                reader.startNextSuperstep();
            }
        }
    }

    /** Returns whether termination has been requested for this task. */
    @Override
    public boolean terminationRequested() {
        return this.terminationRequested;
    }

    /** Marks this task as requested to terminate. */
    @Override
    public void requestTermination() {
        this.terminationRequested = true;
    }

    /** Requests termination, then delegates to the superclass cancellation. */
    @Override
    public void cancel() throws Exception {
        requestTermination();
        super.cancel();
    }

    /**
     * Creates a collector that writes records into the workset back channel and forwards them to
     * {@code delegate}.
     *
     * @param delegate downstream collector; NOTE(review): may be {@code null} (see the no-arg
     *     overload) — confirm {@link WorksetUpdateOutputCollector} tolerates that
     */
    public Collector<OT> createWorksetUpdateOutputCollector(Collector<OT> delegate) {
        return new WorksetUpdateOutputCollector<OT>(
                this.worksetBackChannel.getWriteEnd(), getOutputSerializer(), delegate);
    }

    /** Creates a workset-update collector without a downstream delegate ({@code null}). */
    public Collector<OT> createWorksetUpdateOutputCollector() {
        return createWorksetUpdateOutputCollector(null);
    }

    /**
     * Creates a collector that writes records into the solution set registered for this
     * iteration, forwarding them to {@code delegate}.
     *
     * @throws RuntimeException if the brokered solution set is neither a
     *     {@link CompactingHashTable} nor a {@link JoinHashMap}
     */
    public Collector<OT> createSolutionSetUpdateOutputCollector(Collector<OT> delegate) {
        final Object solutionSet = SolutionSetBroker.instance().get(brokerKey());

        if (solutionSet instanceof CompactingHashTable) {
            @SuppressWarnings("unchecked")
            final CompactingHashTable<OT> table = (CompactingHashTable<OT>) solutionSet;
            return new SolutionSetUpdateOutputCollector<OT>(table, delegate);
        }
        if (solutionSet instanceof JoinHashMap) {
            @SuppressWarnings("unchecked")
            final JoinHashMap<OT> map = (JoinHashMap<OT>) solutionSet;
            return new SolutionSetObjectsUpdateOutputCollector<OT>(map, delegate);
        }
        throw new RuntimeException("Unrecognized solution set handle: " + solutionSet);
    }

    /**
     * Returns the output serializer configured on the last chained task.
     *
     * @throws RuntimeException if no output serializer is configured
     */
    private TypeSerializer<OT> getOutputSerializer() {
        final TypeSerializerFactory<OT> serializerFactory =
                getLastTasksConfig().getOutputSerializer(getUserCodeClassLoader());
        if (serializerFactory == null) {
            throw new RuntimeException("Missing output serializer for workset update.");
        }
        return serializerFactory.getSerializer();
    }
}
