package cc.fasttext;

import cc.fasttext.Args;
import com.google.common.collect.TreeMultimap;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.AtomicDouble;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang.Validate;
import org.apache.commons.math3.random.RandomAdaptor;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:cc/fasttext/Model.class */
/**
 * fastText model: holds the input ({@code wi_}) and output ({@code wo_}) matrices (optionally their
 * quantized counterparts), computes the averaged hidden representation of an input, trains with one of
 * three losses (negative sampling, hierarchical softmax, full softmax) via SGD updates, and produces
 * top-k label predictions.
 *
 * <p>Training and {@link #predict(List, int)} use per-instance scratch buffers ({@code hidden_},
 * {@code output_}, {@code grad_}), so a single instance should not be used from several threads at once.
 */
public class Model {
    /** Number of buckets in the precomputed sigmoid table (the table stores one extra entry). */
    private static final int SIGMOID_TABLE_SIZE = 512;
    /** Sigmoid inputs outside [-MAX_SIGMOID, MAX_SIGMOID] are clamped to 0 or 1. */
    private static final int MAX_SIGMOID = 8;
    /** Number of buckets in the precomputed log table (the table stores one extra entry). */
    private static final int LOG_TABLE_SIZE = 512;
    /** Scale factor used to size the negative-sampling table. */
    private static final int NEGATIVE_TABLE_SIZE = 10000000;
    // Reverse order: heap.asMap().firstKey() is the BEST score and lastKey() the WORST one.
    private static final Comparator<Float> HEAP_PROBABILITY_COMPARATOR = Comparator.reverseOrder();
    private static final Comparator<Integer> HEAP_LABEL_COMPARATOR = Comparator.reverseOrder();
    private static final int PARALLEL_SIZE_THRESHOLD = Integer.parseInt(System.getProperty("parallel.model.threshold", String.valueOf(FastText.PARALLEL_THRESHOLD_FACTOR * 100)));
    private QMatrix qwi_;
    private QMatrix qwo_;
    private RandomGenerator rng;
    private Matrix wi_;
    private Matrix wo_;
    private Vector hidden_;
    private Vector output_;
    private Vector grad_;
    private int osz_;
    private float loss_;
    private long nexamples_;
    private float[] t_sigmoid;
    private float[] t_log;
    private List<Integer> negatives;
    private int negpos;
    private List<List<Integer>> paths;
    private List<List<Boolean>> codes;
    private List<Node> tree;
    private final Args.ModelName model;
    private final Args.LossName loss;
    private final int dim;
    private final int neg;
    private final boolean qout;

    /**
     * A node of the Huffman tree used by hierarchical softmax. Indices {@code 0..osz_-1} are leaves
     * (labels); the remaining indices are internal nodes.
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:cc/fasttext/Model$Node.class */
    public static class Node {
        int parent;
        int left;
        int right;
        long count;
        boolean binary;

        private Node() {
        }
    }

    public Model(Matrix wi, Matrix wo, Args args, RandomGenerator random) {
        this(wi, wo, args.model(), args.loss(), args.dim(), args.neg(), args.qout(), random);
    }

    private Model(Matrix wi, Matrix wo, Args.ModelName modelName, Args.LossName lossName, int dim, int neg, boolean qout, RandomGenerator random) {
        this.model = modelName;
        this.loss = lossName;
        this.dim = dim;
        this.neg = neg;
        this.qout = qout;
        this.hidden_ = new Vector(dim);
        this.output_ = new Vector(wo.getM());
        this.grad_ = new Vector(dim);
        this.rng = random;
        this.wi_ = wi;
        this.wo_ = wo;
        this.osz_ = wo.getM();
        this.negpos = 0;
        this.loss_ = 0.0f;
        // Starts at 1 so getLoss() never divides by zero.
        this.nexamples_ = 1L;
        initSigmoid();
        initLog();
    }

    /**
     * Attaches quantized input/output matrices. When {@code qout} is set, the output size is taken from the
     * quantized output matrix.
     *
     * @return this model, for chaining
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public Model setQuantizePointer(QMatrix qMatrix, QMatrix qMatrix2) {
        this.qwi_ = qMatrix;
        this.qwo_ = qMatrix2;
        if (this.qout) {
            this.osz_ = this.qwo_.getM();
        }
        return this;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public RandomGenerator random() {
        return this.rng;
    }

    public Matrix input() {
        return this.wi_;
    }

    public Matrix output() {
        return this.wo_;
    }

    public QMatrix qinput() {
        return this.qwi_;
    }

    public QMatrix qoutput() {
        return this.qwo_;
    }

    /** @return {@code true} when a non-empty quantized input matrix is attached */
    public boolean isQuant() {
        return this.qwi_ != null && !this.qwi_.isEmpty();
    }

