package com.datumbox.applications.datamodeling;

import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.datatransformation.DataTransformer;
import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.bases.wrappers.BaseWrapper;

/* loaded from: input_file:com/datumbox/applications/datamodeling/Modeler.class */
/**
 * Wrapper that chains an optional {@link DataTransformer}, an optional {@link FeatureSelection}
 * step and a {@link BaseMLmodel} into a single trainable / evaluable pipeline.
 *
 * <p>The concrete component classes and their training parameters are read from
 * {@link TrainingParameters}. A {@code null} data-transformer or feature-selection class skips
 * that stage entirely; the ML model stage always runs.
 */
public class Modeler extends BaseWrapper<Modeler.ModelParameters, Modeler.TrainingParameters> {

    /**
     * Model parameters of the Modeler. Adds nothing beyond those of {@link BaseWrapper}.
     */
    public static class ModelParameters extends BaseWrapper.ModelParameters {

        /**
         * @param databaseConnector connector handed unchanged to the base-class parameters
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }
    }

    /**
     * Training parameters of the Modeler.
     *
     * <p>Through the inherited getters, these select the data-transformer, feature-selection and
     * ML-model classes, together with the training parameters passed to each of them.
     */
    public static class TrainingParameters
            extends BaseWrapper.TrainingParameters<DataTransformer, FeatureSelection, BaseMLmodel> {
    }

    /**
     * Creates a Modeler.
     *
     * @param dbName database name handed to the base wrapper
     * @param databaseConfiguration persistence configuration handed to the base wrapper
     */
    public Modeler(String dbName, DatabaseConfiguration databaseConfiguration) {
        super(dbName, databaseConfiguration, ModelParameters.class, TrainingParameters.class);
    }

    /**
     * Runs the stored pipeline on {@code dataset} in prediction mode.
     *
     * <p>The configured data transformer and feature selection are applied to {@code dataset}
     * before {@code BaseMLmodel.predict} is called on it. Afterwards, the dataset is denormalized
     * if a data transformer is configured. Presumably the predictions are stored on the dataset's
     * records; confirm against {@code BaseMLmodel.predict}.
     *
     * @param dataset data to predict; it is transformed by the pipeline stages
     */
    public void predict(Dataset dataset) {
        this.logger.info("predict()");
        evaluateData(dataset, false);
    }

    /**
     * Runs the stored pipeline on {@code dataset} and validates the ML model against it.
     *
     * <p>Preprocessing and denormalization are the same as in {@link #predict(Dataset)}.
     *
     * @param dataset labelled data to validate against; it is transformed by the pipeline stages
     * @return the metrics returned by {@code BaseMLmodel.validate}
     */
    public BaseMLmodel.ValidationMetrics validate(Dataset dataset) {
        this.logger.info("validate()");
        return evaluateData(dataset, true);
    }

    /**
     * Trains every configured pipeline stage on {@code dataset}, in order.
     *
     * <p>The stages are: data transformer ({@code fit_transform}), then feature selection
     * ({@code fit_transform}), then the ML model ({@code fit}). Each stage sees the output of the
     * previous one. If a data transformer was used, the dataset is denormalized at the end.
     *
     * @param dataset training data; it is modified by the transformation stages
     */
    @Override
    protected void _fit(Dataset dataset) {
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        DatabaseConfiguration dbConf = this.knowledgeBase.getDbConf();

        Class<? extends DataTransformer> dataTransformerClass = trainingParameters.getDataTransformerClass();
        boolean transformData = dataTransformerClass != null;
        if (transformData) {
            this.dataTransformer = (DataTransformer) DataTransformer.newInstance(dataTransformerClass, this.dbName, dbConf);
            this.dataTransformer.fit_transform(dataset, trainingParameters.getDataTransformerTrainingParameters());
        }

        Class<? extends FeatureSelection> featureSelectionClass = trainingParameters.getFeatureSelectionClass();
        if (featureSelectionClass != null) {
            this.featureSelection = (FeatureSelection) FeatureSelection.newInstance(featureSelectionClass, this.dbName, dbConf);
            this.featureSelection.fit_transform(dataset, trainingParameters.getFeatureSelectionTrainingParameters());
        }

        this.mlmodel = (BaseMLmodel) BaseMLmodel.newInstance(trainingParameters.getMLmodelClass(), this.dbName, dbConf);
        // The model's training parameters are passed through as-is: they are a parameters
        // object, not a Dataset, so they must not be cast to one.
        this.mlmodel.fit(dataset, trainingParameters.getMLmodelTrainingParameters());

        if (transformData) {
            this.dataTransformer.denormalize(dataset);
        }
    }

    /**
     * Shared implementation of {@link #predict(Dataset)} and {@link #validate(Dataset)}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Load the knowledge base.</li>
     *   <li>Instantiate any configured component not yet held in memory.</li>
     *   <li>Apply the transformer and feature selection via {@code transform}.</li>
     *   <li>Either validate or predict.</li>
     *   <li>Denormalize the dataset if a data transformer is configured.</li>
     * </ol>
     *
     * @param dataset data to evaluate; it is transformed in place by the pipeline stages
     * @param computeValidationMetrics {@code true} to call {@code validate}, {@code false} to call
     *     {@code predict}
     * @return the validation metrics, or {@code null} when {@code computeValidationMetrics} is false
     */
    private BaseMLmodel.ValidationMetrics evaluateData(Dataset dataset, boolean computeValidationMetrics) {
        this.knowledgeBase.load();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        DatabaseConfiguration dbConf = this.knowledgeBase.getDbConf();

        Class<? extends DataTransformer> dataTransformerClass = trainingParameters.getDataTransformerClass();
        boolean transformData = dataTransformerClass != null;
        if (transformData) {
            // Components are created lazily: they may not exist yet if this instance was not the
            // one that ran _fit.
            if (this.dataTransformer == null) {
                this.dataTransformer = (DataTransformer) DataTransformer.newInstance(dataTransformerClass, this.dbName, dbConf);
            }
            this.dataTransformer.transform(dataset);
        }

        Class<? extends FeatureSelection> featureSelectionClass = trainingParameters.getFeatureSelectionClass();
        if (featureSelectionClass != null) {
            if (this.featureSelection == null) {
                this.featureSelection = (FeatureSelection) FeatureSelection.newInstance(featureSelectionClass, this.dbName, dbConf);
            }
            this.featureSelection.transform(dataset);
        }

        if (this.mlmodel == null) {
            this.mlmodel = (BaseMLmodel) BaseMLmodel.newInstance(trainingParameters.getMLmodelClass(), this.dbName, dbConf);
        }

        BaseMLmodel.ValidationMetrics validationMetrics = null;
        if (computeValidationMetrics) {
            validationMetrics = this.mlmodel.validate(dataset);
        } else {
            this.mlmodel.predict(dataset);
        }

        if (transformData) {
            this.dataTransformer.denormalize(dataset);
        }
        return validationMetrics;
    }
}
