package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Batches in-flight generation sequences that may have different lengths.
 *
 * <p>Rows can be merged in from another batcher ({@link #addBatch}), marked as finished ({@link
 * #exitCriteria}) and then removed ({@link #collectAndTrim}). When two batches of different
 * sequence length are merged, the shorter one is placed at the end of the sequence dimension, so
 * its leading positions are padding. {@link #offSets} records, per row, where the real sequence
 * starts along that dimension.
 */
public class SeqBatcher {

    /**
     * Value written into the padding of the first tensor of the list when a shorter batch is
     * merged into a longer one; every other tensor is padded with zeros.
     */
    // NOTE(review): hard-coded; presumably a tokenizer-specific pad token id — confirm.
    private static final long PAD_TOKEN_ID = 220;

    /** Sub-manager owned by this batcher; used to allocate padding and index arrays. */
    // NOTE(review): never closed in this class — confirm who owns its lifetime.
    NDManager manager;

    /** Number of rows currently in the batch. */
    long batchSize;

    /** Current length of the sequence dimension (dim 1 of the past output ids at construction). */
    long seqLength;

    /** Per-row uid; kept 2-D, any other rank is reshaped to {@code (-1, 1)}. */
    NDArray batchUid;

    /**
     * Per-row start position of the real sequence along the sequence dimension (earlier positions
     * are padding); kept 2-D, any other rank is reshaped to {@code (-1, 1)}.
     */
    NDArray offSets;

    /** The batched tensors; {@code null} once every row has been collected. */
    BatchTensorList data;

    /** Row index mapped to the {@link #seqLength} at which that row first met the exit criteria. */
    private Map<Long, Long> exitIndexEndPosition;

    /**
     * Creates a batcher over the given tensors.
     *
     * @param data the batched tensors; batch size and sequence length are read from dims 0 and 1
     *     of {@link BatchTensorList#getPastOutputIds()}
     * @param batchUid per-row uids; reshaped to {@code (-1, 1)} unless already 2-D
     * @param offSets per-row sequence start offsets; reshaped to {@code (-1, 1)} unless already 2-D
     * @param manager parent manager from which this batcher's own sub-manager is created
     */
    public SeqBatcher(
            BatchTensorList data, NDArray batchUid, NDArray offSets, NDManager manager) {
        this.manager = manager.newSubManager();
        this.data = data;
        this.batchUid = toColumn(batchUid);
        this.offSets = toColumn(offSets);
        Shape idsShape = data.getPastOutputIds().getShape();
        this.batchSize = idsShape.get(0);
        this.seqLength = idsShape.get(1);
        this.exitIndexEndPosition = new ConcurrentHashMap<>();
    }

    /**
     * Returns the batched tensors.
     *
     * @return the current tensors, or {@code null} after {@link #collectAndTrim} removed every row
     */
    public BatchTensorList getData() {
        return data;
    }

    /**
     * Appends the rows of {@code seqBatcher} to this batch. The batch with the shorter sequence
     * dimension is padded so both share this batcher's resulting {@link #seqLength}.
     *
     * <p>Note: {@code seqBatcher.offSets} may be modified in place by this call.
     *
     * @param seqBatcher the batcher whose rows are merged in
     */
    public void addBatch(SeqBatcher seqBatcher) {
        merge(this, seqBatcher, seqLength - seqBatcher.seqLength);
    }

    /**
     * Merges two batchers into this one. The longer batch's rows come first; the shorter batch is
     * written into the trailing {@code seqDelta} positions of the sequence dimension, and its
     * offsets are shifted by {@code seqDelta}.
     *
     * @param seqBatcher1 the first batcher
     * @param seqBatcher2 the second batcher
     * @param seqDelta {@code seqBatcher1.seqLength - seqBatcher2.seqLength}; a negative value swaps
     *     the roles of the two batchers
     */
    private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta) {
        if (seqDelta < 0) {
            // Ensure seqBatcher1 is the longer batch; a temp is required for a correct swap.
            SeqBatcher swapTmp = seqBatcher1;
            seqBatcher1 = seqBatcher2;
            seqBatcher2 = swapTmp;
            seqDelta = -seqDelta;
        }

        try (NDScope scope = new NDScope()) {
            scope.suppressNotUsedWarning();

            NDList list1 = seqBatcher1.data.getList();
            NDList list2 = seqBatcher2.data.getList();
            NDList merged = new NDList(list1.size());
            long[] seqDimOrder = seqBatcher1.data.getSeqDimOrder();

            for (int i = 0; i < list1.size(); i++) {
                NDArray batch1 = list1.get(i);
                NDArray batch2 = list2.get(i);
                if (seqDelta == 0) {
                    // Same sequence length: a plain concat along the batch dimension suffices.
                    merged.add(batch1.concat(batch2, 0));
                    continue;
                }

                long[] shape1 = batch1.getShape().getShape();
                long[] shape2 = batch2.getShape().getShape();

                // Grow batch1 along dim 0 by batch2's row count, then copy batch2 into the new rows.
                long[] paddingShape = shape1.clone();
                paddingShape[0] = shape2[0];
                // NOTE(review): presumably list entry 0 is the output-id tensor — confirm.
                NDArray padding =
                        i == 0
                                ? manager.full(
                                        new Shape(paddingShape),
                                        (float) PAD_TOKEN_ID,
                                        batch1.getDataType())
                                : manager.zeros(new Shape(paddingShape), batch1.getDataType());
                NDArray grown = batch1.concat(padding, 0);

                NDIndex index;
                if (seqDimOrder[i] > 0) {
                    int seqDim = Math.toIntExact(seqDimOrder[i]);
                    assert seqDelta + shape2[seqDim] == shape1[seqDim]
                            : "Wrong shapes. batch1 and batch2 are not mergable";
                    index =
                            seqSlice(
                                    new NDIndex("{}:", seqBatcher1.batchSize),
                                    seqDim,
                                    seqDelta,
                                    shape1[seqDim]);
                } else {
                    // No sequence dimension: just fill the appended rows.
                    index = new NDIndex("{}:, ...", seqBatcher1.batchSize);
                }
                grown.set(index, batch2);
                merged.add(grown);
            }

            data = data.fromList(merged, data.getSeqDimOrder());
            batchSize = seqBatcher1.batchSize + seqBatcher2.batchSize;
            batchUid = seqBatcher1.batchUid.concat(seqBatcher2.batchUid, 0);
            offSets = seqBatcher1.offSets.concat(seqBatcher2.offSets.addi(seqDelta), 0);
            seqLength = seqBatcher1.seqLength;

            // Keep the results alive after the scope closes.
            NDScope.unregister(batchUid, offSets);
            NDScope.unregister(merged);
        }
    }

    /**
     * Marks rows as finished. A row is marked when its length ({@code seqLength - offset}) reaches
     * {@code maxLength} or its entry in {@code outputIds} equals {@code eosTokenId}. The first
     * marking of a row is kept; later calls do not overwrite it.
     *
     * @param outputIds one id per row (presumably the latest generated token — confirm)
     * @param maxLength the length at which a row is considered finished
     * @param eosTokenId the id that marks a row as finished
     */
    public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) {
        long[] ids = outputIds.toLongArray();
        long[] offsets = offSets.toLongArray();
        for (int i = 0; i < ids.length; i++) {
            boolean reachedMaxLength = seqLength - offsets[i] >= maxLength;
            if (reachedMaxLength || ids[i] == eosTokenId) {
                exitIndexEndPosition.putIfAbsent((long) i, seqLength);
            }
        }
    }

    /**
     * Removes every row marked by {@link #exitCriteria} and returns its output ids. The remaining
     * rows are compacted, and leading positions that are padding in every remaining row are
     * trimmed from the sequence dimension.
     *
     * @return map from row uid to that row's past output ids in {@code [offset, endPosition)};
     *     empty if no row was marked
     */
    public Map<Long, NDArray> collectAndTrim() {
        if (exitIndexEndPosition.isEmpty()) {
            return new ConcurrentHashMap<>();
        }
        Map<Long, NDArray> finished = new ConcurrentHashMap<>();
        try (NDScope scope = new NDScope()) {
            scope.suppressNotUsedWarning();

            HashSet<Long> exitedRows = new HashSet<>();
            for (Map.Entry<Long, Long> entry : exitIndexEndPosition.entrySet()) {
                long row = entry.getKey();
                long endPosition = entry.getValue();
                long uid = batchUid.getLong(row);
                NDArray sequence =
                        data.getPastOutputIds()
                                .get("{}, {}:{}", row, offSets.getLong(row), endPosition);
                finished.put(uid, sequence);
                exitedRows.add(row);
                NDScope.unregister(sequence);
            }

            long[] keep = new long[Math.toIntExact(batchSize) - exitedRows.size()];
            int kept = 0;
            for (long row = 0; row < batchSize; row++) {
                if (!exitedRows.contains(row)) {
                    keep[kept++] = row;
                }
            }

            if (keep.length == 0) {
                // Every row finished: reset to an empty batcher.
                batchUid = manager.create(new Shape(0, 1), batchUid.getDataType());
                offSets = manager.create(new Shape(0, 1), offSets.getDataType());
                data = null;
                batchSize = 0L;
                seqLength = 0L;
                exitIndexEndPosition = new ConcurrentHashMap<>();
                NDScope.unregister(batchUid, offSets);
                return finished;
            }

            NDArray keepIndices = manager.create(keep);
            NDIndex keepRows = new NDIndex("{}", keepIndices);
            batchUid = batchUid.get(keepRows).reshape(-1, 1);
            offSets = offSets.get(keepRows).reshape(-1, 1);

            // Positions before the smallest offset are padding in every remaining row.
            long minOffset = offSets.min(new int[] {0}).toLongArray()[0];
            offSets = offSets.subi(minOffset);

            NDList list = data.getList();
            NDList trimmed = new NDList(list.size());
            long[] seqDimOrder = data.getSeqDimOrder();
            for (int i = 0; i < list.size(); i++) {
                NDIndex index;
                if (minOffset != 0 && seqDimOrder[i] > 0) {
                    index =
                            seqSlice(
                                    new NDIndex("{}", keepIndices),
                                    seqDimOrder[i],
                                    minOffset,
                                    seqLength);
                } else {
                    index = new NDIndex("{}, ...", keepIndices);
                }
                trimmed.add(list.get(i).get(index));
            }

            data = data.fromList(trimmed, data.getSeqDimOrder());
            batchSize -= exitIndexEndPosition.size();
            seqLength -= minOffset;
            exitIndexEndPosition = new ConcurrentHashMap<>();

            NDScope.unregister(trimmed);
            NDScope.unregister(batchUid, offSets);
            return finished;
        }
    }

    /**
     * Returns whether any row has met the exit criteria and has not been collected yet.
     *
     * @return {@code true} if {@link #collectAndTrim} would remove at least one row
     */
    public boolean sequenceComplete() {
        return !exitIndexEndPosition.isEmpty();
    }

    /** Returns {@code array} if it is already 2-D, otherwise its {@code (-1, 1)} reshape. */
    private static NDArray toColumn(NDArray array) {
        return array.getShape().dimension() == 2 ? array : array.reshape(-1, 1);
    }

    /**
     * Extends {@code batchIndex} (which covers dim 0) with full slices for dims {@code 1 ..
     * seqDim-1}, the slice {@code [start, end)} on {@code seqDim}, and a trailing ellipsis.
     */
    private static NDIndex seqSlice(NDIndex batchIndex, long seqDim, long start, long end) {
        NDIndex index = batchIndex;
        for (long dim = 1; dim < seqDim; dim++) {
            index = index.addAllDim();
        }
        return index.addSliceDim(start, end).addEllipseDim();
    }
}
