package com.datumbox.framework.machinelearning.recommendersystem;

import com.datumbox.common.dataobjects.AssociativeArray;
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.dataobjects.FlatDataList;
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.dataobjects.TransposeDataList;
import com.datumbox.common.dataobjects.TypeInference;
import com.datumbox.common.persistentstorage.interfaces.BigMap;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.MapFunctions;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLrecommender;
import com.datumbox.framework.mathematics.distances.Distance;
import com.datumbox.framework.statistics.parametrics.relatedsamples.PearsonCorrelation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:com/datumbox/framework/machinelearning/recommendersystem/CollaborativeFiltering.class */
/**
 * Collaborative filtering recommender based on pairwise record similarities.
 *
 * <p>Training compares every pair of training records with different {@code Y} values. It stores
 * each strictly positive similarity under both orderings of the pair {@code [y1, y2]}. Prediction
 * treats a record's {@code X} as a set of known ratings keyed by the same identifiers. It scores
 * every not-yet-rated identifier by the similarity-weighted average of those ratings.
 */
public class CollaborativeFiltering extends BaseMLrecommender<ModelParameters, TrainingParameters> {

    /**
     * Model parameters of {@link CollaborativeFiltering}.
     *
     * <p>Holds the learned similarities keyed by an ordered pair {@code [y1, y2]} of training
     * record {@code Y} values. The map is annotated with {@link BigMap}. Presumably the framework
     * allocates it through the {@link DatabaseConnector} — confirm in the base class.
     */
    public static class ModelParameters extends BaseMLrecommender.ModelParameters {

        /** Pair {@code [y1, y2]} -> similarity; only strictly positive scores are stored. */
        @BigMap
        private Map<List<Object>, Double> similarities;

        /**
         * @param databaseConnector connector passed through to the base model parameters
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the similarity map (the live instance, not a copy)
         */
        public Map<List<Object>, Double> getSimilarities() {
            return this.similarities;
        }

        /**
         * @param map the similarity map to use
         */
        protected void setSimilarities(Map<List<Object>, Double> map) {
            this.similarities = map;
        }
    }

    /**
     * Training parameters of {@link CollaborativeFiltering}.
     */
    public static class TrainingParameters extends BaseMLrecommender.TrainingParameters {

        /** Measure used to compare two records; defaults to {@link SimilarityMeasure#EUCLIDIAN}. */
        private SimilarityMeasure similarityMethod = SimilarityMeasure.EUCLIDIAN;

        /** Supported similarity measures between two records. */
        public enum SimilarityMeasure {
            /** {@code 1 / (1 + euclideanDistance)}. */
            EUCLIDIAN,
            /** {@code 1 / (1 + manhattanDistance)}. */
            MANHATTAN,
            /** Pearson correlation of the two records' values over the union of their keys. */
            PEARSONS_CORRELATION
        }

        /**
         * @return the configured similarity measure
         */
        public SimilarityMeasure getSimilarityMethod() {
            return this.similarityMethod;
        }

        /**
         * @param similarityMeasure the similarity measure to use during training
         */
        public void setSimilarityMethod(SimilarityMeasure similarityMeasure) {
            this.similarityMethod = similarityMeasure;
        }
    }

    /**
     * @param str identifier forwarded to the base class (presumably the model/database name — confirm)
     * @param databaseConfiguration storage configuration forwarded to the base class
     */
    public CollaborativeFiltering(String str, DatabaseConfiguration databaseConfiguration) {
        super(str, databaseConfiguration, ModelParameters.class, TrainingParameters.class);
    }

    /**
     * Computes the similarity between every pair of training records with different {@code Y}
     * values. The strictly positive scores are stored in the model's similarity map.
     *
     * <p>Each positive score is stored under both {@code [y1, y2]} and {@code [y2, y1]}. Pairs
     * already present in the map are not recomputed. Every supported measure is symmetric, so each
     * unordered pair of records is evaluated at most once.
     *
     * @param dataset the training data; each record's {@code Y} identifies it in the similarity map
     */
    @Override
    protected void _fit(Dataset dataset) {
        Map<List<Object>, Double> similarities = ((ModelParameters) this.knowledgeBase.getModelParameters()).getSimilarities();

        // Snapshot the ids so the inner loop can start right after the outer position.
        List<Integer> ids = new ArrayList<>();
        Iterator<Integer> idIterator = dataset.iterator();
        while (idIterator.hasNext()) {
            ids.add(idIterator.next());
        }

        int n = ids.size();
        for (int i = 0; i < n; ++i) {
            Record first = dataset.get(ids.get(i));
            Object firstY = first.getY();
            for (int j = i + 1; j < n; ++j) {
                Record second = dataset.get(ids.get(j));
                Object secondY = second.getY();
                if (Objects.equals(firstY, secondY)) {
                    continue;
                }
                List<Object> pair = Arrays.asList(firstY, secondY);
                if (similarities.containsKey(pair)) {
                    continue;
                }
                double similarity = calculateSimilarity(first, second);
                // Non-positive (and NaN) scores are dropped, so every stored weight is > 0 and
                // the normalising sums in predictDataset can never be zero.
                if (similarity > 0.0) {
                    similarities.put(pair, similarity);
                    similarities.put(Arrays.asList(secondY, firstY), similarity);
                }
            }
        }
    }

