package com.datumbox.framework.machinelearning.common.bases.validation;

import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.FlatDataList;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.utilities.PHPfunctions;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel.ModelParameters;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel.TrainingParameters;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel.ValidationMetrics;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/validation/ModelValidation.class */
/**
 * Base class for estimating the performance of machine learning models via k-fold cross
 * validation.
 *
 * <p>Subclasses decide how the per-fold validation metrics are combined into a single result by
 * implementing {@link #calculateAverageValidationMetrics(List)}.
 *
 * @param <MP> the model parameters type
 * @param <TP> the training parameters type
 * @param <VM> the validation metrics type produced for each fold
 */
public abstract class ModelValidation<MP extends BaseMLmodel.ModelParameters, TP extends BaseMLmodel.TrainingParameters, VM extends BaseMLmodel.ValidationMetrics> {
    protected final Logger logger = LoggerFactory.getLogger(getClass());
    private static final String DB_INDICATOR = "Kfold";

    /**
     * Runs k-fold cross validation and returns the combined validation metrics.
     *
     * <p>The record ids of {@code dataset} are shuffled and split into {@code k} contiguous folds
     * whose sizes differ by at most one, so every record is used for validation exactly once. For
     * each fold, a new model of type {@code modelClass} is trained on the records outside the fold
     * and validated on the records inside it. When {@code k == 1} the single fold is used for both
     * training and validation. All temporary subsets and fold models are erased, even if training
     * or validation fails.
     *
     * @param dataset the data to cross-validate; subsets of it are generated per fold
     * @param k the number of folds; must satisfy {@code 0 < k < dataset.getRecordNumber()}
     * @param dbName base name from which each fold model's storage name is derived
     * @param dbConf storage configuration passed to every fold model
     * @param modelClass the model implementation to instantiate for each fold
     * @param trainingParameters the parameters passed to {@code fit} for every fold
     * @return the per-fold metrics combined by {@link #calculateAverageValidationMetrics(List)}
     * @throws IllegalArgumentException if {@code k} is not in the range described above
     */
    public VM kFoldCrossValidation(Dataset dataset, int k, String dbName, DatabaseConfiguration dbConf, Class<? extends BaseMLmodel> modelClass, TP trainingParameters) {
        int n = dataset.getRecordNumber();
        if (k <= 0 || n <= k) {
            throw new IllegalArgumentException("Invalid number of folds");
        }

        Integer[] ids = shuffledRecordIds(dataset, n);
        String foldDbName = dbName + dbConf.getDBnameSeparator() + DB_INDICATOR;

        List<VM> validationMetricsList = new ArrayList<>(k);
        for (int fold = 0; fold < k; fold++) {
            this.logger.info("Kfold {}", fold);

            // Boundaries of the validation window inside the shuffled ids. Using floor(f*n/k)
            // spreads the n % k leftover records across the folds, so none is left out of
            // validation (a fixed fold size of n / k would never validate the tail records).
            int validationStart = foldBoundary(fold, n, k);
            int validationEnd = foldBoundary(fold + 1, n, k);
            int validationSize = validationEnd - validationStart;

            FlatDataList trainingIds = new FlatDataList(new ArrayList<>(n - validationSize));
            FlatDataList validationIds = new FlatDataList(new ArrayList<>(validationSize));
            for (int pos = 0; pos < n; pos++) {
                if (pos >= validationStart && pos < validationEnd) {
                    validationIds.add(ids[pos]);
                } else {
                    trainingIds.add(ids[pos]);
                }
            }
            if (k == 1) {
                // With a single fold the training side is empty; train on the validation records.
                trainingIds = validationIds;
            }

            validationMetricsList.add(trainAndValidateFold(dataset, trainingIds, validationIds,
                    foldDbName + (fold + 1), dbConf, modelClass, trainingParameters));
        }
        return calculateAverageValidationMetrics(validationMetricsList);
    }

    /**
     * Collects the record ids of {@code dataset} into an array and shuffles it in place via
     * {@code PHPfunctions.shuffle}.
     */
    private static Integer[] shuffledRecordIds(Dataset dataset, int recordNumber) {
        Integer[] ids = new Integer[recordNumber];
        int pos = 0;
        Iterator<Integer> it = dataset.iterator();
        while (it.hasNext()) {
            ids[pos++] = it.next();
        }
        PHPfunctions.shuffle(ids);
        return ids;
    }

    /** Returns {@code floor(fold * n / k)}, computed in {@code long} to avoid int overflow. */
    private static int foldBoundary(int fold, int n, int k) {
        return (int) ((long) fold * n / k);
    }

    /**
     * Trains a fresh model on {@code trainingIds} and validates it on {@code validationIds}.
     * The temporary subsets and the model itself are erased whether or not an exception occurs.
     */
    @SuppressWarnings("unchecked")
    private VM trainAndValidateFold(Dataset dataset, FlatDataList trainingIds, FlatDataList validationIds, String modelDbName, DatabaseConfiguration dbConf, Class<? extends BaseMLmodel> modelClass, TP trainingParameters) {
        BaseMLmodel model = (BaseMLmodel) BaseMLmodel.newInstance(modelClass, modelDbName, dbConf);
        try {
            Dataset trainingData = dataset.generateNewSubset(trainingIds);
            try {
                model.fit(trainingData, trainingParameters);
            } finally {
                trainingData.erase();
            }

            Dataset validationData = dataset.generateNewSubset(validationIds);
            try {
                // Unchecked: the raw model erases its metrics type. Callers are expected to pass
                // a modelClass whose validation metrics are of type VM.
                return (VM) model.validate(validationData);
            } finally {
                validationData.erase();
            }
        } finally {
            model.erase();
        }
    }

    /**
     * Combines the validation metrics of all folds into a single result.
     *
     * @param list the metrics of each fold, in fold order
     * @return the combined metrics
     */
    protected abstract VM calculateAverageValidationMetrics(List<VM> list);
}
