package ai.djl.basicdataset.tabular;

import ai.djl.Model;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ExpansionTranslatorFactory;
import ai.djl.translate.PostProcessor;
import ai.djl.translate.PreProcessor;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/* loaded from: input_file:ai/djl/basicdataset/tabular/TabularTranslatorFactory.class */
/**
 * An {@link ExpansionTranslatorFactory} whose base translator is a {@link TabularTranslator}
 * mapping {@link ListFeatures} to {@link TabularResults}.
 *
 * <p>Besides the base types, the factory can build translators that accept {@link MapFeatures}
 * as input, and that return {@link Classifications} or {@link Float} as output. It does this by
 * wrapping the base translator's pre/post-processors.
 */
public class TabularTranslatorFactory extends ExpansionTranslatorFactory<ListFeatures, TabularResults> {

    /**
     * Returns the runtime class name of {@code obj}, or {@code "null"} when it is null, for use in
     * error messages (calling {@code getClass()} on null would throw an NPE instead of the intended
     * exception).
     */
    private static String describeType(Object obj) {
        return obj == null ? "null" : obj.getClass().getName();
    }

    /**
     * Adapts the base {@link TabularResults} post-processor so that it yields the single {@link
     * Classifications} produced by the model.
     */
    static final class ClassificationsTabularPostProcessor implements PostProcessor<Classifications> {

        private final PostProcessor<TabularResults> postProcessor;

        /**
         * Creates the adapter.
         *
         * @param postProcessor the base post-processor whose output is unwrapped
         */
        ClassificationsTabularPostProcessor(PostProcessor<TabularResults> postProcessor) {
            this.postProcessor = postProcessor;
        }

        /**
         * Runs the base post-processor and extracts its only result as {@link Classifications}.
         *
         * @param translatorContext the translator context
         * @param nDList the model output
         * @return the single {@link Classifications} result
         * @throws IllegalStateException if the model did not produce exactly one output, or if
         *     that output is not a {@link Classifications}
         * @throws Exception if the base post-processor fails
         */
        @Override
        public Classifications processOutput(TranslatorContext translatorContext, NDList nDList)
                throws Exception {
            TabularResults results = postProcessor.processOutput(translatorContext, nDList);
            if (results.size() != 1) {
                throw new IllegalStateException(
                        "The ClassificationsTabularPostProcessor expected the model to produce one output, but instead it produced "
                                + results.size());
            }
            Object result = results.getFeature(0).getResult();
            if (result instanceof Classifications) {
                return (Classifications) result;
            }
            throw new IllegalStateException(
                    "The ClassificationsTabularPostProcessor expected the model to produce a Classifications, but instead it produced "
                            + describeType(result));
        }
    }

    /**
     * Adapts a {@link TabularTranslator}'s pre-processing so that it accepts {@link MapFeatures}.
     * Values are looked up by feature name and reordered into the translator's feature order.
     */
    static final class MapPreProcessor implements PreProcessor<MapFeatures> {

        private final TabularTranslator preProcessor;

        /**
         * Creates the adapter.
         *
         * @param preProcessor the base pre-processor; must be a {@link TabularTranslator}
         * @throws IllegalArgumentException if {@code preProcessor} is not a {@link
         *     TabularTranslator} (including when it is null)
         */
        MapPreProcessor(PreProcessor<ListFeatures> preProcessor) {
            if (!(preProcessor instanceof TabularTranslator)) {
                throw new IllegalArgumentException(
                        "The MapPreProcessor for the TabularTranslatorFactory expects a TabularTranslator, but received "
                                + describeType(preProcessor));
            }
            this.preProcessor = (TabularTranslator) preProcessor;
        }