    /**
     * One logistic-regression step against output row {@code target}: accumulates the gradient into
     * {@code grad_}, updates the output row, and returns the binary log-loss.
     */
    private float binaryLogistic(int target, boolean label, float lr) {
        float score = sigmoid(this.wo_.dotRow(this.hidden_, target));
        float alpha = lr * ((label ? 1 : 0) - score);
        this.grad_.addRow(this.wo_, target, alpha);
        this.wo_.addRow(this.hidden_, target, alpha);
        return label ? -log(score) : -log(1.0f - score);
    }

    /** Negative-sampling loss: one positive update for {@code target} plus {@code neg} sampled negatives. */
    private float negativeSampling(int target, float lr) {
        this.grad_.clear();
        float total = binaryLogistic(target, true, lr);
        for (int n = 0; n < this.neg; n++) {
            total += binaryLogistic(getNegative(target), false, lr);
        }
        return total;
    }

    /** Hierarchical-softmax loss: one binary decision per internal node on the path of {@code target}. */
    private float hierarchicalSoftmax(int target, float lr) {
        float total = 0.0f;
        this.grad_.clear();
        List<Boolean> binaryCode = this.codes.get(target);
        List<Integer> path = this.paths.get(target);
        for (int j = 0; j < path.size(); j++) {
            total += binaryLogistic(path.get(j).intValue(), binaryCode.get(j).booleanValue(), lr);
        }
        return total;
    }

    /**
     * Fills {@code output} with softmax(wo * hidden) over the first {@code osz_} entries, using the
     * max-subtraction trick for numerical stability.
     */
    private void computeOutputSoftmax(Vector hidden, Vector output) {
        if (isQuant() && this.qout) {
            output.mul(this.qwo_, hidden);
        } else {
            output.mul(this.wo_, hidden);
        }
        if (FastText.USE_PARALLEL_COMPUTATION && this.osz_ > PARALLEL_SIZE_THRESHOLD) {
            double max = IntStream.range(0, this.osz_).parallel().mapToDouble(output::get).max().orElseThrow(() -> {
                return new IllegalStateException("Can't calc max");
            });
            AtomicDouble sum = new AtomicDouble();
            // Each index is written by exactly one task; the shared sum is accumulated atomically.
            IntStream.range(0, this.osz_).parallel().forEach(j -> {
                double exp = FastMath.exp(output.get(j) - max);
                output.set(j, (float) exp);
                sum.addAndGet(exp);
            });
            float z = sum.floatValue();
            IntStream.range(0, this.osz_).parallel().forEach(j -> output.set(j, output.get(j) / z));
            return;
        }
        float max = output.get(0);
        for (int j = 0; j < this.osz_; j++) {
            max = FastMath.max(output.get(j), max);
        }
        float z = 0.0f;
        for (int j = 0; j < this.osz_; j++) {
            output.set(j, (float) FastMath.exp(output.get(j) - max));
            z += output.get(j);
        }
        for (int j = 0; j < this.osz_; j++) {
            output.set(j, output.get(j) / z);
        }
    }

    private void computeOutputSoftmax() {
        computeOutputSoftmax(this.hidden_, this.output_);
    }

    /** Full softmax loss: updates every output row and returns -log p(target). */
    private float softmax(int target, float lr) {
        this.grad_.clear();
        computeOutputSoftmax();
        for (int j = 0; j < this.osz_; j++) {
            float label = j == target ? 1.0f : 0.0f;
            float alpha = lr * (label - this.output_.get(j));
            this.grad_.addRow(this.wo_, j, alpha);
            this.wo_.addRow(this.hidden_, j, alpha);
        }
        return -log(this.output_.get(target));
    }

    /**
     * Sets {@code hidden} to the average of the input-matrix rows named by {@code input}.
     * Callers must ensure {@code input} is non-empty (an empty list would divide by zero).
     */
    private void computeHidden(List<Integer> input, Vector hidden) {
        Validate.isTrue(hidden.size() == this.dim, "Wrong size of hidden vector: " + hidden.size() + "!=" + this.dim);
        hidden.clear();
        for (Integer id : input) {
            if (isQuant()) {
                hidden.addRow(this.qwi_, id.intValue());
            } else {
                hidden.addRow(this.wi_, id.intValue());
            }
        }
        hidden.mul(1.0f / input.size());
    }

