package com.datumbox.framework.machinelearning.classification;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.dataobjects.TypeInference;
import com.datumbox.common.persistentstorage.interfaces.BigMap;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclassifier;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.validation.OrdinalRegressionValidation;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/* loaded from: input_file:com/datumbox/framework/machinelearning/classification/OrdinalRegression.class */
/**
 * Ordinal regression classifier with a logistic link.
 *
 * <p>The model learns one weight per feature and one threshold ("thita") per class. Classes are
 * ordered by the iteration order of {@code getClasses()}. The last class in that order gets a
 * fixed threshold of {@code +Infinity}. For an input {@code x} the score of class {@code k} is
 * {@code g(thita_k - x.w) - g(thita_prev(k) - x.w)}, where {@code g} is the logistic sigmoid.
 * For the first class the subtracted term is omitted.
 *
 * <p>Training minimises the average over records of
 * {@code h(x.w - thita_y) + h(thita_prev(y) - x.w)}, with {@code h(z) = log(1 + e^z)}. It uses
 * batch gradient descent with an adaptive learning rate: the rate is multiplied by 1.05 after an
 * accepted step and halved after a rejected one.
 *
 * <p>NOTE(review): the ordering relies on {@code getClasses()} iterating in label order. Labels
 * are added from a {@link TreeSet} during fit, but whether the underlying set keeps that order
 * depends on its implementation — confirm.
 */
public class OrdinalRegression extends BaseMLclassifier<OrdinalRegression.ModelParameters, OrdinalRegression.TrainingParameters, OrdinalRegression.ValidationMetrics> {

    /** Learned parameters: per-feature weights and per-class thresholds. */
    public static class ModelParameters extends BaseMLclassifier.ModelParameters {

        /** Feature key -> weight. Presumably backed by the storage layer via {@code @BigMap}. */
        @BigMap
        private Map<Object, Double> weights;

        /** Class label -> threshold. After fitting, the last class holds {@code +Infinity}. */
        @BigMap
        private Map<Object, Double> thitas;

        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /** @return the feature-weight map */
        public Map<Object, Double> getWeights() {
            return this.weights;
        }

        protected void setWeights(Map<Object, Double> map) {
            this.weights = map;
        }

        /** @return the class-threshold map */
        public Map<Object, Double> getThitas() {
            return this.thitas;
        }

        protected void setThitas(Map<Object, Double> map) {
            this.thitas = map;
        }
    }

    /** Hyper-parameters of the gradient-descent training loop. */
    public static class TrainingParameters extends BaseMLclassifier.TrainingParameters {
        /** Number of gradient-descent iterations (default 100). */
        private int totalIterations = 100;
        /** Initial learning rate (default 0.1); it is adapted during training. */
        private double learningRate = 0.1d;

        public int getTotalIterations() {
            return this.totalIterations;
        }

        public void setTotalIterations(int i) {
            this.totalIterations = i;
        }

        public double getLearningRate() {
            return this.learningRate;
        }

        public void setLearningRate(double d) {
            this.learningRate = d;
        }
    }

    /** Validation metrics specific to ordinal regression. */
    public static class ValidationMetrics extends BaseMLclassifier.ValidationMetrics {
        /**
         * Despite its name, this holds the average training loss computed by
         * {@code calculateError}, not a sum of squared errors.
         */
        private double SSE = 0.0d;
        /** Count R-squared; set equal to the classifier accuracy. */
        private double CountRSquare = 0.0d;

        public double getSSE() {
            return this.SSE;
        }

        public void setSSE(double d) {
            this.SSE = d;
        }

        public double getCountRSquare() {
            return this.CountRSquare;
        }

        public void setCountRSquare(double d) {
            this.CountRSquare = d;
        }
    }

    /**
     * Creates the model.
     *
     * @param str identifier forwarded to the base class (presumably the model/database name — confirm)
     * @param databaseConfiguration storage configuration forwarded to the base class
     */
    public OrdinalRegression(String str, DatabaseConfiguration databaseConfiguration) {
        super(str, databaseConfiguration, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class, new OrdinalRegressionValidation());
    }

    /**
     * Scores every record of {@code dataset}. Each record is replaced by a copy carrying the
     * selected class and the per-class scores.
     *
     * @param dataset the records to predict; modified in place
     */
    @Override
    public void predictDataset(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        Map<Object, Double> weights = modelParameters.getWeights();
        Map<Object, Double> thitas = modelParameters.getThitas();
        Map<Object, Object> previousThitaMappings = getPreviousThitaMappings();

        for (Iterator<Integer> it = dataset.iterator(); it.hasNext();) {
            Integer rId = it.next();
            Record r = dataset.get(rId);
            AssociativeArray classScores = hypothesisFunction(r.getX(), previousThitaMappings, weights, thitas);
            dataset.set(rId, new Record(r.getX(), r.getY(), getSelectedClassFromClassScores(classScores), classScores));
        }
    }