        /**
         * Converts the keyed input to a {@link ListFeatures}, ordered by the translator's
         * features, and delegates to the translator.
         *
         * @param translatorContext the translator context
         * @param mapFeatures the input values keyed by feature name
         * @return the model input produced by the underlying translator
         * @throws IllegalArgumentException if a feature expected by the translator is missing
         * @throws Exception if the underlying translator fails
         */
        @Override
        public NDList processInput(TranslatorContext translatorContext, MapFeatures mapFeatures)
                throws Exception {
            ListFeatures listFeatures = new ListFeatures(preProcessor.getFeatures().size());
            for (Feature feature : preProcessor.getFeatures()) {
                String name = feature.getName();
                if (!mapFeatures.containsKey(name)) {
                    throw new IllegalArgumentException(
                            "The input to the TabularTranslator is missing the feature: " + name);
                }
                listFeatures.add(mapFeatures.get(name));
            }
            return preProcessor.processInput(translatorContext, listFeatures);
        }
    }

    /**
     * Adapts the base {@link TabularResults} post-processor so that it yields the single {@link
     * Float} (regression value) produced by the model.
     */
    static final class RegressionTabularPostProcessor implements PostProcessor<Float> {

        private final PostProcessor<TabularResults> postProcessor;

        /**
         * Creates the adapter.
         *
         * @param postProcessor the base post-processor whose output is unwrapped
         */
        RegressionTabularPostProcessor(PostProcessor<TabularResults> postProcessor) {
            this.postProcessor = postProcessor;
        }

        /**
         * Runs the base post-processor and extracts its only result as a {@link Float}.
         *
         * @param translatorContext the translator context
         * @param nDList the model output
         * @return the single {@link Float} result
         * @throws IllegalStateException if the model did not produce exactly one output, or if
         *     that output is not a {@link Float}
         * @throws Exception if the base post-processor fails
         */
        @Override
        public Float processOutput(TranslatorContext translatorContext, NDList nDList)
                throws Exception {
            TabularResults results = postProcessor.processOutput(translatorContext, nDList);
            if (results.size() != 1) {
                throw new IllegalStateException(
                        "The RegressionTabularPostProcessor expected the model to produce one output, but instead it produced "
                                + results.size());
            }
            Object result = results.getFeature(0).getResult();
            if (result instanceof Float) {
                return (Float) result;
            }
            throw new IllegalStateException(
                    "The RegressionTabularPostProcessor expected the model to produce a float, but instead it produced "
                            + describeType(result));
        }
    }

    /** {@inheritDoc} Builds a {@link TabularTranslator} for the given model and arguments. */
    @Override
    protected Translator<ListFeatures, TabularResults> buildBaseTranslator(
            Model model, Map<String, ?> map) {
        return new TabularTranslator(model, map);
    }

    /** {@inheritDoc} */
    @Override
    public Class<ListFeatures> getBaseInputType() {
        return ListFeatures.class;
    }

    /** {@inheritDoc} */
    @Override
    public Class<TabularResults> getBaseOutputType() {
        return TabularResults.class;
    }

    /** {@inheritDoc} Adds support for {@link MapFeatures} input. */
    @Override
    protected Map<Type, Function<PreProcessor<ListFeatures>, PreProcessor<?>>>
            getPreprocessorExpansions() {
        Map<Type, Function<PreProcessor<ListFeatures>, PreProcessor<?>>> expansions =
                new ConcurrentHashMap<>();
        expansions.put(MapFeatures.class, MapPreProcessor::new);
        return expansions;
    }

    /** {@inheritDoc} Adds support for {@link Classifications} and {@link Float} output. */
    @Override
    protected Map<Type, Function<PostProcessor<TabularResults>, PostProcessor<?>>>
            getPostprocessorExpansions() {
        Map<Type, Function<PostProcessor<TabularResults>, PostProcessor<?>>> expansions =
                new ConcurrentHashMap<>();
        expansions.put(Classifications.class, ClassificationsTabularPostProcessor::new);
        expansions.put(Float.class, RegressionTabularPostProcessor::new);
        return expansions;
    }
}
