package com.datumbox.framework.machinelearning.clustering;

import com.datumbox.common.dataobjects.MatrixDataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseDPMM;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.statistics.distributions.ContinuousDistributions;
import java.util.Map;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;

/**
 * DPMM clusterer (built on {@link BaseDPMM}) whose clusters score records with a
 * multinomial likelihood under a symmetric Dirichlet prior: the same pseudo-count
 * {@code alphaWords} is added to every feature count before scoring.
 */
public class MultinomialDPMM extends BaseDPMM<Cluster, ModelParameters, TrainingParameters, ValidationMetrics> {

    /**
     * One mixture component. It keeps the element-wise sum of the parsed feature
     * vectors of its member records and scores candidate records against it.
     */
    public static class Cluster extends BaseDPMM.Cluster {
        // Number of features; the length of wordCounts and of the alpha vector.
        private int dimensions;
        // Symmetric pseudo-count added to every entry of wordCounts when scoring.
        private double alphaWords;
        // Element-wise sum of the parsed vectors of the records in this cluster.
        private RealVector wordCounts;
        // Cached logMultivariateBeta(wordCounts + alpha); null means "recompute".
        // NOTE(review): transient, so presumably absent from persisted state and
        // recomputed lazily on first use after loading — confirm storage backend.
        private transient Double cache_wordcounts_plusalpha;

        /**
         * Creates a cluster. Callers are expected to set dimensions, alpha and
         * feature ids and then call {@link #initializeClusterParameters()}, as
         * {@code MultinomialDPMM.createNewCluster} does.
         *
         * @param id identifier forwarded to the {@link BaseDPMM.Cluster} constructor
         */
        protected Cluster(Integer id) {
            super(id);
        }

        /** @return the symmetric pseudo-count applied to every feature */
        protected double getAlphaWords() {
            return alphaWords;
        }

        /**
         * Sets the symmetric pseudo-count. This does not clear the cached
         * normalizer, so it should be called before the cluster is scored.
         *
         * @param alphaWords the pseudo-count added to every feature count
         */
        protected void setAlphaWords(double alphaWords) {
            this.alphaWords = alphaWords;
        }

        /** @return the number of features this cluster was sized for */
        protected int getDimensions() {
            return dimensions;
        }

        /**
         * Sets the number of features. This takes effect on the next
         * {@link #initializeClusterParameters()} and when building the alpha vector.
         *
         * @param dimensions the feature-vector length
         */
        protected void setDimensions(int dimensions) {
            this.dimensions = dimensions;
        }

        /**
         * Resets the cluster statistics: clears the cached normalizer and replaces
         * the accumulated counts with an all-zero vector of length {@code dimensions}.
         */
        @Override
        protected void initializeClusterParameters() {
            cache_wordcounts_plusalpha = null;
            wordCounts = new ArrayRealVector(dimensions);
        }

        /**
         * Scores a record against this cluster.
         *
         * <p>With {@code n} = wordCounts, {@code a} = a vector filled with
         * alphaWords and {@code x} = the parsed record, this returns
         * {@code logB(n + a + x) - logB(n + a)}, where {@code logB} is
         * {@link #logMultivariateBeta(RealVector)}. This is the log Dirichlet-multinomial
         * predictive, up to the multinomial coefficient of {@code x}, which does not
         * depend on the cluster. The {@code logB(n + a)} term is cached until the
         * counts change.
         *
         * @param record the record to score
         * @return the (unnormalized) log posterior predictive of the record
         */
        @Override
        protected double posteriorLogPdf(Record record) {
            RealVector x = MatrixDataset.parseRecord(record, featureIds);
            RealVector alphaVector = new ArrayRealVector(dimensions, alphaWords);
            RealVector countsPlusAlpha = wordCounts.add(alphaVector);

            Double normalizer = cache_wordcounts_plusalpha;
            if (normalizer == null) {
                normalizer = logMultivariateBeta(countsPlusAlpha);
                cache_wordcounts_plusalpha = normalizer;
            }

            double numerator = logMultivariateBeta(countsPlusAlpha.add(x));
            return numerator - normalizer;
        }

        /** @return the feature-id mapping used to parse records (may be null) */
        protected Map<Object, Integer> getFeatureIds() {
            return featureIds;
        }

        /**
         * Sets the feature-id mapping used to parse records into vectors.
         *
         * @param featureIds mapping from feature key to vector index
         */
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        /**
         * Adds a record to the cluster and folds its vector into the counts.
         *
         * @param recordId id of the record; ignored if already a member
         * @param record the record whose parsed vector is added
         * @return {@code true} if the record was added, {@code false} if it was
         *         already present (in which case nothing changes)
         */
        @Override
        protected boolean add(Integer recordId, Record record) {
            boolean firstRecord = recordIdSet.size() == 0;

            if (!recordIdSet.add(recordId)) {
                return false;
            }

            RealVector x = MatrixDataset.parseRecord(record, featureIds);
            // For the first member, the parsed vector itself becomes the counts.
            // This is safe to share because later updates use the non-mutating
            // add/subtract, which return new vectors.
            wordCounts = firstRecord ? x : wordCounts.add(x);

            updateClusterParameters();
            return true;
        }

