package com.datumbox.framework.machinelearning.classification;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.dataobjects.TypeInference;
import com.datumbox.common.persistentstorage.interfaces.BigMap;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclassifier;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.validation.ClassifierValidation;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/classification/MaximumEntropy.class */
/**
 * Maximum Entropy classifier whose weights ("lambdas", one per {@code [featureKey, class]} pair)
 * are fitted with an iterative-scaling procedure.
 *
 * <p>A feature of a record is "active" when its value converts to a non-null, non-zero double;
 * the magnitude of the value is not used. The score of a class for a record is the sum of the
 * lambdas of the record's active features paired with that class.
 */
public class MaximumEntropy extends BaseMLclassifier<MaximumEntropy.ModelParameters, MaximumEntropy.TrainingParameters, MaximumEntropy.ValidationMetrics> {

    /** Name of the temporary BigMap holding the observed (empirical) expectation of each feature-class pair. */
    private static final String OBSERVED_MAP_NAME = "tmp_EpFj_observed";

    /** Name of the temporary BigMap holding the model expectation of each feature-class pair for one iteration. */
    private static final String MODEL_MAP_NAME = "tmp_EpFj_model";

    /** Observed and model expectations closer than this are considered equal and leave the lambda untouched. */
    private static final double EXPECTATION_TOLERANCE = 1.0E-8d;

    /**
     * Model parameters: in addition to the base classifier parameters, holds the lambda of every
     * {@code [featureKey, class]} pair seen during training.
     */
    public static class ModelParameters extends BaseMLclassifier.ModelParameters {

        /** Weight of each {@code [featureKey, class]} pair (field is annotated {@code @BigMap}). */
        @BigMap
        private Map<List<Object>, Double> lambdas;

        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the map from {@code [featureKey, class]} to its lambda
         */
        public Map<List<Object>, Double> getLambdas() {
            return this.lambdas;
        }

        /**
         * @param lambdas the map from {@code [featureKey, class]} to its lambda
         */
        protected void setLambdas(Map<List<Object>, Double> lambdas) {
            this.lambdas = lambdas;
        }
    }

    /** Training parameters: the number of scaling iterations to run (default 100). */
    public static class TrainingParameters extends BaseMLclassifier.TrainingParameters {
        private int totalIterations = 100;

