package ai.djl.timeseries;

import ai.djl.ndarray.NDArray;
import java.time.LocalDateTime;

/* loaded from: input_file:ai/djl/timeseries/SampleForecast.class */
/**
 * A {@link Forecast} represented by a tensor of sampled forecast trajectories.
 *
 * <p>Axis 0 of the sample tensor indexes the individual samples. The size of axis 1 is forwarded
 * to the {@link Forecast} constructor (presumably the prediction length — confirm against {@code
 * Forecast}). An optional axis 2 is treated as the target dimension by {@link #copyDim(int)}.
 */
public class SampleForecast extends Forecast {

    /** The sample tensor; axis 0 indexes the individual samples. */
    private final NDArray samples;

    /** Number of samples, i.e. the size of axis 0 of {@link #samples}. */
    private final int numSamples;

    /**
     * Constructs a {@code SampleForecast}.
     *
     * @param samples the sampled forecasts. The array must have at least two axes, because the size
     *     of axis 1 is read here. Axis 0 indexes the samples.
     * @param startDate the start date, passed through to {@link Forecast}
     * @param freq the frequency string, passed through to {@link Forecast}
     */
    public SampleForecast(NDArray samples, LocalDateTime startDate, String freq) {
        super(startDate, (int) samples.getShape().get(1), freq);
        this.samples = samples;
        this.numSamples = (int) samples.getShape().head();
    }

    /**
     * Returns the samples sorted along the sample axis (axis 0).
     *
     * <p>The sort is recomputed on every call.
     *
     * @return the samples sorted along axis 0
     */
    public NDArray getSortedSamples() {
        return samples.sort(0);
    }

    /**
     * Returns the number of samples in this forecast.
     *
     * @return the size of axis 0 of the sample tensor
     */
    public int getNumSamples() {
        return numSamples;
    }

    /**
     * Returns the {@code q}-quantile of the samples along the sample axis.
     *
     * <p>The samples are sorted along axis 0. The entry at index {@code round((numSamples - 1) * q)}
     * is then selected. No interpolation between neighbouring samples is performed.
     *
     * @param q the quantile level; must lie in {@code [0, 1]}
     * @return the sorted sample at the selected index
     * @throws IllegalArgumentException if {@code q} is NaN or outside {@code [0, 1]}
     */
    @Override
    public NDArray quantile(float q) {
        // Reject out-of-range levels up front. A negative q would otherwise yield a negative index,
        // which NDArray indexing may resolve from the end instead of failing.
        // The negated form also rejects NaN.
        if (!(q >= 0f && q <= 1f)) {
            throw new IllegalArgumentException(
                    String.format("quantile must be in [0, 1], but got q=%s", q));
        }
        int sampleIdx = Math.round((numSamples - 1) * q);
        return getSortedSamples().get("{}, :", sampleIdx);
    }

    /**
     * Returns a forecast restricted to a single target dimension.
     *
     * <p>If the sample tensor has exactly two axes, {@code dim} is ignored. In that case the
     * returned forecast shares this forecast's sample array. Otherwise axis 2 is treated as the
     * target dimension, and the slice at {@code dim} is selected.
     *
     * @param dim the target dimension to keep; must satisfy {@code 0 <= dim < target_dim} when the
     *     samples have a third axis
     * @return a new {@code SampleForecast} with the same start date and frequency
     * @throws IllegalArgumentException if the samples have a third axis and {@code dim} is out of
     *     range
     */
    public SampleForecast copyDim(int dim) {
        NDArray copySamples;
        if (samples.getShape().dimension() == 2) {
            copySamples = samples;
        } else {
            int targetDim = (int) samples.getShape().get(2);
            // Enforce both bounds stated in the message. A negative dim must not reach NDArray
            // indexing, which may interpret it relative to the end of the axis.
            if (dim < 0 || dim >= targetDim) {
                throw new IllegalArgumentException(
                        String.format(
                                "must set 0 <= dim < target_dim, but got dim=%d, target_dim=%d",
                                dim, targetDim));
            }
            copySamples = samples.get(":, :, {}", dim);
        }
        return new SampleForecast(copySamples, this.startDate, this.freq);
    }

    /**
     * Returns the mean of the samples along the sample axis.
     *
     * @return the mean over axis 0 of the sample tensor
     */
    @Override
    public NDArray mean() {
        return samples.mean(new int[] {0});
    }
}
