package cc.fasttext;

import cc.fasttext.io.FTInputStream;
import cc.fasttext.io.FTOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang.Validate;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:cc/fasttext/Matrix.class */
public class Matrix {
    /**
     * Row operations switch to parallel streams when the relevant dimension is strictly greater
     * than this value (and {@code FastText.USE_PARALLEL_COMPUTATION} is enabled). Read once from the
     * {@code parallel.matrix.threshold} system property, defaulting to
     * {@code FastText.PARALLEL_THRESHOLD_FACTOR * 100}.
     */
    private static final int PARALLEL_SIZE_THRESHOLD = Integer.parseInt(System.getProperty("parallel.matrix.threshold", String.valueOf(FastText.PARALLEL_THRESHOLD_FACTOR * 100)));

    /** Backing storage indexed as {@code data[row][column]}; {@code null} when built by the no-arg constructor. */
    private float[][] data;
    /** Number of rows. */
    protected int m;
    /** Number of columns. */
    protected int n;

    /**
     * Creates a 0x0 matrix without backing storage (see {@link #empty()}).
     */
    public Matrix() {
    }

    /**
     * Creates a zero-filled matrix.
     *
     * @param m number of rows, must be positive
     * @param n number of columns, must be positive
     * @throws IllegalArgumentException if either dimension is not positive
     */
    public Matrix(int m, int n) {
        Validate.isTrue(m > 0, "Wrong m-size: " + m);
        Validate.isTrue(n > 0, "Wrong n-size: " + n);
        this.m = m;
        this.n = n;
        this.data = new float[m][n];
    }

    /**
     * Returns a deep copy of this matrix; the copy shares no storage with the original.
     * Copying a matrix without storage (created via {@link #empty()}) yields another such matrix.
     *
     * @return an independent copy
     */
    public Matrix copy() {
        if (this.data == null) {
            // No storage to copy; the sizing constructor would reject 0x0 dimensions.
            return new Matrix();
        }
        Matrix copy = new Matrix(this.m, this.n);
        for (int row = 0; row < this.m; row++) {
            System.arraycopy(this.data[row], 0, copy.data[row], 0, this.n);
        }
        return copy;
    }

    /**
     * Returns all values in a new one-dimensional array, rows laid out consecutively
     * (element {@code (i, j)} lands at index {@code i * n + j}). Intended for package-internal use.
     *
     * @return a freshly allocated array of length {@code m * n}
     */
    public float[] flatData() {
        float[] flat = new float[this.m * this.n];
        for (int row = 0; row < this.m; row++) {
            System.arraycopy(this.data[row], 0, flat, row * this.n, this.n);
        }
        return flat;
    }

    /**
     * Exposes the live backing array (no copy); writes through it modify this matrix.
     *
     * @return the internal {@code data[row][column]} array, possibly {@code null} for an empty matrix
     */
    float[][] data() {
        return this.data;
    }

    /**
     * Returns the rows as an unmodifiable list of {@link Vector}s, one per row, each built with
     * {@code new Vector(float[])} from the corresponding row array.
     * NOTE(review): whether a row Vector copies or wraps the row array depends on the Vector
     * constructor, which is not visible here.
     *
     * @return unmodifiable list of row vectors; empty if this matrix has no storage
     */
    public List<Vector> getData() {
        if (this.data == null) {
            return Collections.emptyList();
        }
        List<Vector> rows = Arrays.stream(this.data).map(Vector::new).collect(Collectors.toList());
        return Collections.unmodifiableList(rows);
    }

    /**
     * @return {@code true} if either dimension is zero
     */
    public boolean isEmpty() {
        return this.m == 0 || this.n == 0;
    }

    /**
     * @return the number of rows
     */
    public int getM() {
        return this.m;
    }

    /**
     * @return the number of columns
     */
    public int getN() {
        return this.n;
    }

    /**
     * Returns the total number of elements.
     *
     * @return {@code m * n}, computed in {@code long} arithmetic
     */
    public long size() {
        // Widen before multiplying: int * int would overflow for large matrices.
        return (long) this.m * this.n;
    }

    /**
     * Returns the element at {@code (i, i2)} after bounds validation.
     *
     * @param i row index
     * @param i2 column index
     * @return the stored value
     * @throws IllegalArgumentException if an index is out of range
     */
    public float get(int i, int i2) {
        validateMIndex(i);
        validateNIndex(i2);
        return at(i, i2);
    }

