package com.datumbox.framework.machinelearning.common.bases.mlmodels;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclassifier.ModelParameters;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclassifier.TrainingParameters;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclassifier.ValidationMetrics;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.bases.validation.ModelValidation;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.enums.SensitivityRates;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/mlmodels/BaseMLclassifier.class */
public abstract class BaseMLclassifier<MP extends ModelParameters, TP extends TrainingParameters, VM extends ValidationMetrics> extends BaseMLmodel<MP, TP, VM> {

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/mlmodels/BaseMLclassifier$ModelParameters.class */
    /**
     * Model parameters common to every classifier: the set of class values the
     * model knows about.
     */
    public static abstract class ModelParameters extends BaseMLmodel.ModelParameters {
        /** Known class values; initialised as a LinkedHashSet, so iteration follows insertion order. */
        private Set<Object> classes;

        /**
         * Creates the parameters with an empty class set.
         *
         * @param databaseConnector connector handed unchanged to the parent constructor
         */
        /* JADX INFO: Access modifiers changed from: protected */
        public ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
            // Diamond instead of the raw type so the assignment is checked by the compiler.
            this.classes = new LinkedHashSet<>();
        }

        /**
         * Returns the number of classes currently stored.
         *
         * @return size of the class set
         */
        public Integer getC() {
            return Integer.valueOf(this.classes.size());
        }

        /**
         * Returns the stored class set itself (not a copy).
         *
         * @return the set of class values
         */
        public Set<Object> getClasses() {
            return this.classes;
        }

