package com.datumbox.framework.machinelearning.regression;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.dataobjects.TypeInference;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseLinearRegression;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:com/datumbox/framework/machinelearning/regression/NLMS.class */
/**
 * Linear regression model fitted by iterative gradient descent on the squared error.
 *
 * <p>Each iteration proposes a new weight vector ("thitas") with one batch gradient step.
 * The proposal is kept only if it does not increase the total squared error on the training data:
 * <ul>
 *   <li>accepted: the learning rate grows by 5%;</li>
 *   <li>rejected: the learning rate is halved and the previous weights are kept.</li>
 * </ul>
 * The per-record step is scaled by {@code learningRate / n}, where {@code n} is read from the
 * model parameters (presumably the number of training records — confirm in the base class).
 */
public class NLMS extends BaseLinearRegression<NLMS.ModelParameters, NLMS.TrainingParameters, NLMS.ValidationMetrics> {

    /** Model parameters of {@link NLMS}; the weights live in the inherited {@code thitas} map. */
    public static class ModelParameters extends BaseLinearRegression.ModelParameters {
        protected ModelParameters(DatabaseConnector dbc) {
            super(dbc);
        }
    }

    /** Training parameters of {@link NLMS}. */
    public static class TrainingParameters extends BaseLinearRegression.TrainingParameters {
        /** Number of gradient-descent iterations performed by {@code _fit}. Defaults to 1000. */
        private int totalIterations = 1000;
        /** Initial learning rate; adapted during training. Defaults to 0.1. */
        private double learningRate = 0.1d;

        /**
         * @return the number of gradient-descent iterations
         */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /**
         * @param totalIterations the number of gradient-descent iterations
         */
        public void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /**
         * @return the initial learning rate
         */
        public double getLearningRate() {
            return this.learningRate;
        }

        /**
         * @param learningRate the initial learning rate
         */
        public void setLearningRate(double learningRate) {
            this.learningRate = learningRate;
        }
    }

    /** Validation metrics of {@link NLMS}; adds nothing to the base linear-regression metrics. */
    public static class ValidationMetrics extends BaseLinearRegression.ValidationMetrics {
    }

    /**
     * @param dbName name passed to the base class
     * @param dbConf database configuration passed to the base class
     */
    public NLMS(String dbName, DatabaseConfiguration dbConf) {
        super(dbName, dbConf, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class);
    }

    /**
     * Fits the weights on {@code trainingData}.
     *
     * <p>All weights (the intercept and one per X column) start at zero. Each iteration then
     * applies one adaptive-step batch gradient-descent update, as described on the class.
     *
     * @param trainingData the training records
     */
    @Override
    protected void _fit(Dataset trainingData) {
        MLmodelKnowledgeBase kb = (MLmodelKnowledgeBase) this.knowledgeBase;
        Map<Object, Double> thitas = ((ModelParameters) kb.getModelParameters()).getThitas();
        TrainingParameters trainingParameters = (TrainingParameters) kb.getTrainingParameters();
        DatabaseConnector dbc = kb.getDbc();

        // Start from an all-zero weight vector: the intercept plus one weight per X column.
        thitas.put(Dataset.constantColumnName, Double.valueOf(0.0d));
        for (Object column : trainingData.getXDataTypes().keySet()) {
            thitas.put(column, Double.valueOf(0.0d));
        }

        double learningRate = trainingParameters.getLearningRate();
        int totalIterations = trainingParameters.getTotalIterations();
        // +Inf guarantees that the first proposal is always accepted.
        double minError = Double.POSITIVE_INFINITY;

        for (int iteration = 0; iteration < totalIterations; iteration++) {
            this.logger.debug("Iteration {}", iteration);

            Map<Object, Double> newThitas = dbc.getBigMap("tmp_newThitas", true);
            try {
                newThitas.putAll(thitas);
                batchGradientDescent(trainingData, newThitas, learningRate);
                double newError = calculateError(trainingData, newThitas);

                if (newError > minError) {
                    // The step overshot: keep the old weights and take smaller steps.
                    learningRate /= 2.0d;
                } else {
                    // The step helped: accept it and speed up slightly.
                    learningRate *= 1.05d;
                    minError = newError;
                    thitas.clear();
                    thitas.putAll(newThitas);
                }
            } finally {
                // Always release the temporary map, even if an iteration throws.
                dbc.dropBigMap("tmp_newThitas", newThitas);
            }
        }
    }

