package com.datumbox.framework.machinelearning.common.bases.basemodels;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.dataobjects.TypeInference;
import com.datumbox.common.persistentstorage.interfaces.BigMap;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseNaiveBayes.ModelParameters;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseNaiveBayes.TrainingParameters;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseNaiveBayes.ValidationMetrics;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclassifier;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.validation.ClassifierValidation;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseNaiveBayes.class */
public abstract class BaseNaiveBayes<MP extends ModelParameters, TP extends TrainingParameters, VM extends ValidationMetrics> extends BaseMLclassifier<MP, TP, VM> {

    /**
     * When {@code true}, positive feature values are replaced by 1.0 both in training and in
     * prediction (feature presence instead of feature weight). Initialised to {@code false} by the
     * constructor; presumably switched on by binarized subclasses.
     */
    protected boolean isBinarized;

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseNaiveBayes$ModelParameters.class */
    /**
     * Learned state of a Naive Bayes model: log class priors and log likelihoods of
     * (feature, class) pairs.
     */
    public static abstract class ModelParameters extends BaseMLclassifier.ModelParameters {

        /** Class label -> log prior probability of that class. */
        @BigMap
        private Map<Object, Double> logPriors;

        /** {@code [feature, class]} -> smoothed log likelihood of the feature given the class. */
        @BigMap
        private Map<List<Object>, Double> logLikelihoods;

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * @param databaseConnector connector handed to the parent parameters object
         */
        public ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /** @return class label -> log prior map */
        public Map<Object, Double> getLogPriors() {
            return this.logPriors;
        }

        /** @param map class label -> log prior map */
        protected void setLogPriors(Map<Object, Double> map) {
            this.logPriors = map;
        }

        /** @return {@code [feature, class]} -> log likelihood map */
        public Map<List<Object>, Double> getLogLikelihoods() {
            return this.logLikelihoods;
        }

