package com.datumbox.framework.machinelearning.clustering;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import com.datumbox.framework.machinelearning.common.validation.ClustererValidation;
import com.datumbox.framework.mathematics.distances.Distance;
import com.datumbox.framework.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/clustering/HierarchicalAgglomerative.class */
public class HierarchicalAgglomerative extends BaseMLclusterer<Cluster, ModelParameters, TrainingParameters, ValidationMetrics> {

    /* loaded from: input_file:com/datumbox/framework/machinelearning/clustering/HierarchicalAgglomerative$Cluster.class */
    /**
     * A single cluster of the agglomerative hierarchy.
     *
     * <p>Besides the member record ids kept by the base class, it tracks a running per-feature sum of the
     * member records ({@code xi_sum}), the centroid derived from that sum, and an {@code active} flag.</p>
     */
    public static class Cluster extends BaseMLclusterer.Cluster {
        /** Centroid as last computed by {@link #updateClusterParameters()}. */
        private Record centroid;
        /** Flag read through {@link #isActive()} and written through {@link #setActive(boolean)}. */
        private boolean active;
        /** Running sum of the feature values of every added/merged record; transient, so not serialized. */
        private final transient AssociativeArray xi_sum;

        /**
         * Creates an empty, active cluster whose centroid has no features and a null Y value.
         *
         * @param i the numeric id handed to the base cluster
         */
        protected Cluster(int i) {
            super(Integer.valueOf(i));
            this.xi_sum = new AssociativeArray();
            this.centroid = new Record(new AssociativeArray(), null);
            this.active = true;
        }

        /**
         * @return the centroid record as of the last {@link #updateClusterParameters()} call
         */
        public Record getCentroid() {
            return centroid;
        }

        /**
         * Absorbs another cluster: its feature sums are added to this cluster's sums and its record ids
         * are added to this cluster's id set. The centroid is not recomputed here.
         *
         * @param cluster the cluster to absorb
         * @return the result of adding the other cluster's record ids to this cluster's id set
         */
        protected boolean merge(Cluster cluster) {
            xi_sum.addValues(cluster.xi_sum);
            boolean idsChanged = recordIdSet.addAll(cluster.recordIdSet);
            return idsChanged;
        }

        /**
         * Recomputes the centroid as the feature sum divided by the number of member ids (the sum is
         * left unscaled when the cluster has no members) and stores it if it differs from the current one.
         *
         * @return {@code true} if the centroid was replaced, {@code false} if it was already up to date
         */
        protected boolean updateClusterParameters() {
            AssociativeArray mean = new AssociativeArray();
            mean.addValues(xi_sum);

            int members = recordIdSet.size();
            if (members > 0) {
                mean.multiplyValues(1.0d / members);
            }

            if (centroid.getX().equals(mean)) {
                return false;
            }
            // Keep the existing Y value; only the feature vector changes.
            centroid = new Record(mean, centroid.getY());
            return true;
        }

        /**
         * @return the current value of the active flag
         */
        protected boolean isActive() {
            return active;
        }

        /**
         * @param z the new value of the active flag
         */
        protected void setActive(boolean z) {
            active = z;
        }

        /**
         * Adds a record id to the cluster and, only when the id was not already present, adds the
         * record's feature values to the running sum.
         *
         * @param num    the record id
         * @param record the record whose features are accumulated
         * @return {@code true} if the id was newly added
         */
        @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.Cluster
        protected boolean add(Integer num, Record record) {
            if (!recordIdSet.add(num)) {
                return false;
            }
            xi_sum.addValues(record.getX());
            return true;
        }

        /**
         * Removal is not supported by this cluster type.
         *
         * @throws UnsupportedOperationException always
         */
        @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.Cluster
        protected boolean remove(Integer num, Record record) {
            throw new UnsupportedOperationException();
        }

