package com.datumbox.framework.machinelearning.common.bases.baseobjects;

import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.objecttypes.Trainable;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.framework.machinelearning.common.bases.baseobjects.BaseModelParameters;
import com.datumbox.framework.machinelearning.common.bases.baseobjects.BaseTrainingParameters;
import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase;
import java.lang.reflect.InvocationTargetException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/baseobjects/BaseTrainable.class */
/**
 * Base class for trainable objects whose state is held in a {@link KnowledgeBase}.
 *
 * <p>It provides a reflective factory ({@link #newInstance}), derives the database
 * name used by the instance, and implements {@link #fit} as a template that
 * prepares the knowledge base, delegates the actual training to {@link #_fit}
 * and then saves the knowledge base.
 *
 * @param <MP> type of the model parameters stored in the knowledge base
 * @param <TP> type of the training parameters stored in the knowledge base
 * @param <KB> type of the knowledge base holding {@code MP} and {@code TP}
 */
public abstract class BaseTrainable<MP extends BaseModelParameters, TP extends BaseTrainingParameters, KB extends KnowledgeBase<MP, TP>> implements Trainable<MP, TP> {
    /** Logger obtained for the concrete runtime class of this instance. */
    protected final Logger logger;

    /**
     * Knowledge base backing this object. Only the four-argument constructor
     * assigns it; the two-argument constructor leaves it {@code null}
     * (NOTE(review): presumably subclasses using that constructor assign it themselves).
     */
    protected KB knowledgeBase;

    /** Database name derived from the constructor argument and the class simple name. */
    protected String dbName;

    /**
     * Reflectively creates an instance of {@code cls} by invoking its public
     * {@code (String, DatabaseConfiguration)} constructor.
     *
     * @param cls                   class to instantiate
     * @param str                   database name passed to the constructor
     * @param databaseConfiguration configuration passed to the constructor
     * @param <BT>                  concrete type being created
     * @return the newly created instance
     * @throws RuntimeException if the constructor is missing, inaccessible or fails;
     *                          the original reflective exception is kept as the cause
     */
    @SuppressWarnings("rawtypes") // the raw bound is part of the existing signature; tightening it could break callers
    public static <BT extends BaseTrainable> BT newInstance(Class<BT> cls, String str, DatabaseConfiguration databaseConfiguration) {
        try {
            return cls.getConstructor(String.class, DatabaseConfiguration.class).newInstance(str, databaseConfiguration);
        } catch (IllegalAccessException | IllegalArgumentException | InstantiationException | NoSuchMethodException | SecurityException | InvocationTargetException e) {
            // Name the target class so the failure can be diagnosed without digging through the cause chain.
            throw new RuntimeException("Could not instantiate " + cls.getName(), e);
        }
    }

    /**
     * Creates the logger and derives {@link #dbName}. The knowledge base is not
     * created by this constructor.
     *
     * <p>If {@code str} already contains the class simple name followed by the
     * configured separator, it is used unchanged; otherwise the separator and
     * the simple name are appended to it.
     *
     * @param str                   requested database name
     * @param databaseConfiguration configuration supplying the database name separator
     */
    public BaseTrainable(String str, DatabaseConfiguration databaseConfiguration) {
        this.logger = LoggerFactory.getLogger(getClass());
        String simpleName = getClass().getSimpleName();
        String dBnameSeparator = databaseConfiguration.getDBnameSeparator();
        this.dbName = str.contains(simpleName + dBnameSeparator) ? str : str + dBnameSeparator + simpleName;
    }

    /**
     * Same as the two-argument constructor, and additionally creates the
     * {@link KnowledgeBase} using the derived {@link #dbName}.
     *
     * @param str                   requested database name
     * @param databaseConfiguration configuration supplying the separator; also passed to the knowledge base
     * @param cls                   model parameters class passed to the knowledge base
     * @param cls2                  training parameters class passed to the knowledge base
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // KB is erased at runtime, so the cast to KB cannot be checked
    public BaseTrainable(String str, DatabaseConfiguration databaseConfiguration, Class<MP> cls, Class<TP> cls2) {
        this(str, databaseConfiguration);
        this.knowledgeBase = (KB) new KnowledgeBase(this.dbName, databaseConfiguration, cls, cls2);
    }

    /**
     * Returns the model parameters held by the knowledge base.
     *
     * @return the model parameters
     */
    @Override // com.datumbox.common.objecttypes.Trainable
    public MP getModelParameters() {
        return (MP) this.knowledgeBase.getModelParameters();
    }

    /**
     * Returns the training parameters held by the knowledge base.
     *
     * @return the training parameters
     */
    @Override // com.datumbox.common.objecttypes.Trainable
    public TP getTrainingParameters() {
        return (TP) this.knowledgeBase.getTrainingParameters();
    }

    /**
     * Trains on the given dataset: reinitializes the knowledge base, stores the
     * training parameters, records the dataset's record and variable counts in
     * the model parameters, delegates to {@link #_fit} and finally saves the
     * knowledge base.
     *
     * @param dataset training data; must not be {@code null}
     * @param tp      training parameters to store in the knowledge base
     * @throws NullPointerException if {@code dataset} is {@code null}
     */
    @Override // com.datumbox.common.objecttypes.Trainable
    public void fit(Dataset dataset, TP tp) {
        this.logger.info("fit()");
        // Fail before touching the knowledge base: without this check a null dataset
        // would only be detected after reinitialize() and setTrainingParameters() had run.
        if (dataset == null) {
            throw new NullPointerException("dataset must not be null");
        }
        this.knowledgeBase.reinitialize();
        this.knowledgeBase.setTrainingParameters(tp);
        BaseModelParameters modelParameters = this.knowledgeBase.getModelParameters();
        modelParameters.setN(Integer.valueOf(dataset.getRecordNumber()));
        modelParameters.setD(Integer.valueOf(dataset.getVariableNumber()));
        _fit(dataset);
        this.logger.info("Saving model");
        this.knowledgeBase.save();
    }

    /** Delegates to {@code erase()} on the knowledge base. */
    @Override // com.datumbox.common.objecttypes.Trainable
    public void erase() {
        this.knowledgeBase.erase();
    }

    /** Delegates to {@code close()} on the knowledge base. */
    @Override // com.datumbox.common.objecttypes.Trainable
    public void close() {
        this.knowledgeBase.close();
    }

    /**
     * Performs the actual training. Called by {@link #fit} after the knowledge
     * base has been reinitialized and populated with the training parameters and
     * dataset dimensions, and before the knowledge base is saved.
     *
     * @param dataset training data
     */
    protected abstract void _fit(Dataset dataset);
}
