package com.datumbox.framework.machinelearning.regression;

import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.MatrixDataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.BigMap;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.PHPfunctions;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseLinearRegression;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.interfaces.StepwiseCompatible;
import com.datumbox.framework.statistics.distributions.ContinuousDistributions;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * Ordinary least-squares linear regression fitted in closed form with the normal equations
 * ({@code beta = (X^T X)^-1 X^T y}) on Apache Commons Math matrices.
 *
 * <p>Besides the coefficients ("thitas"), training also stores a p-value per feature, which is
 * exposed through {@link StepwiseCompatible} so the model can be used in stepwise selection.
 */
public class MatrixLinearRegression extends BaseLinearRegression<ModelParameters, TrainingParameters, ValidationMetrics> implements StepwiseCompatible {

    /**
     * Model parameters of {@link MatrixLinearRegression}.
     */
    public static class ModelParameters extends BaseLinearRegression.ModelParameters {

        /**
         * Maps each feature to its column index in the design matrix built during training;
         * the intercept (constant) column is expected to be registered here as well.
         */
        @BigMap
        private Map<Object, Integer> featureIds;

        /** Per-feature p-values computed at the end of training. */
        private Map<Object, Double> featurePvalues;

        /**
         * @param databaseConnector the connector passed through to the base parameters
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the feature-to-column-index mapping
         */
        public Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /**
         * @param map the feature-to-column-index mapping
         */
        protected void setFeatureIds(Map<Object, Integer> map) {
            this.featureIds = map;
        }

        /**
         * @return the per-feature p-values
         */
        public Map<Object, Double> getFeaturePvalues() {
            return this.featurePvalues;
        }

        /**
         * @param map the per-feature p-values
         */
        protected void setFeaturePvalues(Map<Object, Double> map) {
            this.featurePvalues = map;
        }
    }

    /**
     * Training parameters of {@link MatrixLinearRegression}; adds nothing to the base class.
     */
    public static class TrainingParameters extends BaseLinearRegression.TrainingParameters {
    }

    /**
     * Validation metrics of {@link MatrixLinearRegression}; adds nothing to the base class.
     */
    public static class ValidationMetrics extends BaseLinearRegression.ValidationMetrics {
    }

    /**
     * @param str the name handed to the base model (presumably the database name — confirm)
     * @param databaseConfiguration the storage configuration handed to the base model
     */
    public MatrixLinearRegression(String str, DatabaseConfiguration databaseConfiguration) {
        super(str, databaseConfiguration, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class);
    }

    /**
     * Returns the two-sided p-values of the t-tests {@code H0: coefficient == 0}, keyed by
     * feature (including the intercept column).
     *
     * @return the p-value map stored by the last call to {@code _fit}
     */
    @Override
    public Map<Object, Double> getFeaturePvalues() {
        return currentModelParameters().getFeaturePvalues();
    }

