package com.datumbox.framework.machinelearning.regression;

import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLregressor;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.interfaces.StepwiseCompatible;
import java.util.HashSet;
import java.util.Map;

/* loaded from: input_file:com/datumbox/framework/machinelearning/regression/StepwiseRegression.class */
/**
 * Backward-elimination stepwise regression built on top of another regressor.
 *
 * <p>The wrapped regressor class is configured via
 * {@link TrainingParameters#setRegressionClass(Class)} and must implement
 * {@link StepwiseCompatible} so that feature p-values are available. Training works on a
 * private copy of the data. Each round fits the wrapped regressor and drops the feature with
 * the highest p-value (the constant column is never dropped). Training stops as soon as one of
 * these holds:
 * <ul>
 *   <li>every remaining p-value is at most {@code aout};</li>
 *   <li>no features are left;</li>
 *   <li>{@code maxIterations} rounds have run.</li>
 * </ul>
 * The wrapped regressor is then fitted on the surviving features, and validation and
 * prediction are delegated to it.
 */
public class StepwiseRegression extends BaseMLregressor<ModelParameters, TrainingParameters, BaseMLregressor.ValidationMetrics> {

    /**
     * The wrapped regressor used for validation and prediction. It is a raw type because the
     * concrete class is only known at runtime from the training parameters. When {@code null},
     * it is re-instantiated on demand by {@link #loadRegressor()}.
     */
    private transient BaseMLregressor mlregressor;