    /**
     * Fits weights and thresholds to {@code dataset} with adaptive batch gradient descent.
     *
     * <p>Each iteration writes a candidate step into temporary big maps. The step is committed
     * only when it does not increase the loss. A NaN loss (divergence) is always rejected.
     *
     * @param dataset the training records
     */
    @Override
    protected void _fit(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        Map<Object, Double> weights = modelParameters.getWeights();
        Map<Object, Double> thitas = modelParameters.getThitas();

        // Collect the distinct labels in sorted order.
        TreeSet<Object> observedClasses = new TreeSet<Object>();
        for (Iterator<Integer> it = dataset.iterator(); it.hasNext();) {
            observedClasses.add(dataset.get(it.next()).getY());
        }
        if (observedClasses.isEmpty()) {
            // Nothing to learn. Without this guard the +Infinity threshold could be keyed by null.
            return;
        }

        Set<Object> classes = modelParameters.getClasses();
        classes.addAll(observedClasses);

        // All weights and thresholds start at zero.
        for (Object feature : dataset.getXDataTypes().keySet()) {
            weights.put(feature, Double.valueOf(0.0d));
        }
        for (Object label : observedClasses) {
            thitas.put(label, Double.valueOf(0.0d));
        }

        // The last class (in getClasses() order) is unbounded above: pin its threshold to +Infinity.
        Object lastClass = null;
        for (Object label : classes) {
            lastClass = label;
        }
        thitas.put(lastClass, Double.valueOf(Double.POSITIVE_INFINITY));

        Map<Object, Object> previousThitaMappings = getPreviousThitaMappings();
        double minError = Double.POSITIVE_INFINITY;
        double learningRate = trainingParameters.getLearningRate();
        int totalIterations = trainingParameters.getTotalIterations();
        DatabaseConnector dbc = ((MLmodelKnowledgeBase) this.knowledgeBase).getDbc();

        for (int i = 0; i < totalIterations; i++) {
            this.logger.debug("Iteration {}", Integer.valueOf(i));

            // Nested try/finally so both temporary maps are dropped even if a step throws.
            Map<Object, Double> newThitas = dbc.getBigMap("tmp_newThitas", true);
            try {
                Map<Object, Double> newWeights = dbc.getBigMap("tmp_newWeights", true);
                try {
                    newThitas.putAll(thitas);
                    newWeights.putAll(weights);

                    batchGradientDescent(dataset, previousThitaMappings, newWeights, newThitas, learningRate);
                    double newError = calculateError(dataset, previousThitaMappings, newWeights, newThitas);

                    // `<=` is false for NaN, so a diverged step is rejected instead of committed.
                    if (newError <= minError) {
                        learningRate *= 1.05d;
                        minError = newError;
                        weights.clear();
                        weights.putAll(newWeights);
                        thitas.clear();
                        thitas.putAll(newThitas);
                    } else {
                        learningRate /= 2.0d;
                    }
                } finally {
                    dbc.dropBigMap("tmp_newWeights", newWeights);
                }
            } finally {
                dbc.dropBigMap("tmp_newThitas", newThitas);
            }
        }
    }

    /**
     * Runs the base-class validation, then fills in the ordinal-regression metrics.
     * {@code CountRSquare} is set to the accuracy, and {@code SSE} to the average training loss
     * of the stored model on {@code dataset}.
     *
     * @param dataset the validation records
     * @return the populated metrics
     */
    @Override
    public ValidationMetrics validateModel(Dataset dataset) {
        ValidationMetrics validationMetrics = (ValidationMetrics) super.validateModel(dataset);
        Map<Object, Object> previousThitaMappings = getPreviousThitaMappings();
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();

        validationMetrics.setCountRSquare(validationMetrics.getAccuracy());
        validationMetrics.setSSE(calculateError(dataset, previousThitaMappings, modelParameters.getWeights(), modelParameters.getThitas()));
        return validationMetrics;
    }

    /**
     * Performs one batch gradient-descent step.
     *
     * <p>Gradients are evaluated at the committed parameters stored in the model. Updates are
     * accumulated into {@code newWeights} and {@code newThitas}, which the caller initialises as
     * copies of those parameters. As a result, every record's gradient uses the same point.
     *
     * @param dataset training records
     * @param previousThitaMappings class -> preceding class (null for the first class)
     * @param newWeights weights to update in place
     * @param newThitas thresholds to update in place
     * @param learningRate step size; it is divided by the model's {@code getN()}
     */
    private void batchGradientDescent(Dataset dataset, Map<Object, Object> previousThitaMappings, Map<Object, Double> newWeights, Map<Object, Double> newThitas, double learningRate) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        double multiplier = (-learningRate) / modelParameters.getN().intValue();
        Map<Object, Double> weights = modelParameters.getWeights();
        Map<Object, Double> thitas = modelParameters.getThitas();

