package org.apache.mahout.cf.taste.example.kddcup.track1.svd;

import java.util.Collection;
import java.util.Random;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.class */
/**
 * Factorizes a preference matrix with stochastic gradient descent, training one latent feature at a
 * time (Simon Funk style incremental SVD).
 *
 * <p>All preferences are copied at construction time into parallel primitive arrays
 * ({@code userIndexes}, {@code itemIndexes}, {@code values}, {@code cachedEstimates}) instead of being
 * kept as {@link Preference} objects. For every preference, {@code cachedEstimates} holds the summed
 * contribution of all features that have already been fully trained. Training feature {@code f}
 * therefore only needs to add feature {@code f}'s product on top of that cache.
 *
 * <p>{@link #factorize()} mutates the feature matrices and reorders the preference arrays in place
 * without any synchronization, so an instance should not be used from several threads at once.
 */
public class ParallelArraysSGDFactorizer implements Factorizer {

    /** Learning rate (SGD step size) used by the constructors that do not take one explicitly. */
    public static final double DEFAULT_LEARNING_RATE = 0.005d;
    /** Regularization (weight-decay) coefficient used by the constructors that do not take one. */
    public static final double DEFAULT_PREVENT_OVERFITTING = 0.02d;
    /** Scale of the random perturbation of initial feature values used by the short constructors. */
    public static final double DEFAULT_RANDOM_NOISE = 0.005d;