    /**
     * Unchecked element read (no bounds validation). Intended for package-internal use.
     *
     * @param i row index
     * @param i2 column index
     * @return the stored value
     */
    public float at(int i, int i2) {
        return this.data[i][i2];
    }

    /**
     * Stores {@code f} at {@code (i, i2)} after bounds validation.
     *
     * @param i row index
     * @param i2 column index
     * @param f value to store
     * @throws IllegalArgumentException if an index is out of range
     */
    public void set(int i, int i2, float f) {
        validateMIndex(i);
        validateNIndex(i2);
        put(i, i2, f);
    }

    /**
     * Unchecked element write (no bounds validation). Intended for package-internal use.
     *
     * @param i row index
     * @param i2 column index
     * @param f value to store
     */
    public void put(int i, int i2, float f) {
        this.data[i][i2] = f;
    }

    /**
     * Replaces the element at {@code (i, i2)} with {@code operator(current)}, computed in double
     * precision and narrowed back to float.
     *
     * @param i row index
     * @param i2 column index
     * @param doubleUnaryOperator function applied to the current value
     * @throws NullPointerException if the operator is {@code null}
     * @throws IllegalArgumentException if an index is out of range
     */
    public void compute(int i, int i2, DoubleUnaryOperator doubleUnaryOperator) {
        Objects.requireNonNull(doubleUnaryOperator, "Null operator");
        // Validate like get/set so misuse reports a descriptive error instead of a raw AIOOBE.
        validateMIndex(i);
        validateNIndex(i2);
        this.data[i][i2] = (float) doubleUnaryOperator.applyAsDouble(this.data[i][i2]);
    }

    /**
     * Checks that {@code i} is a valid row index. The message is only built on failure because
     * this runs on hot paths such as {@link #dotRow} and {@link #addRow}.
     *
     * @param i row index
     * @throws IllegalArgumentException if {@code i} is outside {@code [0, m)}
     */
    public void validateMIndex(int i) {
        if (i < 0 || i >= this.m) {
            throw new IllegalArgumentException("First index (" + i + ") is out of range [0, " + this.m + ")");
        }
    }

    /**
     * Checks that {@code i} is a valid column index.
     *
     * @param i column index
     * @throws IllegalArgumentException if {@code i} is outside {@code [0, n)}
     */
    void validateNIndex(int i) {
        if (i < 0 || i >= this.n) {
            throw new IllegalArgumentException("Second index (" + i + ") is out of range [0, " + this.n + ")");
        }
    }

    /**
     * Checks that {@code vector} is non-null and has exactly {@code n} elements (one per column).
     *
     * @param vector vector to check
     * @throws NullPointerException if {@code vector} is {@code null}
     * @throws IllegalArgumentException if its size differs from {@code n}
     */
    public void validateNVector(Vector vector) {
        Objects.requireNonNull(vector, "Null vector");
        if (vector.size() != this.n) {
            throw new IllegalArgumentException("Wrong vector size: " + vector.size() + " (!= " + this.n + ")");
        }
    }

    /**
     * Checks that {@code vector} is non-null and has exactly {@code m} elements (one per row).
     *
     * @param vector vector to check
     * @throws NullPointerException if {@code vector} is {@code null}
     * @throws IllegalArgumentException if its size differs from {@code m}
     */
    void validateMVector(Vector vector) {
        Objects.requireNonNull(vector, "Null vector");
        if (vector.size() != this.m) {
            throw new IllegalArgumentException("Wrong vector size: " + vector.size() + " (!= " + this.m + ")");
        }
    }

    /**
     * @return always {@code false} for this dense implementation
     */
    public boolean isQuant() {
        return false;
    }

    /**
     * Fills every element with an independent sample from the uniform distribution on
     * {@code [-f, f]} drawn with {@code randomGenerator}.
     *
     * @param randomGenerator source of randomness
     * @param f half-width of the sampling interval (commons-math rejects an empty interval)
     */
    public void uniform(RandomGenerator randomGenerator, float f) {
        UniformRealDistribution distribution = new UniformRealDistribution(randomGenerator, -f, f);
        for (int row = 0; row < this.m; row++) {
            for (int col = 0; col < this.n; col++) {
                this.data[row][col] = (float) distribution.sample();
            }
        }
    }