    /**
     * Predicts the {@code k} most likely labels for {@code input}.
     *
     * @param input  input row ids (words / n-grams)
     * @param k      number of labels to return, must be positive
     * @param hidden scratch vector of size {@code dim}
     * @param output scratch vector for the output distribution
     * @return multimap from log-probability to label id, ordered from best to worst; empty when
     *         {@code input} is empty
     * @throws IllegalArgumentException if {@code k <= 0} or the model is not supervised
     */
    public TreeMultimap<Float, Integer> predict(List<Integer> input, int k, Vector hidden, Vector output) {
        if (k <= 0) {
            throw new IllegalArgumentException("k needs to be 1 or higher!");
        }
        if (!Args.ModelName.SUP.equals(this.model)) {
            throw new IllegalArgumentException("Model needs to be supervised for prediction!");
        }
        TreeMultimap<Float, Integer> heap = TreeMultimap.create(HEAP_PROBABILITY_COMPARATOR, HEAP_LABEL_COMPARATOR);
        if (input.isEmpty()) {
            // Averaging zero rows would produce NaN scores; there is nothing to predict from.
            return heap;
        }
        computeHidden(input, hidden);
        if (Args.LossName.HS == this.loss) {
            // 2 * osz_ - 2 is the root of the Huffman tree.
            dfs(k, 2 * this.osz_ - 2, 0.0f, heap, hidden);
        } else {
            findKBest(k, heap, hidden, output);
        }
        return heap;
    }

    /** Same as {@link #predict(List, int, Vector, Vector)} but uses this model's own scratch buffers. */
    public TreeMultimap<Float, Integer> predict(List<Integer> list, int i) {
        return predict(list, i, this.hidden_, this.output_);
    }

    /** Scans the full softmax output and keeps the {@code k} best log-probabilities in {@code heap}. */
    private void findKBest(int k, TreeMultimap<Float, Integer> heap, Vector hidden, Vector output) {
        computeOutputSoftmax(hidden, output);
        for (int label = 0; label < this.osz_; label++) {
            float logProb = stdLog(output.get(label));
            if (canEnter(heap, k, logProb)) {
                put(heap, k, Float.valueOf(logProb), Integer.valueOf(label));
            }
        }
    }

    /**
     * @return whether {@code score} could still be part of the top-{@code k}: either the heap is not full
     *         yet, or the score is at least as good as the current worst entry ({@code lastKey()} under the
     *         reverse ordering)
     */
    private static boolean canEnter(TreeMultimap<Float, Integer> heap, int k, float score) {
        return heap.size() < k || score >= heap.asMap().lastKey().floatValue();
    }

    /** Inserts an entry and, if the heap grew beyond {@code k}, evicts the worst one. */
    private static <K, V> void put(TreeMultimap<K, V> heap, int k, K key, V value) {
        heap.put(key, value);
        if (heap.size() > k) {
            heap.get(heap.asMap().lastKey()).pollLast();
        }
    }

    /**
     * Depth-first search over the Huffman tree accumulating log-probabilities; branches that can no
     * longer beat the current k-th best score are pruned.
     */
    private void dfs(int k, int node, float score, TreeMultimap<Float, Integer> heap, Vector hidden) {
        if (!canEnter(heap, k, score)) {
            return;
        }
        Node current = this.tree.get(node);
        if (current.left == -1 && current.right == -1) {
            put(heap, k, Float.valueOf(score), Integer.valueOf(node));
            return;
        }
        double dot = (isQuant() && this.qout) ? this.qwo_.dotRow(hidden, node - this.osz_) : this.wo_.dotRow(hidden, node - this.osz_);
        float f = (float) (1.0d / (1.0d + FastMath.exp(-dot)));
        dfs(k, current.left, score + stdLog(1.0f - f), heap, hidden);
        dfs(k, current.right, score + stdLog(f), heap, hidden);
    }