    private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class);

    private final int numFeatures;
    /** Number of full passes over all preferences spent on each feature. */
    private final int numIterations;
    /** Lower bound used to clamp estimates during training. */
    private final float minPreference;
    /** Upper bound used to clamp estimates during training. */
    private final float maxPreference;
    private final Random random;
    private final double learningRate;
    private final double preventOverfitting;
    /** Maps each user ID to its row in {@code userFeatures}. */
    private final FastByIDMap<Integer> userIDMapping;
    /** Maps each item ID to its row in {@code itemFeatures}. */
    private final FastByIDMap<Integer> itemIDMapping;
    private final double[][] userFeatures;
    private final double[][] itemFeatures;
    // The following four arrays are parallel: entry k of each describes the same preference.
    private final int[] userIndexes;
    private final int[] itemIndexes;
    private final float[] values;
    /** Sum of the contributions of all already-trained features, per preference. */
    private final double[] cachedEstimates;
    /** Base initial value of every feature. */
    private final double defaultValue;
    /** Tenth of the preference range divided by the number of features; scales the init noise. */
    private final double interval;

    /**
     * Creates a factorizer over {@code dataModel} using the default learning rate,
     * regularization and noise.
     *
     * @param dataModel source of users, items and preferences
     * @param numFeatures number of latent features to learn
     * @param numIterations number of SGD passes per feature
     */
    public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) {
        this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations,
                DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE);
    }

    /**
     * Creates a factorizer over {@code dataModel}.
     *
     * @param dataModel source of users, items and preferences
     * @param numFeatures number of latent features to learn
     * @param numIterations number of SGD passes per feature
     * @param learningRate SGD step size
     * @param preventOverfitting regularization coefficient applied to each feature value
     * @param randomNoise scale factor of the random perturbation of initial feature values
     */
    public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations,
            double learningRate, double preventOverfitting, double randomNoise) {
        this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations,
                learningRate, preventOverfitting, randomNoise);
    }

    /**
     * Creates a factorizer over {@code factorizablePrefs} using the default learning rate,
     * regularization and noise.
     *
     * @param factorizablePrefs source of user IDs, item IDs and preferences
     * @param numFeatures number of latent features to learn
     * @param numIterations number of SGD passes per feature
     */
    public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures,
            int numIterations) {
        this(factorizablePrefs, numFeatures, numIterations,
                DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE);
    }

    /**
     * Loads all preferences into memory, assigns dense indexes to users and items and initializes
     * the feature matrices.
     *
     * @param factorizablePrefs source of user IDs, item IDs and preferences
     * @param numFeatures number of latent features to learn
     * @param numIterations number of SGD passes per feature
     * @param learningRate SGD step size
     * @param preventOverfitting regularization coefficient applied to each feature value
     * @param randomNoise scale factor of the random perturbation of initial feature values
     * @throws IllegalArgumentException if a preference refers to a user or item ID that was not
     *     returned by {@code getUserIDs()} / {@code getItemIDs()}
     */
    public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures,
            int numIterations, double learningRate, double preventOverfitting, double randomNoise) {
        this.numFeatures = numFeatures;
        this.numIterations = numIterations;
        this.minPreference = factorizablePrefs.getMinPreference();
        this.maxPreference = factorizablePrefs.getMaxPreference();
        this.random = RandomUtils.getRandom();
        this.learningRate = learningRate;
        this.preventOverfitting = preventOverfitting;

        int numUsers = factorizablePrefs.numUsers();
        int numItems = factorizablePrefs.numItems();
        int numPrefs = factorizablePrefs.numPreferences();

        log.info("Mapping {} users...", numUsers);
        this.userIDMapping = mapIDsToIndexes(factorizablePrefs.getUserIDs(), numUsers);
        log.info("Mapping {} items", numItems);
        this.itemIDMapping = mapIDsToIndexes(factorizablePrefs.getItemIDs(), numItems);

        this.userIndexes = new int[numPrefs];
        this.itemIndexes = new int[numPrefs];
        this.values = new float[numPrefs];
        // No feature has been trained yet; Java zero-fills the array, which is the correct start.
        this.cachedEstimates = new double[numPrefs];

        log.info("Loading {} preferences into memory", numPrefs);
        FullRunningAverage average = new FullRunningAverage();
        int index = 0;
        for (Preference pref : factorizablePrefs.getPreferences()) {
            userIndexes[index] = indexOf(userIDMapping, pref.getUserID(), "user");
            itemIndexes[index] = indexOf(itemIDMapping, pref.getItemID(), "item");
            values[index] = pref.getValue();
            average.addDatum(pref.getValue());
            index++;
            if (index % 1000000 == 0) {
                log.info("Processed {} preferences", index);
            }
        }
        log.info("Processed {} preferences, done.", index);
        if (index != numPrefs) {
            // Unfilled slots remain (user 0, item 0, value 0) and would be trained like real data.
            log.warn("Expected {} preferences but loaded only {}", numPrefs, index);
        }

        double averagePreference = average.getAverage();
        log.info("Average preference value is {}", averagePreference);

        double prefRange = maxPreference - minPreference;
        // Pick a start value so that numFeatures * defaultValue^2 == average - 10% of the range.
        // When that target is negative, or NaN because there were no preferences, sqrt would
        // return NaN and poison every feature. Fall back to 0 in that case.
        double squaredDefault = (averagePreference - prefRange * 0.1d) / numFeatures;
        this.defaultValue = squaredDefault > 0.0d ? Math.sqrt(squaredDefault) : 0.0d;
        this.interval = prefRange * 0.1d / numFeatures;

        this.userFeatures = new double[numUsers][numFeatures];
        this.itemFeatures = new double[numItems][numFeatures];
        log.info("Initializing feature vectors...");
        for (int feature = 0; feature < numFeatures; feature++) {
            for (int userIndex = 0; userIndex < numUsers; userIndex++) {
                userFeatures[userIndex][feature] =
                        defaultValue + (random.nextDouble() - 0.5d) * interval * randomNoise;
            }
            for (int itemIndex = 0; itemIndex < numItems; itemIndex++) {
                itemFeatures[itemIndex][feature] =
                        defaultValue + (random.nextDouble() - 0.5d) * interval * randomNoise;
            }
        }
    }

    /** Assigns consecutive indexes 0, 1, 2, ... to the IDs in iteration order. */
    private static FastByIDMap<Integer> mapIDsToIndexes(LongPrimitiveIterator ids, int expectedSize) {
        FastByIDMap<Integer> mapping = new FastByIDMap<>(expectedSize);
        int nextIndex = 0;
        while (ids.hasNext()) {
            mapping.put(ids.nextLong(), nextIndex++);
        }
        return mapping;
    }

    /** Looks up the dense index of {@code id}, failing with a descriptive message if unmapped. */
    private static int indexOf(FastByIDMap<Integer> mapping, long id, String kind) {
        Integer index = mapping.get(id);
        if (index == null) {
            throw new IllegalArgumentException("Preference refers to unknown " + kind + " ID " + id);
        }
        return index;
    }

    /**
     * Trains all features one after another and returns the resulting factorization.
     *
     * <p>For each feature, the preferences are shuffled, then {@code numIterations} SGD passes are
     * run. Afterwards the feature's contribution is folded into {@code cachedEstimates}. The cache
     * update is skipped after the last feature because nothing reads it any more.
     *
     * @return factorization backed by this instance's ID mappings and feature matrices (not copies)
     * @throws TasteException declared by {@link Factorizer}; not thrown by this implementation
     */
    public Factorization factorize() throws TasteException {
        for (int feature = 0; feature < numFeatures; feature++) {
            log.info("Shuffling preferences...");
            shufflePreferences();
            log.info("Starting training of feature {} ...", feature);
            double rmse = Double.NaN;
            for (int iteration = 0; iteration < numIterations; iteration++) {
                rmse = trainingIteration(feature);
            }
            if (numIterations > 0) {
                log.info("Finished training feature {} with RMSE {}", feature, rmse);
            }
            if (feature < numFeatures - 1) {
                log.info("Updating cache...");
                for (int k = 0; k < userIndexes.length; k++) {
                    cachedEstimates[k] =
                            estimate(userIndexes[k], itemIndexes[k], feature, cachedEstimates[k], false);
                }
            }
        }
        log.info("Factorization done");
        return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
    }

    /**
     * Runs one SGD pass over all preferences for {@code feature}.
     *
     * @return root-mean-square of the errors observed during this pass. Each error is measured
     *     just before its own update. Returns 0 when there are no preferences.
     */
    private double trainingIteration(int feature) {
        int numPrefs = userIndexes.length;
        if (numPrefs == 0) {
            return 0.0d;
        }
        double sumSquaredError = 0.0d;
        for (int k = 0; k < numPrefs; k++) {
            double error = train(userIndexes[k], itemIndexes[k], feature, values[k], cachedEstimates[k]);
            sumSquaredError += error * error;
        }
        return Math.sqrt(sumSquaredError / numPrefs);
    }

    /**
     * Estimates a preference from the cache plus the current value of {@code feature}.
     *
     * @param trailing if true, also add an approximation of the not-yet-trained features (each
     *     taken as {@code (defaultValue + interval)^2}) and clamp the result to the preference
     *     range; if false, return only cache + this feature's product (used to update the cache)
     */
    private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate,
            boolean trailing) {
        double sum = cachedEstimate + userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature];
        if (trailing) {
            double untrainedValue = defaultValue + interval;
            sum += (numFeatures - feature - 1) * untrainedValue * untrainedValue;
            if (sum > maxPreference) {
                sum = maxPreference;
            } else if (sum < minPreference) {
                sum = minPreference;
            }
        }
        return sum;
    }

    /**
     * Performs one SGD step for a single preference on a single feature.
     *
     * @param userIndex row of the user in the user feature matrix
     * @param itemIndex row of the item in the item feature matrix
     * @param feature feature currently being trained
     * @param original observed preference value
     * @param cachedEstimate summed contribution of the already-trained features
     * @return prediction error ({@code original - estimate}) before this step
     */
    public double train(int userIndex, int itemIndex, int feature, double original,
            double cachedEstimate) {
        double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true);
        double[] userVector = userFeatures[userIndex];
        double[] itemVector = itemFeatures[itemIndex];
        // Read both values before writing either, so that both gradient steps use the parameters
        // that produced `error`. Updating the user value first and then reading it back for the
        // item step would mix old and new parameters.
        double userValue = userVector[feature];
        double itemValue = itemVector[feature];
        userVector[feature] = userValue + learningRate * (error * itemValue - preventOverfitting * userValue);
        itemVector[feature] = itemValue + learningRate * (error * userValue - preventOverfitting * itemValue);
        return error;
    }

    /** Shuffles the preference arrays in place (Fisher-Yates), keeping the parallel arrays aligned. */
    protected void shufflePreferences() {
        for (int last = userIndexes.length - 1; last > 0; last--) {
            swapPreferences(last, random.nextInt(last + 1));
        }
    }

    /** Swaps preferences {@code a} and {@code b} across all four parallel arrays. */
    private void swapPreferences(int a, int b) {
        int userIndex = userIndexes[a];
        int itemIndex = itemIndexes[a];
        float value = values[a];
        double cached = cachedEstimates[a];

        userIndexes[a] = userIndexes[b];
        itemIndexes[a] = itemIndexes[b];
        values[a] = values[b];
        cachedEstimates[a] = cachedEstimates[b];

        userIndexes[b] = userIndex;
        itemIndexes[b] = itemIndex;
        values[b] = value;
        cachedEstimates[b] = cached;
    }

    /** No-op: all data was copied into memory at construction time. */
    public void refresh(Collection<Refreshable> alreadyRefreshed) {
    }
}
