package com.datumbox.framework.machinelearning.clustering;

import com.datumbox.common.dataobjects.MatrixDataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.basemodels.BaseDPMM;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import java.util.Map;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * Gaussian-component variant of the {@link BaseDPMM} mixture model.
 *
 * <p>Every {@link Cluster} is a multivariate Gaussian whose parameters are updated with the
 * conjugate Normal-Inverse-Wishart equations from the prior hyperparameters {@code kappa0},
 * {@code nu0}, {@code mu0} and {@code psi0} (see {@link TrainingParameters}). Clusters keep
 * running sufficient statistics (the sum of their record vectors and the sum of their outer
 * products), so adding or removing one record updates the parameters without revisiting the
 * other members.
 */
public class GaussianDPMM extends BaseDPMM<Cluster, ModelParameters, TrainingParameters, ValidationMetrics> {

    /**
     * A single Gaussian component: its prior hyperparameters, its current posterior estimates and
     * the sufficient statistics of the records currently assigned to it.
     */
    public static class Cluster extends BaseDPMM.Cluster {
        /** Number of features, i.e. the length of every parsed record vector. */
        private int dimensions;

        /** Prior pseudo-count on the mean. */
        private int kappa0;

        /** Prior degrees of freedom; raised to {@link #dimensions} when smaller. */
        private int nu0;

        /** Prior mean; defaults to the zero vector. */
        private RealVector mu0;

        /** Prior scale matrix; defaults to the identity matrix. */
        private RealMatrix psi0;

        /** Current mean used by {@link #posteriorLogPdf(Record)}. */
        private RealVector mean;

        /** Current covariance used by {@link #posteriorLogPdf(Record)}. */
        private RealMatrix covariance;

        /**
         * {@code psi / (kappa * (nu - dimensions + 1))}; judging by its name, a scale describing
         * the uncertainty of {@link #mean}.
         */
        private RealMatrix meanError;

        /** Degrees of freedom paired with {@link #meanError}: {@code nu - dimensions + 1}. */
        private int meanDf;

        /** Running sum of the assigned record vectors; {@code null} while the cluster is empty. */
        private transient RealVector xi_sum;

        /** Running sum of {@code x x^T} over the assigned records; {@code null} while empty. */
        private transient RealMatrix xi_square_sum;

        /** Lazily computed determinant of {@link #covariance}; cleared whenever it changes. */
        private transient Double cache_covariance_determinant;

        /** Lazily computed inverse of {@link #covariance}; cleared whenever it changes. */
        private transient RealMatrix cache_covariance_inverse;

        /**
         * Creates an empty cluster.
         *
         * @param num the cluster id, forwarded to the base class
         */
        protected Cluster(Integer num) {
            super(num);
        }

        /** @return the feature-name to column-index mapping used to parse records */
        protected Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /** @param map the feature-name to column-index mapping used to parse records */
        protected void setFeatureIds(Map<Object, Integer> map) {
            this.featureIds = map;
        }

        /** @return the prior pseudo-count on the mean */
        protected int getKappa0() {
            return this.kappa0;
        }

        /** @param i the prior pseudo-count on the mean */
        protected void setKappa0(int i) {
            this.kappa0 = i;
        }

        /** @return the prior degrees of freedom */
        protected int getNu0() {
            return this.nu0;
        }

        /** @param i the prior degrees of freedom */
        protected void setNu0(int i) {
            this.nu0 = i;
        }

        /** @return the prior mean, or {@code null} before initialisation if none was supplied */
        protected RealVector getMu0() {
            return this.mu0;
        }

        /** @param realVector the prior mean; must have {@link #getDimensions()} entries */
        protected void setMu0(RealVector realVector) {
            this.mu0 = realVector;
        }

        /** @return the prior scale matrix, or {@code null} before initialisation if none was supplied */
        protected RealMatrix getPsi0() {
            return this.psi0;
        }

        /** @param realMatrix the prior scale matrix; must be {@code dimensions x dimensions} */
        protected void setPsi0(RealMatrix realMatrix) {
            this.psi0 = realMatrix;
        }

        /** @return the number of features */
        protected int getDimensions() {
            return this.dimensions;
        }

        /** @param i the number of features */
        protected void setDimensions(int i) {
            this.dimensions = i;
        }

