package com.datumbox.framework.machinelearning.common.bases.mlmodels;

import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.objecttypes.Learnable;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.Cluster;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.ModelParameters;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.TrainingParameters;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLclusterer.ValidationMetrics;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.bases.validation.ModelValidation;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/common/bases/mlmodels/BaseMLclusterer.class */
/**
 * Base class for clustering models.
 *
 * <p>Provides the shared {@link Cluster} representation, the model parameters that hold the
 * clusters and the gold-standard classes, and an evaluation routine that scores a clustering
 * with purity and normalized mutual information (NMI).
 *
 * @param <CL> the concrete cluster type
 * @param <MP> the model-parameters type
 * @param <TP> the training-parameters type
 * @param <VM> the validation-metrics type
 */
public abstract class BaseMLclusterer<CL extends Cluster, MP extends ModelParameters, TP extends TrainingParameters, VM extends ValidationMetrics> extends BaseMLmodel<MP, TP, VM> {

    /**
     * A single cluster: an id, the ids of the records assigned to it and an optional label
     * ({@code labelY}) that {@code validateModel} sets to the gold-standard class most
     * represented in the cluster.
     *
     * <p>Membership is changed only through the subclass-defined {@link #add} / {@link #remove}
     * (and {@link #clear()}); the public views {@link #getRecordIdSet()} and {@link #iterator()}
     * are read-only.
     */
    public static abstract class Cluster implements Learnable, Iterable<Integer> {
        protected Integer clusterId;
        protected Set<Integer> recordIdSet = new HashSet<>();
        protected Object labelY;

        /**
         * Creates an empty cluster.
         *
         * @param num the id of the cluster
         */
        public Cluster(Integer num) {
            this.clusterId = num;
        }

        /** @return the id of the cluster */
        public Integer getClusterId() {
            return this.clusterId;
        }

        /** @return a read-only view of the ids of the records assigned to this cluster */
        public Set<Integer> getRecordIdSet() {
            return Collections.unmodifiableSet(this.recordIdSet);
        }

        /** @return the label assigned to the cluster, or {@code null} if none was assigned */
        public Object getLabelY() {
            return this.labelY;
        }

        /**
         * Sets the label of the cluster.
         *
         * @param obj the new label
         */
        protected void setLabelY(Object obj) {
            this.labelY = obj;
        }

        /** @return the number of records in the cluster (0 if the record set is {@code null}) */
        public int size() {
            if (this.recordIdSet == null) {
                return 0;
            }
            return this.recordIdSet.size();
        }

        /**
         * Iterates over the ids of the records in the cluster.
         *
         * @return a read-only iterator; {@code remove()} throws {@link UnsupportedOperationException}
         */
        @Override
        public Iterator<Integer> iterator() {
            // The unmodifiable view's iterator rejects remove(), so callers cannot bypass
            // add()/remove() to change membership.
            return Collections.unmodifiableSet(this.recordIdSet).iterator();
        }

        /** Removes all record ids from the cluster (no-op if the record set is {@code null}). */
        public void clear() {
            if (this.recordIdSet != null) {
                this.recordIdSet.clear();
            }
        }

        /** Hashes on the cluster id only, consistently with {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            // Objects.hashCode(i) == i.intValue() for non-null ids, so existing hash values are
            // preserved; a null id no longer throws (equals() already tolerates it).
            return 89 * 7 + Objects.hashCode(this.clusterId);
        }

        /** Two clusters are equal when they have the same runtime class and the same id. */
        @Override
        public boolean equals(Object obj) {
            return obj != null && getClass() == obj.getClass() && Objects.equals(this.clusterId, ((Cluster) obj).clusterId);
        }

        /**
         * Adds a record to the cluster. Implemented by subclasses; {@code num} is presumably the
         * record id stored in {@code recordIdSet} — TODO confirm against implementations.
         *
         * @param num    the record id
         * @param record the record itself
         * @return subclass-defined result
         */
        protected abstract boolean add(Integer num, Record record);

        /**
         * Removes a record from the cluster. Implemented by subclasses; {@code num} is presumably
         * the record id stored in {@code recordIdSet} — TODO confirm against implementations.
         *
         * @param num    the record id
         * @param record the record itself
         * @return subclass-defined result
         */
        protected abstract boolean remove(Integer num, Record record);
    }

