package com.datumbox.framework.machinelearning.clustering;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.dataobjects.TypeInference;
import com.datumbox.common.persistentstorage.interfaces.BigMap;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.common.utilities.PHPfunctions;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.validation.ClustererValidation;
import com.datumbox.framework.mathematics.distances.Distance;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.statistics.sampling.SRS;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/clustering/Kmeans.class */
public class Kmeans extends BaseMLclusterer<Cluster, ModelParameters, TrainingParameters, ValidationMetrics> {

    /**
     * Cluster used by {@link Kmeans}. Besides the member ids kept in the inherited
     * {@code recordIdSet}, it maintains a running per-feature sum of the member vectors
     * and the centroid that is derived from that sum.
     */
    public static class Cluster extends BaseMLclusterer.Cluster {
        /** Current centre of the cluster; replaced whenever the recomputed mean differs. */
        private Record centroid;

        /** Running sum of the x vectors of every record added since the last {@link #clear()}. */
        private final transient AssociativeArray xi_sum;

        /**
         * Creates an empty cluster whose centroid has no feature values and a null y.
         *
         * @param i the cluster id handed to the base class
         */
        protected Cluster(int i) {
            super(Integer.valueOf(i));
            xi_sum = new AssociativeArray();
            centroid = new Record(new AssociativeArray(), null);
        }

        /**
         * @return the current centroid record
         */
        public Record getCentroid() {
            return centroid;
        }

        /**
         * Recomputes the centroid as the accumulated sum divided by the number of member ids
         * (the sum is used as-is when the cluster has no members).
         *
         * @return {@code true} when the centroid's x values changed and a new centroid record
         *         was installed, {@code false} when they were already equal
         */
        protected boolean updateClusterParameters() {
            AssociativeArray mean = new AssociativeArray();
            mean.addValues(xi_sum);

            int memberCount = recordIdSet.size();
            if (memberCount > 0) {
                mean.multiplyValues(1.0d / memberCount);
            }

            if (centroid.getX().equals(mean)) {
                return false;
            }
            // Keep the previous y; only the feature vector is replaced.
            centroid = new Record(mean, centroid.getY());
            return true;
        }

        /**
         * Registers a record id and, when the id was not already present, adds the record's
         * x values to the running sum.
         *
         * @return {@code true} when the id was newly added
         */
        @Override
        protected boolean add(Integer num, Record record) {
            if (!recordIdSet.add(num)) {
                return false;
            }
            xi_sum.addValues(record.getX());
            return true;
        }

        /**
         * Removal is not supported by this cluster type.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        protected boolean remove(Integer num, Record record) {
            throw new UnsupportedOperationException();
        }

        /**
         * Clears the base-class state and resets the running sum. The centroid is left as is.
         */
        @Override
        public void clear() {
            super.clear();
            xi_sum.clear();
        }
    }

    /**
     * Parameters learned while fitting {@link Kmeans}: the iteration count recorded by
     * training and the per-feature weights used by the weighted distance functions.
     */
    public static class ModelParameters extends BaseMLclusterer.ModelParameters<Cluster> {
        /** Value recorded by the training loop; see {@link #getTotalIterations()}. */
        private int totalIterations;

        /** Weight of each feature, keyed by feature identifier. */
        @BigMap
        private Map<Object, Double> featureWeights;

        /**
         * @param databaseConnector connector forwarded to the base class
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the iteration value stored by the last training run
         */
        public int getTotalIterations() {
            return totalIterations;
        }

        /**
         * @param i the iteration value to store
         */
        protected void setTotalIterations(int i) {
            totalIterations = i;
        }

        /**
         * @return the feature-weight map (the stored reference, not a copy)
         */
        public Map<Object, Double> getFeatureWeights() {
            return featureWeights;
        }

        /**
         * @param map the feature-weight map to store (kept by reference)
         */
        protected void setFeatureWeights(Map<Object, Double> map) {
            featureWeights = map;
        }
    }

    /**
     * User-tunable settings for fitting {@link Kmeans}. Every setting has a default, so an
     * untouched instance is usable as is.
     */
    public static class TrainingParameters extends BaseMLclusterer.TrainingParameters {

        /** Distance functions the clusterer can use. */
        public enum Distance {
            EUCLIDIAN,
            MANHATTAN
        }

