package com.datumbox.framework.machinelearning.common.bases.basemodels;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.BigMap;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.common.utilities.PHPfunctions;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseDPMM.Cluster;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseDPMM.ModelParameters;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseDPMM.TrainingParameters;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseDPMM.ValidationMetrics;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.validation.ClustererValidation;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.statistics.sampling.SRS;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseDPMM.class */
public abstract class BaseDPMM<CL extends Cluster, MP extends ModelParameters, TP extends TrainingParameters, VM extends ValidationMetrics> extends BaseMLclusterer<CL, MP, TP, VM> {

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseDPMM$Cluster.class */
    /**
     * Base type for the clusters handled by {@link BaseDPMM}.
     *
     * <p>Adds to {@code BaseMLclusterer.Cluster} the hooks that the sampler in
     * {@code BaseDPMM} relies on: scoring a record against the cluster
     * ({@link #posteriorLogPdf(Record)}) and adding/removing records by their
     * dataset id. The two parameter-maintenance hooks are declared here but are
     * never invoked from this file.
     */
    public static abstract class Cluster extends BaseMLclusterer.Cluster {
        /**
         * Feature-to-integer-id mapping. Declared {@code transient}; it is not
         * assigned anywhere in this file.
         * NOTE(review): presumably injected by subclasses from
         * {@code ModelParameters#getFeatureIds()} — confirm against the concrete models.
         */
        protected transient Map<Object, Integer> featureIds;

        /**
         * Creates a cluster with the given id by delegating to the superclass constructor.
         *
         * @param num the cluster id
         */
        /* JADX INFO: Access modifiers changed from: protected */
        public Cluster(Integer num) {
            super(num);
        }

        /**
         * Overwrites the id of this cluster. {@code BaseDPMM} uses it to renumber the
         * clusters once sampling has finished.
         *
         * @param num the new cluster id
         */
        protected void setClusterId(Integer num) {
            this.clusterId = num;
        }

        /**
         * Hook for refreshing the cluster's parameters.
         * NOTE(review): not called from this file; exact contract is defined by subclasses.
         */
        protected abstract void updateClusterParameters();

        /**
         * Hook for setting the cluster's initial parameters.
         * NOTE(review): not called from this file; exact contract is defined by subclasses.
         */
        protected abstract void initializeClusterParameters();

        /**
         * Scores a record against this cluster. Callers in {@code BaseDPMM} add
         * logarithms to the returned value and pass the results to
         * {@code Descriptives.normalizeExp}, i.e. they treat it as a log-scale quantity.
         *
         * @param record the record to score
         * @return the score of the record under this cluster
         */
        protected abstract double posteriorLogPdf(Record record);

        /**
         * Adds a record to this cluster.
         *
         * @param num    the id of the record inside its dataset
         * @param record the record itself
         * @return a boolean whose meaning is defined by the implementation
         *         (ignored by the callers in this file)
         */
        @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.Cluster
        protected abstract boolean add(Integer num, Record record);

        /**
         * Removes a record from this cluster.
         *
         * @param num    the id of the record inside its dataset
         * @param record the record itself
         * @return a boolean whose meaning is defined by the implementation
         *         (ignored by the callers in this file)
         */
        @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.Cluster
        protected abstract boolean remove(Integer num, Record record);
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseDPMM$ModelParameters.class */
    /**
     * Learned state common to every {@link BaseDPMM} model: the number of sampling
     * iterations that were executed and the feature-to-id dictionary built while
     * fitting.
     *
     * @param <CL> the concrete cluster type stored by the model
     */
    public static abstract class ModelParameters<CL extends Cluster> extends BaseMLclusterer.ModelParameters<CL> {
        /** Iteration count reported by the sampler; written by {@code BaseDPMM._fit}. */
        private int totalIterations;

        /** Maps each feature key seen during fitting to an integer id. */
        @BigMap
        private Map<Object, Integer> featureIds;

        /**
         * Forwards the connector to the superclass constructor.
         *
         * @param databaseConnector the connector handed to the superclass
         */
        /* JADX INFO: Access modifiers changed from: protected */
        public ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the stored iteration count
         */
        public int getTotalIterations() {
            return totalIterations;
        }

        /**
         * @param i the iteration count to store
         */
        public void setTotalIterations(int i) {
            totalIterations = i;
        }

        /**
         * @return the feature-to-id map (the live instance, not a copy)
         */
        public Map<Object, Integer> getFeatureIds() {
            return featureIds;
        }