    /**
     * Model parameters shared by all clusterers: the clusters keyed by id and the set of
     * gold-standard classes (initially an empty, insertion-ordered set).
     *
     * @param <CL> the concrete cluster type
     */
    public static abstract class ModelParameters<CL extends Cluster> extends BaseMLmodel.ModelParameters {
        private Set<Object> goldStandardClasses;
        private Map<Integer, CL> clusterList;

        /**
         * @param databaseConnector the connector passed on to the base model parameters
         */
        public ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
            this.goldStandardClasses = new LinkedHashSet<>();
            this.clusterList = new HashMap<>();
        }

        /** @return the number of clusters currently stored */
        public Integer getC() {
            return Integer.valueOf(this.clusterList.size());
        }

        /** @return the (mutable) set of gold-standard classes */
        public Set<Object> getGoldStandardClasses() {
            return this.goldStandardClasses;
        }

        /** @param set the new set of gold-standard classes */
        protected void setGoldStandardClasses(Set<Object> set) {
            this.goldStandardClasses = set;
        }

        /** @return the (mutable) map from cluster id to cluster */
        public Map<Integer, CL> getClusterList() {
            return this.clusterList;
        }

        /** @param map the new map from cluster id to cluster */
        protected void setClusterList(Map<Integer, CL> map) {
            this.clusterList = map;
        }
    }

    /** Training parameters shared by all clusterers (currently none beyond the base class). */
    public static abstract class TrainingParameters extends BaseMLmodel.TrainingParameters {
    }

    /** Validation metrics of a clustering; both values stay {@code null} until they are set. */
    public static abstract class ValidationMetrics extends BaseMLmodel.ValidationMetrics {
        private Double purity = null;
        private Double NMI = null;

        /** @return the purity of the clustering, or {@code null} if not computed */
        public Double getPurity() {
            return this.purity;
        }

        /** @param d the purity of the clustering */
        public void setPurity(Double d) {
            this.purity = d;
        }

        /** @return the normalized mutual information, or {@code null} if not computed */
        public Double getNMI() {
            return this.NMI;
        }

        /** @param d the normalized mutual information */
        public void setNMI(Double d) {
            this.NMI = d;
        }
    }

    /**
     * Passes all arguments unchanged to the {@link BaseMLmodel} constructor.
     */
    public BaseMLclusterer(String str, DatabaseConfiguration databaseConfiguration, Class<MP> cls, Class<TP> cls2, Class<VM> cls3, ModelValidation<MP, TP, VM> modelValidation) {
        super(str, databaseConfiguration, cls, cls2, cls3, modelValidation);
    }

    /**
     * Assigns every record of {@code dataset} to a cluster (via {@code predictDataset}) and
     * compares that assignment with the records' classes ({@code getY()}).
     *
     * <p>Side effect: the {@code labelY} of every cluster in the cluster list is set to the
     * gold-standard class occurring most often among its records (the first such class in
     * iteration order on ties, so a cluster with no records gets the first gold class).
     *
     * <p>With N the number of records:
     * <ul>
     *   <li>purity = (1/N) * sum, over the clusters in the cluster list, of the count of the
     *       cluster's majority gold-standard class;</li>
     *   <li>NMI = I(clusters; classes) / ((H(clusters) + H(classes)) / 2), using the empirical
     *       distributions; empty clusters/classes contribute 0 (0 log 0 = 0).</li>
     * </ul>
     *
     * @param dataset the records to evaluate
     * @return the validation metrics; left unset when there are no gold-standard classes
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public VM validateModel(Dataset dataset) {
        predictDataset(dataset);
        int n = dataset.getRecordNumber();

        MLmodelKnowledgeBase kb = (MLmodelKnowledgeBase) this.knowledgeBase;
        ModelParameters<CL> modelParameters = (ModelParameters<CL>) kb.getModelParameters();
        Map<Integer, CL> clusterList = modelParameters.getClusterList();
        Set<Object> goldStandardClasses = modelParameters.getGoldStandardClasses();

        VM validationMetrics = (VM) kb.getEmptyValidationMetricsObject();
        if (goldStandardClasses.isEmpty()) {
            // Without gold-standard classes neither purity nor NMI can be computed.
            return validationMetrics;
        }

        // Contingency table: counts per (cluster, class) cell, per cluster and per class.
        // Missing keys mean 0, so records with an unexpected class or cluster id are counted
        // instead of triggering a NullPointerException.
        Map<List<Object>, Double> jointCounts = new HashMap<>();
        Map<Integer, Double> clusterCounts = new HashMap<>();
        Map<Object, Double> classCounts = new HashMap<>();
        Iterator<Integer> recordIds = dataset.iterator();
        while (recordIds.hasNext()) {
            Record r = dataset.get(recordIds.next());
            Integer predictedCluster = (Integer) r.getYPredicted();
            Object actualClass = r.getY();
            increment(jointCounts, Arrays.<Object>asList(predictedCluster, actualClass));
            increment(clusterCounts, predictedCluster);
            increment(classCounts, actualClass);
        }

        // Purity; also labels each cluster with its majority gold-standard class.
        double correctlyAssigned = 0.0;
        for (CL cluster : clusterList.values()) {
            Integer clusterId = cluster.getClusterId();
            double bestCount = Double.NEGATIVE_INFINITY;
            for (Object goldClass : goldStandardClasses) {
                double cellCount = countOf(jointCounts, Arrays.<Object>asList(clusterId, goldClass));
                if (cellCount > bestCount) {
                    bestCount = cellCount;
                    cluster.setLabelY(goldClass);
                }
            }
            correctlyAssigned += bestCount;
        }
        validationMetrics.setPurity(correctlyAssigned / n);

        // Mutual information over the observed (non-empty) cells of the contingency table:
        // sum of P(w,c) * log(P(w,c) / (P(w) P(c))) = (n_wc/N) * (log n_wc - log n_c - log n_w + log N).
        double logN = Math.log(n);
        double mutualInformation = 0.0;
        for (Map.Entry<List<Object>, Double> cell : jointCounts.entrySet()) {
            double cellCount = cell.getValue();
            double clusterCount = countOf(clusterCounts, cell.getKey().get(0));
            double classCount = countOf(classCounts, cell.getKey().get(1));
            mutualInformation += (cellCount / n) * (Math.log(cellCount) - Math.log(classCount) - Math.log(clusterCount) + logN);
        }

        // The class entropy must come from the class counts; the decompiled version reused the
        // cluster counts here, which made H(classes) == H(clusters) and skewed NMI.
        double clusterEntropy = entropy(clusterCounts, n);
        double classEntropy = entropy(classCounts, n);
        validationMetrics.setNMI(mutualInformation / ((clusterEntropy + classEntropy) / 2.0));
        return validationMetrics;
    }

    /** Adds one to the count stored under {@code key}, treating a missing key as 0. */
    private static <K> void increment(Map<K, Double> counts, K key) {
        Double current = counts.get(key);
        counts.put(key, current == null ? 1.0 : current + 1.0);
    }

    /** Returns the count stored under {@code key}, or 0 when the key is absent. */
    private static double countOf(Map<?, Double> counts, Object key) {
        Double current = counts.get(key);
        return current == null ? 0.0 : current;
    }

    /** Shannon entropy (natural log) of the empirical distribution {@code count / n}; zero counts contribute 0. */
    private static double entropy(Map<?, Double> counts, int n) {
        double logN = Math.log(n);
        double h = 0.0;
        for (Double count : counts.values()) {
            if (count > 0.0) {
                h -= (count / n) * (Math.log(count) - logN);
            }
        }
        return h;
    }

    /**
     * Returns the clusters of the model.
     *
     * @return the (mutable) map from cluster id to cluster, or {@code null} when no knowledge
     *         base is present
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Map<Integer, CL> getClusters() {
        // The decompiled source compared the knowledge base with the int literal 0, which does
        // not compile for an object reference; the intended check is against null.
        if (this.knowledgeBase == null) {
            return null;
        }
        return ((ModelParameters<CL>) ((MLmodelKnowledgeBase) this.knowledgeBase).getModelParameters()).getClusterList();
    }
}
