/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.slowtaskdetector;

import java.time.Duration;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.SlowTaskDetectorOptions;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.scheduler.slowtaskdetector.SlowTaskDetector;
import org.apache.flink.runtime.scheduler.slowtaskdetector.SlowTaskDetectorListener;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.util.IterableUtils;
import org.apache.flink.util.Preconditions;

/**
 * A {@link SlowTaskDetector} that periodically looks for running executions whose execution time
 * (weighted by input bytes when those are known) reaches a baseline derived from the executions
 * that have already finished in the same job vertex, and reports them to a {@link
 * SlowTaskDetectorListener}.
 *
 * <p>Detection rounds are scheduled on the {@link ComponentMainThreadExecutor} passed to {@link
 * #start}; each round re-schedules the next one after {@code checkIntervalMillis}.
 */
public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {

    /** Delay between two consecutive detection rounds, in milliseconds. Always positive. */
    private final long checkIntervalMillis;

    /** Executions running for less than this many milliseconds are never reported. */
    private final long baselineLowerBoundMillis;

    /**
     * Fraction in [0, 1) of a job vertex's tasks that must be finished before the vertex is
     * checked; also determines how many finished executions form the baseline.
     */
    private final double baselineRatio;

    /** Positive factor applied to the median finished execution time to obtain the baseline. */
    private final double baselineMultiplier;

    /** Handle of the currently scheduled detection round, or {@code null} before {@link #start}. */
    private ScheduledFuture<?> scheduledDetectionFuture;

    /**
     * Creates a detector configured from the slow task detector options.
     *
     * @param configuration configuration providing the check interval, the execution time
     *     baseline lower bound, the baseline ratio and the baseline multiplier
     * @throws IllegalArgumentException if the check interval is not positive, the lower bound is
     *     negative, the ratio is outside [0, 1), or the multiplier is not positive
     */
    public ExecutionTimeBasedSlowTaskDetector(Configuration configuration) {
        this.checkIntervalMillis =
                configuration.get(SlowTaskDetectorOptions.CHECK_INTERVAL).toMillis();
        // Flink's Preconditions only substitutes "%s" placeholders (not "{}").
        Preconditions.checkArgument(
                this.checkIntervalMillis > 0L,
                "The configuration %s should be positive, but is %s.",
                SlowTaskDetectorOptions.CHECK_INTERVAL.key(),
                this.checkIntervalMillis);

        this.baselineLowerBoundMillis =
                configuration
                        .get(SlowTaskDetectorOptions.EXECUTION_TIME_BASELINE_LOWER_BOUND)
                        .toMillis();
        Preconditions.checkArgument(
                this.baselineLowerBoundMillis >= 0L,
                "The configuration %s cannot be negative, but is %s.",
                SlowTaskDetectorOptions.EXECUTION_TIME_BASELINE_LOWER_BOUND.key(),
                this.baselineLowerBoundMillis);

        this.baselineRatio =
                configuration.getDouble(SlowTaskDetectorOptions.EXECUTION_TIME_BASELINE_RATIO);
        Preconditions.checkArgument(
                this.baselineRatio >= 0.0 && this.baselineRatio < 1.0,
                "The configuration %s should be in [0, 1), but is %s.",
                SlowTaskDetectorOptions.EXECUTION_TIME_BASELINE_RATIO.key(),
                this.baselineRatio);

        this.baselineMultiplier =
                configuration.getDouble(
                        SlowTaskDetectorOptions.EXECUTION_TIME_BASELINE_MULTIPLIER);
        Preconditions.checkArgument(
                this.baselineMultiplier > 0.0,
                "The configuration %s should be positive, but is %s.",
                SlowTaskDetectorOptions.EXECUTION_TIME_BASELINE_MULTIPLIER.key(),
                this.baselineMultiplier);
    }

    /**
     * Starts periodic detection. The first round runs {@code checkIntervalMillis} after this call.
     *
     * @param executionGraph graph whose executions are inspected
     * @param listener receives the slow tasks found in every round
     * @param mainThreadExecutor executor on which detection rounds are scheduled
     */
    @Override
    public void start(
            ExecutionGraph executionGraph,
            SlowTaskDetectorListener listener,
            ComponentMainThreadExecutor mainThreadExecutor) {
        scheduleTask(executionGraph, listener, mainThreadExecutor);
    }

    /** Schedules one detection round which, after notifying the listener, schedules the next. */
    private void scheduleTask(
            ExecutionGraph executionGraph,
            SlowTaskDetectorListener listener,
            ComponentMainThreadExecutor mainThreadExecutor) {
        this.scheduledDetectionFuture =
                mainThreadExecutor.schedule(
                        () -> {
                            listener.notifySlowTasks(findSlowTasks(executionGraph));
                            scheduleTask(executionGraph, listener, mainThreadExecutor);
                        },
                        checkIntervalMillis,
                        TimeUnit.MILLISECONDS);
    }

    /**
     * Finds the executions whose execution time reaches the baseline of their job vertex.
     *
     * <p>Only job vertices that are initialized, not {@code FINISHED}, and whose finished ratio is
     * at least {@code baselineRatio} are checked. Within them, vertices in a terminal state are
     * skipped.
     *
     * @param executionGraph graph to inspect
     * @return slow execution attempts grouped by execution vertex; vertices without slow attempts
     *     are absent
     */
    @VisibleForTesting
    Map<ExecutionVertexID, Collection<ExecutionAttemptID>> findSlowTasks(
            ExecutionGraph executionGraph) {
        final long currentTimeMillis = System.currentTimeMillis();
        final Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks = new HashMap<>();

        for (ExecutionJobVertex ejv : getJobVerticesToCheck(executionGraph)) {
            final ExecutionTimeWithInputBytes baseline = getBaseline(ejv, currentTimeMillis);
            for (ExecutionVertex ev : ejv.getTaskVertices()) {
                if (ev.getExecutionState().isTerminal()) {
                    continue;
                }
                final List<ExecutionAttemptID> slowExecutions =
                        findExecutionsExceedingBaseline(
                                ev.getCurrentExecutions(), baseline, currentTimeMillis);
                if (!slowExecutions.isEmpty()) {
                    slowTasks.put(ev.getID(), slowExecutions);
                }
            }
        }
        return slowTasks;
    }

    /**
     * Returns the job vertices (in topological order) that are initialized, not yet {@code
     * FINISHED}, and have a finished ratio of at least {@code baselineRatio}.
     */
    private List<ExecutionJobVertex> getJobVerticesToCheck(ExecutionGraph executionGraph) {
        return IterableUtils.toStream(executionGraph.getVerticesTopologically())
                .filter(ExecutionJobVertex::isInitialized)
                .filter(ejv -> ejv.getAggregateState() != ExecutionState.FINISHED)
                .filter(ejv -> getFinishedRatio(ejv) >= baselineRatio)
                .collect(Collectors.toList());
    }

    /**
     * Returns the fraction of the vertex's task vertices that are in state {@code FINISHED}.
     *
     * @throws IllegalStateException if the job vertex has no task vertices
     */
    private double getFinishedRatio(ExecutionJobVertex executionJobVertex) {
        final ExecutionVertex[] taskVertices = executionJobVertex.getTaskVertices();
        Preconditions.checkState(taskVertices.length > 0);
        final long finishedCount =
                Arrays.stream(taskVertices)
                        .filter(ev -> ev.getExecutionState() == ExecutionState.FINISHED)
                        .count();
        return (double) finishedCount / (double) taskVertices.length;
    }

    /**
     * Computes the baseline of a job vertex: the median finished execution time multiplied by
     * {@code baselineMultiplier}, paired with the median execution's input bytes.
     */
    private ExecutionTimeWithInputBytes getBaseline(
            ExecutionJobVertex executionJobVertex, long currentTimeMillis) {
        final ExecutionTimeWithInputBytes weightedExecutionTimeMedian =
                calculateFinishedTaskExecutionTimeMedian(executionJobVertex, currentTimeMillis);
        final long multipliedBaseline =
                (long) (weightedExecutionTimeMedian.getExecutionTime() * baselineMultiplier);
        return new ExecutionTimeWithInputBytes(
                multipliedBaseline, weightedExecutionTimeMedian.getInputBytes());
    }

    /**
     * Returns the median of the {@code round(parallelism * baselineRatio)} smallest finished
     * executions, ordered by {@link ExecutionTimeWithInputBytes#compareTo}.
     *
     * <p>If that count is 0, a zero execution time with unknown input bytes is returned.
     *
     * @throws IllegalStateException if fewer finished executions exist than the required count
     */
    private ExecutionTimeWithInputBytes calculateFinishedTaskExecutionTimeMedian(
            ExecutionJobVertex executionJobVertex, long currentTime) {
        final int baselineExecutionCount =
                (int) Math.round(executionJobVertex.getParallelism() * baselineRatio);
        if (baselineExecutionCount == 0) {
            return new ExecutionTimeWithInputBytes(
                    0L, ExecutionTimeWithInputBytes.NUM_BYTES_UNKNOWN);
        }

        final List<Execution> finishedExecutions =
                Arrays.stream(executionJobVertex.getTaskVertices())
                        .flatMap(ev -> ev.getCurrentExecutions().stream())
                        .filter(e -> e.getState() == ExecutionState.FINISHED)
                        .collect(Collectors.toList());
        Preconditions.checkState(finishedExecutions.size() >= baselineExecutionCount);

        final List<ExecutionTimeWithInputBytes> firstFinishedExecutions =
                finishedExecutions.stream()
                        .map(e -> getExecutionTimeAndInputBytes(e, currentTime))
                        .sorted()
                        .limit(baselineExecutionCount)
                        .collect(Collectors.toList());
        return firstFinishedExecutions.get(baselineExecutionCount / 2);
    }

    /**
     * Returns the attempt ids of the executions that are neither terminal nor {@code CANCELING},
     * have run for at least {@code baselineLowerBoundMillis}, and compare greater than or equal
     * to the baseline.
     */
    private List<ExecutionAttemptID> findExecutionsExceedingBaseline(
            Collection<Execution> executions,
            ExecutionTimeWithInputBytes baseline,
            long currentTimeMillis) {
        return executions.stream()
                .filter(
                        e ->
                                !e.getState().isTerminal()
                                        && e.getState() != ExecutionState.CANCELING)
                .filter(
                        e -> {
                            final ExecutionTimeWithInputBytes timeWithBytes =
                                    getExecutionTimeAndInputBytes(e, currentTimeMillis);
                            return timeWithBytes.getExecutionTime() >= baselineLowerBoundMillis
                                    && timeWithBytes.compareTo(baseline) >= 0;
                        })
                .map(Execution::getAttemptId)
                .collect(Collectors.toList());
    }

    /**
     * Returns the execution time in milliseconds measured from the {@code DEPLOYING} timestamp.
     *
     * <p>The end point is the {@code FINISHED} timestamp for finished executions, otherwise
     * {@code currentTime}. Returns 0 when the {@code DEPLOYING} timestamp is 0 (not recorded).
     */
    private long getExecutionTime(Execution execution, long currentTime) {
        final long deployingTimestamp = execution.getStateTimestamp(ExecutionState.DEPLOYING);
        if (deployingTimestamp == 0L) {
            return 0L;
        }
        if (execution.getState() == ExecutionState.FINISHED) {
            return execution.getStateTimestamp(ExecutionState.FINISHED) - deployingTimestamp;
        }
        return currentTime - deployingTimestamp;
    }

    /** Returns the input bytes reported by the execution's vertex. */
    private long getExecutionInputBytes(Execution execution) {
        return execution.getVertex().getInputBytes();
    }

    /** Pairs the execution time of an execution with its input bytes. */
    private ExecutionTimeWithInputBytes getExecutionTimeAndInputBytes(
            Execution execution, long currentTime) {
        return new ExecutionTimeWithInputBytes(
                getExecutionTime(execution, currentTime), getExecutionInputBytes(execution));
    }

    /** Cancels the pending detection round, if any, without interrupting a running one. */
    @Override
    public void stop() {
        if (scheduledDetectionFuture != null) {
            scheduledDetectionFuture.cancel(false);
        }
    }

    /**
     * An execution time paired with the number of input bytes the execution consumed.
     *
     * <p>Instances compare by execution time per input byte when both sides know their input
     * bytes, and by plain execution time otherwise (see {@link #compareTo}).
     */
    @VisibleForTesting
    static class ExecutionTimeWithInputBytes implements Comparable<ExecutionTimeWithInputBytes> {

        /** Sentinel for "number of input bytes is unknown". */
        private static final long NUM_BYTES_UNKNOWN = -1L;

        private final long executionTime;
        private final long inputBytes;

        public ExecutionTimeWithInputBytes(long executionTime, long inputBytes) {
            this.executionTime = executionTime;
            this.inputBytes = inputBytes;
        }

        public long getExecutionTime() {
            return executionTime;
        }

        public long getInputBytes() {
            return inputBytes;
        }

        /**
         * Compares by execution time per input byte, or by plain execution time if input bytes
         * are unknown.
         *
         * @throws IllegalArgumentException if exactly one side has unknown input bytes and
         *     neither side has an execution time of 0
         */
        @Override
        public int compareTo(ExecutionTimeWithInputBytes other) {
            if (inputBytes == NUM_BYTES_UNKNOWN || other.getInputBytes() == NUM_BYTES_UNKNOWN) {
                // Mixing known and unknown input bytes is only tolerated when one side has a zero
                // execution time, e.g. the zero baseline returned when no finished executions
                // are required.
                if ((inputBytes == NUM_BYTES_UNKNOWN
                                && other.getInputBytes() == NUM_BYTES_UNKNOWN)
                        || executionTime == 0L
                        || other.executionTime == 0L) {
                    // Long.compare instead of (int) (a - b): casting a long difference to int
                    // truncates it and can flip the sign or yield 0 for unequal values.
                    return Long.compare(executionTime, other.getExecutionTime());
                }
                throw new IllegalArgumentException(
                        "Both compared elements should be NUM_BYTES_UNKNOWN.");
            }
            // Double.MIN_VALUE avoids dividing by zero when an execution consumed 0 bytes.
            return Double.compare(
                    (double) executionTime / Math.max((double) inputBytes, Double.MIN_VALUE),
                    (double) other.getExecutionTime()
                            / Math.max((double) other.getInputBytes(), Double.MIN_VALUE));
        }
    }
}