    /**
     * Scores, for every record, the identifiers it has not rated and stores the ranking in it.
     *
     * <p>For each stored similarity {@code [rated, candidate] -> s}, the record's {@code X} must
     * contain {@code rated} but not {@code candidate}. When it does, the rating of {@code rated}
     * times {@code s} is added to {@code candidate}'s score. Each score is then divided by the sum
     * of the similarities that contributed to it. Records without any candidate are left
     * untouched. Otherwise the record is replaced by one whose predicted {@code Y} is the
     * best-scoring candidate, carrying all candidate scores in descending order.
     *
     * @param dataset the records to score; modified in place
     */
    @Override
    protected void predictDataset(Dataset dataset) {
        Map<List<Object>, Double> similarities = ((ModelParameters) this.knowledgeBase.getModelParameters()).getSimilarities();

        Iterator<Integer> idIterator = dataset.iterator();
        while (idIterator.hasNext()) {
            Integer id = idIterator.next();
            Record record = dataset.get(id);
            AssociativeArray ratings = record.getX();

            Map<Object, Double> weightedSums = new HashMap<>();
            Map<Object, Double> similaritySums = new HashMap<>();
            // One pass over the (possibly disk-backed) similarity map per record, instead of one
            // full pass per rated item.
            for (Map.Entry<List<Object>, Double> entry : similarities.entrySet()) {
                List<Object> pair = entry.getKey();
                Object rated = pair.get(0);
                Object candidate = pair.get(1);
                if (!ratings.containsKey(rated) || ratings.containsKey(candidate)) {
                    continue; // not rated by this record, or the candidate is already rated
                }
                Double rating = TypeInference.toDouble(ratings.get(rated));
                if (rating == null) {
                    continue; // a null value carries no rating to propagate
                }
                double similarity = entry.getValue();

                Double previousWeighted = weightedSums.get(candidate);
                Double previousSimilarity = similaritySums.get(candidate);
                weightedSums.put(candidate, (previousWeighted == null ? 0.0 : previousWeighted) + similarity * rating);
                similaritySums.put(candidate, (previousSimilarity == null ? 0.0 : previousSimilarity) + similarity);
            }

            if (weightedSums.isEmpty()) {
                continue;
            }
            // Turn the weighted sums into similarity-weighted averages.
            for (Map.Entry<Object, Double> score : weightedSums.entrySet()) {
                score.setValue(score.getValue() / similaritySums.get(score.getKey()));
            }
            dataset.set(id, withRecommendations(record, weightedSums));
        }
    }

    /**
     * Builds the replacement record that carries the ranked recommendation scores.
     *
     * @param record the scored record; its {@code X} and {@code Y} are kept
     * @param scores candidate -> predicted score; must not be empty
     * @return a record whose predicted {@code Y} is the highest-scoring candidate
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Record withRecommendations(Record record, Map<Object, Double> scores) {
        // MapFunctions and AssociativeArray are called with raw maps so this compiles regardless of
        // the exact generic bounds of their signatures.
        Map ranked = MapFunctions.sortNumberMapByValueDescending((Map) scores);
        Object best = ranked.keySet().iterator().next();
        return new Record(record.getX(), record.getY(), best, new AssociativeArray(ranked));
    }

    /**
     * Computes the similarity between two records using the configured
     * {@link TrainingParameters.SimilarityMeasure}.
     *
     * <p>The distance-based measures are mapped to {@code 1 / (1 + distance)}. The Pearson measure
     * returns the correlation of the two records' values aligned on the union of their keys.
     *
     * @param record first record
     * @param record2 second record
     * @return the similarity score
     * @throws RuntimeException if the configured measure is not supported (e.g. {@code null})
     */
    private double calculateSimilarity(Record record, Record record2) {
        TrainingParameters.SimilarityMeasure similarityMethod = ((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getSimilarityMethod();
        if (similarityMethod == TrainingParameters.SimilarityMeasure.EUCLIDIAN) {
            return 1.0d / (1.0d + Distance.euclidean(record.getX(), record2.getX()));
        }
        if (similarityMethod == TrainingParameters.SimilarityMeasure.MANHATTAN) {
            return 1.0d / (1.0d + Distance.manhattan(record.getX(), record2.getX()));
        }
        if (similarityMethod == TrainingParameters.SimilarityMeasure.PEARSONS_CORRELATION) {
            // Align both records on the union of their keys.
            Set<Object> columns = new HashSet<>(record.getX().keySet());
            columns.addAll(record2.getX().keySet());
            FlatDataList values1 = new FlatDataList();
            FlatDataList values2 = new FlatDataList();
            for (Object column : columns) {
                // NOTE(review): a key present in only one record yields TypeInference.toDouble(null),
                // presumably null; confirm PearsonCorrelation tolerates such entries (or use 0.0).
                values1.add(TypeInference.toDouble(record.getX().get(column)));
                values2.add(TypeInference.toDouble(record2.getX().get(column)));
            }
            TransposeDataList transposeDataList = new TransposeDataList();
            transposeDataList.put(1, values1);
            transposeDataList.put(2, values2);
            return PearsonCorrelation.calculateCorrelation(transposeDataList);
        }
        throw new RuntimeException("Unsupported Distance method");
    }
}