        /** @return the current {@code meanError} matrix */
        protected RealMatrix getMeanError() {
            return this.meanError;
        }

        /** @return the degrees of freedom paired with {@link #getMeanError()} */
        protected int getMeanDf() {
            return this.meanDf;
        }

        /**
         * Returns the log-density of {@code record} under a multivariate normal with this
         * cluster's current mean and covariance.
         *
         * <p>The covariance's determinant and inverse are computed once via an LU decomposition
         * and cached until the covariance next changes. The normalising constant is evaluated in
         * log space, {@code -d/2 * log(2*pi) - 1/2 * log|Sigma|}, so it does not overflow or
         * underflow for high-dimensional data or extreme determinants.
         *
         * <p>NOTE(review): the exact posterior predictive of a Normal-Inverse-Wishart model is a
         * Student-t with {@code meanDf} degrees of freedom; this evaluates a Gaussian with the
         * matching scale matrix. Confirm that approximation is intended.
         *
         * @param record the record to evaluate; parsed with this cluster's feature ids
         * @return the log-density of the record
         * @throws org.apache.commons.math3.linear.SingularMatrixException if the covariance is singular
         */
        @Override
        protected double posteriorLogPdf(Record record) {
            RealVector diff = MatrixDataset.parseRecord(record, this.featureIds).subtract(this.mean);
            if (this.cache_covariance_determinant == null || this.cache_covariance_inverse == null) {
                LUDecomposition lu = new LUDecomposition(this.covariance);
                this.cache_covariance_determinant = lu.getDeterminant();
                this.cache_covariance_inverse = lu.getSolver().getInverse();
            }
            // Squared Mahalanobis distance: diff^T * Sigma^-1 * diff.
            double mahalanobis = this.cache_covariance_inverse.preMultiply(diff).dotProduct(diff);
            double logNormalizer = -0.5d * this.dimensions * Math.log(2.0d * Math.PI)
                    - 0.5d * Math.log(this.cache_covariance_determinant);
            return -0.5d * mahalanobis + logNormalizer;
        }

        /**
         * Fills in defaults for missing prior hyperparameters and puts the cluster in its
         * empty-cluster state.
         *
         * <p>{@code nu0} is raised to {@code dimensions} when smaller, a missing {@code mu0}
         * becomes the zero vector and a missing {@code psi0} the identity. The mean is set to the
         * zero vector, the covariance to the identity, and {@code meanError}/{@code meanDf} are
         * derived from the prior alone. With {@code kappa0 == 0} the prior {@code meanError}
         * involves a division by zero and is therefore not finite.
         */
        @Override
        protected void initializeClusterParameters() {
            if (this.nu0 < this.dimensions) {
                this.nu0 = this.dimensions;
            }
            if (this.mu0 == null) {
                this.mu0 = new ArrayRealVector(this.dimensions);
            }
            if (this.psi0 == null) {
                this.psi0 = MatrixUtils.createRealIdentityMatrix(this.dimensions);
            }
            this.mean = new ArrayRealVector(this.dimensions);
            this.covariance = MatrixUtils.createRealIdentityMatrix(this.dimensions);
            this.meanError = calculateMeanError(this.psi0, this.kappa0, this.nu0);
            this.meanDf = (this.nu0 - this.dimensions) + 1;
            this.cache_covariance_determinant = null;
            this.cache_covariance_inverse = null;
        }

        /**
         * Assigns a record to this cluster, updates the sufficient statistics and recomputes the
         * posterior parameters.
         *
         * @param num the record id
         * @param record the record's data
         * @return {@code false} if the id was already assigned (nothing changes), else {@code true}
         */
        @Override
        protected boolean add(Integer num, Record record) {
            int previousSize = this.recordIdSet.size();
            if (!this.recordIdSet.add(num)) {
                return false;
            }
            RealVector x = MatrixDataset.parseRecord(record, this.featureIds);
            if (previousSize == 0) {
                // First member: start fresh rather than accumulating onto stale statistics.
                this.xi_sum = x;
                this.xi_square_sum = x.outerProduct(x);
            } else {
                this.xi_sum = this.xi_sum.add(x);
                this.xi_square_sum = this.xi_square_sum.add(x.outerProduct(x));
            }
            updateClusterParameters();
            return true;
        }

