package ai.djl.timeseries.transform.split;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.transform.InstanceSampler;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:ai/djl/timeseries/transform/split/Split.class */
/**
 * Utilities that split the series stored in a {@link TimeSeriesData} entry into a "past" window
 * and a "future" window around a sampled split index.
 */
public final class Split {

    private Split() {}

    /**
     * Splits the target and the given time-series fields of {@code data} at an index chosen by
     * {@code instanceSampler}, modifying {@code data} in place.
     *
     * <p>Only the first index returned by the sampler is used. For each field in {@code
     * timeSeriesFields}, followed by {@code targetField}, this method does the following:
     *
     * <ul>
     *   <li>It stores under {@code "PAST_" + name} the last-axis slice {@code [index - pastLength,
     *       index)}. When {@code index < pastLength}, the slice is left-padded with {@code
     *       dummyValue} (in the series' dtype) up to {@code pastLength} elements.
     *   <li>It stores under {@code "FUTURE_" + name} the last-axis slice {@code [index + leadTime,
     *       index + leadTime + futureLength)}. When {@code index + leadTime} is at or past the end
     *       of the series, it stores an empty array (last axis of size 0) instead.
     *   <li>It removes the original field from {@code data}.
     * </ul>
     *
     * <p>It then stores under {@code "PAST_" + isPadField.name()} a 1-D mask of length {@code
     * pastLength}, in the target's dtype. The mask is 1 at padded positions and 0 elsewhere. When
     * {@code outputNTC} is true, every stored past/future array is replaced by its {@link
     * NDArray#transpose()}. Finally, the forecast start time is set to the start time plus {@code
     * index + leadTime} days.
     *
     * @param manager manager used to allocate the padding, mask and empty-future arrays
     * @param targetField the target series field; it is split like the other fields, and its dtype
     *     is used for the padding mask
     * @param isPadField field whose {@code "PAST_"} key receives the padding mask
     * @param startField not read by this implementation; kept for signature compatibility
     * @param forecastStartField not read by this implementation; kept for signature compatibility
     * @param instanceSampler chooses the split index from the target series
     * @param pastLength number of elements in each past window
     * @param futureLength number of elements requested for each future window
     * @param leadTime offset between the split index and the start of the future window
     * @param outputNTC whether to transpose the resulting past/future arrays
     * @param timeSeriesFields additional series fields to split alongside the target
     * @param dummyValue value used to left-pad past windows that are too short
     * @param data the entry to split; modified in place
     * @throws IndexOutOfBoundsException if the sampler returns no index
     */
    public static void instanceSplit(
            NDManager manager,
            FieldName targetField,
            FieldName isPadField,
            FieldName startField,
            FieldName forecastStartField,
            InstanceSampler instanceSampler,
            int pastLength,
            int futureLength,
            int leadTime,
            boolean outputNTC,
            FieldName[] timeSeriesFields,
            float dummyValue,
            TimeSeriesData data) {
        List<FieldName> sliceFields = new ArrayList<>(timeSeriesFields.length + 1);
        sliceFields.addAll(Arrays.asList(timeSeriesFields));
        sliceFields.add(targetField);

        NDArray target = data.get(targetField);
        // Only the first sampled index is used; subList(0, 1) throws IndexOutOfBoundsException
        // when the sampler yields nothing.
        List<Integer> splitIndices = instanceSampler.call(target).subList(0, 1);

        for (int splitIndex : splitIndices) {
            for (FieldName field : sliceFields) {
                NDArray series = data.get(field);
                data.setField(
                        past(field), slicePast(manager, series, splitIndex, pastLength, dummyValue));
                data.setField(
                        future(field),
                        sliceFuture(manager, series, splitIndex, futureLength, leadTime));
                data.remove(field);
            }

            // Mask marking which leading positions of the past window are padding.
            int padLength = Math.max(pastLength - splitIndex, 0);
            NDArray pastIsPad = manager.zeros(new Shape(pastLength), target.getDataType());
            if (padLength > 0) {
                pastIsPad.set(new NDIndex(":{}", padLength), 1);
            }

            if (outputNTC) {
                for (FieldName field : sliceFields) {
                    data.setField(past(field), data.get(past(field)).transpose());
                    data.setField(future(field), data.get(future(field)).transpose());
                }
            }

            data.setField(past(isPadField), pastIsPad);
            // NOTE(review): advances by whole days, i.e. assumes a daily series frequency — confirm.
            data.setForecastStartTime(data.getStartTime().plusDays(splitIndex + leadTime));
        }
    }

    /**
     * Splits {@code data} with a lead time of 0 and transposed output; see the full overload.
     *
     * @param manager manager used to allocate the padding, mask and empty-future arrays
     * @param targetField the target series field
     * @param isPadField field whose {@code "PAST_"} key receives the padding mask
     * @param startField not read; kept for signature compatibility
     * @param forecastStartField not read; kept for signature compatibility
     * @param instanceSampler chooses the split index from the target series
     * @param pastLength number of elements in each past window
     * @param futureLength number of elements requested for each future window
     * @param timeSeriesFields additional series fields to split alongside the target
     * @param dummyValue value used to left-pad past windows that are too short
     * @param data the entry to split; modified in place
     */
    public static void instanceSplit(
            NDManager manager,
            FieldName targetField,
            FieldName isPadField,
            FieldName startField,
            FieldName forecastStartField,
            InstanceSampler instanceSampler,
            int pastLength,
            int futureLength,
            FieldName[] timeSeriesFields,
            float dummyValue,
            TimeSeriesData data) {
        instanceSplit(
                manager,
                targetField,
                isPadField,
                startField,
                forecastStartField,
                instanceSampler,
                pastLength,
                futureLength,
                0,
                true,
                timeSeriesFields,
                dummyValue,
                data);
    }

    /** Returns the last-axis window of {@code pastLength} elements ending just before {@code index}. */
    private static NDArray slicePast(
            NDManager manager, NDArray series, int index, int pastLength, float dummyValue) {
        if (index > pastLength) {
            return series.get("..., {}:{}", index - pastLength, index);
        }
        if (index == pastLength) {
            return series.get("..., :{}", index);
        }
        // Fewer than pastLength elements are available: left-pad with dummyValue.
        Shape shape = series.getShape();
        Shape padShape = shape.slice(0, shape.dimension() - 1).add(pastLength - index);
        NDArray pad = manager.full(padShape, dummyValue, series.getDataType());
        return index == 0 ? pad : pad.concat(series.get("..., :{}", index), -1);
    }

    /** Returns the last-axis window starting {@code leadTime} elements after {@code index}. */
    private static NDArray sliceFuture(
            NDManager manager, NDArray series, int index, int futureLength, int leadTime) {
        int start = index + leadTime;
        Shape shape = series.getShape();
        if (start >= shape.tail()) {
            // Nothing left to forecast against: empty last axis, same dtype as the series.
            return manager.create(shape.slice(0, shape.dimension() - 1).add(0), series.getDataType());
        }
        return series.get("..., {}:{}", start, start + futureLength);
    }

    /** Returns the key under which the past window of {@code fieldName} is stored. */
    private static String past(FieldName fieldName) {
        return "PAST_" + fieldName.name();
    }

    /** Returns the key under which the future window of {@code fieldName} is stored. */
    private static String future(FieldName fieldName) {
        return "FUTURE_" + fieldName.name();
    }
}
