package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/modality/nlp/generate/SeqBatchScheduler.class */
/**
 * Base class for schedulers that run step-by-step sequence generation over a {@link SeqBatcher}.
 *
 * <p>Each request passed to {@link #addRequest} is turned into a {@link SeqBatcher} by {@link
 * #initForward} and either becomes the running {@link #seqBatcher} or is merged into it with
 * {@code addBatch}. {@link #incrementForward(int)} calls {@link #inferenceCall()} once per step
 * and moves whatever {@code SeqBatcher.collectAndTrim()} returns into {@link #results}. Those
 * results are handed out (and cleared) by {@link #collectResults()}.
 */
public abstract class SeqBatchScheduler {
    private static final Logger logger = LoggerFactory.getLogger(SeqBatchScheduler.class);

    Predictor<NDList, CausalLMOutput> predictor;
    SeqBatcher seqBatcher;
    // Not assigned in this class; presumably set by subclasses — confirm.
    NDManager manager;
    SearchConfig config;
    Map<Long, NDArray> results = new ConcurrentHashMap<>();

    /**
     * Creates a scheduler.
     *
     * @param predictor the predictor used by subclasses to run the model
     * @param searchConfig the search configuration (pad token id, etc.)
     */
    public SeqBatchScheduler(
            Predictor<NDList, CausalLMOutput> predictor, SearchConfig searchConfig) {
        this.predictor = predictor;
        this.config = searchConfig;
    }

