package ai.djl.timeseries.transform.feature;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import com.ibm.icu.text.DateFormat;
import java.time.Duration;
import java.time.LocalDateTime;
import java.time.Period;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;

/* loaded from: input_file:ai/djl/timeseries/transform/feature/Feature.class */
/**
 * Feature-engineering transforms for {@link TimeSeriesData}.
 *
 * <p>Each transform reads one or more fields of a {@link TimeSeriesData} entry, derives a new
 * array from them, and writes the result back into the same entry with {@code setField}.
 */
public final class Feature {

    private Feature() {
    }

    /**
     * Imputes missing values in a target field and records which values were observed.
     *
     * <p>Every NaN in {@code targetField} is replaced with {@code 0}. A mask with the same shape
     * and data type as the original target is written to {@code outputField}: 1 where the value
     * was not NaN, 0 where it was.
     *
     * @param manager the manager used to allocate the fill array
     * @param targetField the field holding the values to impute
     * @param outputField the field that receives the observed-values mask
     * @param data the time series entry, modified in place
     */
    public static void addObservedValuesIndicator(
            NDManager manager, FieldName targetField, FieldName outputField, TimeSeriesData data) {
        NDArray target = data.get(targetField);
        data.setField(targetField, dummyValueImputation(manager, target, 0.0f));
        // The imputed array is a new array built by where(), so `target` still holds the NaNs.
        data.setField(outputField, target.isNaN().logicalNot().toType(target.getDataType(), false));
    }

    /**
     * Adds calendar/time features for prediction, extending the time axis by {@code
     * predictionLength}.
     *
     * <p>Equivalent to the nine-argument overload with {@code isTrain = false}.
     *
     * @param manager the manager passed to each feature function
     * @param startField currently unused; the start time comes from {@link
     *     TimeSeriesData#getStartTime()}
     * @param targetField the field whose last dimension gives the series length
     * @param outputField the field that receives the stacked features
     * @param timeFeatures functions mapping a list of timestamps to a feature array
     * @param predictionLength number of extra steps appended to the time axis
     * @param freq the sampling frequency, e.g. {@code "D"}, {@code "2W"}, {@code "H"}
     * @param data the time series entry, modified in place
     */
    public static void addTimeFeature(
            NDManager manager,
            FieldName startField,
            FieldName targetField,
            FieldName outputField,
            List<BiFunction<NDManager, List<LocalDateTime>, NDArray>> timeFeatures,
            int predictionLength,
            String freq,
            TimeSeriesData data) {
        addTimeFeature(
                manager,
                startField,
                targetField,
                outputField,
                timeFeatures,
                predictionLength,
                freq,
                data,
                false);
    }

    /**
     * Adds calendar/time features derived from the series' timestamps.
     *
     * <p>Builds one timestamp per time step. The first is {@link TimeSeriesData#getStartTime()},
     * and each later one is advanced by {@code freq}. Every function in {@code timeFeatures} is
     * applied to this list, and the results are stacked along a new leading axis (so they must
     * share a shape) into {@code outputField}. If {@code timeFeatures} is empty, {@code
     * outputField} is set to {@code null} and nothing else happens.
     *
     * @param manager the manager passed to each feature function
     * @param startField currently unused; the start time comes from {@link
     *     TimeSeriesData#getStartTime()}
     * @param targetField the field whose last dimension gives the series length
     * @param outputField the field that receives the stacked features
     * @param timeFeatures functions mapping a list of timestamps to a feature array
     * @param predictionLength number of extra steps appended when {@code isTrain} is false
     * @param freq the sampling frequency: an optional count followed by a unit. {@code H}, {@code
     *     T} (minutes) and {@code S} are time units; other units are parsed as an ISO-8601 period
     *     ({@code D}, {@code W}, {@code M}, {@code Y})
     * @param data the time series entry, modified in place
     * @param isTrain if true, the time axis has exactly the target's length; otherwise it is
     *     extended by {@code predictionLength}
     * @throws java.time.format.DateTimeParseException if {@code freq} cannot be parsed
     */
    public static void addTimeFeature(
            NDManager manager,
            FieldName startField,
            FieldName targetField,
            FieldName outputField,
            List<BiFunction<NDManager, List<LocalDateTime>, NDArray>> timeFeatures,
            int predictionLength,
            String freq,
            TimeSeriesData data,
            boolean isTrain) {
        if (timeFeatures.isEmpty()) {
            // Nothing to stack: stop here instead of calling NDArrays.stack on an empty list.
            data.setField(outputField, (NDArray) null);
            return;
        }
        LocalDateTime current = data.getStartTime();
        int length = targetTransformationLength(data.get(targetField), predictionLength, isTrain);
        TemporalAmount step = parseFrequency(freq);

        List<LocalDateTime> timestamps = new ArrayList<>();
        for (int t = 0; t < length; t++) {
            timestamps.add(current);
            current = current.plus(step);
        }

        NDList features = new NDList(timeFeatures.size());
        for (BiFunction<NDManager, List<LocalDateTime>, NDArray> feature : timeFeatures) {
            features.add(feature.apply(manager, timestamps));
        }
        data.setField(outputField, NDArrays.stack(features));
    }