        /**
         * Runs the base-class clear and then empties the running feature sum.
         */
        /* JADX INFO: Access modifiers changed from: protected */
        @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.Cluster
        public void clear() {
            super.clear();
            xi_sum.clear();
        }
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/clustering/HierarchicalAgglomerative$ModelParameters.class */
    /**
     * Model parameters of the clusterer, specialised for {@link Cluster}. It declares no members of its
     * own; everything comes from {@code BaseMLclusterer.ModelParameters}.
     */
    public static class ModelParameters extends BaseMLclusterer.ModelParameters<Cluster> {
        /**
         * Forwards the connector to the base-class constructor.
         *
         * @param databaseConnector the connector passed unchanged to the superclass
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/clustering/HierarchicalAgglomerative$TrainingParameters.class */
    /**
     * Tunable settings of the agglomerative clusterer.
     *
     * <p>Defaults: {@link Linkage#COMPLETE} linkage, {@link Distance#EUCLIDIAN} distance, a maximum
     * merge distance of {@link Double#MAX_VALUE} and a minimum cluster count of {@code 2}.</p>
     */
    public static class TrainingParameters extends BaseMLclusterer.TrainingParameters {

        /** Record-to-record distance measures selectable for training and prediction. */
        /* loaded from: input_file:com/datumbox/framework/machinelearning/clustering/HierarchicalAgglomerative$TrainingParameters$Distance.class */
        public enum Distance {
            EUCLIDIAN,
            MANHATTAN,
            MAXIMUM
        }

        /** Rules for deriving the distance between a merged cluster and the remaining clusters. */
        /* loaded from: input_file:com/datumbox/framework/machinelearning/clustering/HierarchicalAgglomerative$TrainingParameters$Linkage.class */
        public enum Linkage {
            AVERAGE,
            SINGLE,
            COMPLETE
        }

        private Linkage linkageMethod = Linkage.COMPLETE;
        private Distance distanceMethod = Distance.EUCLIDIAN;
        private double maxDistanceThreshold = Double.MAX_VALUE;
        private double minClustersThreshold = 2.0d;

        /** @return the configured linkage rule */
        public Linkage getLinkageMethod() {
            return linkageMethod;
        }

        /** @param linkage the linkage rule to use */
        public void setLinkageMethod(Linkage linkage) {
            linkageMethod = linkage;
        }

        /** @return the configured distance measure */
        public Distance getDistanceMethod() {
            return distanceMethod;
        }

        /** @param distance the distance measure to use */
        public void setDistanceMethod(Distance distance) {
            distanceMethod = distance;
        }

        /** @return the distance at or above which the closest pair of clusters is no longer merged */
        public double getMaxDistanceThreshold() {
            return maxDistanceThreshold;
        }

        /** @param d the distance at or above which merging stops */
        public void setMaxDistanceThreshold(double d) {
            maxDistanceThreshold = d;
        }

        /** @return the active-cluster count at or below which merging stops */
        public double getMinClustersThreshold() {
            return minClustersThreshold;
        }

        /** @param d the active-cluster count at or below which merging stops */
        public void setMinClustersThreshold(double d) {
            minClustersThreshold = d;
        }
    }

    /* loaded from: input_file:com/datumbox/framework/machinelearning/clustering/HierarchicalAgglomerative$ValidationMetrics.class */
    /**
     * Validation metrics of the clusterer. Adds no members to {@code BaseMLclusterer.ValidationMetrics};
     * it exists so the algorithm has its own concrete metrics type.
     */
    public static class ValidationMetrics extends BaseMLclusterer.ValidationMetrics {
    }