    /**
     * Builds a {@link SeqBatcher} for a new request.
     *
     * @param inputIds the request's input tokens (presumably shape (batch, seqLength) — confirm)
     * @param batchUids presumably one id per request row, later used as result keys — confirm
     * @return the batcher holding the new request
     * @throws TranslateException if the initial forward pass fails
     */
    public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids)
            throws TranslateException;

    /**
     * Runs up to {@code count} inference steps on the current batch.
     *
     * <p>Before each step the batcher is checked. If it is missing or holds no data, nothing more
     * can be done. After each step, if the batcher reports a completed sequence, its collected
     * outputs are moved into {@link #results}.
     *
     * @param count maximum number of inference steps to run
     * @return {@code true} if the loop stopped early because the batcher was unset or empty;
     *     {@code false} if all {@code count} steps were run (or {@code count <= 0})
     * @throws TranslateException if an inference step fails
     */
    public boolean incrementForward(int count) throws TranslateException {
        for (int step = 0; step < count; step++) {
            if (seqBatcher == null || seqBatcher.getData() == null) {
                // `step` is the number of inferenceCall()s already performed in this call.
                logger.info(
                        "seqBatcher not set or is empty. Please call addBatch. Current inference"
                                + " times is {}",
                        step);
                return true;
            }
            inferenceCall();
            if (seqBatcher.sequenceComplete()) {
                results.putAll(seqBatcher.collectAndTrim());
            }
        }
        return false;
    }

    /**
     * Advances the running batch by one inference step.
     *
     * @return the array produced by this step (semantics defined by the subclass)
     * @throws TranslateException if inference fails
     */
    protected abstract NDArray inferenceCall() throws TranslateException;

    /**
     * Adds a new request to the running batch.
     *
     * <p>The request is first turned into a {@link SeqBatcher} via {@link #initForward}. If no
     * batch is running yet, that batcher becomes the running one. Otherwise it is merged in with
     * {@code addBatch}.
     *
     * @param inputIds forwarded to {@link #initForward} as its first argument
     * @param batchUids forwarded to {@link #initForward} as its second argument
     * @throws TranslateException if {@link #initForward} fails
     */
    public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateException {
        SeqBatcher incoming = initForward(inputIds, batchUids);
        if (seqBatcher == null) {
            seqBatcher = incoming;
        } else {
            seqBatcher.addBatch(incoming);
        }
    }

    /**
     * Returns all results collected so far and starts a fresh, empty result map.
     *
     * <p>NOTE(review): the read-and-replace of {@link #results} is not atomic; if another thread
     * can add results concurrently, entries written during the swap could be lost — confirm the
     * threading model.
     *
     * @return the previously accumulated results
     */
    public Map<Long, NDArray> collectResults() {
        Map<Long, NDArray> collected = results;
        results = new ConcurrentHashMap<>();
        return collected;
    }

    /**
     * Counts, for each row of a 2-D token array, how many leading entries equal the pad token.
     *
     * @param inputIds 2-D array of token ids; must be readable with {@code toLongArray()}
     * @param config supplies the pad token id
     * @return an INT64 array of shape (rows, 1) holding the per-row leading-pad count
     */
    public static NDArray computeOffSets(NDArray inputIds, SearchConfig config) {
        int batchSize = Math.toIntExact(inputIds.getShape().get(0));
        int seqLength = Math.toIntExact(inputIds.getShape().get(1));
        long padTokenId = config.getPadTokenId();
        // One host copy of the whole array instead of an unclosed temporary NDArray per row.
        long[] tokens = inputIds.toLongArray();

        long[] offsets = new long[batchSize];
        for (int row = 0; row < batchSize; row++) {
            offsets[row] = countLeadingPads(tokens, row, seqLength, padTokenId);
        }
        return inputIds.getManager().create(offsets).reshape(-1, 1);
    }

    /**
     * Builds an attention mask that is 0 over each row's leading pad tokens and 1 elsewhere.
     *
     * @param inputIds 2-D array of token ids; must be readable with {@code toLongArray()}
     * @param config supplies the pad token id
     * @return an INT64 array with the same (rows, columns) shape as {@code inputIds}
     */
    public static NDArray computeAttentionMask(NDArray inputIds, SearchConfig config) {
        int batchSize = Math.toIntExact(inputIds.getShape().get(0));
        int seqLength = Math.toIntExact(inputIds.getShape().get(1));
        long padTokenId = config.getPadTokenId();
        long[] tokens = inputIds.toLongArray();

        // Build the mask on the host and create it once, rather than one set() call per row.
        long[] mask = new long[Math.multiplyExact(batchSize, seqLength)];
        for (int row = 0; row < batchSize; row++) {
            int rowStart = row * seqLength;
            int pads = countLeadingPads(tokens, row, seqLength, padTokenId);
            for (int col = pads; col < seqLength; col++) {
                mask[rowStart + col] = 1L;
            }
        }
        return inputIds.getManager().create(mask, new Shape(batchSize, seqLength));
    }

    /**
     * Computes position ids {@code max(pastSeqLength + col - offset, 0)} for every cell of {@code
     * inputIds}.
     *
     * <p>{@code offSets} is reshaped to a column and repeated {@code repeat} times along axis 0 so
     * that it lines up with the rows of {@code inputIds}. Presumably this is one copy per row
     * group, e.g. per beam — confirm against callers.
     *
     * @param inputIds only its shape is used: rows and last dimension
     * @param offSets per-request offsets, e.g. as produced by {@link #computeOffSets}
     * @param pastSeqLength position of the first column
     * @param repeat how many times each offset is repeated along axis 0
     * @return an INT64 array of shape (rows, lastDimension) with negative positions clamped to 0
     */
    public static NDArray computePositionIds(
            NDArray inputIds, NDArray offSets, long pastSeqLength, int repeat) {
        NDManager ndManager = inputIds.getManager();
        Shape shape = inputIds.getShape();
        // NOTE: arange takes float bounds, so positions are exact only while
        // pastSeqLength + seqLength stays below 2^24.
        NDArray positions =
                ndManager
                        .arange(
                                (float) pastSeqLength,
                                (float) (pastSeqLength + shape.getLastDimension()),
                                1.0f,
                                DataType.INT64)
                        .expandDims(0)
                        .repeat(0, shape.get(0));
        NDArray rowOffsets = offSets.reshape(-1, 1).repeat(0, repeat);
        positions.subi(rowOffsets);
        return positions.maximum(positions.zerosLike());
    }

    /** Returns how many entries at the start of row {@code row} of a row-major array are pads. */
    private static int countLeadingPads(long[] flat, int row, int seqLength, long padTokenId) {
        int rowStart = row * seqLength;
        int count = 0;
        while (count < seqLength && flat[rowStart + count] == padTokenId) {
            count++;
        }
        return count;
    }
}