    /**
     * Fits the coefficients with the normal equations and computes a p-value per coefficient.
     *
     * <p>The design matrix gets a constant column, and {@code featureIds} is populated with the
     * column index of every feature. The fitted coefficients are written into the thitas map. The
     * p-values use the coefficient standard errors taken from the diagonal of
     * {@code sigma^2 (X^T X)^-1}, with {@code n - (d + 1)} residual degrees of freedom.
     *
     * @param dataset the training data
     */
    @Override
    protected void _fit(Dataset dataset) {
        ModelParameters modelParameters = currentModelParameters();
        int n = modelParameters.getN().intValue();
        int d = modelParameters.getD().intValue();
        Map<Object, Double> thitas = modelParameters.getThitas();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();

        // Build X (with a constant column, as requested by the 'true' flag) and y; this call also
        // fills featureIds with the column index assigned to each feature.
        MatrixDataset matrixDataset = MatrixDataset.newInstance(dataset, true, featureIds);
        RealVector y = matrixDataset.getY();
        RealMatrix x = matrixDataset.getX();

        // The explicit inverse is needed anyway: its diagonal yields the coefficient standard errors.
        RealMatrix xt = x.transpose();
        RealMatrix xtxInverse = new LUDecomposition(xt.multiply(x)).getSolver().getInverse();
        // Apply X^T to y first so the (d+1) x n product (X^T X)^-1 X^T is never materialised.
        RealVector coefficients = xtxInverse.operate(xt.operate(y));

        thitas.put(Dataset.constantColumnName, coefficients.getEntry(0));
        for (Map.Entry<Object, Integer> entry : featureIds.entrySet()) {
            thitas.put(entry.getKey(), coefficients.getEntry(entry.getValue()));
        }

        // Residual sum of squares.
        RealVector residuals = x.operate(coefficients).subtract(y);
        double sse = residuals.dotProduct(residuals);

        // NOTE(review): when n <= d + 1 there are no residual degrees of freedom and the
        // standard errors below are undefined; this matches the previous behaviour.
        int degreesOfFreedom = n - (d + 1);
        RealMatrix covariance = xtxInverse.scalarMultiply(sse / degreesOfFreedom);

        Map<Integer, Object> idsToFeatures = PHPfunctions.array_flip(featureIds);
        Map<Object, Double> pvalues = new HashMap<>();
        for (int i = 0; i <= d; i++) {
            double variance = covariance.getEntry(i, i);
            Object feature = idsToFeatures.get(i);
            if (variance <= 0.0) {
                // Degenerate (zero) standard error: treat the coefficient as fully significant.
                pvalues.put(feature, 0.0);
            } else {
                double tstat = coefficients.getEntry(i) / Math.sqrt(variance);
                pvalues.put(feature, twoSidedPvalue(tstat, degreesOfFreedom));
            }
        }
        modelParameters.setFeaturePvalues(pvalues);
    }

    /**
     * Predicts every record of {@code dataset} as {@code X * beta} and stores the prediction
     * back into the dataset. Each record's x, y and predicted probabilities are kept unchanged.
     *
     * @param dataset the records to score; modified in place
     */
    @Override
    public void predictDataset(Dataset dataset) {
        ModelParameters modelParameters = currentModelParameters();
        Map<Object, Double> thitas = modelParameters.getThitas();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();

        // Lay the coefficients out in the column order of the design matrix: d features + constant.
        ArrayRealVector coefficients = new ArrayRealVector(modelParameters.getD().intValue() + 1);
        for (Map.Entry<Object, Double> entry : thitas.entrySet()) {
            coefficients.setEntry(featureIds.get(entry.getKey()).intValue(), entry.getValue().doubleValue());
        }

        RealVector predictions = MatrixDataset.parseDataset(dataset, featureIds).getX().operate(coefficients);

        Iterator<Integer> recordIds = dataset.iterator();
        while (recordIds.hasNext()) {
            Integer rId = recordIds.next();
            Record record = dataset.get(rId);
            // NOTE(review): the record id doubles as the row index of the parsed matrix, which
            // assumes MatrixDataset.parseDataset places record rId at row rId — confirm.
            dataset.set(rId, new Record(record.getX(), record.getY(), Double.valueOf(predictions.getEntry(rId.intValue())), record.getYPredictedProbabilities()));
        }
    }

    /**
     * Returns this model's {@link ModelParameters} from the knowledge base.
     */
    @SuppressWarnings("rawtypes")
    private ModelParameters currentModelParameters() {
        return (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
    }

    /**
     * Returns the two-sided p-value {@code P(|T| >= |t|)} of a Student's t statistic.
     *
     * <p>A one-sided upper tail ({@code 1 - CDF(t)}) would give strongly negative coefficients
     * p-values near 1, so stepwise selection would wrongly discard them.
     *
     * @param tstat the t statistic of a coefficient
     * @param degreesOfFreedom the residual degrees of freedom
     * @return the p-value, clamped to [0, 1] to absorb rounding in the CDF
     */
    private static double twoSidedPvalue(double tstat, int degreesOfFreedom) {
        double upperTail = 1.0 - ContinuousDistributions.StudentsCdf(Math.abs(tstat), degreesOfFreedom);
        return Math.max(0.0, Math.min(1.0, 2.0 * upperTail));
    }
}