        /** @param map {@code [feature, class]} -> log likelihood map */
        protected void setLogLikelihoods(Map<List<Object>, Double> map) {
            this.logLikelihoods = map;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseNaiveBayes$TrainingParameters.class */
    /** Training options shared by all Naive Bayes variants. */
    public static abstract class TrainingParameters extends BaseMLclassifier.TrainingParameters {

        /**
         * When {@code false} (the default), positive feature values are clamped to 1.0 at
         * prediction time; when {@code true}, the raw value weights each log likelihood (unless
         * the model is binarized).
         */
        private boolean multiProbabilityWeighted = false;

        /** @return whether feature values weight the log likelihoods during prediction */
        public boolean isMultiProbabilityWeighted() {
            return this.multiProbabilityWeighted;
        }

        /** @param z whether feature values should weight the log likelihoods during prediction */
        public void setMultiProbabilityWeighted(boolean z) {
            this.multiProbabilityWeighted = z;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseNaiveBayes$ValidationMetrics.class */
    /** Validation metrics for Naive Bayes models; adds nothing to the classifier metrics. */
    public static abstract class ValidationMetrics extends BaseMLclassifier.ValidationMetrics {
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a non-binarized Naive Bayes model validated with a {@link ClassifierValidation}.
     *
     * @param dbName name passed through to the parent (presumably the model's storage name — confirm)
     * @param dbConf database configuration passed through to the parent
     * @param mpClass concrete model-parameters class
     * @param tpClass concrete training-parameters class
     * @param vmClass concrete validation-metrics class
     */
    public BaseNaiveBayes(String dbName, DatabaseConfiguration dbConf, Class<MP> mpClass, Class<TP> tpClass, Class<VM> vmClass) {
        super(dbName, dbConf, mpClass, tpClass, vmClass, new ClassifierValidation());
        this.isBinarized = false;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Predicts every record of {@code dataset} and writes the result back into the dataset.
     *
     * <p>The score of each class starts at its log prior. Each feature known to the model adds
     * {@code weight * logLikelihood(feature, class)}. The weight is the feature value, clamped to
     * 1.0 when positive if the model is binarized or multi-probability weighting is off. Features
     * unseen during training are ignored. The predicted class is chosen from the raw log scores;
     * the scores are then exponentiated/normalized by {@link Descriptives#normalizeExp} and stored
     * as the record's class probabilities.
     *
     * @param dataset records to predict; returned untouched if empty
     * @throws java.util.NoSuchElementException if the model has no classes (never trained)
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel
    public void predictDataset(Dataset dataset) {
        if (dataset.isEmpty()) {
            return;
        }
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        Map<List<Object>, Double> logLikelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classes = modelParameters.getClasses();

        // Training stores a likelihood for every (feature, class) pair, so checking a single class
        // is enough to tell whether a feature is known to the model.
        Object probeClass = classes.iterator().next();

        // Loop-invariant: decided once per call instead of once per feature.
        boolean clampPositiveWeights = !trainingParameters.isMultiProbabilityWeighted() || this.isBinarized;

        Iterator<Integer> recordIds = dataset.iterator();
        while (recordIds.hasNext()) {
            Integer rId = recordIds.next();
            Record record = dataset.get(rId);

            AssociativeArray classScores = new AssociativeArray(new HashMap(logPriors));
            for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
                Object feature = entry.getKey();
                if (!logLikelihoods.containsKey(Arrays.asList(feature, probeClass))) {
                    continue; // feature was not seen during training
                }

                double weight = TypeInference.toDouble(entry.getValue());
                if (clampPositiveWeights && weight > 0.0) {
                    weight = 1.0;
                }

                for (Object theClass : classes) {
                    double logLikelihood = logLikelihoods.get(Arrays.asList(feature, theClass));
                    classScores.put(theClass, classScores.getDouble(theClass) + weight * logLikelihood);
                }
            }

            // Select on the log scores first; normalizeExp then turns them into probabilities in place.
            Object selectedClass = getSelectedClassFromClassScores(classScores);
            Descriptives.normalizeExp(classScores);
            dataset.set(rId, new Record(record.getX(), record.getY(), selectedClass, classScores));
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Estimates the log priors and log likelihoods from {@code dataset}.
     *
     * <p>Priors are {@code log(classCount / N)}. Likelihoods use add-one smoothing:
     * {@code log((featureWeightInClass + 1) / (totalWeightInClass + D))}, where N and D are the
     * values held by the model parameters (presumably record count and feature count — confirm).
     * When the model is binarized, positive feature values count as 1.0.
     *
     * @param dataset training records
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.baseobjects.BaseTrainable
    public void _fit(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        int n = modelParameters.getN().intValue();
        int d = modelParameters.getD().intValue();
        Map<List<Object>, Double> logLikelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classes = modelParameters.getClasses();

        // Sum of (possibly binarized) feature values per class: the likelihood denominators.
        AssociativeArray totalWeightPerClass = new AssociativeArray();

        // Pass 1: count records per class (temporarily stored in logPriors) and register classes.
        Iterator<Integer> recordIds = dataset.iterator();
        while (recordIds.hasNext()) {
            Object theClass = dataset.get(recordIds.next()).getY();
            Double classCount = logPriors.get(theClass);
            if (classCount == null) {
                classes.add(theClass);
                logPriors.put(theClass, 1.0);
                totalWeightPerClass.put(theClass, 0.0);
            } else {
                logPriors.put(theClass, classCount + 1.0);
            }
        }

        // Pass 2: accumulate feature weights per (feature, class); every pair gets an entry so that
        // prediction can detect known features by probing a single class.
        Iterator<Integer> recordIds2 = dataset.iterator();
        while (recordIds2.hasNext()) {
            Record record = dataset.get(recordIds2.next());
            Object recordClass = record.getY();
            for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
                Object feature = entry.getKey();
                double value = TypeInference.toDouble(entry.getValue());
                if (this.isBinarized && value > 0.0) {
                    value = 1.0;
                }

                for (Object theClass : classes) {
                    List<Object> pairKey = Arrays.asList(feature, theClass);
                    Double featureWeight = logLikelihoods.get(pairKey);
                    if (featureWeight == null) {
                        featureWeight = 0.0;
                        logLikelihoods.put(pairKey, 0.0);
                    }
                    if (theClass.equals(recordClass)) {
                        logLikelihoods.put(pairKey, featureWeight + value);
                        totalWeightPerClass.put(theClass, totalWeightPerClass.getDouble(theClass) + value);
                    }
                }
            }
        }

        // Pass 3: convert counts to logs in place. NOTE: values are replaced while iterating the
        // entry set. java.util.Map leaves this undefined; it relies on the backing maps tolerating
        // value replacement of existing keys (no structural change), as HashMap does.
        for (Map.Entry<Object, Double> entry : logPriors.entrySet()) {
            logPriors.put(entry.getKey(), Math.log(entry.getValue() / n));
        }
        for (Map.Entry<List<Object>, Double> entry : logLikelihoods.entrySet()) {
            List<Object> pairKey = entry.getKey();
            Object theClass = pairKey.get(1);
            Double storedWeight = entry.getValue();
            double featureWeight = (storedWeight == null) ? 0.0 : storedWeight;
            double smoothed = (featureWeight + 1.0) / (totalWeightPerClass.getDouble(theClass) + d);
            logLikelihoods.put(pairKey, Math.log(smoothed));
        }
    }
}