        /**
         * Removes a record from this cluster, updates the sufficient statistics and recomputes
         * the posterior parameters.
         *
         * @param num the record id
         * @param record the record's data; must be the same data that was passed to {@code add}
         * @return {@code false} if the id was not assigned (nothing changes), else {@code true}
         */
        @Override
        protected boolean remove(Integer num, Record record) {
            if (!this.recordIdSet.remove(num)) {
                return false;
            }
            RealVector x = MatrixDataset.parseRecord(record, this.featureIds);
            this.xi_sum = this.xi_sum.subtract(x);
            this.xi_square_sum = this.xi_square_sum.subtract(x.outerProduct(x));
            updateClusterParameters();
            return true;
        }

        /**
         * Computes {@code psi / (kappa * (nu - dimensions + 1))}.
         *
         * @param realMatrix the scale matrix {@code psi}
         * @param i the pseudo-count {@code kappa}; {@code 0} yields non-finite entries
         * @param i2 the degrees of freedom {@code nu}
         * @return the scaled matrix
         */
        private RealMatrix calculateMeanError(RealMatrix realMatrix, int i, int i2) {
            return realMatrix.scalarMultiply(1.0d / (i * ((i2 - this.dimensions) + 1.0d)));
        }

        /**
         * Clears the base-class state and discards the running sufficient statistics.
         */
        @Override
        public void clear() {
            super.clear();
            this.xi_sum = null;
            this.xi_square_sum = null;
        }

        /**
         * Recomputes the posterior parameters from the running sufficient statistics.
         *
         * <p>With {@code n} assigned records, sample mean {@code xbar} and scatter matrix
         * {@code C = sum(x x^T) - n * xbar xbar^T}:
         * <pre>
         * kappa_n = kappa0 + n
         * nu_n    = nu0 + n
         * psi_n   = psi0 + C + (kappa0 * n / kappa_n) * (xbar - mu0)(xbar - mu0)^T
         * mean    = (kappa0 * mu0 + n * xbar) / kappa_n
         * cov     = psi_n * (kappa_n + 1) / (kappa_n * (nu_n - d + 1))
         * </pre>
         * If the cluster has become empty, the statistics are discarded and the cluster returns
         * to the state produced by {@link #initializeClusterParameters()} instead of dividing by
         * zero.
         */
        @Override
        protected void updateClusterParameters() {
            int n = this.recordIdSet.size();
            if (n == 0) {
                // Nothing to average: drop the (round-off-polluted) sums and fall back to the
                // empty-cluster state rather than filling mean/covariance with NaN.
                this.xi_sum = null;
                this.xi_square_sum = null;
                initializeClusterParameters();
                return;
            }
            int kappaN = this.kappa0 + n;
            int nuN = this.nu0 + n;

            RealVector sampleMean = this.xi_sum.mapDivide(n);
            RealVector meanShift = sampleMean.subtract(this.mu0);
            RealMatrix scatter = this.xi_square_sum.subtract(sampleMean.outerProduct(sampleMean).scalarMultiply(n));
            // Must be floating-point: kappa0 * n / kappa_n is fractional in general, and integer
            // division would truncate it (usually to 0), silently dropping this term.
            double shrinkage = ((double) this.kappa0 * n) / kappaN;
            RealMatrix psiN = this.psi0.add(scatter.add(meanShift.outerProduct(meanShift).scalarMultiply(shrinkage)));

            this.mean = this.mu0.mapMultiply(this.kappa0).add(sampleMean.mapMultiply(n)).mapDivide(kappaN);
            this.covariance = psiN.scalarMultiply((kappaN + 1.0d) / (kappaN * ((nuN - this.dimensions) + 1.0d)));
            this.cache_covariance_determinant = null;
            this.cache_covariance_inverse = null;
            this.meanError = calculateMeanError(psiN, kappaN, nuN);
            this.meanDf = (nuN - this.dimensions) + 1;
        }
    }

    /**
     * Model parameters for {@link GaussianDPMM}.
     */
    public static class ModelParameters extends BaseDPMM.ModelParameters<Cluster> {
        /**
         * @param databaseConnector connector forwarded to the base class
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * Returns the cluster map after giving every cluster that lacks a feature-id mapping the
         * model-wide one, so that each returned cluster can parse records.
         *
         * @return the clusters keyed by id, as returned by the base class
         */
        @Override
        public Map<Integer, Cluster> getClusterList() {
            Map<Integer, Cluster> clusterList = super.getClusterList();
            Map<Object, Integer> featureIds = getFeatureIds();
            for (Cluster cluster : clusterList.values()) {
                if (cluster.getFeatureIds() == null) {
                    cluster.setFeatureIds(featureIds);
                }
            }
            return clusterList;
        }
    }

