package com.datumbox.framework.machinelearning.classification;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.dataobjects.TypeInference;
import com.datumbox.common.persistentstorage.interfaces.BigMap;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.RandomGenerator;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclassifier;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.validation.ClassifierValidation;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

/* loaded from: input_file:com/datumbox/framework/machinelearning/classification/SupportVectorMachine.class */
/**
 * Classifier that delegates training and prediction to the LIBSVM library.
 *
 * <p>During fitting every class value and every feature name found in the
 * training dataset is mapped to an integer id; records are then converted to
 * dense {@code svm_node} vectors and handed to {@code svm.svm_train}. During
 * prediction the per-class probability estimates returned by
 * {@code svm.svm_predict_probability} are used as class scores.</p>
 */
public class SupportVectorMachine extends BaseMLclassifier<SupportVectorMachine.ModelParameters, SupportVectorMachine.TrainingParameters, SupportVectorMachine.ValidationMetrics> {

    /**
     * Parameters learned during training: the feature-to-id and class-to-id
     * mappings plus the trained LIBSVM model.
     */
    public static class ModelParameters extends BaseMLclassifier.ModelParameters {

        /**
         * Maps each feature name to its zero-based id. This field is never
         * assigned in this class; it is presumably initialised by the
         * framework because of the {@code @BigMap} annotation (TODO confirm).
         */
        @BigMap
        private Map<Object, Integer> featureIds;

        /** Maps each class value to its zero-based id (also used as the LIBSVM label). */
        private Map<Object, Integer> classIds;

        /** The model produced by {@code svm.svm_train}; {@code null} until training has run. */
        private svm_model svmModel;

        /**
         * Creates the parameters object with an empty class-id map.
         *
         * @param databaseConnector connector forwarded to the parent constructor
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
            this.classIds = new HashMap<>();
        }

        /** @return the feature-name to feature-id mapping */
        public Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /** @param featureIds the feature-name to feature-id mapping to use */
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        /** @return the trained LIBSVM model, or {@code null} if none has been set */
        public svm_model getSvmModel() {
            return this.svmModel;
        }

        /** @param svmModel the trained LIBSVM model to store */
        protected void setSvmModel(svm_model svmModel) {
            this.svmModel = svmModel;
        }

        /** @return the class-value to class-id mapping */
        public Map<Object, Integer> getClassIds() {
            return this.classIds;
        }

        /** @param classIds the class-value to class-id mapping to use */
        protected void setClassIds(Map<Object, Integer> classIds) {
            this.classIds = classIds;
        }
    }

    /**
     * Training configuration; wraps the {@code svm_parameter} object that is
     * passed to LIBSVM.
     */
    public static class TrainingParameters extends BaseMLclassifier.TrainingParameters {
        private svm_parameter svmParameter = new svm_parameter();

        /**
         * Initialises the LIBSVM parameters with this classifier's defaults:
         * a C-SVC with a linear kernel and probability estimates enabled.
         */
        public TrainingParameters() {
            this.svmParameter.svm_type = svm_parameter.C_SVC;
            this.svmParameter.kernel_type = svm_parameter.LINEAR;
            this.svmParameter.degree = 3;
            this.svmParameter.gamma = 0.0d;
            this.svmParameter.coef0 = 0.0d;
            this.svmParameter.nu = 0.5d;
            this.svmParameter.cache_size = 100.0d;
            this.svmParameter.C = 1.0d;
            this.svmParameter.eps = 0.001d;
            this.svmParameter.p = 0.1d;
            this.svmParameter.shrinking = 1;
            this.svmParameter.probability = 1;
            this.svmParameter.nr_weight = 0;
            this.svmParameter.weight_label = new int[0];
            this.svmParameter.weight = new double[0];
        }

        /** @return the LIBSVM parameter object used for training */
        public svm_parameter getSvmParameter() {
            return this.svmParameter;
        }

        /** @param svmParameter the LIBSVM parameter object to use for training */
        public void setSvmParameter(svm_parameter svmParameter) {
            this.svmParameter = svmParameter;
        }
    }

    /** Validation metrics of the classifier; adds nothing to the parent class. */
    public static class ValidationMetrics extends BaseMLclassifier.ValidationMetrics {
    }

    /**
     * Creates the classifier and re-seeds LIBSVM's static random generator
     * ({@code svm.rand}) from the framework's thread-local random generator.
     *
     * @param dbName name forwarded to the parent constructor
     * @param dbConf database configuration forwarded to the parent constructor
     */
    public SupportVectorMachine(String dbName, DatabaseConfiguration dbConf) {
        super(dbName, dbConf, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class, new ClassifierValidation());
        svm.rand.setSeed(RandomGenerator.getThreadLocalRandom().nextLong());
    }