    /**
     * Computes the dot product of row {@code i} with {@code vector}.
     * The parallel path accumulates in double before narrowing, so its result may differ in the
     * last bits from the sequential float accumulation.
     *
     * @param vector vector of length {@code n}
     * @param i row index
     * @return the dot product
     * @throws IllegalArgumentException on a bad row index or vector size
     * @throws IllegalStateException if the result is NaN
     */
    public float dotRow(Vector vector, int i) {
        validateMIndex(i);
        validateNVector(vector);
        float[] row = this.data[i];
        float dot;
        if (!FastText.USE_PARALLEL_COMPUTATION || this.n <= PARALLEL_SIZE_THRESHOLD) {
            dot = 0.0f;
            for (int col = 0; col < this.n; col++) {
                dot += row[col] * vector.get(col);
            }
        } else {
            dot = (float) IntStream.range(0, this.n).parallel().mapToDouble(col -> row[col] * vector.get(col)).sum();
        }
        if (Float.isNaN(dot)) {
            throw new IllegalStateException("Encountered NaN.");
        }
        return dot;
    }

    /**
     * Adds {@code f * vector} to row {@code i} in place.
     *
     * @param vector vector of length {@code n}
     * @param i row index
     * @param f scale applied to {@code vector}
     * @throws IllegalArgumentException on a bad row index or vector size
     */
    public void addRow(Vector vector, int i, float f) {
        validateMIndex(i);
        validateNVector(vector);
        float[] row = this.data[i];
        if (FastText.USE_PARALLEL_COMPUTATION && this.n > PARALLEL_SIZE_THRESHOLD) {
            // Each task touches a distinct column of the same row, so no two tasks write the same slot.
            IntStream.range(0, this.n).parallel().forEach(col -> row[col] += f * vector.get(col));
            return;
        }
        for (int col = 0; col < this.n; col++) {
            row[col] += f * vector.get(col);
        }
    }

    /**
     * Multiplies each row {@code r} in {@code [i, i2)} element-wise by {@code vector.get(r - i)};
     * rows whose factor is zero are left untouched (see {@link #rowOp}).
     *
     * @param vector per-row factors
     * @param i first row (inclusive)
     * @param i2 last row (exclusive), or {@code -1} for {@code m}
     */
    protected void multiplyRow(Vector vector, int i, int i2) {
        rowOp(vector, i, i2, (value, factor) -> value * factor);
    }

    /**
     * Multiplies every row {@code r} by {@code vector.get(r)}; rows whose factor is zero are left
     * untouched.
     *
     * @param vector per-row factors, at least {@code m} elements
     */
    public void multiplyRow(Vector vector) {
        multiplyRow(vector, 0, -1);
    }

    /**
     * Divides each row {@code r} in {@code [i, i2)} element-wise by {@code vector.get(r - i)};
     * rows whose divisor is zero are skipped rather than divided.
     *
     * @param vector per-row divisors
     * @param i first row (inclusive)
     * @param i2 last row (exclusive), or {@code -1} for {@code m}
     */
    protected void divideRow(Vector vector, int i, int i2) {
        rowOp(vector, i, i2, (value, divisor) -> value / divisor);
    }

    /**
     * Divides every row {@code r} by {@code vector.get(r)}; rows whose divisor is zero are skipped.
     *
     * @param vector per-row divisors, at least {@code m} elements
     */
    public void divideRow(Vector vector) {
        divideRow(vector, 0, -1);
    }

    /**
     * Applies {@code operator(element, vector.get(r - i))} to every element of each row {@code r}
     * in {@code [i, i2)}. Rows are processed in parallel when the row range exceeds the threshold.
     *
     * @param vector per-row operands, indexed relative to {@code i}
     * @param i first row (inclusive)
     * @param i2 last row (exclusive), or {@code -1} for {@code m}
     * @param doubleBinaryOperator element-wise operation {@code (element, rowOperand) -> result}
     * @throws IllegalArgumentException if {@code i2 > vector.size()} or {@code i2 < i}
     */
    protected void rowOp(Vector vector, int i, int i2, DoubleBinaryOperator doubleBinaryOperator) {
        int end = i2 == -1 ? this.m : i2;
        Validate.isTrue(end <= vector.size());
        Validate.isTrue(end >= i);
        if (FastText.USE_PARALLEL_COMPUTATION && end - i > PARALLEL_SIZE_THRESHOLD) {
            IntStream.range(i, end).parallel().forEach(row -> vectorOp(vector, row, doubleBinaryOperator, i));
            return;
        }
        for (int row = i; row < end; row++) {
            vectorOp(vector, row, doubleBinaryOperator, i);
        }
    }