    /**
     * Adds an age feature for prediction, extending the time axis by {@code predictionLength}.
     *
     * <p>Equivalent to the seven-argument overload with {@code isTrain = false}.
     *
     * @param manager the manager used to allocate the age array
     * @param targetField the field whose last dimension gives the series length
     * @param outputField the field that receives the age feature
     * @param predictionLength number of extra steps appended to the time axis
     * @param logScale if true, the feature is {@code log10(age + 2)} instead of {@code age}
     * @param data the time series entry, modified in place
     * @return {@code data}
     */
    public static TimeSeriesData addAgeFeature(
            NDManager manager,
            FieldName targetField,
            FieldName outputField,
            int predictionLength,
            boolean logScale,
            TimeSeriesData data) {
        return addAgeFeature(
                manager, targetField, outputField, predictionLength, logScale, data, false);
    }

    /**
     * Adds an age feature: the index of each time step, optionally log-scaled.
     *
     * <p>The feature is {@code [0, 1, ..., length - 1]} in the target's data type. It is
     * optionally transformed to {@code log10(age + 2)} and written to {@code outputField} with
     * shape {@code (1, length)}.
     *
     * @param manager the manager used to allocate the age array
     * @param targetField the field whose last dimension gives the series length
     * @param outputField the field that receives the age feature
     * @param predictionLength number of extra steps appended when {@code isTrain} is false
     * @param logScale if true, the feature is {@code log10(age + 2)} instead of {@code age}
     * @param data the time series entry, modified in place
     * @param isTrain if true, the time axis has exactly the target's length; otherwise it is
     *     extended by {@code predictionLength}
     * @return {@code data}
     */
    public static TimeSeriesData addAgeFeature(
            NDManager manager,
            FieldName targetField,
            FieldName outputField,
            int predictionLength,
            boolean logScale,
            TimeSeriesData data,
            boolean isTrain) {
        NDArray target = data.get(targetField);
        int length = targetTransformationLength(target, predictionLength, isTrain);
        NDArray age = manager.arange(0, length, 1, target.getDataType());
        if (logScale) {
            // The +2 offset keeps the first step (age 0) away from log10(0).
            age = age.add(2.0f).log10();
        }
        data.setField(outputField, age.reshape(new Shape(1, length)));
        return data;
    }

    /**
     * Adds a log-scaled age feature for prediction.
     *
     * <p>Equivalent to {@code addAgeFeature(manager, targetField, outputField, predictionLength,
     * true, data)}.
     *
     * @param manager the manager used to allocate the age array
     * @param targetField the field whose last dimension gives the series length
     * @param outputField the field that receives the age feature
     * @param predictionLength number of extra steps appended to the time axis
     * @param data the time series entry, modified in place
     */
    public static void addAgeFeature(
            NDManager manager,
            FieldName targetField,
            FieldName outputField,
            int predictionLength,
            TimeSeriesData data) {
        addAgeFeature(manager, targetField, outputField, predictionLength, true, data);
    }

    /**
     * Returns the length of the time axis after a transform.
     *
     * <p>The result is the size of the last dimension of {@code array}, plus {@code
     * predictionLength} unless {@code isTrain} is true.
     */
    private static int targetTransformationLength(
            NDArray array, int predictionLength, boolean isTrain) {
        return ((int) array.getShape().tail()) + (isTrain ? 0 : predictionLength);
    }

    /**
     * Converts a frequency string such as {@code "D"}, {@code "2W"} or {@code "H"} into a step.
     *
     * <p>A missing leading count means 1. {@code H} (hours), {@code T} (minutes) and {@code S}
     * (seconds) become a {@link Duration}. Anything else is parsed as an ISO-8601 {@link Period}
     * unit.
     */
    private static TemporalAmount parseFrequency(String freq) {
        String amount = freq.matches("\\d+.*") ? freq : "1" + freq;
        if (freq.endsWith("H") || freq.endsWith("S")) {
            // Duration.parse only accepts hour/second units inside the time section ("PT1H");
            // "P1H" is rejected.
            return Duration.parse("PT" + amount);
        }
        if (freq.endsWith("T")) {
            // "T" is treated as minutes; ISO-8601 spells minutes "M" inside the time section.
            return Duration.parse("PT" + amount.substring(0, amount.length() - 1) + "M");
        }
        return Period.parse("P" + amount);
    }

    /**
     * Returns a copy of {@code array} with every NaN replaced by {@code fillValue}.
     *
     * <p>The fill array is created with the input's data type, so {@code where} combines arrays
     * of a single dtype.
     */
    private static NDArray dummyValueImputation(
            NDManager manager, NDArray array, float fillValue) {
        NDArray fill = manager.full(array.getShape(), fillValue, array.getDataType());
        return NDArrays.where(array.isNaN(), fill, array);
    }
}