        /** Strategies available for choosing the initial clusters. */
        public enum Initialization {
            FORGY,
            RANDOM_PARTITION,
            SET_FIRST_K,
            FURTHEST_FIRST,
            SUBSET_FURTHEST_FIRST,
            PLUS_PLUS
        }

        /** Number of clusters to build; defaults to 2. */
        private int k = 2;

        /** Initial-cluster strategy; defaults to {@link Initialization#PLUS_PLUS}. */
        private Initialization initializationMethod = Initialization.PLUS_PLUS;

        /** Distance function; defaults to {@link Distance#EUCLIDIAN}. */
        private Distance distanceMethod = Distance.EUCLIDIAN;

        /** Upper bound on training iterations; defaults to 200. */
        private int maxIterations = 200;

        /** Multiplier used when sizing the sample for SUBSET_FURTHEST_FIRST; defaults to 2.0. */
        private double subsetFurthestFirstcValue = 2.0d;

        /** Factor applied to the weight of non-numerical features; defaults to 1.0. */
        private double categoricalGamaMultiplier = 1.0d;

        /** Whether feature weights are estimated from the data; defaults to false. */
        private boolean weighted = false;

        public int getK() {
            return k;
        }

        public void setK(int i) {
            k = i;
        }

        public Initialization getInitializationMethod() {
            return initializationMethod;
        }

        public void setInitializationMethod(Initialization initialization) {
            initializationMethod = initialization;
        }

        public Distance getDistanceMethod() {
            return distanceMethod;
        }

        public void setDistanceMethod(Distance distance) {
            distanceMethod = distance;
        }

        public int getMaxIterations() {
            return maxIterations;
        }

        public void setMaxIterations(int i) {
            maxIterations = i;
        }

        public double getSubsetFurthestFirstcValue() {
            return subsetFurthestFirstcValue;
        }

        public void setSubsetFurthestFirstcValue(double d) {
            subsetFurthestFirstcValue = d;
        }

        public double getCategoricalGamaMultiplier() {
            return categoricalGamaMultiplier;
        }

        public void setCategoricalGamaMultiplier(double d) {
            categoricalGamaMultiplier = d;
        }

        public boolean isWeighted() {
            return weighted;
        }

        public void setWeighted(boolean z) {
            weighted = z;
        }
    }

    /**
     * Validation metrics type of {@link Kmeans}. It declares no members of its own; everything
     * it offers is inherited from {@code BaseMLclusterer.ValidationMetrics}.
     */
    public static class ValidationMetrics extends BaseMLclusterer.ValidationMetrics {
    }

    /**
     * Creates a k-means clusterer, handing the base class the nested parameter and metrics
     * types of this algorithm together with a fresh {@link ClustererValidation}.
     *
     * @param str                   first argument of the base constructor, forwarded unchanged
     * @param databaseConfiguration storage configuration, forwarded unchanged
     */
    public Kmeans(String str, DatabaseConfiguration databaseConfiguration) {
        super(
                str,
                databaseConfiguration,
                ModelParameters.class,
                TrainingParameters.class,
                ValidationMetrics.class,
                new ClustererValidation());
    }

    /**
     * Assigns each record of the dataset to a cluster. For every record the distance to each
     * cluster centroid is computed, the distances are passed through
     * {@code Descriptives.normalize}, and the record is replaced by a copy that also carries
     * the cluster with the smallest value and the normalized distance table.
     * An empty dataset is left untouched.
     *
     * @param dataset the records to label; modified in place
     */
    @Override
    public void predictDataset(Dataset dataset) {
        if (!dataset.isEmpty()) {
            ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
            Map<Integer, Cluster> clusters = modelParameters.getClusterList();

            for (Iterator<Integer> ids = dataset.iterator(); ids.hasNext();) {
                Integer recordId = ids.next();
                Record current = dataset.get(recordId);

                // Distance from this record to every centroid, keyed by cluster id.
                AssociativeArray distances = new AssociativeArray();
                for (Cluster candidate : clusters.values()) {
                    double distance = calculateDistance(current, candidate.getCentroid());
                    distances.put(candidate.getClusterId(), Double.valueOf(distance));
                }
                Descriptives.normalize(distances);

                Object selectedCluster = getSelectedClusterFromDistances(distances);
                dataset.set(recordId, new Record(current.getX(), current.getY(), selectedCluster, distances));
            }
        }
    }

