package com.datumbox.framework.machinelearning.common.bases.basemodels;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.DataTable2D;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.FlatDataList;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseBoostingBagging.ModelParameters;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseBoostingBagging.TrainingParameters;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseBoostingBagging.ValidationMetrics;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclassifier;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.validation.ClassifierValidation;
import com.datumbox.framework.machinelearning.ensemblelearning.FixedCombinationRules;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.statistics.sampling.SRS;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseBoostingBagging.class */
/**
 * Base class for ensemble classifiers (boosting/bagging) built from a sequence of weak classifiers.
 *
 * <p>During training, every weak learner is fitted on a weighted resample of the training data.
 * Each learner is stored under its own database name: this model's database name, the configured
 * separator, {@code "Cmp"} and the learner's index. At prediction time the stored learners are
 * instantiated again. Their class probabilities are combined by a weighted average that uses the
 * weights held in {@link ModelParameters#getWeakClassifierWeights()}.
 *
 * <p>After every round, subclasses decide how observation weights and learner weights change, via
 * {@link #updateObservationAndClassifierWeights(Dataset, AssociativeArray, FlatDataList)}.
 *
 * @param <MP> model parameters type
 * @param <TP> training parameters type
 * @param <VM> validation metrics type
 */
public abstract class BaseBoostingBagging<MP extends ModelParameters, TP extends TrainingParameters, VM extends ValidationMetrics> extends BaseMLclassifier<MP, TP, VM> {

    /** Infix placed between this model's database name and the index of a weak classifier. */
    private static final String DB_INDICATOR = "Cmp";

    /**
     * Number of consecutive retries allowed after a weak learner is rejected with
     * {@link Status#IGNORE}. One further rejection stops training.
     */
    private static final int MAX_NUMBER_OF_RETRIES = 2;

    /** Model parameters shared by boosting/bagging ensembles. */
    public static abstract class ModelParameters extends BaseMLclassifier.ModelParameters {

        /** Weight of each stored weak classifier, in learner-index order. */
        private List<Double> weakClassifierWeights;

        /**
         * Creates empty ensemble parameters.
         *
         * @param databaseConnector connector forwarded to the parent parameters
         */
        public ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
            this.weakClassifierWeights = new ArrayList<>();
        }

        /**
         * Returns the weak classifier weights. The size of this list is the number of weak
         * classifiers used at prediction time.
         *
         * @return the live list of weights (not a copy)
         */
        public List<Double> getWeakClassifierWeights() {
            return this.weakClassifierWeights;
        }