    /**
     * Creates the clusterer by forwarding both arguments to the base class together with this algorithm's
     * {@link ModelParameters}, {@link TrainingParameters} and {@link ValidationMetrics} class tokens and a
     * new {@link ClustererValidation} instance.
     *
     * @param str                   passed unchanged as the first superclass argument (presumably the
     *                              model/database name - TODO confirm against BaseMLclusterer)
     * @param databaseConfiguration passed unchanged to the superclass constructor
     */
    public HierarchicalAgglomerative(String str, DatabaseConfiguration databaseConfiguration) {
        super(str, databaseConfiguration, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class, new ClustererValidation());
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Replaces every record of the dataset with a copy that also carries a predicted cluster.
     *
     * <p>For each record the distance to every stored cluster centroid is computed, the distances are
     * passed through {@code Descriptives.normalize}, and the cluster picked by
     * {@link #getSelectedClusterFromDistances(AssociativeArray)} is stored together with the normalised
     * distances. An empty dataset is left untouched.</p>
     *
     * @param dataset the dataset whose records are rewritten in place
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel
    public void predictDataset(Dataset dataset) {
        if (!dataset.isEmpty()) {
            ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
            Map<Integer, Cluster> clusters = modelParameters.getClusterList();

            for (Iterator<Integer> ids = dataset.iterator(); ids.hasNext(); ) {
                Integer recordId = ids.next();
                Record current = dataset.get(recordId);

                // One entry per cluster: cluster id -> distance from the record to that cluster's centroid.
                AssociativeArray distances = new AssociativeArray();
                for (Cluster candidate : clusters.values()) {
                    double distance = calculateDistance(current, candidate.getCentroid());
                    distances.put(candidate.getClusterId(), Double.valueOf(distance));
                }
                Descriptives.normalize(distances);

                Object selected = getSelectedClusterFromDistances(distances);
                dataset.set(recordId, new Record(current.getX(), current.getY(), selected, distances));
            }
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Trains the model: collects every non-null Y value of the dataset into the model's gold-standard
     * class set and then builds the clusters.
     *
     * @param dataset the training data
     */
    @Override // com.datumbox.framework.machinelearning.common.bases.baseobjects.BaseTrainable
    protected void _fit(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        Set<Object> goldStandardClasses = modelParameters.getGoldStandardClasses();

        for (Iterator<Integer> ids = dataset.iterator(); ids.hasNext(); ) {
            Object label = dataset.get(ids.next()).getY();
            if (label == null) {
                continue;
            }
            goldStandardClasses.add(label);
        }

        calculateClusters(dataset);

        // NOTE(review): the result is discarded; kept in case the getter has side effects - confirm.
        modelParameters.getClusterList();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Computes the distance between the feature vectors of two records using the distance measure
     * configured in the training parameters.
     *
     * @param record  the first record
     * @param record2 the second record
     * @return the distance between the two records' X values
     * @throws RuntimeException if the configured distance measure is none of EUCLIDIAN, MANHATTAN or MAXIMUM
     */
    private double calculateDistance(Record record, Record record2) {
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        TrainingParameters.Distance metric = trainingParameters.getDistanceMethod();

        if (metric == TrainingParameters.Distance.EUCLIDIAN) {
            return Distance.euclidean(record.getX(), record2.getX());
        }
        if (metric == TrainingParameters.Distance.MANHATTAN) {
            return Distance.manhattan(record.getX(), record2.getX());
        }
        if (metric == TrainingParameters.Distance.MAXIMUM) {
            return Distance.maximum(record.getX(), record2.getX());
        }
        throw new RuntimeException("Unsupported Distance method");
    }

    /**
     * Picks the cluster for a record from its per-cluster distance table.
     *
     * @param associativeArray map of cluster id to distance, as built by {@code predictDataset}
     * @return the key of the entry returned by {@code MapFunctions.selectMinKeyValue} (presumably the
     *         entry with the smallest value - confirm against MapFunctions)
     */
    private Object getSelectedClusterFromDistances(AssociativeArray associativeArray) {
        return MapFunctions.selectMinKeyValue(associativeArray).getKey();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds the cluster list by agglomerative merging.
     *
     * <ol>
     *   <li>Every record of the dataset becomes its own active cluster, with ids assigned from 0 upwards
     *       in iteration order.</li>
     *   <li>All pairwise centroid distances are stored in a temporary map (a cluster's distance to itself
     *       is {@link Double#MAX_VALUE}), together with each cluster's nearest neighbour.</li>
     *   <li>The closest pair is merged repeatedly until {@link #mergeClosest(Map, Map)} reports that
     *       nothing can be merged or the number of active clusters is at or below the configured
     *       minimum.</li>
     *   <li>Inactive (absorbed) clusters are removed and the centroids of the survivors are refreshed.</li>
     * </ol>
     *
     * <p>The minimum-cluster limit is checked <em>before</em> each merge, so a dataset that already has no
     * more clusters than the limit is left unmerged.</p>
     *
     * @param dataset the training data
     */
    private void calculateClusters(Dataset dataset) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        Map<Integer, Cluster> clusterList = modelParameters.getClusterList();
        DatabaseConnector dbc = ((MLmodelKnowledgeBase) this.knowledgeBase).getDbc();

        // Temporary working maps; both are dropped at the end of the method.
        Map<List<Object>, Double> distanceArray = dbc.getBigMap("tmp_distanceArray", true);
        Map<Integer, Integer> minClusterDistanceId = dbc.getBigMap("tmp_minClusterDistanceId", true);

        // Step 1: one singleton cluster per record.
        int nextClusterId = 0;
        Iterator<Integer> recordIds = dataset.iterator();
        while (recordIds.hasNext()) {
            Integer recordId = recordIds.next();
            Cluster singleton = new Cluster(nextClusterId);
            singleton.add(recordId, dataset.get(recordId));
            singleton.setActive(true);
            singleton.updateClusterParameters();
            clusterList.put(Integer.valueOf(nextClusterId), singleton);
            nextClusterId++;
        }

        // Step 2: pairwise distances and the nearest neighbour of every cluster.
        for (Map.Entry<Integer, Cluster> first : clusterList.entrySet()) {
            Integer firstId = first.getKey();
            Cluster firstCluster = first.getValue();
            for (Map.Entry<Integer, Cluster> second : clusterList.entrySet()) {
                Integer secondId = second.getKey();
                // A cluster must never be selected as its own nearest neighbour with a finite distance.
                double distance = Objects.equals(firstId, secondId)
                        ? Double.MAX_VALUE
                        : calculateDistance(firstCluster.getCentroid(), second.getValue().getCentroid());
                distanceArray.put(Arrays.asList(firstId, secondId), Double.valueOf(distance));
                distanceArray.put(Arrays.asList(secondId, firstId), Double.valueOf(distance));

                Integer currentNearest = minClusterDistanceId.get(firstId);
                if (currentNearest == null
                        || distance < distanceArray.get(Arrays.asList(firstId, currentNearest)).doubleValue()) {
                    minClusterDistanceId.put(firstId, secondId);
                }
            }
        }

        // Step 3: merge until the cluster-count limit is reached or no pair qualifies.
        // The negated "<=" keeps the original comparison semantics (including for a NaN threshold).
        double minClusters = trainingParameters.getMinClustersThreshold();
        while (!(countActiveClusters(clusterList) <= minClusters)
                && mergeClosest(minClusterDistanceId, distanceArray)) {
            // mergeClosest performed one merge; re-evaluate the stopping conditions.
        }

        // Step 4: drop absorbed clusters, refresh the centroids of the survivors.
        Iterator<Map.Entry<Integer, Cluster>> entries = clusterList.entrySet().iterator();
        while (entries.hasNext()) {
            Cluster candidate = entries.next().getValue();
            if (candidate.isActive()) {
                candidate.updateClusterParameters();
            } else {
                entries.remove();
            }
        }

        dbc.dropBigMap("tmp_distanceArray", distanceArray);
        dbc.dropBigMap("tmp_minClusterDistanceId", minClusterDistanceId);
    }

    /** Counts the clusters of the given map whose active flag is set. */
    private static int countActiveClusters(Map<Integer, Cluster> clusterList) {
        int active = 0;
        for (Cluster candidate : clusterList.values()) {
            if (candidate.isActive()) {
                active++;
            }
        }
        return active;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Merges the closest pair of active clusters, if one qualifies.
     *
     * <p>The active cluster with the smallest nearest-neighbour distance absorbs its nearest neighbour,
     * which is then marked inactive. The stored distances between the merged cluster and every active
     * cluster are updated according to the configured linkage, and the nearest neighbour is recomputed
     * for every active cluster whose previous nearest neighbour was one of the two merged clusters.</p>
     *
     * @param map  nearest-neighbour table: cluster id to the id of its closest cluster
     * @param map2 distance table keyed by two-element lists of cluster ids (both orders are stored)
     * @return {@code true} if a merge was performed; {@code false} if no active cluster has a nearest
     *         neighbour closer than {@link Double#MAX_VALUE} or the smallest distance is not below the
     *         configured maximum distance threshold
     */
    private boolean mergeClosest(Map<Integer, Integer> map, Map<List<Object>, Double> map2) {
        ModelParameters modelParameters = (ModelParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) ((MLmodelKnowledgeBase) this.knowledgeBase).getTrainingParameters();
        Map<Integer, Cluster> clusterList = modelParameters.getClusterList();

        // Find the active cluster whose nearest neighbour is closest overall.
        Integer closestId = null;
        double closestDistance = Double.MAX_VALUE;
        for (Map.Entry<Integer, Cluster> entry : clusterList.entrySet()) {
            if (!entry.getValue().isActive()) {
                continue;
            }
            Integer id = entry.getKey();
            double distance = map2.get(Arrays.asList(id, map.get(id))).doubleValue();
            if (distance < closestDistance) {
                closestId = id;
                closestDistance = distance;
            }
        }

        // closestId stays null when there is no cluster at all or only self-distances (MAX_VALUE) remain;
        // without this guard an infinite/NaN threshold would let the null id be dereferenced below.
        if (closestId == null || closestDistance >= trainingParameters.getMaxDistanceThreshold()) {
            return false;
        }

        Integer absorbedId = map.get(closestId);
        Cluster merged = clusterList.get(closestId);
        Cluster absorbed = clusterList.get(absorbedId);

        // Sizes before the merge are the weights of the average-linkage update.
        double mergedSize = merged.size();
        double absorbedSize = absorbed.size();
        merged.merge(absorbed);
        absorbed.setActive(false);

        // Update the distance between the merged cluster and every active cluster (itself included).
        TrainingParameters.Linkage linkageMethod = trainingParameters.getLinkageMethod();
        for (Cluster other : clusterList.values()) {
            if (!other.isActive()) {
                continue;
            }
            double updated = linkageDistance(linkageMethod, map2, merged, absorbed, other, mergedSize, absorbedSize);
            map2.put(Arrays.asList(merged.getClusterId(), other.getClusterId()), Double.valueOf(updated));
            map2.put(Arrays.asList(other.getClusterId(), merged.getClusterId()), Double.valueOf(updated));
        }

        // Recompute the nearest neighbour of every active cluster that pointed at one of the merged pair.
        for (Map.Entry<Integer, Cluster> entry : clusterList.entrySet()) {
            if (!entry.getValue().isActive()) {
                continue;
            }
            Integer id = entry.getKey();
            Integer previousNearest = map.get(id);
            boolean affected = Objects.equals(previousNearest, merged.getClusterId())
                    || Objects.equals(previousNearest, absorbed.getClusterId());
            if (!affected) {
                continue;
            }

            // Start from the cluster itself (distance MAX_VALUE) and take the first strictly closer one.
            Integer nearest = id;
            for (Map.Entry<Integer, Cluster> candidate : clusterList.entrySet()) {
                Integer candidateId = candidate.getKey();
                if (candidate.getValue().isActive()
                        && map2.get(Arrays.asList(id, candidateId)).doubleValue()
                                < map2.get(Arrays.asList(id, nearest)).doubleValue()) {
                    nearest = candidateId;
                }
            }
            map.put(id, nearest);
        }
        return true;
    }

    /**
     * Computes the new distance between the just-merged cluster and another active cluster from the
     * distances that were stored before the merge.
     */
    private double linkageDistance(TrainingParameters.Linkage linkageMethod, Map<List<Object>, Double> distances,
            Cluster merged, Cluster absorbed, Cluster other, double mergedSize, double absorbedSize) {
        if (Objects.equals(merged.getClusterId(), other.getClusterId())) {
            // Self-distance: keeps a cluster from being chosen as its own nearest neighbour.
            return Double.MAX_VALUE;
        }
        if (linkageMethod != TrainingParameters.Linkage.SINGLE
                && linkageMethod != TrainingParameters.Linkage.COMPLETE
                && linkageMethod != TrainingParameters.Linkage.AVERAGE) {
            // Fallback for an unrecognised/unset linkage: plain centroid distance.
            return calculateDistance(merged.getCentroid(), other.getCentroid());
        }

        double toMerged = distances.get(Arrays.asList(merged.getClusterId(), other.getClusterId())).doubleValue();
        double toAbsorbed = distances.get(Arrays.asList(absorbed.getClusterId(), other.getClusterId())).doubleValue();
        if (linkageMethod == TrainingParameters.Linkage.SINGLE) {
            return Math.min(toMerged, toAbsorbed);
        }
        if (linkageMethod == TrainingParameters.Linkage.COMPLETE) {
            return Math.max(toMerged, toAbsorbed);
        }
        // AVERAGE: mean of the two old distances weighted by the pre-merge cluster sizes.
        return ((toMerged * mergedSize) + (toAbsorbed * absorbedSize)) / (mergedSize + absorbedSize);
    }
}