    /**
     * Performs one SGD step for a single example.
     *
     * @param input  input row ids; an empty list is a no-op
     * @param target target output index, must be in {@code [0, osz_)}
     * @param lr     learning rate
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public void update(List<Integer> input, int target, float lr) {
        Validate.isTrue(target >= 0);
        Validate.isTrue(target < this.osz_);
        if (input.isEmpty()) {
            return;
        }
        boolean parallel = FastText.USE_PARALLEL_COMPUTATION && input.size() > PARALLEL_SIZE_THRESHOLD;
        Events.MODEL_COMPUTE_HIDDEN.start();
        Map<Integer, Long> counts = Collections.emptyMap();
        if (parallel) {
            // Group duplicate ids so each distinct row is touched once, weighted by its multiplicity.
            counts = input.parallelStream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
            this.hidden_.clear();
            // Sequential on purpose: every addRow writes the same hidden_ vector, so concurrent calls
            // would race and lose updates.
            for (Map.Entry<Integer, Long> entry : counts.entrySet()) {
                this.hidden_.addRow(isQuant() ? this.qwi_ : this.wi_, entry.getKey().intValue(), (float) entry.getValue().longValue());
            }
            this.hidden_.mul(1.0f / input.size());
        } else {
            computeHidden(input, this.hidden_);
        }
        Events.MODEL_COMPUTE_HIDDEN.end();
        Events.MODEL_LOSS_CALC.start();
        if (Args.LossName.NS == this.loss) {
            this.loss_ += negativeSampling(target, lr);
        } else if (Args.LossName.HS == this.loss) {
            this.loss_ += hierarchicalSoftmax(target, lr);
        } else {
            this.loss_ += softmax(target, lr);
        }
        Events.MODEL_LOSS_CALC.end();
        this.nexamples_++;
        Events.MODEL_GRAD_MUL.start();
        if (Args.ModelName.SUP == this.model) {
            this.grad_.mul(1.0f / input.size());
        }
        Events.MODEL_GRAD_MUL.end();
        Events.MODEL_INPUT_ADD_ROW.start();
        if (parallel) {
            // Keys are distinct, so each task updates a different wi_ row; grad_ is only read here.
            counts.entrySet().parallelStream().forEach(entry -> {
                this.wi_.addRow(this.grad_, entry.getKey().intValue(), (float) entry.getValue().longValue());
            });
        } else {
            for (Integer id : input) {
                this.wi_.addRow(this.grad_, id.intValue(), 1.0f);
            }
        }
        Events.MODEL_INPUT_ADD_ROW.end();
    }

    /**
     * Supplies per-label counts: builds the negative table (NS loss) or the Huffman tree (HS loss).
     *
     * @param list one count per output label; its size must equal the output size
     */
    public void setTargetCounts(List<Long> list) {
        Validate.isTrue(list.size() == this.osz_);
        if (Args.LossName.NS == this.loss) {
            initTableNegatives(list);
        }
        if (Args.LossName.HS == this.loss) {
            buildTree(list);
        }
    }

    /**
     * Builds the shuffled negative-sampling table. Each label appears roughly in proportion to the square
     * root of its count. The table is always filled in label order before the seeded shuffle, so the result
     * is reproducible whether or not the per-label sizes were computed in parallel.
     */
    private void initTableNegatives(List<Long> counts) {
        boolean parallel = FastText.USE_PARALLEL_COMPUTATION && counts.size() > PARALLEL_SIZE_THRESHOLD;
        double normalizer;
        if (parallel) {
            normalizer = counts.parallelStream().mapToDouble(c -> FastMath.sqrt(c.longValue())).sum();
        } else {
            normalizer = 0.0d;
            for (Long c : counts) {
                normalizer += FastMath.sqrt(c.longValue());
            }
        }
        final double z = normalizer;
        IntStream labels = IntStream.range(0, counts.size());
        if (parallel) {
            labels = labels.parallel();
        }
        // toArray() preserves encounter order even for a parallel stream.
        int[] repetitions = labels.map(label -> negativeTableRepetitions(counts.get(label).longValue(), z)).toArray();
        int total = 0;
        for (int r : repetitions) {
            total += r;
        }
        List<Integer> table = new ArrayList<>(total);
        for (int label = 0; label < repetitions.length; label++) {
            table.addAll(Collections.nCopies(repetitions[label], Integer.valueOf(label)));
        }
        Collections.shuffle(table, new RandomAdaptor(this.rng));
        this.negatives = table;
    }

    /**
     * Number of table slots for a label: the count of integers {@code r >= 0} with
     * {@code r < sqrt(count) * NEGATIVE_TABLE_SIZE / normalizer}, i.e. its ceiling (0 for NaN).
     */
    private static int negativeTableRepetitions(long count, double normalizer) {
        return (int) FastMath.ceil(FastMath.sqrt(count) * NEGATIVE_TABLE_SIZE / normalizer);
    }

    /**
     * Returns the next label from the negative table that differs from {@code target}, cycling through
     * the table.
     *
     * @throws IllegalStateException if a full pass over the table finds no label other than {@code target}
     *         (previously this looped forever)
     */
    private int getNegative(int target) {
        int size = this.negatives.size();
        for (int attempt = 0; attempt < size; attempt++) {
            int candidate = this.negatives.get(this.negpos).intValue();
            this.negpos = (this.negpos + 1) % size;
            if (candidate != target) {
                return candidate;
            }
        }
        throw new IllegalStateException("Negative table contains no label other than " + target);
    }