    /**
     * Predicts every record of the dataset in place: each record is replaced
     * by a copy that additionally carries the selected class and the
     * normalized class scores.
     *
     * @param dataset the records to predict; modified in place
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel
    public void predictDataset(Dataset dataset) {
        Iterator<Integer> recordIds = dataset.iterator();
        while (recordIds.hasNext()) {
            Integer recordId = recordIds.next();
            Record record = dataset.get(recordId);

            AssociativeArray classScores = calculateClassScores(record.getX());
            // The class is selected from the raw scores, before they are normalized.
            Object predictedClass = getSelectedClassFromClassScores(classScores);
            Descriptives.normalize(classScores);

            dataset.set(recordId, new Record(record.getX(), record.getY(), predictedClass, classScores));
        }
    }

    /**
     * Assigns an integer id to every class value and feature name of the
     * training dataset (in order of first appearance) and then trains the
     * LIBSVM model.
     *
     * @param dataset the training records
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.baseobjects.BaseTrainable
    protected void _fit(Dataset dataset) {
        // Prediction relies on svm_predict_probability, so probability
        // estimates are always switched on regardless of the user's setting.
        currentTrainingParameters().getSvmParameter().probability = 1;

        ModelParameters modelParameters = currentModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        Map<Object, Integer> classIds = modelParameters.getClassIds();
        Set<Object> classes = modelParameters.getClasses();

        int nextClassId = 0;
        int nextFeatureId = 0;
        Iterator<Integer> recordIds = dataset.iterator();
        while (recordIds.hasNext()) {
            Record record = dataset.get(recordIds.next());

            Object theClass = record.getY();
            if (!classIds.containsKey(theClass)) {
                classIds.put(theClass, Integer.valueOf(nextClassId++));
                classes.add(theClass);
            }

            for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
                Object feature = entry.getKey();
                if (!featureIds.containsKey(feature)) {
                    featureIds.put(feature, Integer.valueOf(nextFeatureId++));
                }
            }
        }

        LibSVMTrainer(dataset);
    }

    /**
     * Converts the dataset to a LIBSVM problem and stores the trained model
     * in the model parameters.
     *
     * <p>Rows of the problem are filled in dataset iteration order using a
     * running counter rather than the record id, so record ids do not have to
     * be the contiguous range {@code 0..n-1}.</p>
     *
     * @param dataset the training records; every class and feature must
     *                already be registered in the id maps
     */
    private void LibSVMTrainer(Dataset dataset) {
        ModelParameters modelParameters = currentModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        Map<Object, Integer> classIds = modelParameters.getClassIds();
        int n = modelParameters.getN().intValue();

        svm_problem problem = new svm_problem();
        problem.l = n;
        problem.y = new double[n];
        problem.x = new svm_node[n][];

        int row = 0;
        Iterator<Integer> recordIds = dataset.iterator();
        while (recordIds.hasNext()) {
            Record record = dataset.get(recordIds.next());
            problem.y[row] = classIds.get(record.getY()).intValue();
            problem.x[row] = toDenseVector(record.getX(), featureIds);
            row++;
        }

        svm_parameter svmParameter = currentTrainingParameters().getSvmParameter();

        // Route LIBSVM's console output to this model's logger.
        svm.svm_set_print_string_function(message -> {
            this.logger.debug(message);
        });

        modelParameters.setSvmModel(svm.svm_train(problem, svmParameter));
    }

    /**
     * Estimates the probability of every known class for a single record.
     *
     * <p>LIBSVM returns the probabilities in the order of the model's label
     * array, so each class id is looked up in that array instead of being
     * used directly as an index.</p>
     *
     * @param x the feature values of the record; features that were not seen
     *          during training are ignored
     * @return a map from class value to its estimated probability; a class
     *         whose id is missing from the model's labels gets {@code 0.0}
     */
    private AssociativeArray calculateClassScores(AssociativeArray x) {
        ModelParameters modelParameters = currentModelParameters();
        Map<Object, Integer> classIds = modelParameters.getClassIds();
        svm_model model = modelParameters.getSvmModel();
        int c = modelParameters.getC().intValue();

        svm_node[] vector = toDenseVector(x, modelParameters.getFeatureIds());

        int[] labels = new int[c];
        svm.svm_get_labels(model, labels);
        double[] probabilities = new double[c];
        svm.svm_predict_probability(model, vector, probabilities);

        // Position of each LIBSVM label inside the probability array.
        Map<Integer, Integer> labelPositions = new HashMap<>();
        for (int position = 0; position < labels.length; position++) {
            labelPositions.put(Integer.valueOf(labels[position]), Integer.valueOf(position));
        }

        AssociativeArray classScores = new AssociativeArray();
        for (Map.Entry<Object, Integer> entry : classIds.entrySet()) {
            Integer position = labelPositions.get(entry.getValue());
            double score = (position == null) ? 0.0d : probabilities[position.intValue()];
            classScores.put(entry.getKey(), Double.valueOf(score));
        }
        return classScores;
    }

    /**
     * Builds a dense LIBSVM vector with one node per known feature.
     *
     * @param x          the feature values of a record
     * @param featureIds the feature-name to zero-based feature-id mapping
     * @return an array of {@code featureIds.size()} nodes; node {@code i} has
     *         the one-based index {@code i + 1} and the feature's value, or
     *         {@code 0.0} when the feature is absent or its value cannot be
     *         converted to a number
     */
    private static svm_node[] toDenseVector(AssociativeArray x, Map<Object, Integer> featureIds) {
        svm_node[] vector = new svm_node[featureIds.size()];

        for (Map.Entry<Object, Object> entry : x.entrySet()) {
            Integer featureId = featureIds.get(entry.getKey());
            if (featureId == null) {
                continue; // unknown feature: not part of the model
            }
            Double value = TypeInference.toDouble(entry.getValue());
            svm_node node = new svm_node();
            node.index = featureId.intValue() + 1;
            node.value = (value == null) ? 0.0d : value.doubleValue();
            vector[featureId.intValue()] = node;
        }

        // Every feature that the record does not carry becomes an explicit zero.
        for (int i = 0; i < vector.length; i++) {
            if (vector[i] == null) {
                svm_node node = new svm_node();
                node.index = i + 1;
                node.value = 0.0d;
                vector[i] = node;
            }
        }
        return vector;
    }

    /** Returns the model parameters held by the knowledge base. */
    private ModelParameters currentModelParameters() {
        return (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
    }

    /** Returns the training parameters held by the knowledge base. */
    private TrainingParameters currentTrainingParameters() {
        return (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
    }
}