    /**
     * Trains the model: collects every non-null y value of the dataset into the model's
     * gold-standard class set, then computes the feature weights, builds the initial
     * clusters and runs the iterative cluster refinement, in that order.
     *
     * @param dataset the training records
     */
    @Override
    protected void _fit(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        Set<Object> goldStandardClasses = modelParameters.getGoldStandardClasses();

        for (Iterator<Integer> ids = dataset.iterator(); ids.hasNext();) {
            Object label = dataset.get(ids.next()).getY();
            if (label == null) {
                continue;
            }
            goldStandardClasses.add(label);
        }

        calculateFeatureWeights(dataset);
        initializeClusters(dataset);
        calculateClusters(dataset);
    }

    /**
     * Fills the model's feature-weight map.
     *
     * <p>When the training parameters are not "weighted", every feature seen in the dataset
     * gets weight 1.0 if it is numerical and the categorical gamma multiplier otherwise.
     *
     * <p>When they are "weighted", per-feature statistics are accumulated over all non-null,
     * non-zero values in three temporary big maps, and each feature listed in the dataset's
     * x data types gets:
     * <ul>
     *   <li>non-numerical: {@code 1 - p^2}, where {@code p} is the fraction of records with a
     *       non-zero value for the feature;</li>
     *   <li>numerical: {@code 2 * (sumSq/n - (sum/n)^2)}.</li>
     * </ul>
     * A positive result is replaced by its reciprocal, and non-numerical weights are then
     * multiplied by the categorical gamma multiplier. A feature for which no non-zero value
     * was seen is treated as having zero sums/frequency. The temporary maps are always
     * dropped, even when the computation fails.
     *
     * @param dataset the training records
     */
    private void calculateFeatureWeights(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        Map<Object, TypeInference.DataType> xDataTypes = dataset.getXDataTypes();
        Map<Object, Double> featureWeights = modelParameters.getFeatureWeights();
        double categoricalGamaMultiplier = trainingParameters.getCategoricalGamaMultiplier();

        if (!trainingParameters.isWeighted()) {
            Iterator<Integer> ids = dataset.iterator();
            while (ids.hasNext()) {
                for (Object feature : dataset.get(ids.next()).getX().keySet()) {
                    boolean numerical = xDataTypes.get(feature) == TypeInference.DataType.NUMERICAL;
                    featureWeights.put(feature, Double.valueOf(numerical ? 1.0d : categoricalGamaMultiplier));
                }
            }
            return;
        }

        int n = modelParameters.getN().intValue();
        DatabaseConnector dbc = ((MLmodelKnowledgeBase) this.knowledgeBase).getDbc();
        Map<Object, Double> categoricalFrequencies = dbc.getBigMap("tmp_categoricalFrequencies", true);
        Map<Object, Double> varianceSumX = dbc.getBigMap("tmp_varianceSumX", true);
        Map<Object, Double> varianceSumXsquare = dbc.getBigMap("tmp_varianceSumXsquare", true);
        try {
            // Pass 1: accumulate frequency / sum / sum of squares per feature.
            Iterator<Integer> ids = dataset.iterator();
            while (ids.hasNext()) {
                for (Map.Entry<Object, Object> entry : dataset.get(ids.next()).getX().entrySet()) {
                    Double value = TypeInference.toDouble(entry.getValue());
                    if (value == null || value.doubleValue() == 0.0d) {
                        continue; // null and zero values do not contribute to any statistic
                    }
                    Object feature = entry.getKey();
                    double x = value.doubleValue();
                    if (xDataTypes.get(feature) != TypeInference.DataType.NUMERICAL) {
                        categoricalFrequencies.put(feature, Double.valueOf(valueOrZero(categoricalFrequencies, feature) + 1.0d));
                    } else {
                        varianceSumX.put(feature, Double.valueOf(valueOrZero(varianceSumX, feature) + x));
                        varianceSumXsquare.put(feature, Double.valueOf(valueOrZero(varianceSumXsquare, feature) + (x * x)));
                    }
                }
            }

            // Pass 2: turn the statistics into weights.
            for (Map.Entry<Object, TypeInference.DataType> typeEntry : xDataTypes.entrySet()) {
                Object feature = typeEntry.getKey();
                boolean numerical = typeEntry.getValue() == TypeInference.DataType.NUMERICAL;

                // valueOrZero guards features that only ever had null/zero values: they were
                // never stored in pass 1, so a plain get() would return null here.
                double weight;
                if (numerical) {
                    double mean = valueOrZero(varianceSumX, feature) / n;
                    weight = 2.0d * ((valueOrZero(varianceSumXsquare, feature) / n) - (mean * mean));
                } else {
                    double percentage = valueOrZero(categoricalFrequencies, feature) / n;
                    weight = 1.0d - (percentage * percentage);
                }

                if (weight > 0.0d) {
                    weight = 1.0d / weight;
                }
                if (!numerical) {
                    weight *= categoricalGamaMultiplier;
                }
                featureWeights.put(feature, Double.valueOf(weight));
            }
        } finally {
            // Each temporary collection is dropped together with its own map instance.
            dbc.dropBigMap("tmp_categoricalFrequencies", categoricalFrequencies);
            dbc.dropBigMap("tmp_varianceSumX", varianceSumX);
            dbc.dropBigMap("tmp_varianceSumXsquare", varianceSumXsquare);
        }
    }