    /**
     * Applies {@code operator} with operand {@code vector.get(i - i2)} to every element of row
     * {@code i}. A zero operand leaves the row unchanged: this avoids division by zero for
     * {@link #divideRow} and also means {@link #multiplyRow} does not zero the row.
     *
     * @param vector per-row operands
     * @param i row to update
     * @param doubleBinaryOperator element-wise operation
     * @param i2 offset subtracted from {@code i} to index {@code vector}
     */
    private void vectorOp(Vector vector, int i, DoubleBinaryOperator doubleBinaryOperator, int i2) {
        float operand = vector.get(i - i2);
        if (operand == 0.0f) {
            return;
        }
        float[] row = this.data[i];
        if (FastText.USE_PARALLEL_COMPUTATION && this.n > PARALLEL_SIZE_THRESHOLD) {
            IntStream.range(0, this.n).parallel().forEach(col -> row[col] = (float) doubleBinaryOperator.applyAsDouble(row[col], operand));
            return;
        }
        for (int col = 0; col < this.n; col++) {
            row[col] = (float) doubleBinaryOperator.applyAsDouble(row[col], operand);
        }
    }

    /**
     * Computes the Euclidean (L2) norm of row {@code i}.
     *
     * @param i row index (not validated)
     * @return {@code sqrt(sum of squares)} of the row
     * @throws IllegalStateException if the sum of squares is NaN
     */
    private float l2NormRow(int i) {
        float[] row = this.data[i];
        float sumOfSquares;
        if (!FastText.USE_PARALLEL_COMPUTATION || this.n <= PARALLEL_SIZE_THRESHOLD) {
            sumOfSquares = 0.0f;
            for (int col = 0; col < this.n; col++) {
                float value = row[col];
                sumOfSquares += value * value;
            }
        } else {
            sumOfSquares = (float) IntStream.range(0, this.n).parallel().mapToDouble(col -> row[col] * row[col]).sum();
        }
        if (Float.isNaN(sumOfSquares)) {
            throw new IllegalStateException("Encountered NaN.");
        }
        return (float) FastMath.sqrt(sumOfSquares);
    }

    /**
     * Returns a new vector of length {@code m} whose element {@code r} is the L2 norm of row {@code r}.
     * NOTE(review): the parallel path calls {@code Vector.set} from several threads on distinct
     * indices; this assumes Vector tolerates that — confirm against Vector's implementation.
     *
     * @return per-row L2 norms
     * @throws IllegalStateException if any row's sum of squares is NaN
     */
    public Vector l2NormRow() {
        Vector norms = new Vector(this.m);
        IntStream rows = IntStream.range(0, this.m);
        if (FastText.USE_PARALLEL_COMPUTATION && this.m > PARALLEL_SIZE_THRESHOLD) {
            rows = rows.parallel();
        }
        rows.forEach(row -> norms.set(row, l2NormRow(row)));
        return norms;
    }

    /**
     * Writes the matrix as {@code m} and {@code n} (each as a long) followed by all elements in
     * row order. Intended for package-internal use.
     *
     * @param fTOutputStream destination stream
     * @throws IOException if writing fails
     */
    public void save(FTOutputStream fTOutputStream) throws IOException {
        fTOutputStream.writeLong(this.m);
        fTOutputStream.writeLong(this.n);
        for (int row = 0; row < this.m; row++) {
            for (int col = 0; col < this.n; col++) {
                fTOutputStream.writeFloat(this.data[row][col]);
            }
        }
    }

    /**
     * Reads a matrix in the layout produced by {@link #save}. Intended for package-internal use.
     *
     * @param fTInputStream source stream
     * @return the loaded matrix
     * @throws IOException if reading fails
     * @throws IllegalArgumentException if a stored dimension does not fit in an int or is not positive
     */
    public static Matrix load(FTInputStream fTInputStream) throws IOException {
        long rows = fTInputStream.readLong();
        long cols = fTInputStream.readLong();
        // A plain (int) cast would silently truncate corrupt or oversized headers.
        Validate.isTrue(rows == (int) rows && cols == (int) cols, "Matrix dimensions out of int range: " + rows + "x" + cols);
        Matrix matrix = new Matrix((int) rows, (int) cols);
        for (int row = 0; row < matrix.m; row++) {
            for (int col = 0; col < matrix.n; col++) {
                matrix.data[row][col] = fTInputStream.readFloat();
            }
        }
        return matrix;
    }

    /**
     * Creates a 0x0 matrix without backing storage. Intended for package-internal use.
     *
     * @return an empty matrix
     */
    public static Matrix empty() {
        return new Matrix();
    }

    @Override
    public String toString() {
        return String.format("%s[m=%d, n=%d]", getClass().getSimpleName(), this.m, this.n);
    }
}