        /**
         * Replaces the weak classifier weights.
         *
         * @param weakClassifierWeights the new weights, in learner-index order
         */
        protected void setWeakClassifierWeights(List<Double> weakClassifierWeights) {
            this.weakClassifierWeights = weakClassifierWeights;
        }
    }

    /** Verdict a subclass returns after evaluating the most recently trained weak learner. */
    protected enum Status {
        /** Keep the learner and move on to the next index. */
        NEXT,
        /** Stop training immediately (logged as "low error"). */
        STOP,
        /** Reject the learner and retrain at the same index (logged as "high error"). */
        IGNORE
    }

    /** Training parameters shared by boosting/bagging ensembles. */
    public static abstract class TrainingParameters extends BaseMLclassifier.TrainingParameters {

        /** Upper bound on the number of weak learners to train; defaults to 5. */
        private int maxWeakClassifiers = 5;

        /** Concrete classifier type used for every weak learner. */
        private Class<? extends BaseMLclassifier> weakClassifierClass;

        /** Training parameters passed unchanged to every weak learner. */
        private BaseMLclassifier.TrainingParameters weakClassifierTrainingParameters;

        /** @return the maximum number of weak learners to train */
        public int getMaxWeakClassifiers() {
            return this.maxWeakClassifiers;
        }

        /** @param maxWeakClassifiers the maximum number of weak learners to train */
        public void setMaxWeakClassifiers(int maxWeakClassifiers) {
            this.maxWeakClassifiers = maxWeakClassifiers;
        }

        /** @return the classifier type used for the weak learners */
        public Class<? extends BaseMLclassifier> getWeakClassifierClass() {
            return this.weakClassifierClass;
        }

        /** @param weakClassifierClass the classifier type used for the weak learners */
        public void setWeakClassifierClass(Class<? extends BaseMLclassifier> weakClassifierClass) {
            this.weakClassifierClass = weakClassifierClass;
        }

        /** @return the training parameters given to every weak learner */
        public BaseMLclassifier.TrainingParameters getWeakClassifierTrainingParameters() {
            return this.weakClassifierTrainingParameters;
        }

        /** @param weakClassifierTrainingParameters the training parameters given to every weak learner */
        public void setWeakClassifierTrainingParameters(BaseMLclassifier.TrainingParameters weakClassifierTrainingParameters) {
            this.weakClassifierTrainingParameters = weakClassifierTrainingParameters;
        }
    }

    /** Validation metrics for boosting/bagging ensembles. */
    public static abstract class ValidationMetrics extends BaseMLclassifier.ValidationMetrics {
    }

    /**
     * Creates the ensemble model.
     *
     * @param dbName database name of this model; weak learners derive their names from it
     * @param dbConf database configuration
     * @param mpClass model parameters class
     * @param tpClass training parameters class
     * @param vmClass validation metrics class
     */
    public BaseBoostingBagging(String dbName, DatabaseConfiguration dbConf, Class<MP> mpClass, Class<TP> tpClass, Class<VM> vmClass) {
        super(dbName, dbConf, mpClass, tpClass, vmClass, new ClassifierValidation());
    }

    /**
     * Predicts every record of {@code dataset} by combining the stored weak classifiers.
     *
     * <p>Each weak classifier predicts the dataset, and its class probabilities are recorded per
     * record. The per-record tables are then merged with
     * {@link FixedCombinationRules#weightedAverage} using the learner weights. The result is
     * normalized, and the class with the highest score becomes the prediction. The temporary
     * decision map is dropped even if prediction fails.
     *
     * @param dataset the records to predict; every record is replaced in place
     */
    @Override
    public void predictDataset(Dataset dataset) {
        MLmodelKnowledgeBase kb = (MLmodelKnowledgeBase) this.knowledgeBase;
        Class<? extends BaseMLclassifier> weakClassifierClass = ((TrainingParameters) kb.getTrainingParameters()).getWeakClassifierClass();
        List<Double> weakClassifierWeights = ((ModelParameters) kb.getModelParameters()).getWeakClassifierWeights();
        DatabaseConnector dbc = kb.getDbc();

        Map<Object, Object> decisionsMap = dbc.getBigMap("tmp_recordDecisions", true);
        try {
            // recordId -> table of (weak learner index -> predicted class probabilities)
            AssociativeArray recordDecisions = new AssociativeArray(decisionsMap);
            Iterator<Integer> it = dataset.iterator();
            while (it.hasNext()) {
                recordDecisions.put(it.next(), new DataTable2D());
            }

            // weak learner index -> weight used in the weighted average
            AssociativeArray classifierWeights = new AssociativeArray();
            int numberOfClassifiers = weakClassifierWeights.size();
            for (int t = 0; t < numberOfClassifiers; t++) {
                BaseMLclassifier weakClassifier = newWeakClassifier(weakClassifierClass, t);
                try {
                    weakClassifier.predict(dataset);
                } finally {
                    weakClassifier.close();
                }
                classifierWeights.put(Integer.valueOf(t), weakClassifierWeights.get(t));

                it = dataset.iterator();
                while (it.hasNext()) {
                    Integer id = it.next();
                    DataTable2D decisions = (DataTable2D) recordDecisions.get(id);
                    decisions.put(Integer.valueOf(t), dataset.get(id).getYPredictedProbabilities());
                    // Write back: if the big map returns copies (e.g. a disk-backed store),
                    // mutating the fetched table alone would be lost.
                    recordDecisions.put(id, decisions);
                }
            }

            it = dataset.iterator();
            while (it.hasNext()) {
                Integer id = it.next();
                Record record = dataset.get(id);
                AssociativeArray classScores = FixedCombinationRules.weightedAverage((DataTable2D) recordDecisions.get(id), classifierWeights);
                Descriptives.normalize(classScores);
                dataset.set(id, new Record(record.getX(), record.getY(), MapFunctions.selectMaxKeyValue(classScores).getKey(), classScores));
            }
        } finally {
            dbc.dropBigMap("tmp_recordDecisions", decisionsMap);
        }
    }

    /**
     * Trains up to {@code maxWeakClassifiers} weak learners on weighted resamples of the data.
     *
     * <p>Observation weights start uniform at {@code 1/n}. After every learner, the subclass
     * updates the weights and returns a {@link Status}:
     * <ul>
     *   <li>{@code STOP} ends training.</li>
     *   <li>{@code IGNORE} retrains at the same index, at most {@value #MAX_NUMBER_OF_RETRIES}
     *       consecutive times.</li>
     *   <li>Anything else advances to the next index. Only {@code NEXT} resets the retry
     *       counter.</li>
     * </ul>
     *
     * @param dataset the training data
     */
    @Override
    protected void _fit(Dataset dataset) {
        MLmodelKnowledgeBase kb = (MLmodelKnowledgeBase) this.knowledgeBase;
        ModelParameters modelParameters = (ModelParameters) kb.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) kb.getTrainingParameters();

        int n = modelParameters.getN().intValue();

        // Register every class label present in the training data.
        Set<Object> classes = modelParameters.getClasses();
        Iterator<Integer> it = dataset.iterator();
        while (it.hasNext()) {
            classes.add(dataset.get(it.next()).getY());
        }

        // Uniform initial observation weights; the subclass updates them after every round.
        AssociativeArray observationWeights = new AssociativeArray();
        it = dataset.iterator();
        while (it.hasNext()) {
            observationWeights.put(it.next(), Double.valueOf(1.0d / n));
        }

        Class<? extends BaseMLclassifier> weakClassifierClass = trainingParameters.getWeakClassifierClass();
        BaseMLclassifier.TrainingParameters weakClassifierTrainingParameters = trainingParameters.getWeakClassifierTrainingParameters();
        int maxWeakClassifiers = trainingParameters.getMaxWeakClassifiers();

        int t = 0;
        int retries = 0;
        while (t < maxWeakClassifiers) {
            this.logger.debug("Training Weak learner {}", Integer.valueOf(t));

            // Draw n record ids according to the current weights
            // (the third argument presumably means "with replacement" -- confirm in SRS).
            FlatDataList sampledIds = SRS.weightedSampling(observationWeights, n, true).toFlatDataList();

            BaseMLclassifier weakClassifier = newWeakClassifier(weakClassifierClass, t);
            try {
                Dataset sampledTrainingData = dataset.generateNewSubset(sampledIds);
                try {
                    weakClassifier.fit(sampledTrainingData, weakClassifierTrainingParameters);
                } finally {
                    sampledTrainingData.erase();
                }
                // Predict the full dataset; presumably the subclass reads these predictions
                // when updating the weights below.
                weakClassifier.predict(dataset);
            } finally {
                weakClassifier.close();
            }

            Status status = updateObservationAndClassifierWeights(dataset, observationWeights, sampledIds);
            if (status == Status.STOP) {
                this.logger.debug("Skipping further training due to low error");
                return;
            }
            if (status == Status.IGNORE) {
                if (retries >= MAX_NUMBER_OF_RETRIES) {
                    this.logger.debug("Too many retries, skipping further training");
                    return;
                }
                this.logger.debug("Ignoring last weak learner due to high error");
                retries++;
                continue; // retrain at the same index t
            }
            if (status == Status.NEXT) {
                retries = 0;
            }
            t++;
        }
    }

    /**
     * Evaluates the weak learner trained in the current round and updates the weights.
     *
     * @param dataset the training data, holding the latest weak learner's predictions
     * @param associativeArray the observation weights (record id -> weight), updated in place
     * @param flatDataList the record ids sampled to train the latest weak learner
     * @return how training should proceed
     */
    protected abstract Status updateObservationAndClassifierWeights(Dataset dataset, AssociativeArray associativeArray, FlatDataList flatDataList);

    /** Erases the stored weak classifiers, then this model. */
    @Override
    public void erase() {
        eraseWeakClassifiers();
        super.erase();
    }

    /**
     * Erases the databases of the stored weak classifiers. Does nothing when the parameters
     * needed to locate them are unavailable.
     */
    private void eraseWeakClassifiers() {
        MLmodelKnowledgeBase kb = (MLmodelKnowledgeBase) this.knowledgeBase;
        ModelParameters modelParameters = (ModelParameters) kb.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) kb.getTrainingParameters();
        if (modelParameters == null || trainingParameters == null) {
            return;
        }
        Class<? extends BaseMLclassifier> weakClassifierClass = trainingParameters.getWeakClassifierClass();
        if (weakClassifierClass == null) {
            return; // no weak learner type configured, so none can have been stored
        }

        // "+1": presumably covers a learner stored at index size() whose weight was never
        // recorded (e.g. training ended on STOP or after too many IGNORE retries).
        int storedLearners = Math.min(modelParameters.getWeakClassifierWeights().size() + 1, trainingParameters.getMaxWeakClassifiers());
        for (int t = 0; t < storedLearners; t++) {
            newWeakClassifier(weakClassifierClass, t).erase();
        }
    }

    /**
     * Instantiates the weak classifier bound to the database of learner {@code index}.
     *
     * @param weakClassifierClass the weak classifier type
     * @param index the learner index
     * @return the classifier instance
     */
    private BaseMLclassifier newWeakClassifier(Class<? extends BaseMLclassifier> weakClassifierClass, int index) {
        return (BaseMLclassifier) BaseMLmodel.newInstance(weakClassifierClass, weakClassifierDbName(index), ((MLmodelKnowledgeBase) this.knowledgeBase).getDbConf());
    }

    /** Builds the database name under which weak learner {@code index} is stored. */
    private String weakClassifierDbName(int index) {
        return this.dbName + ((MLmodelKnowledgeBase) this.knowledgeBase).getDbConf().getDBnameSeparator() + DB_INDICATOR + String.valueOf(index);
    }
}