    /** Model parameters; this wrapper adds no fields beyond the base regressor parameters. */
    public static class ModelParameters extends BaseMLregressor.ModelParameters {
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }
    }

    /** Training parameters controlling the elimination loop and the wrapped regressor. */
    public static class TrainingParameters extends BaseMLregressor.TrainingParameters {
        /** Maximum number of elimination rounds; {@code null} means unlimited. */
        private Integer maxIterations = null;
        /** P-value threshold: features with a p-value above this are candidates for removal. */
        private Double aout = Double.valueOf(0.05d);
        /** Class of the wrapped regressor; must implement {@link StepwiseCompatible}. */
        private Class<? extends BaseMLregressor> regressionClass;
        /** Training parameters passed to every fit of the wrapped regressor. */
        private BaseMLregressor.TrainingParameters regressionTrainingParameters;

        /** @return the maximum number of elimination rounds, or {@code null} for unlimited */
        public Integer getMaxIterations() {
            return this.maxIterations;
        }

        /** @param num the maximum number of elimination rounds, or {@code null} for unlimited */
        public void setMaxIterations(Integer num) {
            this.maxIterations = num;
        }

        /** @return the p-value removal threshold (defaults to 0.05) */
        public Double getAout() {
            return this.aout;
        }

        /** @param d the p-value removal threshold */
        public void setAout(Double d) {
            this.aout = d;
        }

        /** @return the wrapped regressor class */
        public Class<? extends BaseMLregressor> getRegressionClass() {
            return this.regressionClass;
        }

        /**
         * Sets the wrapped regressor class.
         *
         * @param cls a regressor class implementing {@link StepwiseCompatible}
         * @throws IllegalArgumentException if {@code cls} does not implement {@link StepwiseCompatible}
         */
        public void setRegressionClass(Class<? extends BaseMLregressor> cls) {
            if (!StepwiseCompatible.class.isAssignableFrom(cls)) {
                throw new IllegalArgumentException("The regression model is not Stepwise Compatible as it does not calculates the pvalues of the features.");
            }
            this.regressionClass = cls;
        }

        /** @return the training parameters of the wrapped regressor */
        public BaseMLregressor.TrainingParameters getRegressionTrainingParameters() {
            return this.regressionTrainingParameters;
        }

        /** @param trainingParameters the training parameters of the wrapped regressor */
        public void setRegressionTrainingParameters(BaseMLregressor.TrainingParameters trainingParameters) {
            this.regressionTrainingParameters = trainingParameters;
        }
    }

    /**
     * Creates a stepwise regression model.
     *
     * @param dbName name under which the model (and its wrapped regressor) is stored
     * @param databaseConfiguration storage configuration
     */
    public StepwiseRegression(String dbName, DatabaseConfiguration databaseConfiguration) {
        super(dbName, databaseConfiguration, ModelParameters.class, TrainingParameters.class, BaseMLregressor.ValidationMetrics.class, null);
        this.mlregressor = null;
    }

    /** Erases the wrapped regressor and then this model. */
    @Override
    public void erase() {
        loadRegressor();
        this.mlregressor.erase();
        this.mlregressor = null;
        super.erase();
    }

    /** Closes the wrapped regressor and then this model. */
    @Override
    public void close() {
        loadRegressor();
        this.mlregressor.close();
        this.mlregressor = null;
        super.close();
    }

    /**
     * Not supported for the stepwise wrapper.
     *
     * @throws UnsupportedOperationException always; run k-fold validation on the wrapped regressor instead
     */
    @Override
    public BaseMLregressor.ValidationMetrics kFoldCrossValidation(Dataset dataset, TrainingParameters trainingParameters, int i) {
        throw new UnsupportedOperationException("K-fold Cross Validation is not supported. Run it directly to the wrapped regressor.");
    }

    /** Delegates validation to the wrapped regressor. */
    @Override
    public BaseMLregressor.ValidationMetrics validateModel(Dataset dataset) {
        loadRegressor();
        return (BaseMLregressor.ValidationMetrics) this.mlregressor.validate(dataset);
    }

    /**
     * Runs backward elimination on a copy of {@code dataset} and fits the wrapped regressor on
     * the surviving features. The copy is always erased, even if a fit fails.
     */
    @Override
    protected void _fit(Dataset dataset) {
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        Integer maxIterationsParam = trainingParameters.getMaxIterations();
        int maxIterations = (maxIterationsParam == null) ? Integer.MAX_VALUE : maxIterationsParam.intValue();
        double aout = trainingParameters.getAout().doubleValue();

        // Work on a copy so that dropping columns never modifies the caller's dataset.
        Dataset copy = dataset.copy();
        try {
            for (int i = 0; i < maxIterations; i++) {
                Map<Object, Double> pvalues = runRegression(copy);
                // The constant column is never a candidate for elimination.
                pvalues.remove(Dataset.constantColumnName);
                if (pvalues.isEmpty()) {
                    break;
                }

                Map.Entry<Object, Double> worst = MapFunctions.selectMaxKeyValue(pvalues);
                if (worst == null || worst.getValue().doubleValue() <= aout) {
                    break; // every remaining feature is significant enough
                }

                HashSet columnsToRemove = new HashSet();
                columnsToRemove.add(worst.getKey());
                copy.removeColumns(columnsToRemove);
                if (copy.getVariableNumber() == 0) {
                    break;
                }
            }

            this.mlregressor = newRegressor(trainingParameters);
            this.mlregressor.fit(copy, trainingParameters.getRegressionTrainingParameters());
        } finally {
            copy.erase();
        }
    }

    /** Delegates prediction to the wrapped regressor. */
    @Override
    public void predictDataset(Dataset dataset) {
        loadRegressor();
        this.mlregressor.predict(dataset);
    }

    /**
     * Re-instantiates the wrapped regressor under this model's dbName if it is not in memory.
     * Presumably this reloads the persisted wrapped model; confirm against BaseMLmodel.newInstance.
     */
    private void loadRegressor() {
        if (this.mlregressor == null) {
            this.mlregressor = newRegressor((TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters());
        }
    }

    /** Creates an instance of the configured wrapped regressor class using this model's dbName and db configuration. */
    private BaseMLregressor newRegressor(TrainingParameters trainingParameters) {
        return (BaseMLregressor) BaseMLmodel.newInstance(trainingParameters.getRegressionClass(), this.dbName, ((MLmodelKnowledgeBase) this.knowledgeBase).getDbConf());
    }

    /**
     * Fits a temporary wrapped regressor on {@code dataset} and returns its feature p-values.
     * The temporary regressor is always erased afterwards, even if fitting fails, and is never
     * stored in the {@code mlregressor} field.
     */
    private Map<Object, Double> runRegression(Dataset dataset) {
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        BaseMLregressor regressor = newRegressor(trainingParameters);
        try {
            regressor.fit(dataset, trainingParameters.getRegressionTrainingParameters());
            return ((StepwiseCompatible) regressor).getFeaturePvalues();
        } finally {
            regressor.erase();
        }
    }
}
