package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import java.util.Map;

/**
 * Callback interface through which a {@link Trainer} reports progress of the training loop.
 *
 * <p>Each callback receives the {@link Trainer} that issued it. {@link Defaults} offers
 * ready-made arrays of commonly combined listeners.
 */
public interface TrainingListener {

    /**
     * Data associated with one batch, handed to {@link #onTrainingBatch} and
     * {@link #onValidationBatch}.
     *
     * <p>The references held here cannot be reassigned after construction. The maps are stored
     * and returned as-is (no defensive copy), so any mutation made through them is shared with
     * whoever else holds the same map.
     */
    public static class BatchData {
        private final Batch batch;
        private final Map<Device, NDList> labels;
        private final Map<Device, NDList> predictions;

        /**
         * Creates a holder for one batch's data.
         *
         * @param batch the batch that was processed
         * @param labels the labels, keyed by the {@link Device} they reside on
         * @param predictions the predictions, keyed by the {@link Device} they reside on
         */
        public BatchData(Batch batch, Map<Device, NDList> labels, Map<Device, NDList> predictions) {
            this.batch = batch;
            this.labels = labels;
            this.predictions = predictions;
        }

        /**
         * Returns the batch this data belongs to.
         *
         * @return the batch passed to the constructor
         */
        public Batch getBatch() {
            return this.batch;
        }

        /**
         * Returns the labels for this batch.
         *
         * @return the labels map passed to the constructor, keyed by {@link Device}
         */
        public Map<Device, NDList> getLabels() {
            return this.labels;
        }

        /**
         * Returns the predictions for this batch.
         *
         * @return the predictions map passed to the constructor, keyed by {@link Device}
         */
        public Map<Device, NDList> getPredictions() {
            return this.predictions;
        }
    }

    /**
     * Factory methods for common combinations of {@link TrainingListener}s.
     *
     * <p>Every method returns a newly allocated array of newly constructed listeners, so the
     * result may be modified freely by the caller.
     */
    public interface Defaults {

        /**
         * Returns the basic listener set without any logging.
         *
         * @return an {@link EpochTrainingListener}, an {@link EvaluatorTrainingListener} and a
         *     {@link DivergenceCheckTrainingListener}, in that order
         */
        static TrainingListener[] basic() {
            return new TrainingListener[] {
                new EpochTrainingListener(),
                new EvaluatorTrainingListener(),
                new DivergenceCheckTrainingListener()
            };
        }

        /**
         * Returns the basic listener set followed by a default {@link LoggingTrainingListener}.
         *
         * @return the listeners of {@link #basic()} plus a {@link LoggingTrainingListener}
         */
        static TrainingListener[] logging() {
            return new TrainingListener[] {
                new EpochTrainingListener(),
                new EvaluatorTrainingListener(),
                new DivergenceCheckTrainingListener(),
                new LoggingTrainingListener()
            };
        }

        /**
         * Returns the basic listener set followed by a {@link LoggingTrainingListener} built
         * with the given argument.
         *
         * @param frequency forwarded unchanged to {@code new LoggingTrainingListener(int)};
         *     presumably controls how often it logs (TODO confirm against that class)
         * @return the listeners of {@link #basic()} plus a configured
         *     {@link LoggingTrainingListener}
         */
        static TrainingListener[] logging(int frequency) {
            return new TrainingListener[] {
                new EpochTrainingListener(),
                new EvaluatorTrainingListener(),
                new DivergenceCheckTrainingListener(),
                new LoggingTrainingListener(frequency)
            };
        }

        /**
         * Returns a listener set that additionally records memory and timing information into
         * the given output directory.
         *
         * <p>The result contains, in order: {@link EpochTrainingListener},
         * {@link MemoryTrainingListener}, {@link EvaluatorTrainingListener},
         * {@link DivergenceCheckTrainingListener}, {@link LoggingTrainingListener} and
         * {@link TimeMeasureTrainingListener}.
         *
         * @param outputDir the output directory handed to {@link MemoryTrainingListener} and
         *     {@link TimeMeasureTrainingListener}
         * @return the configured listeners
         * @throws IllegalArgumentException if {@code outputDir} is {@code null}
         */
        static TrainingListener[] logging(String outputDir) {
            if (outputDir == null) {
                throw new IllegalArgumentException("The output directory can't be null");
            }
            return new TrainingListener[] {
                new EpochTrainingListener(),
                new MemoryTrainingListener(outputDir),
                new EvaluatorTrainingListener(),
                new DivergenceCheckTrainingListener(),
                new LoggingTrainingListener(),
                new TimeMeasureTrainingListener(outputDir)
            };
        }
    }

    /**
     * Callback for epoch events; the exact point of invocation is decided by the trainer
     * (presumably at the end of each epoch — confirm with the caller).
     *
     * @param trainer the trainer issuing the callback
     */
    void onEpoch(Trainer trainer);

    /**
     * Callback for a processed training batch.
     *
     * @param trainer the trainer issuing the callback
     * @param batchData the batch together with its labels and predictions
     */
    void onTrainingBatch(Trainer trainer, BatchData batchData);

    /**
     * Callback for a processed validation batch.
     *
     * @param trainer the trainer issuing the callback
     * @param batchData the batch together with its labels and predictions
     */
    void onValidationBatch(Trainer trainer, BatchData batchData);

    /**
     * Callback issued when training begins.
     *
     * @param trainer the trainer issuing the callback
     */
    void onTrainingBegin(Trainer trainer);

    /**
     * Callback issued when training ends.
     *
     * @param trainer the trainer issuing the callback
     */
    void onTrainingEnd(Trainer trainer);
}