    /**
     * Writes the model's prediction into every record of {@code newData}.
     * Each record is replaced by a copy whose yPredicted is the fitted linear function.
     *
     * @param newData the records to score; modified in place
     */
    @Override
    public void predictDataset(Dataset newData) {
        Map<Object, Double> thitas = ((ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getThitas();
        Iterator<Integer> it = newData.iterator();
        while (it.hasNext()) {
            Integer rId = it.next();
            Record r = newData.get(rId);
            double yPredicted = hypothesisFunction(r.getX(), thitas);
            newData.set(rId, new Record(r.getX(), r.getY(), Double.valueOf(yPredicted), r.getYPredictedProbabilities()));
        }
    }

    /**
     * Performs one batch gradient-descent step, accumulating into {@code newThitas}.
     *
     * <p>Residuals are computed against the model's current weights, not against
     * {@code newThitas}. Every record therefore contributes to the same step.
     *
     * @param trainingData the training records
     * @param newThitas    the candidate weights; expected to start as a copy of the current weights
     * @param learningRate the step size before scaling by {@code 1 / n}
     */
    private void batchGradientDescent(Dataset trainingData, Map<Object, Double> newThitas, double learningRate) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        double multiplier = learningRate / modelParameters.getN().intValue();
        Map<Object, Double> thitas = modelParameters.getThitas();

        Iterator<Integer> it = trainingData.iterator();
        while (it.hasNext()) {
            Record r = trainingData.get(it.next());
            double residual = TypeInference.toDouble(r.getY()).doubleValue() - hypothesisFunction(r.getX(), thitas);
            addScaledRecord(newThitas, r.getX(), multiplier * residual);
        }
    }

    /**
     * Performs one stochastic gradient-descent pass, updating {@code thitas} after every record.
     * Later records therefore see the updated weights.
     *
     * <p>NOTE(review): this method is not called anywhere in this class; {@code _fit} uses
     * {@link #batchGradientDescent}.
     *
     * @param trainingData the training records
     * @param thitas       the weights, updated in place
     * @param learningRate the step size before scaling by {@code 1 / n}
     */
    private void stochasticGradientDescent(Dataset trainingData, Map<Object, Double> thitas, double learningRate) {
        double multiplier = learningRate / ((ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getN().intValue();

        Iterator<Integer> it = trainingData.iterator();
        while (it.hasNext()) {
            Record r = trainingData.get(it.next());
            double residual = TypeInference.toDouble(r.getY()).doubleValue() - hypothesisFunction(r.getX(), thitas);
            addScaledRecord(thitas, r.getX(), multiplier * residual);
        }
    }

    /**
     * Applies one record's gradient contribution to {@code weights}.
     * It adds {@code delta} to the intercept and {@code delta * x_j} to every feature weight already
     * present. Features of {@code x} without a weight are ignored.
     */
    private static void addScaledRecord(Map<Object, Double> weights, AssociativeArray x, double delta) {
        weights.put(Dataset.constantColumnName, Double.valueOf(weights.get(Dataset.constantColumnName).doubleValue() + delta));
        for (Map.Entry<Object, Object> entry : x.entrySet()) {
            Object feature = entry.getKey();
            Double weight = weights.get(feature);
            if (weight != null) {
                weights.put(feature, Double.valueOf(weight.doubleValue() + (delta * TypeInference.toDouble(entry.getValue()).doubleValue())));
            }
        }
    }

    /**
     * Returns the sum of squared residuals of {@code thitas} over {@code trainingData}.
     *
     * <p>NOTE(review): as a side effect, each training record's yPredicted is overwritten with the
     * prediction of {@code thitas}. This happens even for candidate weights that {@code _fit} later
     * rejects. The behavior is kept because code elsewhere may rely on it; confirm before removing.
     *
     * @param trainingData the training records; their predictions are overwritten
     * @param thitas       the weights to evaluate
     * @return the total squared error (not averaged)
     */
    private double calculateError(Dataset trainingData, Map<Object, Double> thitas) {
        double error = 0.0d;
        Iterator<Integer> it = trainingData.iterator();
        while (it.hasNext()) {
            Integer rId = it.next();
            Record r = trainingData.get(rId);
            double yPredicted = hypothesisFunction(r.getX(), thitas);
            trainingData.set(rId, new Record(r.getX(), r.getY(), Double.valueOf(yPredicted), r.getYPredictedProbabilities()));
            error += Math.pow(TypeInference.toDouble(r.getY()).doubleValue() - yPredicted, 2.0d);
        }
        return error;
    }

    /**
     * Evaluates the linear function {@code intercept + sum_j w_j * x_j}.
     * Features of {@code x} that have no weight in {@code thitas} contribute nothing.
     *
     * @param x      the feature values of one record
     * @param thitas the weights; must contain {@code Dataset.constantColumnName}
     * @return the predicted value
     */
    private double hypothesisFunction(AssociativeArray x, Map<Object, Double> thitas) {
        double sum = thitas.get(Dataset.constantColumnName).doubleValue();
        for (Map.Entry<Object, Object> entry : x.entrySet()) {
            Double weight = thitas.get(entry.getKey());
            if (weight != null) {
                sum += weight.doubleValue() * TypeInference.toDouble(entry.getValue()).doubleValue();
            }
        }
        return sum;
    }
}