        for (Iterator<Integer> it = dataset.iterator(); it.hasNext();) {
            Record r = dataset.get(it.next());
            Object y = r.getY();
            Object previousClass = previousThitaMappings.get(y);
            double xTw = xTw(r.getX(), weights);

            // Derivatives of h(x.w - thita_y) and h(thita_prev - x.w).
            double gUpper = g(xTw - thitas.get(y).doubleValue());
            double gLower = previousClass != null ? g(thitas.get(previousClass).doubleValue() - xTw) : 0.0d;

            for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                Double value = TypeInference.toDouble(entry.getValue());
                if (value == null) {
                    continue; // xTw treats null as 0, so it contributes no gradient
                }
                Object feature = entry.getKey();
                newWeights.put(feature, Double.valueOf(newWeights.get(feature).doubleValue() + (multiplier * value.doubleValue() * (gUpper - gLower))));
            }
            newThitas.put(y, Double.valueOf(newThitas.get(y).doubleValue() + (multiplier * (-gUpper))));
            if (previousClass != null) {
                newThitas.put(previousClass, Double.valueOf(newThitas.get(previousClass).doubleValue() + (multiplier * gLower)));
            }
        }
    }

    /**
     * Computes the per-class scores of one input. For a class {@code k} this is
     * {@code g(thita_k - x.w) - g(thita_prev(k) - x.w)}, or just the first term when {@code k}
     * has no predecessor.
     */
    private AssociativeArray hypothesisFunction(AssociativeArray x, Map<Object, Object> previousThitaMappings, Map<Object, Double> weights, Map<Object, Double> thitas) {
        AssociativeArray classScores = new AssociativeArray();
        double xTw = xTw(x, weights);
        for (Object label : ((ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getClasses()) {
            double score = g(thitas.get(label).doubleValue() - xTw);
            Object previousClass = previousThitaMappings.get(label);
            if (previousClass != null) {
                score -= g(thitas.get(previousClass).doubleValue() - xTw);
            }
            classScores.put(label, Double.valueOf(score));
        }
        return classScores;
    }

    /**
     * Computes the average training loss
     * {@code h(x.w - thita_y) + h(thita_prev(y) - x.w)} over {@code dataset}.
     * The sum is divided by the model's {@code getN()}.
     */
    private double calculateError(Dataset dataset, Map<Object, Object> previousThitaMappings, Map<Object, Double> weights, Map<Object, Double> thitas) {
        double error = 0.0d;
        for (Iterator<Integer> it = dataset.iterator(); it.hasNext();) {
            Record r = dataset.get(it.next());
            double xTw = xTw(r.getX(), weights);
            Object y = r.getY();
            Object previousClass = previousThitaMappings.get(y);
            if (previousClass != null) {
                error += h(thitas.get(previousClass).doubleValue() - xTw);
            }
            error += h(xTw - thitas.get(y).doubleValue());
        }
        return error / ((ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getN().intValue();
    }

    /**
     * Softplus {@code log(1 + e^d)}. It is linearised above 30 and clamped to 0 below -30 to
     * avoid overflow.
     */
    private double h(double d) {
        if (d > 30.0d) {
            return d;
        }
        if (d < -30.0d) {
            return 0.0d;
        }
        return Math.log(1.0d + Math.exp(d));
    }

    /** Logistic sigmoid {@code 1 / (1 + e^-d)}, saturated to 1 above 30 and to 0 below -30. */
    private double g(double d) {
        if (d > 30.0d) {
            return 1.0d;
        }
        if (d < -30.0d) {
            return 0.0d;
        }
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /**
     * Dot product of a record's features with the weights. Features whose value is null or
     * zero, or which have no weight, are skipped.
     */
    private double xTw(AssociativeArray associativeArray, Map<Object, Double> map) {
        Double d;
        double d2 = 0.0d;
        for (Map.Entry<Object, Object> entry : associativeArray.entrySet()) {
            Double d3 = TypeInference.toDouble(entry.getValue());
            if (d3 != null && d3.doubleValue() != 0.0d && (d = map.get(entry.getKey())) != null) {
                d2 += d3.doubleValue() * d.doubleValue();
            }
        }
        return d2;
    }

    /**
     * Maps each class to the class preceding it in {@code getClasses()} iteration order.
     * The first class maps to null.
     */
    private Map<Object, Object> getPreviousThitaMappings() {
        Map<Object, Object> previousClasses = new HashMap<Object, Object>();
        Object previous = null;
        for (Object label : ((ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getClasses()) {
            previousClasses.put(label, previous);
            previous = label;
        }
        return previousClasses;
    }
}