        /**
         * @param map the feature-to-id map to store (kept by reference)
         */
        public void setFeatureIds(Map<Object, Integer> map) {
            featureIds = map;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseDPMM$TrainingParameters.class */
    /**
     * Settings that drive the sampler in {@link BaseDPMM}: the {@code alpha} weight
     * used for opening new clusters, the iteration cap and the way records are
     * assigned to clusters before sampling starts.
     */
    public static abstract class TrainingParameters extends BaseMLclusterer.TrainingParameters {

        /**
         * How records are placed into clusters before the first sampling pass.
         */
        /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseDPMM$TrainingParameters$Initialization.class */
        public enum Initialization {
            /** Every record starts in a freshly created cluster of its own. */
            ONE_CLUSTER_PER_RECORD,
            /**
             * A batch of empty clusters is created first and each record is then put
             * into one picked with {@code PHPfunctions.mt_rand}.
             */
            RANDOM_ASSIGNMENT
        }

        /** Weight of the "open a new cluster" option; no explicit default, so 0.0 until set. */
        private double alpha;

        /** Upper bound on the number of sampling passes; defaults to 1000. */
        private int maxIterations = 1000;

        /** Initial assignment strategy; defaults to {@link Initialization#ONE_CLUSTER_PER_RECORD}. */
        private Initialization initializationMethod = Initialization.ONE_CLUSTER_PER_RECORD;

        /**
         * @return the alpha value
         */
        public double getAlpha() {
            return alpha;
        }

        /**
         * @param d the alpha value to use
         */
        public void setAlpha(double d) {
            alpha = d;
        }

        /**
         * @return the maximum number of sampling passes
         */
        public int getMaxIterations() {
            return maxIterations;
        }

        /**
         * @param i the maximum number of sampling passes
         */
        public void setMaxIterations(int i) {
            maxIterations = i;
        }

        /**
         * @return the initial assignment strategy
         */
        public Initialization getInitializationMethod() {
            return initializationMethod;
        }

        /**
         * @param initialization the initial assignment strategy
         */
        public void setInitializationMethod(Initialization initialization) {
            initializationMethod = initialization;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/basemodels/BaseDPMM$ValidationMetrics.class */
    /**
     * Validation metrics type for {@link BaseDPMM} models. Declares no members of
     * its own; everything comes from {@code BaseMLclusterer.ValidationMetrics}.
     */
    public static abstract class ValidationMetrics extends BaseMLclusterer.ValidationMetrics {
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Forwards all arguments to the superclass constructor, together with a new
     * {@link ClustererValidation} instance as the validator.
     *
     * @param str                   passed through unchanged
     *                              (NOTE(review): presumably the model/database name — confirm in the superclass)
     * @param databaseConfiguration the storage configuration handed to the superclass
     * @param cls                   the concrete model-parameters class
     * @param cls2                  the concrete training-parameters class
     * @param cls3                  the concrete validation-metrics class
     */
    public BaseDPMM(String str, DatabaseConfiguration databaseConfiguration, Class<MP> cls, Class<TP> cls2, Class<VM> cls3) {
        super(str, databaseConfiguration, cls, cls2, cls3, new ClustererValidation());
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Fits the model on the given dataset.
     *
     * <p>First walks every record once to (a) collect each non-null {@code Y} value
     * into the model's gold-standard class set and (b) give every feature key that is
     * not yet in the feature dictionary a fresh integer id. Then runs
     * {@link #collapsedGibbsSampling(Dataset)} and stores the iteration count it
     * returns in the model parameters.
     *
     * <p>New feature ids continue after the entries already present in the
     * dictionary, so they never collide with ids that were stored earlier. This
     * mirrors the way the sampler continues cluster ids after the existing cluster
     * list. With an empty dictionary the numbering starts at 0.
     *
     * @param dataset the training records; their predicted cluster is overwritten by the sampler
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.baseobjects.BaseTrainable
    protected void _fit(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        Set<Object> goldStandardClasses = modelParameters.getGoldStandardClasses();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();

        // Ids are handed out sequentially, so the current size is the first unused id.
        int nextFeatureId = featureIds.size();

        for (Integer recordId : dataset) {
            Record record = dataset.get(recordId);

            Object goldClass = record.getY();
            if (goldClass != null) {
                goldStandardClasses.add(goldClass);
            }

            for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
                Object feature = entry.getKey();
                if (!featureIds.containsKey(feature)) {
                    featureIds.put(feature, Integer.valueOf(nextFeatureId));
                    nextFeatureId++;
                }
            }
        }

        modelParameters.setTotalIterations(collapsedGibbsSampling(dataset));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Assigns the records of the dataset to clusters by repeated re-sampling.
     *
     * <ol>
     *   <li>Works on a private copy of the model's cluster list; new cluster ids
     *       continue after the size of that copy.</li>
     *   <li>Initial assignment, depending on the configured initialization method:
     *       either one new cluster per record, or a batch of
     *       {@code (int) (max(alpha, 1) * ln(recordCount))} (at least one) empty
     *       clusters with every record placed into one chosen by
     *       {@code PHPfunctions.mt_rand}.</li>
     *   <li>Sampling passes: each record is removed from its cluster (the cluster is
     *       dropped when it becomes empty), weighted against every remaining cluster
     *       and against a candidate new cluster, and re-assigned to the id drawn by
     *       {@code SRS.weightedSampling}. Passes repeat until the iteration cap is
     *       reached or a full pass moves no record.</li>
     *   <li>The surviving clusters are renumbered, starting at the size of the
     *       model's cluster list, and stored there.</li>
     * </ol>
     *
     * <p>The records keep the working cluster id that was current when they were
     * last re-assigned; they are not updated by the final renumbering.
     *
     * @param dataset the records to cluster; each record is replaced by a copy that
     *                carries its assigned cluster id
     * @return the number of sampling passes that were executed
     */
    private int collapsedGibbsSampling(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        Map<Integer, CL> workingClusters = new HashMap<>(modelParameters.getClusterList());
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        double alpha = trainingParameters.getAlpha();

        // Next unused working id; continues after whatever the model already holds.
        Integer nextClusterId = Integer.valueOf(workingClusters.size());

        if (trainingParameters.getInitializationMethod() == TrainingParameters.Initialization.ONE_CLUSTER_PER_RECORD) {
            // Every record opens its own cluster.
            for (Integer recordId : dataset) {
                Record source = dataset.get(recordId);
                CL ownCluster = createNewCluster(nextClusterId);
                Record assigned = new Record(source.getX(), source.getY(), nextClusterId, source.getYPredictedProbabilities());
                dataset.set(recordId, assigned);
                ownCluster.add(recordId, assigned);
                workingClusters.put(nextClusterId, ownCluster);
                nextClusterId = Integer.valueOf(nextClusterId.intValue() + 1);
            }
        } else {
            // Create the batch of empty clusters first (never fewer than one) ...
            int clustersToCreate = Math.max(1, (int) (Math.max(alpha, 1.0d) * Math.log(dataset.getRecordNumber())));
            for (int created = 0; created < clustersToCreate; created++) {
                workingClusters.put(nextClusterId, createNewCluster(nextClusterId));
                nextClusterId = Integer.valueOf(nextClusterId.intValue() + 1);
            }

            // ... then drop each record into one picked among ids 0 .. idLimit - 1.
            int idLimit = nextClusterId.intValue();
            for (Integer recordId : dataset) {
                Record source = dataset.get(recordId);
                int pickedId = PHPfunctions.mt_rand(0, idLimit - 1);
                Record assigned = new Record(source.getX(), source.getY(), Integer.valueOf(pickedId), source.getYPredictedProbabilities());
                dataset.set(recordId, assigned);
                workingClusters.get(Integer.valueOf(pickedId)).add(recordId, assigned);
            }
        }

        // Size term of the weights: the number of working clusters right after
        // initialisation. It is captured once and not refreshed while sampling.
        int clusterCount = workingClusters.size();
        int maxIterations = trainingParameters.getMaxIterations();

        // Log weight of "open a new cluster"; constant for the whole run.
        double newClusterLogWeight = Math.log(alpha / ((alpha + clusterCount) - 1.0d));

        boolean settled = false;
        int iteration = 0;
        for (; iteration < maxIterations && !settled; iteration++) {
            this.logger.debug("Iteration {}", Integer.valueOf(iteration));

            // Optimistically assume nothing moves; any re-assignment clears the flag.
            settled = true;
            for (Integer recordId : dataset) {
                Record record = dataset.get(recordId);
                Integer previousClusterId = (Integer) record.getYPredicted();

                // Take the record out of its current cluster before weighing the options.
                CL previousCluster = workingClusters.get(previousClusterId);
                previousCluster.remove(recordId, record);
                if (previousCluster.size() == 0) {
                    workingClusters.remove(previousClusterId);
                }

                AssociativeArray weights = clusterProbabilities(record, clusterCount, workingClusters);

                // Candidate new cluster competes under the next unused id.
                CL candidate = createNewCluster(nextClusterId);
                weights.put(nextClusterId, Double.valueOf(candidate.posteriorLogPdf(record) + newClusterLogWeight));

                Descriptives.normalizeExp(weights);
                Integer drawnClusterId = (Integer) SRS.weightedSampling(weights, 1, true).iterator().next();

                if (Objects.equals(drawnClusterId, nextClusterId)) {
                    // The candidate won: keep it and advance the id counter.
                    Record moved = new Record(record.getX(), record.getY(), nextClusterId, record.getYPredictedProbabilities());
                    dataset.set(recordId, moved);
                    candidate.add(recordId, moved);
                    workingClusters.put(nextClusterId, candidate);
                    settled = false;
                    nextClusterId = Integer.valueOf(nextClusterId.intValue() + 1);
                } else {
                    Record moved = new Record(record.getX(), record.getY(), drawnClusterId, record.getYPredictedProbabilities());
                    dataset.set(recordId, moved);
                    workingClusters.get(drawnClusterId).add(recordId, moved);
                    if (!Objects.equals(previousClusterId, drawnClusterId)) {
                        settled = false;
                    }
                }
            }
        }

        // Publish the surviving clusters under fresh consecutive ids.
        Map<Integer, CL> clusterList = modelParameters.getClusterList();
        Integer publishedId = Integer.valueOf(clusterList.size());
        for (CL cluster : workingClusters.values()) {
            cluster.setClusterId(publishedId);
            clusterList.put(publishedId, cluster);
            publishedId = Integer.valueOf(publishedId.intValue() + 1);
        }
        return iteration;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds the unnormalised log weights of assigning a record to each of the
     * given clusters.
     *
     * <p>For every cluster the weight is
     * {@code posteriorLogPdf(record) + ln(size / (alpha + i - 1))}, keyed by the
     * cluster's id.
     *
     * @param record the record being (re-)assigned
     * @param i      the size term used in the denominator of the weights
     * @param map    the candidate clusters
     * @return a new array holding one log weight per cluster id
     */
    private AssociativeArray clusterProbabilities(Record record, int i, Map<Integer, CL> map) {
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        // Same for every cluster, so compute it once.
        double denominator = (trainingParameters.getAlpha() + i) - 1.0d;

        AssociativeArray logWeights = new AssociativeArray();
        for (CL cluster : map.values()) {
            logWeights.put(cluster.getClusterId(), Double.valueOf(cluster.posteriorLogPdf(record) + Math.log(cluster.size() / denominator)));
        }
        return logWeights;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Assigns every record of the dataset to one of the model's stored clusters.
     *
     * <p>For each record the {@code posteriorLogPdf} of every cluster is collected,
     * the scores are passed through {@code Descriptives.normalizeExp}, and the
     * record is replaced by a copy whose predicted value is the cluster chosen by
     * {@link #getSelectedClusterFromScores(AssociativeArray)} and whose predicted
     * probabilities are the normalised scores. An empty dataset is left untouched.
     *
     * @param dataset the records to label; modified in place
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel
    public void predictDataset(Dataset dataset) {
        if (dataset.isEmpty()) {
            return;
        }

        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        Map<Integer, CL> clusters = modelParameters.getClusterList();

        for (Integer recordId : dataset) {
            Record record = dataset.get(recordId);

            AssociativeArray clusterScores = new AssociativeArray();
            for (Cluster candidate : clusters.values()) {
                clusterScores.put(candidate.getClusterId(), Double.valueOf(candidate.posteriorLogPdf(record)));
            }
            Descriptives.normalizeExp(clusterScores);

            Record labelled = new Record(record.getX(), record.getY(), getSelectedClusterFromScores(clusterScores), clusterScores);
            dataset.set(recordId, labelled);
        }
    }

    /**
     * Picks the winning cluster from a score table.
     *
     * @param associativeArray cluster id mapped to its score
     * @return the key of the entry returned by {@code MapFunctions.selectMaxKeyValue}
     *         (NOTE(review): behaviour for an empty table depends on that helper — confirm)
     */
    private Object getSelectedClusterFromScores(AssociativeArray associativeArray) {
        return MapFunctions.selectMaxKeyValue(associativeArray).getKey();
    }

    /**
     * Factory for the concrete cluster type of the model. The sampler calls it both
     * for clusters it keeps and for throw-away candidates that are only scored.
     *
     * @param num the id to give to the new cluster
     * @return a new cluster carrying that id
     */
    protected abstract CL createNewCluster(Integer num);
}