        /**
         * @return the number of scaling iterations performed by {@code _fit}
         */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /**
         * @param totalIterations the number of scaling iterations performed by {@code _fit}
         */
        public void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }
    }

    /** Validation metrics: no additions over the base classifier metrics. */
    public static class ValidationMetrics extends BaseMLclassifier.ValidationMetrics {
    }

    /**
     * @param dbName name passed to the base model (presumably the storage/database name — confirm)
     * @param dbConf database configuration used by the base model
     */
    public MaximumEntropy(String dbName, DatabaseConfiguration dbConf) {
        super(dbName, dbConf, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class, new ClassifierValidation());
    }

    /**
     * Predicts every record of {@code dataset} in place.
     *
     * <p>For each record, every known class is scored. The predicted class is chosen from the raw
     * scores via {@code getSelectedClassFromClassScores}. The scores are then normalized with
     * {@code Descriptives.normalizeExp} (presumably an exp-normalization, i.e. softmax — confirm)
     * and stored, together with the prediction, in a replacement {@link Record}.
     *
     * @param dataset the records to predict; each record is replaced via {@code dataset.set}
     */
    @Override
    public void predictDataset(Dataset dataset) {
        Set<Object> classes = ((ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getClasses();
        Iterator<Integer> it = dataset.iterator();
        while (it.hasNext()) {
            Integer rId = it.next();
            Record record = dataset.get(rId);

            AssociativeArray classScores = scoreAllClasses(record.getX(), classes);
            // Selection happens on the raw scores, before normalization.
            Object predictedClass = getSelectedClassFromClassScores(classScores);
            Descriptives.normalizeExp(classScores);

            dataset.set(rId, new Record(record.getX(), record.getY(), predictedClass, classScores));
        }
    }

    /**
     * Trains the model.
     *
     * <p>The method:
     * <ol>
     * <li>collects the classes present in the data;</li>
     * <li>computes the observed expectation of every {@code [featureKey, class]} pair, i.e. the
     * fraction (1/n per record) of records with that feature active and that class as label;</li>
     * <li>initializes each such pair's lambda to 0;</li>
     * <li>runs the scaling iterations.</li>
     * </ol>
     * The temporary observed-expectation map is always dropped, even if training fails.
     *
     * @param dataset the training records
     */
    @Override
    protected void _fit(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        int n = modelParameters.getN().intValue();
        Map<List<Object>, Double> lambdas = modelParameters.getLambdas();
        Set<Object> classes = modelParameters.getClasses();

        // All classes must be known before pairs are registered below.
        Iterator<Integer> classIt = dataset.iterator();
        while (classIt.hasNext()) {
            classes.add(dataset.get(classIt.next()).getY());
        }

        DatabaseConnector dbc = ((MLmodelKnowledgeBase) this.knowledgeBase).getDbc();
        Map<List<Object>, Double> observed = dbc.getBigMap(OBSERVED_MAP_NAME, true);
        try {
            // The largest number of active features in any record: the scaling constant of the update rule.
            double maxActiveFeatures = 0.0d;
            double increment = 1.0d / n;

            Iterator<Integer> recordIt = dataset.iterator();
            while (recordIt.hasNext()) {
                Record record = dataset.get(recordIt.next());
                Object label = record.getY();
                int activeFeatures = 0;

                for (Map.Entry<Object, Object> featureEntry : record.getX().entrySet()) {
                    if (!isActive(featureEntry.getValue())) {
                        continue;
                    }
                    Object feature = featureEntry.getKey();
                    // Register the feature with EVERY class so that later passes find an entry for each pair.
                    for (Object theClass : classes) {
                        List<Object> featureClass = Arrays.asList(feature, theClass);
                        Double count = observed.get(featureClass);
                        if (count == null) {
                            count = Double.valueOf(0.0d);
                            observed.put(featureClass, Double.valueOf(0.0d));
                            lambdas.put(featureClass, Double.valueOf(0.0d));
                        }
                        if (theClass.equals(label)) {
                            observed.put(featureClass, Double.valueOf(count.doubleValue() + increment));
                        }
                    }
                    activeFeatures++;
                }

                if (activeFeatures > maxActiveFeatures) {
                    maxActiveFeatures = activeFeatures;
                }
            }

            IIS(dataset, observed, maxActiveFeatures);
        } finally {
            dbc.dropBigMap(OBSERVED_MAP_NAME, observed);
        }
    }

    /**
     * Runs the configured number of scaling iterations.
     *
     * <p>Each iteration computes the model expectation of every pair under the current lambdas,
     * then moves each lambda towards agreement with its observed expectation. The per-iteration
     * temporary map is always dropped, even if an iteration fails.
     *
     * @param dataset the training records
     * @param observed observed expectation of every {@code [featureKey, class]} pair
     * @param maxActiveFeatures largest number of active features in any training record
     */
    private void IIS(Dataset dataset, Map<List<Object>, Double> observed, double maxActiveFeatures) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        int totalIterations = ((TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters()).getTotalIterations();
        Set<Object> classes = modelParameters.getClasses();
        Map<List<Object>, Double> lambdas = modelParameters.getLambdas();
        int n = modelParameters.getN().intValue();
        DatabaseConnector dbc = ((MLmodelKnowledgeBase) this.knowledgeBase).getDbc();

        for (int iteration = 0; iteration < totalIterations; iteration++) {
            this.logger.debug("Iteration {}", Integer.valueOf(iteration));

            Map<List<Object>, Double> expected = dbc.getBigMap(MODEL_MAP_NAME, true);
            try {
                computeModelExpectations(dataset, classes, n, observed, expected);
                updateLambdas(lambdas, observed, expected, maxActiveFeatures);
            } finally {
                dbc.dropBigMap(MODEL_MAP_NAME, expected);
            }
        }
    }

    /**
     * Fills {@code expected} with the model expectation of every pair in {@code observed}.
     *
     * <p>Each record contributes its normalized class score divided by {@code n} to every
     * {@code [activeFeature, class]} pair.
     */
    private void computeModelExpectations(Dataset dataset, Set<Object> classes, int n,
            Map<List<Object>, Double> observed, Map<List<Object>, Double> expected) {
        for (List<Object> featureClass : observed.keySet()) {
            expected.put(featureClass, Double.valueOf(0.0d));
        }

        Iterator<Integer> it = dataset.iterator();
        while (it.hasNext()) {
            Record record = dataset.get(it.next());

            AssociativeArray classProbabilities = scoreAllClasses(record.getX(), classes);
            Descriptives.normalizeExp(classProbabilities);

            for (Map.Entry<Object, Object> classEntry : classProbabilities.entrySet()) {
                Object theClass = classEntry.getKey();
                double fraction = TypeInference.toDouble(classEntry.getValue()).doubleValue() / n;
                for (Map.Entry<Object, Object> featureEntry : record.getX().entrySet()) {
                    if (isActive(featureEntry.getValue())) {
                        // Every active training feature was registered with every class in _fit, so the entry exists.
                        List<Object> featureClass = Arrays.asList(featureEntry.getKey(), theClass);
                        expected.put(featureClass, Double.valueOf(expected.get(featureClass).doubleValue() + fraction));
                    }
                }
            }
        }
    }

    /**
     * Applies one scaling update to {@code lambdas}.
     *
     * <p>Pairs whose observed and model expectations differ by at most the tolerance (or are NaN)
     * are left unchanged. Otherwise the update is
     * {@code lambda += log(observed / expected) / maxActiveFeatures}.
     *
     * <p>When that ratio is degenerate, the lambda is clamped instead of becoming infinite:
     * <ul>
     * <li>observed == 0: the lambda is set to the smallest finite lambda produced in this update;</li>
     * <li>expected == 0: the lambda is set to the largest finite lambda produced in this update;</li>
     * <li>no finite lambda was produced in this update: the lambda is set to 0.0, which contributes
     * nothing to a class score.</li>
     * </ul>
     */
    private static void updateLambdas(Map<List<Object>, Double> lambdas, Map<List<Object>, Double> observed,
            Map<List<Object>, Double> expected, double maxActiveFeatures) {
        List<List<Object>> clampToMin = new ArrayList<List<Object>>();
        List<List<Object>> clampToMax = new ArrayList<List<Object>>();
        Double minLambda = null;
        Double maxLambda = null;

        for (Map.Entry<List<Object>, Double> entry : expected.entrySet()) {
            List<Object> featureClass = entry.getKey();
            double observedValue = observed.get(featureClass).doubleValue();
            double expectedValue = entry.getValue().doubleValue();

            boolean differs = Math.abs(observedValue - expectedValue) > EXPECTATION_TOLERANCE;
            if (!differs) {
                continue;
            }

            if (observedValue == 0.0d) {
                // log(0 / x) would be -infinity.
                clampToMin.add(featureClass);
            } else if (expectedValue == 0.0d) {
                // log(x / 0) would be +infinity.
                clampToMax.add(featureClass);
            } else {
                double updated = lambdas.get(featureClass).doubleValue() + Math.log(observedValue / expectedValue) / maxActiveFeatures;
                lambdas.put(featureClass, Double.valueOf(updated));
                if (minLambda == null || updated < minLambda.doubleValue()) {
                    minLambda = Double.valueOf(updated);
                }
                if (maxLambda == null || updated > maxLambda.doubleValue()) {
                    maxLambda = Double.valueOf(updated);
                }
            }
        }

        // Without any finite update there is no bound to clamp to; 0.0 avoids storing null
        // (which previously caused an NPE on the next iteration).
        Double lowerBound = (minLambda != null) ? minLambda : Double.valueOf(0.0d);
        Double upperBound = (maxLambda != null) ? maxLambda : Double.valueOf(0.0d);
        for (List<Object> featureClass : clampToMin) {
            lambdas.put(featureClass, lowerBound);
        }
        for (List<Object> featureClass : clampToMax) {
            lambdas.put(featureClass, upperBound);
        }
    }

    /**
     * Scores every class in {@code classes} for the feature vector {@code x}.
     *
     * @return a new array mapping each class to its (unnormalized) score
     */
    private AssociativeArray scoreAllClasses(AssociativeArray x, Set<Object> classes) {
        AssociativeArray classScores = new AssociativeArray();
        for (Object theClass : classes) {
            classScores.put(theClass, calculateClassScore(x, theClass));
        }
        return classScores;
    }

    /**
     * Returns whether a feature value counts as present: it converts to a non-null, non-zero double.
     */
    private static boolean isActive(Object value) {
        Double numeric = TypeInference.toDouble(value);
        return numeric != null && numeric.doubleValue() != 0.0d;
    }

    /**
     * Computes the unnormalized score of {@code theClass} for the feature vector {@code x}.
     *
     * <p>The score is the sum of the lambdas of {@code [activeFeature, theClass]}. Pairs without a
     * lambda (e.g. features unseen during training) contribute nothing.
     */
    private Double calculateClassScore(AssociativeArray x, Object theClass) {
        Map<List<Object>, Double> lambdas = ((ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getLambdas();
        double score = 0.0d;
        for (Map.Entry<Object, Object> entry : x.entrySet()) {
            if (!isActive(entry.getValue())) {
                continue;
            }
            Double lambda = lambdas.get(Arrays.asList(entry.getKey(), theClass));
            if (lambda != null) {
                score += lambda.doubleValue();
            }
        }
        return Double.valueOf(score);
    }
}