    /**
     * Builds the Huffman tree over the label counts, then records for each leaf the internal nodes on its
     * path to the root ({@code paths}, offset by {@code osz_}) and the branch taken at each ({@code codes}).
     * Leaves are merged assuming {@code counts} are in descending order (leaf cursor walks from the end).
     */
    private void buildTree(List<Long> counts) {
        int nodeCount = 2 * this.osz_ - 1;
        this.paths = new ArrayList<>(this.osz_);
        this.codes = new ArrayList<>(this.osz_);
        this.tree = new ArrayList<>(nodeCount);
        for (int n = 0; n < nodeCount; n++) {
            Node node = new Node();
            node.parent = -1;
            node.left = -1;
            node.right = -1;
            // Sentinel "infinite" count for internal nodes not built yet.
            node.count = 1000000000000000L;
            node.binary = false;
            this.tree.add(node);
        }
        for (int j = 0; j < this.osz_; j++) {
            this.tree.get(j).count = counts.get(j).longValue();
        }
        int leaf = this.osz_ - 1;
        int node = this.osz_;
        for (int parent = this.osz_; parent < nodeCount; parent++) {
            int[] mini = new int[2];
            for (int j = 0; j < 2; j++) {
                // Pick the smaller of the next unused leaf and the next unused internal node.
                if (leaf >= 0 && this.tree.get(leaf).count < this.tree.get(node).count) {
                    mini[j] = leaf--;
                } else {
                    mini[j] = node++;
                }
            }
            Node p = this.tree.get(parent);
            p.left = mini[0];
            p.right = mini[1];
            p.count = this.tree.get(mini[0]).count + this.tree.get(mini[1]).count;
            this.tree.get(mini[0]).parent = parent;
            this.tree.get(mini[1]).parent = parent;
            this.tree.get(mini[1]).binary = true;
        }
        for (int label = 0; label < this.osz_; label++) {
            List<Integer> path = new ArrayList<>();
            List<Boolean> code = new ArrayList<>();
            int current = label;
            // Walk up until the root (parent == -1); the original loop never terminated here.
            while (this.tree.get(current).parent != -1) {
                path.add(Integer.valueOf(this.tree.get(current).parent - this.osz_));
                code.add(Boolean.valueOf(this.tree.get(current).binary));
                current = this.tree.get(current).parent;
            }
            this.paths.add(path);
            this.codes.add(code);
        }
    }

    /** @return the accumulated training loss averaged over the processed examples */
    public float getLoss() {
        return this.loss_ / ((float) this.nexamples_);
    }

    /** Precomputes sigmoid at {@code SIGMOID_TABLE_SIZE + 1} evenly spaced points over [-MAX_SIGMOID, MAX_SIGMOID]. */
    private void initSigmoid() {
        this.t_sigmoid = new float[SIGMOID_TABLE_SIZE + 1];
        for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) {
            float x = i * 2.0f * MAX_SIGMOID / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
            this.t_sigmoid[i] = (float) (1.0d / (1.0d + FastMath.exp(-x)));
        }
    }

    /** Precomputes log at {@code LOG_TABLE_SIZE + 1} evenly spaced points over (0, 1]. */
    private void initLog() {
        this.t_log = new float[LOG_TABLE_SIZE + 1];
        for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) {
            float x = (i + 1.0E-5f) / LOG_TABLE_SIZE;
            this.t_log[i] = (float) FastMath.log(x);
        }
    }

    /** Table-based log for arguments in [0, 1]; returns 0 for arguments above 1. */
    private float log(float x) {
        if (x > 1.0d) {
            return 0.0f;
        }
        // Explicit (long) cast: Ints.checkedCast takes a long and float -> long is a narrowing conversion.
        return this.t_log[Ints.checkedCast((long) (x * LOG_TABLE_SIZE))];
    }

    /** Exact log with a small epsilon to avoid log(0). */
    private float stdLog(float f) {
        return (float) FastMath.log(f + 1.0E-5d);
    }

    /** Table-based sigmoid, clamped to 0 / 1 outside [-MAX_SIGMOID, MAX_SIGMOID]. */
    private float sigmoid(float x) {
        if (x < -MAX_SIGMOID) {
            return 0.0f;
        }
        if (x > MAX_SIGMOID) {
            return 1.0f;
        }
        return this.t_sigmoid[Ints.checkedCast((long) ((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2.0f))];
    }
}
