package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;

/**
 * A {@link DataIterable} over an {@link ArrayDataset} that loads each batch with one bulk call,
 * either {@link ArrayDataset#getByRange} for contiguous indices or {@link
 * ArrayDataset#getByIndices} otherwise, rather than reading samples individually.
 */
public class BulkDataIterable extends DataIterable {

    /**
     * Creates a new {@code BulkDataIterable}. Every argument is forwarded unchanged, in order, to
     * the {@link DataIterable} constructor.
     *
     * <p>NOTE(review): the descriptions below assume {@link DataIterable} stores these arguments in
     * the same-named fields that {@link #fetch} reads ({@code manager}, {@code dataset}, {@code
     * pipeline}, ...) — confirm against {@link DataIterable}.
     *
     * @param dataset the dataset to iterate over
     * @param manager the manager from which each batch's sub-manager is created
     * @param sampler the sampler supplying batch indices (handled by {@link DataIterable})
     * @param dataBatchifier the batchifier attached to each batch's data
     * @param labelBatchifier the batchifier attached to each batch's labels
     * @param pipeline the transform applied to each batch's data, or {@code null} for none
     * @param targetPipeline the transform applied to each batch's labels, or {@code null} for none
     * @param executor forwarded to {@link DataIterable}; presumably used for prefetching — confirm
     * @param preFetchNumber forwarded to {@link DataIterable}; presumably the number of batches to
     *     prefetch — confirm
     * @param device the device batches are copied to, or {@code null} to leave them where they are
     */
    public BulkDataIterable(
            ArrayDataset dataset,
            NDManager manager,
            Sampler sampler,
            Batchifier dataBatchifier,
            Batchifier labelBatchifier,
            Pipeline pipeline,
            Pipeline targetPipeline,
            ExecutorService executor,
            int preFetchNumber,
            Device device) {
        super(
                dataset,
                manager,
                sampler,
                dataBatchifier,
                labelBatchifier,
                pipeline,
                targetPipeline,
                executor,
                preFetchNumber,
                device);
    }

    /**
     * Loads the samples at {@code indices} as a single {@link Batch}.
     *
     * <p>All arrays are allocated on a fresh sub-manager named {@code "dataIter fetch"}, which is
     * handed to the returned {@link Batch}. If anything fails before the batch is built, the
     * sub-manager is closed here so its arrays are not leaked.
     *
     * @param indices the dataset indices to load
     * @param progress passed through to the {@link Batch} (together with the dataset size);
     *     presumably the iteration progress — confirm against {@link Batch}
     * @return the batch; it owns the sub-manager holding its arrays
     * @throws IOException as declared by {@link DataIterable#fetch}
     */
    @Override
    protected Batch fetch(List<Long> indices, int progress) throws IOException {
        NDManager subManager = manager.newSubManager();
        subManager.setName("dataIter fetch");
        int batchSize = indices.size();
        boolean handedOff = false;
        try {
            ArrayDataset arrayDataset = (ArrayDataset) dataset;
            Batch raw;
            if (isRange(indices)) {
                // Contiguous indices can be read as one [start, start + size) slice.
                long start = indices.get(0);
                raw = arrayDataset.getByRange(subManager, start, start + batchSize);
            } else {
                long[] indexArray = indices.stream().mapToLong(Long::longValue).toArray();
                raw = arrayDataset.getByIndices(subManager, indexArray);
            }

            NDList data = raw.getData();
            if (pipeline != null) {
                data = pipeline.transform(data);
            }

            NDList labels = raw.getLabels();
            if (targetPipeline != null) {
                labels = targetPipeline.transform(labels);
            }

            if (device != null) {
                data = data.toDevice(device, false);
                labels = labels.toDevice(device, false);
            }

            Batch batch =
                    new Batch(
                            subManager,
                            data,
                            labels,
                            batchSize,
                            dataBatchifier,
                            labelBatchifier,
                            progress,
                            dataset.size(),
                            indices);
            handedOff = true;
            return batch;
        } finally {
            // Until the Batch exists nothing else owns subManager, so a failure anywhere above
            // would otherwise leak it and every array already allocated on it.
            if (!handedOff) {
                subManager.close();
            }
        }
    }

    /**
     * Returns whether {@code indices} is a run of consecutive ascending values, i.e. {@code [s, s
     * + 1, ..., s + n - 1]}.
     *
     * @param indices the indices to check; elements must be non-null
     * @return {@code true} if the indices are consecutive and ascending; {@code false} if they are
     *     not, or if the list is empty
     */
    public static boolean isRange(List<Long> indices) {
        if (indices.isEmpty()) {
            return false;
        }
        long expected = indices.get(0);
        for (long index : indices) {
            if (index != expected) {
                return false;
            }
            expected++;
        }
        return true;
    }
}
