package ai.djl.timeseries.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.output.DistributionOutput;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.util.Pair;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * {@code Rmsse} is an {@link Evaluator} that computes the Root Mean Squared Scaled Error (RMSSE) of
 * a probabilistic time-series forecast.
 *
 * <p>The point forecast is the mean of the distribution that the configured {@link
 * DistributionOutput} builds from the predictions. For each series:
 *
 * <ol>
 *   <li>The squared forecast error is averaged along {@code axis}.
 *   <li>That average is divided by the mean squared one-step difference of the label along the same
 *       axis. This is the error of a naive "repeat the previous value" forecast, computed over the
 *       label itself.
 *   <li>The square root of the ratio is the per-series RMSSE.
 * </ol>
 *
 * <p>Series whose naive error is exactly zero are assigned an RMSSE of {@code 1}.
 *
 * <p>The accumulated value is the mean per-series RMSSE over all updates since the last reset.
 */
public class Rmsse extends Evaluator {

    private final DistributionOutput distributionOutput;
    private final int axis;
    private final Map<String, Float> totalLoss;

    /**
     * Creates an RMSSE evaluator named {@code "RMSSE"} that reduces along axis 1.
     *
     * @param distributionOutput the output used to build the forecast distribution
     */
    public Rmsse(DistributionOutput distributionOutput) {
        this("RMSSE", 1, distributionOutput);
    }

    /**
     * Creates an RMSSE evaluator.
     *
     * @param name the name of this evaluator
     * @param axis the axis along which errors are averaged and naive differences are taken
     * @param distributionOutput the output used to build the forecast distribution
     */
    public Rmsse(String name, int axis, DistributionOutput distributionOutput) {
        super(name);
        this.axis = axis;
        this.distributionOutput = distributionOutput;
        this.totalLoss = new ConcurrentHashMap<>();
    }

    /**
     * Computes the per-series RMSSE for one batch.
     *
     * @param labels the labels; only the head array is used
     * @param predictions the distribution arguments passed to the distribution builder
     * @return a pair of (number of series evaluated, per-series RMSSE array)
     */
    protected Pair<Long, NDArray> evaluateHelper(NDList labels, NDList predictions) {
        NDArray label = labels.head();
        NDArray prediction =
                distributionOutput.distributionBuilder().setDistrArgs(predictions).build().mean();
        checkLabelShapes(label, prediction);

        NDArray numerator = label.sub(prediction).square().mean(new int[] {axis});

        // Naive one-step differences taken along the same axis that the means reduce over.
        int rank = label.getShape().dimension();
        NDArray naiveDiff =
                label.get(axisIndex(rank, "1:")).sub(label.get(axisIndex(rank, ":-1")));
        NDArray denominator = naiveDiff.square().mean(new int[] {axis});

        NDArray rmsse = numerator.div(denominator).sqrt();
        // A zero naive error would produce inf (or NaN for 0/0); report 1 for such series.
        rmsse = NDArrays.where(denominator.eq(0), rmsse.onesLike(), rmsse);

        // Count every series, including perfect forecasts whose RMSSE is 0. Counting only
        // non-zero entries would add their 0 to the sum but not to the count, inflating the mean.
        return new Pair<>(rmsse.size(), rmsse);
    }

    /**
     * Builds an index string for {@code NDArray.get(String)} that applies {@code range} along
     * {@link #axis} and takes every preceding dimension whole.
     *
     * @param rank the number of dimensions of the array being indexed
     * @param range the slice to apply on the target axis, e.g. {@code "1:"}
     * @return the index string; for axis 1 and {@code "1:"} this is {@code ":, 1:"}
     */
    private String axisIndex(int rank, String range) {
        int target = axis < 0 ? axis + rank : axis;
        StringBuilder index = new StringBuilder();
        for (int i = 0; i < target; i++) {
            index.append(":, ");
        }
        return index.append(range).toString();
    }

    /** {@inheritDoc} */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        return evaluateHelper(labels, predictions).getValue();
    }

    /** {@inheritDoc} */
    @Override
    public void addAccumulator(String key) {
        totalInstances.put(key, 0L);
        totalLoss.put(key, 0.0f);
    }

    /** {@inheritDoc} */
    @Override
    public void updateAccumulator(String key, NDList labels, NDList predictions) {
        updateAccumulators(new String[] {key}, labels, predictions);
    }

    /**
     * Evaluates the batch once and adds its series count and summed RMSSE to every given key.
     *
     * <p>Each key must already have been registered with {@link #addAccumulator(String)}.
     *
     * @param keys the accumulator keys to update
     * @param labels the labels for the batch
     * @param predictions the distribution arguments for the batch
     */
    @Override
    public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
        Pair<Long, NDArray> update = evaluateHelper(labels, predictions);
        long count = update.getKey();
        // The batch total is identical for every key, so reduce it once outside the loop.
        float batchLoss = sumToFloat(update.getValue());
        for (String key : keys) {
            totalInstances.compute(key, (k, v) -> v + count);
            totalLoss.compute(key, (k, v) -> v + batchLoss);
        }
    }

    /**
     * Sums all elements of {@code array} and returns the result, closing the temporary sum array.
     *
     * @param array the array to reduce
     * @return the sum of all elements as a float
     */
    private static float sumToFloat(NDArray array) {
        try (NDArray sum = array.sum()) {
            return sum.getFloat();
        }
    }

    /** {@inheritDoc} */
    @Override
    public void resetAccumulator(String key) {
        totalInstances.put(key, 0L);
        totalLoss.put(key, 0.0f);
    }

    /**
     * Returns the mean per-series RMSSE accumulated under {@code key}.
     *
     * @param key the accumulator key
     * @return the mean RMSSE, or {@link Float#NaN} if nothing has been accumulated for the key
     */
    @Override
    public float getAccumulator(String key) {
        Long count = totalInstances.get(key);
        if (count == null || count == 0) {
            return Float.NaN;
        }
        return totalLoss.get(key) / count;
    }
}