    /**
     * Training parameters: the prior hyperparameters given to every new {@link Cluster}.
     *
     * <p>Defaults: {@code kappa0 = 0}, {@code nu0 = 1}; {@code mu0} and {@code psi0} are
     * {@code null}, which makes clusters use the zero vector and the identity matrix.
     */
    public static class TrainingParameters extends BaseDPMM.TrainingParameters {
        private int kappa0 = 0;
        private int nu0 = 1;
        private double[] mu0;
        private double[][] psi0;

        /** @return the prior pseudo-count on the mean */
        public int getKappa0() {
            return this.kappa0;
        }

        /** @param i the prior pseudo-count on the mean */
        public void setKappa0(int i) {
            this.kappa0 = i;
        }

        /** @return the prior degrees of freedom */
        public int getNu0() {
            return this.nu0;
        }

        /** @param i the prior degrees of freedom; clusters raise it to the dimensionality if smaller */
        public void setNu0(int i) {
            this.nu0 = i;
        }

        /** @return the prior mean, or {@code null} for the zero vector */
        public double[] getMu0() {
            return this.mu0;
        }

        /** @param dArr the prior mean, or {@code null} for the zero vector */
        public void setMu0(double[] dArr) {
            this.mu0 = dArr;
        }

        /** @return the prior scale matrix, or {@code null} for the identity */
        public double[][] getPsi0() {
            return this.psi0;
        }

        /** @param dArr the prior scale matrix, or {@code null} for the identity */
        public void setPsi0(double[][] dArr) {
            this.psi0 = dArr;
        }
    }

    /**
     * Validation metrics for {@link GaussianDPMM}; adds nothing to the base class.
     */
    public static class ValidationMetrics extends BaseDPMM.ValidationMetrics {
    }

    /**
     * @param dbName name forwarded to {@link BaseDPMM} (presumably identifies the model's storage)
     * @param databaseConfiguration storage configuration forwarded to {@link BaseDPMM}
     */
    public GaussianDPMM(String dbName, DatabaseConfiguration databaseConfiguration) {
        super(dbName, databaseConfiguration, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class);
    }

    /**
     * Creates a cluster configured with the model's dimensionality and feature ids and the prior
     * hyperparameters from the training parameters, then initialises it.
     *
     * @param num the id of the new cluster
     * @return the initialised, empty cluster
     * @throws IllegalArgumentException if a supplied {@code mu0} or {@code psi0} does not match
     *     the model's dimensionality
     */
    @Override
    public Cluster createNewCluster(Integer num) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        int dimensions = modelParameters.getD().intValue();

        Cluster cluster = new Cluster(num);
        cluster.setDimensions(dimensions);
        cluster.setFeatureIds(modelParameters.getFeatureIds());
        cluster.setKappa0(trainingParameters.getKappa0());
        cluster.setNu0(trainingParameters.getNu0());

        double[] mu0 = trainingParameters.getMu0();
        if (mu0 != null) {
            // Fail fast: a wrong-sized prior would otherwise only surface on the first update.
            if (mu0.length != dimensions) {
                throw new IllegalArgumentException("mu0 has length " + mu0.length
                        + " but the model has " + dimensions + " dimensions");
            }
            cluster.setMu0(new ArrayRealVector(mu0));
        }
        double[][] psi0 = trainingParameters.getPsi0();
        if (psi0 != null) {
            if (psi0.length != dimensions) {
                throw new IllegalArgumentException("psi0 has " + psi0.length
                        + " rows but the model has " + dimensions + " dimensions");
            }
            for (double[] row : psi0) {
                if (row == null || row.length != dimensions) {
                    throw new IllegalArgumentException("psi0 must be a " + dimensions + "x" + dimensions + " matrix");
                }
            }
            cluster.setPsi0(new BlockRealMatrix(psi0));
        }
        cluster.initializeClusterParameters();
        return cluster;
    }
}