        /**
         * Replaces the stored class set with the given one.
         *
         * @param set the new set of class values
         */
        protected void setClasses(Set<Object> set) {
            this.classes = set;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/mlmodels/BaseMLclassifier$TrainingParameters.class */
    /**
     * Base type for the training parameters of classifiers. It declares no members
     * of its own; it only extends {@link BaseMLmodel.TrainingParameters}.
     */
    public static abstract class TrainingParameters extends BaseMLmodel.TrainingParameters {
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/mlmodels/BaseMLclassifier$ValidationMetrics.class */
    /**
     * Validation metrics produced for a classifier: overall accuracy, averaged
     * precision/recall/F1, per-class precision/recall/F1 and the contingency table
     * the per-class values are derived from.
     */
    public static abstract class ValidationMetrics extends BaseMLmodel.ValidationMetrics {
        /** Overall accuracy. */
        private double accuracy = 0.0d;
        /** Precision averaged over all classes. */
        private double macroPrecision = 0.0d;
        /** Recall averaged over all classes. */
        private double macroRecall = 0.0d;
        /** F1 averaged over all classes. */
        private double macroF1 = 0.0d;
        /** Precision per class, keyed by the class value. */
        private Map<Object, Double> microPrecision = new HashMap<>();
        /** Recall per class, keyed by the class value. */
        private Map<Object, Double> microRecall = new HashMap<>();
        /** F1 per class, keyed by the class value. */
        private Map<Object, Double> microF1 = new HashMap<>();
        /**
         * Counts keyed by a two-element list {@code [class, SensitivityRates]}.
         * NOTE(review): the field name is capitalised; it is left untouched here because
         * renaming it could affect anything that accesses fields reflectively - confirm before changing.
         */
        private Map<List<Object>, Double> ContingencyTable = new HashMap<>();

        /** @return the overall accuracy */
        public double getAccuracy() {
            return this.accuracy;
        }

        /** @param d the overall accuracy to store */
        public void setAccuracy(double d) {
            this.accuracy = d;
        }

        /** @return the precision averaged over all classes */
        public double getMacroPrecision() {
            return this.macroPrecision;
        }

        /** @param d the averaged precision to store */
        public void setMacroPrecision(double d) {
            this.macroPrecision = d;
        }

        /** @return the recall averaged over all classes */
        public double getMacroRecall() {
            return this.macroRecall;
        }

        /** @param d the averaged recall to store */
        public void setMacroRecall(double d) {
            this.macroRecall = d;
        }

        /** @return the F1 score averaged over all classes */
        public double getMacroF1() {
            return this.macroF1;
        }

        /** @param d the averaged F1 score to store */
        public void setMacroF1(double d) {
            this.macroF1 = d;
        }

        /** @return the per-class precision map itself (not a copy) */
        public Map<Object, Double> getMicroPrecision() {
            return this.microPrecision;
        }

        /** @param map the per-class precision map to store */
        public void setMicroPrecision(Map<Object, Double> map) {
            this.microPrecision = map;
        }

        /** @return the per-class recall map itself (not a copy) */
        public Map<Object, Double> getMicroRecall() {
            return this.microRecall;
        }

        /** @param map the per-class recall map to store */
        public void setMicroRecall(Map<Object, Double> map) {
            this.microRecall = map;
        }

        /** @return the per-class F1 map itself (not a copy) */
        public Map<Object, Double> getMicroF1() {
            return this.microF1;
        }

        /** @param map the per-class F1 map to store */
        public void setMicroF1(Map<Object, Double> map) {
            this.microF1 = map;
        }

        /** @return the contingency table itself (not a copy) */
        public Map<List<Object>, Double> getContingencyTable() {
            return this.ContingencyTable;
        }

        /** @param map the contingency table to store */
        public void setContingencyTable(Map<List<Object>, Double> map) {
            this.ContingencyTable = map;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates the classifier by forwarding every argument unchanged to the
     * {@link BaseMLmodel} constructor.
     *
     * @param str                   string identifier passed to the parent (presumably the
     *                              model/database name - confirm against BaseMLmodel)
     * @param databaseConfiguration database configuration passed to the parent
     * @param cls                   class token of the model-parameters type {@code MP}
     * @param cls2                  class token of the training-parameters type {@code TP}
     * @param cls3                  class token of the validation-metrics type {@code VM}
     * @param modelValidation       validation object passed to the parent
     */
    public BaseMLclassifier(String str, DatabaseConfiguration databaseConfiguration, Class<MP> cls, Class<TP> cls2, Class<VM> cls3, ModelValidation<MP, TP, VM> modelValidation) {
        super(str, databaseConfiguration, cls, cls2, cls3, modelValidation);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Runs the model on the given dataset and computes classification metrics by
     * comparing each record's predicted class with its actual class.
     *
     * <p>For every known class a TP/FP/TN/FN count is kept in the contingency table.
     * From those counts the method fills in per-class precision, recall and F1, their
     * averages over all classes, and the overall accuracy (correct records divided by
     * total records).
     *
     * <p>A class that never occurs as an actual or as a predicted value (TP, FP and FN
     * all zero) gets precision, recall and F1 of 1.0; a class with zero TP but some
     * FP or FN gets 0.0 for all three.
     *
     * @param dataset the records to evaluate; {@code predictDataset} is invoked on it first
     * @return a metrics object obtained from the knowledge base and populated as described;
     *         accuracy is 0.0 when the dataset contains no records
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel
    public VM validateModel(Dataset dataset) {
        // Presumably fills in the predicted class of every record - the loop below reads it back.
        predictDataset(dataset);

        Set<Object> classes = ((ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getClasses();
        VM metrics = (VM) ((MLmodelKnowledgeBase) this.knowledgeBase).getEmptyValidationMetricsObject();
        Map<List<Object>, Double> contingencyTable = metrics.getContingencyTable();

        // Start every (class, rate) cell at zero so the increments below never hit a missing key.
        for (Object cls : classes) {
            contingencyTable.put(Arrays.<Object>asList(cls, SensitivityRates.TP), Double.valueOf(0.0d));
            contingencyTable.put(Arrays.<Object>asList(cls, SensitivityRates.FP), Double.valueOf(0.0d));
            contingencyTable.put(Arrays.<Object>asList(cls, SensitivityRates.TN), Double.valueOf(0.0d));
            contingencyTable.put(Arrays.<Object>asList(cls, SensitivityRates.FN), Double.valueOf(0.0d));
        }

        int recordNumber = dataset.getRecordNumber();
        int classCount = classes.size();
        int correctCount = 0;

        Iterator<Integer> it = dataset.iterator();
        while (it.hasNext()) {
            Record record = dataset.get(it.next());
            Object predicted = record.getYPredicted();
            Object actual = record.getY();
            boolean correct = predicted.equals(actual);
            if (correct) {
                correctCount++;
            }

            // Each record contributes exactly one count to every class.
            for (Object cls : classes) {
                SensitivityRates rate;
                if (cls.equals(predicted)) {
                    rate = correct ? SensitivityRates.TP : SensitivityRates.FP;
                } else if (!correct && cls.equals(actual)) {
                    rate = SensitivityRates.FN;
                } else {
                    rate = SensitivityRates.TN;
                }
                incrementCell(contingencyTable, cls, rate);
            }
        }

        // Both operands are ints: without the cast the quotient truncates to 0 or 1,
        // and an empty dataset would raise ArithmeticException. Report 0.0 for no records.
        metrics.setAccuracy(recordNumber > 0 ? (double) correctCount / recordNumber : 0.0d);

        for (Object cls : classes) {
            double tp = contingencyTable.get(Arrays.<Object>asList(cls, SensitivityRates.TP)).doubleValue();
            double fp = contingencyTable.get(Arrays.<Object>asList(cls, SensitivityRates.FP)).doubleValue();
            double fn = contingencyTable.get(Arrays.<Object>asList(cls, SensitivityRates.FN)).doubleValue();

            double precision = 0.0d;
            double recall = 0.0d;
            double f1 = 0.0d;
            if (tp > 0.0d) {
                precision = tp / (tp + fp);
                recall = tp / (tp + fn);
                f1 = ((2.0d * precision) * recall) / (precision + recall);
            } else if (tp == 0.0d && fp == 0.0d && fn == 0.0d) {
                // The class never showed up as actual or predicted: nothing was misclassified.
                precision = 1.0d;
                recall = 1.0d;
                f1 = 1.0d;
            }

            metrics.getMicroPrecision().put(cls, Double.valueOf(precision));
            metrics.getMicroRecall().put(cls, Double.valueOf(recall));
            metrics.getMicroF1().put(cls, Double.valueOf(f1));

            // Accumulate the unweighted mean over classes one term at a time.
            metrics.setMacroPrecision(metrics.getMacroPrecision() + (precision / classCount));
            metrics.setMacroRecall(metrics.getMacroRecall() + (recall / classCount));
            metrics.setMacroF1(metrics.getMacroF1() + (f1 / classCount));
        }
        return metrics;
    }

    /** Adds one to the contingency-table cell of the given class and rate; the cell must already exist. */
    private static void incrementCell(Map<List<Object>, Double> contingencyTable, Object cls, SensitivityRates rate) {
        List<Object> key = Arrays.<Object>asList(cls, rate);
        contingencyTable.put(key, Double.valueOf(contingencyTable.get(key).doubleValue() + 1.0d));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Picks a class from a table of class scores by returning the key of the entry
     * that {@code MapFunctions.selectMaxKeyValue} selects.
     *
     * @param associativeArray the scores, presumably keyed by class value - confirm with callers
     * @return the key of the selected entry (presumably the class with the highest
     *         score - confirm against MapFunctions)
     */
    public Object getSelectedClassFromClassScores(AssociativeArray associativeArray) {
        return MapFunctions.selectMaxKeyValue(associativeArray).getKey();
    }
}
