package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.Distribution;
import ai.djl.timeseries.distribution.output.DistributionOutput;
import ai.djl.training.loss.Loss;

import java.util.Objects;

/* loaded from: input_file:ai/djl/timeseries/distribution/DistributionLoss.class */
public class DistributionLoss extends Loss {

    private final DistributionOutput distrOutput;

    /**
     * Creates a new {@code DistributionLoss}.
     *
     * @param name the display name of the loss
     * @param distrOutput the {@link DistributionOutput} whose distribution builder is used to
     *     construct the distribution evaluated by {@link #evaluate(NDList, NDList)}
     * @throws NullPointerException if {@code distrOutput} is {@code null}
     */
    public DistributionLoss(String name, DistributionOutput distrOutput) {
        super(name);
        // Fail fast here rather than with an NPE on the first call to evaluate().
        this.distrOutput = Objects.requireNonNull(distrOutput, "distrOutput");
    }

    /**
     * Computes the negative log-likelihood of the labels under the distribution described by the
     * predictions.
     *
     * <p>The whole {@code predictions} list is handed to the distribution builder as its
     * arguments. If the list also contains arrays named {@code "scale"} or {@code "loc"}, they are
     * passed to the builder as the optional scale and location.
     *
     * <p>Reduction is asymmetric. When {@code predictions} contains an array named
     * {@code "loss_weights"}, the result is a scalar weighted mean: the sum of weighted losses
     * divided by {@code max(sum(weights), 1)}. Otherwise the element-wise negative log-likelihood
     * is returned without any reduction.
     *
     * @param labels the observed target values; expected to hold a single array (read via
     *     {@code singletonOrThrow})
     * @param predictions the distribution arguments, optionally including named {@code "scale"},
     *     {@code "loc"} and {@code "loss_weights"} arrays
     * @return the element-wise negative log-likelihood, or its scalar weighted mean when loss
     *     weights are supplied
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        Distribution.DistributionBuilder<?> builder = distrOutput.distributionBuilder();
        builder.setDistrArgs(predictions);
        if (predictions.contains("scale")) {
            builder.optScale(predictions.get("scale"));
        }
        if (predictions.contains("loc")) {
            builder.optLoc(predictions.get("loc"));
        }

        NDArray loss = builder.build().logProb(labels.singletonOrThrow()).neg();

        if (predictions.contains("loss_weights")) {
            NDArray lossWeights = predictions.get("loss_weights");
            // Select zeros where the weight is 0 instead of relying on multiplication, so a
            // NaN/inf loss at a masked position cannot leak into the sum (0 * NaN == NaN).
            NDArray weighted =
                    NDArrays.where(lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike());
            // Clamp the denominator to at least 1 to avoid dividing by zero when all weights are 0.
            NDArray sumWeights = lossWeights.sum().maximum(1f);
            loss = weighted.sum().div(sumWeights);
        }
        return loss;
    }
}