    /** Returns the value stored for {@code key}, or 0.0 when the map has no value for it. */
    private static double valueOrZero(Map<Object, Double> map, Object key) {
        Double stored = map.get(key);
        return stored == null ? 0.0d : stored.doubleValue();
    }

    /**
     * Computes the feature-weighted distance between the x vectors of two records, using the
     * distance function selected in the training parameters and the model's feature weights.
     *
     * @param record  first record
     * @param record2 second record
     * @return the weighted distance
     * @throws RuntimeException when the configured distance method is neither EUCLIDIAN nor
     *                          MANHATTAN
     */
    private double calculateDistance(Record record, Record record2) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        Map<Object, Double> weights = modelParameters.getFeatureWeights();
        TrainingParameters.Distance method = trainingParameters.getDistanceMethod();

        if (method == TrainingParameters.Distance.EUCLIDIAN) {
            return Distance.euclideanWeighted(record.getX(), record2.getX(), weights);
        }
        if (method == TrainingParameters.Distance.MANHATTAN) {
            return Distance.manhattanWeighted(record.getX(), record2.getX(), weights);
        }
        throw new RuntimeException("Unsupported Distance method");
    }

    /**
     * Picks the key of the entry that {@code MapFunctions.selectMinKeyValue} returns for the
     * given distance table.
     *
     * @param associativeArray distances keyed by cluster (or cluster id)
     * @return the key of the selected entry
     */
    private Object getSelectedClusterFromDistances(AssociativeArray associativeArray) {
        final Object selectedKey = MapFunctions.selectMinKeyValue(associativeArray).getKey();
        return selectedKey;
    }

    /**
     * Builds the initial clusters in the model's cluster list according to the configured
     * initialization method:
     * <ul>
     *   <li>SET_FIRST_K and FORGY (handled identically): the first k records each seed a cluster;</li>
     *   <li>RANDOM_PARTITION: records are dealt to k clusters in iteration order, round-robin;</li>
     *   <li>FURTHEST_FIRST / SUBSET_FURTHEST_FIRST: each new seed is the examined record
     *       farthest from its nearest existing centroid;</li>
     *   <li>PLUS_PLUS: each new seed is drawn at random, weighted by the distance to the
     *       nearest existing centroid.</li>
     * </ul>
     * Any other value leaves the cluster list untouched. The furthest-first and PLUS_PLUS
     * strategies stop early, producing fewer than k clusters, when no further seed can be
     * chosen.
     *
     * @param dataset the training records
     */
    private void initializeClusters(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        int k = trainingParameters.getK();
        TrainingParameters.Initialization initializationMethod = trainingParameters.getInitializationMethod();
        Map<Integer, Cluster> clusterList = modelParameters.getClusterList();

        if (initializationMethod == TrainingParameters.Initialization.SET_FIRST_K
                || initializationMethod == TrainingParameters.Initialization.FORGY) {
            seedClustersFromFirstRecords(dataset, clusterList, k);
        } else if (initializationMethod == TrainingParameters.Initialization.RANDOM_PARTITION) {
            seedClustersByRoundRobin(dataset, clusterList, k);
        } else if (initializationMethod == TrainingParameters.Initialization.FURTHEST_FIRST
                || initializationMethod == TrainingParameters.Initialization.SUBSET_FURTHEST_FIRST) {
            int sampleSize = modelParameters.getN().intValue();
            if (initializationMethod == TrainingParameters.Initialization.SUBSET_FURTHEST_FIRST) {
                // Sample size = max(ceil(c * k * log(k, 2)), k) with c taken from the training parameters.
                sampleSize = (int) Math.max(Math.ceil(trainingParameters.getSubsetFurthestFirstcValue() * k * PHPfunctions.log(k, 2.0d)), k);
            }
            seedClustersFurthestFirst(dataset, clusterList, k, sampleSize);
        } else if (initializationMethod == TrainingParameters.Initialization.PLUS_PLUS) {
            seedClustersPlusPlus(dataset, clusterList, k);
        }
    }

    /** Seeds clusters 0..k-1 with the first k records of the dataset (fewer if it is shorter). */
    private void seedClustersFromFirstRecords(Dataset dataset, Map<Integer, Cluster> clusterList, int k) {
        int clusterId = 0;
        Iterator<Integer> ids = dataset.iterator();
        while (clusterId < k && ids.hasNext()) {
            Integer recordId = ids.next();
            addSingleRecordCluster(clusterList, Integer.valueOf(clusterId), recordId, dataset.get(recordId));
            clusterId++;
        }
    }

    /** Deals the records to clusters 0..k-1 in iteration order, then computes every centroid. */
    private void seedClustersByRoundRobin(Dataset dataset, Map<Integer, Cluster> clusterList, int k) {
        int position = 0;
        Iterator<Integer> ids = dataset.iterator();
        while (ids.hasNext()) {
            Integer recordId = ids.next();
            Integer clusterId = Integer.valueOf(position % k);
            Cluster cluster = clusterList.get(clusterId);
            if (cluster == null) {
                cluster = new Cluster(clusterId.intValue());
                clusterList.put(clusterId, cluster);
            }
            cluster.add(recordId, dataset.get(recordId));
            position++;
        }
        for (Cluster cluster : clusterList.values()) {
            cluster.updateClusterParameters();
        }
    }

    /**
     * Adds up to k single-record clusters, each seeded with the examined record that is
     * farthest from its nearest existing centroid. Scanning stops once more than
     * {@code sampleSize} unselected records have been examined in a round.
     */
    private void seedClustersFurthestFirst(Dataset dataset, Map<Integer, Cluster> clusterList, int k, int sampleSize) {
        Set<Integer> selectedIds = new HashSet<Integer>();
        for (int round = 0; round < k; round++) {
            Integer furthestId = null;
            double furthestDistance = 0.0d;
            int examined = 0;
            Iterator<Integer> ids = dataset.iterator();
            while (ids.hasNext()) {
                Integer recordId = ids.next();
                if (examined > sampleSize) {
                    break;
                }
                if (selectedIds.contains(recordId)) {
                    continue;
                }
                // With no clusters yet this is Double.MAX_VALUE, so the first candidate wins.
                double nearest = distanceToNearestCentroid(dataset.get(recordId), clusterList);
                if (nearest > furthestDistance) {
                    furthestDistance = nearest;
                    furthestId = recordId;
                }
                examined++;
            }
            if (furthestId == null) {
                // No unselected record lies at a positive distance from the current centroids
                // (e.g. k exceeds the number of distinct records): no further seed exists, and
                // continuing would dereference a null id/record.
                break;
            }
            selectedIds.add(furthestId);
            addSingleRecordCluster(clusterList, Integer.valueOf(clusterList.size()), furthestId, dataset.get(furthestId));
        }
    }

    /**
     * Adds up to k single-record clusters, each seeded with a record drawn by weighted
     * sampling where the weight of an unselected record is its distance to the nearest
     * existing centroid (1.0 for every record while no cluster exists).
     */
    private void seedClustersPlusPlus(Dataset dataset, Map<Integer, Cluster> clusterList, int k) {
        DatabaseConnector dbc = ((MLmodelKnowledgeBase) this.knowledgeBase).getDbc();
        Set<Integer> selectedIds = new HashSet<Integer>();
        for (int round = 0; round < k; round++) {
            Map<Object, Object> tmpMinClusterDistance = dbc.getBigMap("tmp_minClusterDistance", true);
            Integer chosenId = null;
            try {
                AssociativeArray minDistances = new AssociativeArray(tmpMinClusterDistance);
                boolean hasCandidates = false;
                Iterator<Integer> ids = dataset.iterator();
                while (ids.hasNext()) {
                    Integer recordId = ids.next();
                    if (selectedIds.contains(recordId)) {
                        continue;
                    }
                    double weight = clusterList.isEmpty()
                            ? 1.0d
                            : distanceToNearestCentroid(dataset.get(recordId), clusterList);
                    minDistances.put(recordId, Double.valueOf(weight));
                    hasCandidates = true;
                }
                if (hasCandidates) {
                    Descriptives.normalize(minDistances);
                    chosenId = (Integer) SRS.weightedSampling(minDistances, 1, true).iterator().next();
                }
            } finally {
                // Drop the temporary collection even when normalisation or sampling fails.
                dbc.dropBigMap("tmp_minClusterDistance", tmpMinClusterDistance);
            }
            if (chosenId == null) {
                // Every record is already a seed; nothing is left to sample from.
                break;
            }
            selectedIds.add(chosenId);
            addSingleRecordCluster(clusterList, Integer.valueOf(clusterList.size()), chosenId, dataset.get(chosenId));
        }
    }

    /** Returns the smallest distance from the record to any centroid, or Double.MAX_VALUE when there are no clusters. */
    private double distanceToNearestCentroid(Record record, Map<Integer, Cluster> clusterList) {
        double nearest = Double.MAX_VALUE;
        for (Cluster cluster : clusterList.values()) {
            double distance = calculateDistance(record, cluster.getCentroid());
            if (distance < nearest) {
                nearest = distance;
            }
        }
        return nearest;
    }

    /** Creates a cluster holding exactly one record, computes its centroid and stores it under the given id. */
    private static void addSingleRecordCluster(Map<Integer, Cluster> clusterList, Integer clusterId, Integer recordId, Record record) {
        Cluster cluster = new Cluster(clusterId.intValue());
        cluster.add(recordId, record);
        cluster.updateClusterParameters();
        clusterList.put(clusterId, cluster);
    }

    /**
     * Runs the iterative refinement. Each iteration empties every cluster, assigns every
     * record to the cluster selected from its centroid distances, and recomputes all
     * centroids. The loop ends when an iteration changes no centroid or after the configured
     * maximum number of iterations.
     *
     * <p>The model's total-iterations value is preset to the maximum; when the loop stops
     * because nothing changed, it is overwritten with the zero-based index of that iteration.
     *
     * @param dataset the training records
     */
    private void calculateClusters(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        Map<Integer, Cluster> clusters = modelParameters.getClusterList();

        final int maxIterations = trainingParameters.getMaxIterations();
        modelParameters.setTotalIterations(maxIterations);

        // Reused scratch table: cluster -> distance of the current record to its centroid.
        AssociativeArray distances = new AssociativeArray();
        int iteration = 0;
        while (iteration < maxIterations) {
            this.logger.debug("Iteration {}", Integer.valueOf(iteration));

            for (Cluster cluster : clusters.values()) {
                cluster.clear();
            }

            for (Iterator<Integer> ids = dataset.iterator(); ids.hasNext();) {
                Integer recordId = ids.next();
                Record current = dataset.get(recordId);
                for (Cluster candidate : clusters.values()) {
                    distances.put(candidate, Double.valueOf(calculateDistance(current, candidate.getCentroid())));
                }
                Cluster selected = (Cluster) getSelectedClusterFromDistances(distances);
                selected.add(recordId, current);
                distances.clear();
            }

            // Every cluster must be updated, so the flag is set without short-circuiting.
            boolean anyCentroidMoved = false;
            for (Cluster cluster : clusters.values()) {
                if (cluster.updateClusterParameters()) {
                    anyCentroidMoved = true;
                }
            }

            if (!anyCentroidMoved) {
                modelParameters.setTotalIterations(iteration);
                return;
            }
            iteration++;
        }
    }
}