        /**
         * Removes a record from the cluster and subtracts its vector from the counts.
         *
         * @param recordId id of the record to remove
         * @param record the record whose parsed vector is subtracted
         * @return {@code true} if the record was a member and has been removed,
         *         {@code false} otherwise (in which case nothing changes)
         */
        @Override
        protected boolean remove(Integer recordId, Record record) {
            boolean removed = recordIdSet.remove(recordId);
            if (removed) {
                RealVector x = MatrixDataset.parseRecord(record, featureIds);
                wordCounts = wordCounts.subtract(x);
                updateClusterParameters();
            }
            return removed;
        }

        /** Invalidates the cached normalizer after the counts have changed. */
        @Override
        protected void updateClusterParameters() {
            cache_wordcounts_plusalpha = null;
        }

        /**
         * Computes {@code sum_i LogGamma(v_i) - LogGamma(sum_i v_i)}. When
         * LogGamma is ln Gamma, this is the log of the multivariate Beta function B(v).
         *
         * @param v vector of (expected positive) parameters
         * @return the log multivariate Beta of {@code v}
         */
        private double logMultivariateBeta(RealVector v) {
            double total = 0.0;
            double sumLogGamma = 0.0;
            for (int i = 0, n = v.getDimension(); i < n; ++i) {
                double value = v.getEntry(i);
                sumLogGamma += ContinuousDistributions.LogGamma(value);
                total += value;
            }
            return sumLogGamma - ContinuousDistributions.LogGamma(total);
        }
    }

    /** Model parameters of the multinomial DPMM. */
    public static class ModelParameters extends BaseDPMM.ModelParameters<Cluster> {

        /**
         * @param dbc connector forwarded to the base model parameters
         */
        protected ModelParameters(DatabaseConnector dbc) {
            super(dbc);
        }

        /**
         * Returns the cluster map from the base class. Any cluster whose feature-id
         * mapping is null first receives this model's mapping.
         *
         * @return the clusters keyed by id
         */
        @Override
        public Map<Integer, Cluster> getClusterList() {
            Map<Integer, Cluster> clusters = super.getClusterList();
            Map<Object, Integer> sharedFeatureIds = getFeatureIds();

            for (Cluster c : clusters.values()) {
                if (c.getFeatureIds() != null) {
                    continue;
                }
                c.setFeatureIds(sharedFeatureIds);
            }
            return clusters;
        }
    }

    /** Training parameters; adds the symmetric pseudo-count (default 50.0). */
    public static class TrainingParameters extends BaseDPMM.TrainingParameters {
        private double alphaWords = 50.0;

        /** @return the pseudo-count added to every feature count */
        public double getAlphaWords() {
            return alphaWords;
        }

        /**
         * @param alphaWords the pseudo-count added to every feature count
         */
        public void setAlphaWords(double alphaWords) {
            this.alphaWords = alphaWords;
        }
    }

    /** Validation metrics; adds nothing beyond {@link BaseDPMM.ValidationMetrics}. */
    public static class ValidationMetrics extends BaseDPMM.ValidationMetrics {
    }

    /**
     * @param name name forwarded to the {@link BaseDPMM} constructor
     * @param dbConf database configuration forwarded to the {@link BaseDPMM} constructor
     */
    public MultinomialDPMM(String name, DatabaseConfiguration dbConf) {
        super(name, dbConf, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class);
    }

    /**
     * Creates and initializes a new cluster. The cluster is sized from the model's
     * {@code getD()}, shares the model's feature-id mapping, and takes
     * {@code alphaWords} from the training parameters.
     *
     * @param id identifier forwarded to the new cluster
     * @return the initialized cluster
     */
    @Override
    public Cluster createNewCluster(Integer id) {
        MLmodelKnowledgeBase kb = (MLmodelKnowledgeBase) this.knowledgeBase;
        ModelParameters mp = (ModelParameters) kb.getModelParameters();
        TrainingParameters tp = (TrainingParameters) kb.getTrainingParameters();

        Cluster fresh = new Cluster(id);
        fresh.setDimensions(mp.getD().intValue());
        fresh.setFeatureIds(mp.getFeatureIds());
        fresh.setAlphaWords(tp.getAlphaWords());
        // Must run after the setters: it sizes wordCounts from `dimensions`.
        fresh.initializeClusterParameters();
        return fresh;
    }
}
