package cc.fasttext;

import cc.fasttext.io.FTInputStream;
import cc.fasttext.io.FTOutputStream;
import com.google.common.primitives.Bytes;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.RandomAdaptor;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:cc/fasttext/ProductQuantizer.class */
/**
 * Product quantizer that compresses {@code dim}-dimensional float vectors into {@code nsubq}
 * one-byte codes.
 *
 * <p>Each vector is split into {@code nsubq} consecutive sub-vectors of {@code dsub} components;
 * when {@code dim} is not a multiple of {@code dsub} the last sub-vector has {@code lastdsub}
 * components. For every sub-space a codebook of {@code KSUB} (256) centroids is learned with
 * k-means in {@link #train}. Each sub-vector is then encoded as the index of its nearest centroid
 * ({@link #computeCodes}).
 *
 * <p>All centroids live in one flat list. Sub-quantizer {@code m} owns the range starting at
 * {@code m * KSUB * dsub}, stored centroid after centroid (see {@link #getCentroids(int, byte)}).
 */
public class ProductQuantizer {
    /** Bits per sub-quantizer code. */
    private static final int NBITS = 8;
    /** Centroids per sub-quantizer: {@code 2^NBITS} = 256, i.e. one unsigned byte per code. */
    private static final int KSUB = 1 << NBITS;
    private static final int MAX_POINTS_PER_CLUSTER = 256;
    /** Upper bound on the number of training rows sampled per sub-quantizer (65536). */
    private static final int MAX_POINTS = MAX_POINTS_PER_CLUSTER * KSUB;
    /** Seed handed to the random-generator factory. */
    private static final int SEED = 1234;
    /** Number of k-means iterations (E-step + M-step) per sub-quantizer. */
    private static final int NITER = 25;
    /** Perturbation applied when a centroid is split to refill an empty cluster. */
    private static final float EPS = 1.0E-7f;

    private int dim_;
    private int nsubq_;
    private int dsub_;
    private int lastdsub_;
    private List<Float> centroids_;
    private final RandomGenerator rng;

    private ProductQuantizer(IntFunction<RandomGenerator> randomFactory) {
        this.rng = randomFactory.apply(SEED);
    }

    /**
     * Creates an untrained quantizer.
     *
     * @param randomFactory builds the random generator from the fixed seed
     * @param dim full vector dimension
     * @param dsub sub-vector dimension; the last sub-vector is shorter when {@code dim % dsub != 0}
     * @throws ArithmeticException if {@code dsub == 0}
     */
    public ProductQuantizer(IntFunction<RandomGenerator> randomFactory, int dim, int dsub) {
        this(randomFactory);
        this.dim_ = dim;
        this.dsub_ = dsub;
        this.nsubq_ = dim / dsub;
        this.lastdsub_ = dim % dsub;
        if (this.lastdsub_ == 0) {
            this.lastdsub_ = dsub;
        } else {
            // The remainder becomes an extra, shorter sub-quantizer.
            this.nsubq_++;
        }
        this.centroids_ = asFloatList(new float[dim * KSUB]);
    }

    /**
     * Returns a view of the centroid list starting at centroid {@code code} of sub-quantizer
     * {@code m}.
     *
     * @param m sub-quantizer index
     * @param code centroid index, interpreted as an unsigned byte (0..255)
     * @return a list view beginning at that centroid and running to the end of all centroids
     */
    public List<Float> getCentroids(int m, byte code) {
        int k = Byte.toUnsignedInt(code);
        // The last sub-quantizer has centroids of length lastdsub_, so its stride differs.
        int offset = m == nsubq_ - 1
                ? m * KSUB * dsub_ + k * lastdsub_
                : (m * KSUB + k) * dsub_;
        return shiftFloats(centroids_, offset);
    }

    /** Returns the backing list of all centroids (live view, not a copy). */
    List<Float> getCentroids() {
        return centroids_;
    }

    /** Squared Euclidean distance over the first {@code d} components. */
    private float distL2(List<Float> a, List<Float> b, int d) {
        float sum = 0.0f;
        for (int j = 0; j < d; j++) {
            float diff = getFloat(a, j) - getFloat(b, j);
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Finds the nearest of the {@code KSUB} centroids stored consecutively in {@code centroids}.
     * Writes its index into {@code code[0]}.
     *
     * @return the squared distance to that centroid
     */
    private float assignCentroid(List<Float> x, List<Float> centroids, List<Byte> code, int d) {
        List<Float> c = centroids;
        float best = distL2(x, c, d);
        code.set(0, (byte) 0);
        for (int k = 1; k < KSUB; k++) {
            c = shiftFloats(c, d);
            float dist = distL2(x, c, d);
            if (dist < best) {
                code.set(0, (byte) k);
                best = dist;
            }
        }
        return best;
    }

    /** K-means assignment step: sets {@code codes[i]} to the nearest centroid of row {@code i}. */
    private void eStep(float[] x, List<Float> centroids, List<Byte> codes, int d, int n) {
        List<Float> points = asFloatList(x);
        for (int i = 0; i < n; i++) {
            assignCentroid(shiftFloats(points, i * d), centroids, shiftBytes(codes, i), d);
        }
    }

    /**
     * K-means update step: recomputes every centroid as the mean of its assigned rows.
     *
     * <p>Empty clusters are refilled by splitting a randomly chosen donor cluster. The loop tends
     * to stop on clusters with more points. The donor centroid is copied and the copy and the
     * original are pushed apart by {@code +/-EPS}.
     */
    private void mStep(float[] x, List<Float> centroids, List<Byte> codes, int d, int n) {
        int[] counts = new int[KSUB];
        // Reset this sub-quantizer's codebook: means are accumulated from scratch, so the
        // previous iteration's centroid values must not leak into the sums.
        for (int j = 0; j < d * KSUB; j++) {
            centroids.set(j, 0.0f);
        }
        for (int i = 0; i < n; i++) {
            int k = Byte.toUnsignedInt(codes.get(i));
            int c = k * d;
            int p = i * d;
            for (int j = 0; j < d; j++) {
                centroids.set(c + j, centroids.get(c + j) + x[p + j]);
            }
            counts[k]++;
        }
        for (int k = 0; k < KSUB; k++) {
            float z = counts[k];
            if (z != 0.0f) {
                int c = k * d;
                for (int j = 0; j < d; j++) {
                    centroids.set(c + j, centroids.get(c + j) / z);
                }
            }
        }
        UniformRealDistribution uniform = new UniformRealDistribution(rng, 0.0d, 1.0d);
        for (int k = 0; k < KSUB; k++) {
            if (counts[k] != 0) {
                continue;
            }
            int m = 0;
            while (uniform.sample() * (n - KSUB) >= counts[m] - 1) {
                m = (m + 1) % KSUB;
            }
            int dst = k * d;
            int src = m * d;
            for (int j = 0; j < d; j++) {
                centroids.set(dst + j, centroids.get(src + j));
            }
            for (int j = 0; j < d; j++) {
                float delta = ((j % 2) * 2 - 1) * EPS;
                centroids.set(dst + j, centroids.get(dst + j) + delta);
                centroids.set(src + j, centroids.get(src + j) - delta);
            }
            counts[k] = counts[m] / 2;
            counts[m] -= counts[k];
        }
    }

    /**
     * Learns one codebook per sub-quantizer with k-means.
     *
     * @param n number of rows in {@code x}
     * @param x row-major matrix of {@code n} rows of {@code dim} floats
     * @throws IllegalArgumentException if {@code n < 256} or a row slice falls outside {@code x}
     * @throws ArrayStoreException if a source offset does not fit in an {@code int}
     */
    public void train(int n, float[] x) {
        if (n < KSUB) {
            throw new IllegalArgumentException(
                    "Matrix too small for quantization, must have at least " + KSUB + " rows");
        }
        List<Long> perm = LongStream.range(0, n).boxed().collect(Collectors.toList());
        int np = FastMath.min(n, MAX_POINTS);
        float[] xslice = new float[np * dsub_];
        int d = dsub_;
        for (int m = 0; m < nsubq_; m++) {
            if (m == nsubq_ - 1) {
                d = lastdsub_;
            }
            // Only subsample (in a fresh random order) when there are more rows than MAX_POINTS.
            if (np != n) {
                Collections.shuffle(perm, new RandomAdaptor(rng));
            }
            for (int j = 0; j < np; j++) {
                long srcPos = perm.get(j) * dim_ + m * dsub_;
                if (srcPos > Integer.MAX_VALUE) {
                    throw new ArrayStoreException("Source start index too big : " + srcPos);
                }
                int dstPos = j * d;
                try {
                    System.arraycopy(x, (int) srcPos, xslice, dstPos, d);
                } catch (ArrayIndexOutOfBoundsException | ArrayStoreException e) {
                    throw new IllegalArgumentException("Can't copy arrays: data.length=" + x.length
                            + ", src-pos=" + srcPos + ", xslice.length=" + xslice.length
                            + ", dst-pos=" + dstPos, e);
                }
            }
            kmeans(xslice, getCentroids(m, (byte) 0), np, d);
        }
    }

    /**
     * Runs k-means on {@code n} rows of length {@code d}, writing KSUB centroids into
     * {@code centroids}. Centroids are initialised from KSUB randomly chosen rows.
     */
    private void kmeans(float[] x, List<Float> centroids, int n, int d) {
        List<Integer> perm = IntStream.range(0, n).boxed().collect(Collectors.toList());
        Collections.shuffle(perm, new RandomAdaptor(rng));
        for (int k = 0; k < KSUB; k++) {
            int dst = k * d;
            int src = perm.get(k) * d;
            for (int j = 0; j < d; j++) {
                centroids.set(dst + j, x[src + j]);
            }
        }
        List<Byte> codes = asByteList(new byte[n]);
        for (int iter = 0; iter < NITER; iter++) {
            eStep(x, centroids, codes, d, n);
            mStep(x, centroids, codes, d, n);
        }
    }

    /** Encodes one vector ({@code dim} floats) into {@code nsubq} codes. */
    private void computeCode(List<Float> x, List<Byte> code) {
        int d = dsub_;
        for (int m = 0; m < nsubq_; m++) {
            if (m == nsubq_ - 1) {
                d = lastdsub_;
            }
            assignCentroid(shiftFloats(x, m * dsub_), getCentroids(m, (byte) 0), shiftBytes(code, m), d);
        }
    }

    /**
     * Encodes {@code n} row-major vectors from {@code x} into {@code codes}.
     *
     * @param x {@code n * dim} floats
     * @param codes output buffer of at least {@code n * nsubq} bytes
     * @param n number of vectors
     */
    public void computeCodes(float[] x, byte[] codes, int n) {
        List<Float> points = asFloatList(x);
        List<Byte> out = asByteList(codes);
        for (int i = 0; i < n; i++) {
            computeCode(shiftFloats(points, i * dim_), shiftBytes(out, i * nsubq_));
        }
    }

    /**
     * Returns {@code alpha} times the dot product of {@code vector} with the reconstruction of
     * encoded row {@code t}.
     */
    public float mulCode(Vector vector, byte[] codes, int t, float alpha) {
        return mulCode(vector.data(), codes, t) * alpha;
    }

    private float mulCode(float[] x, byte[] codes, int t) {
        float res = 0.0f;
        int d = dsub_;
        List<Byte> code = shiftBytes(asByteList(codes), nsubq_ * t);
        for (int m = 0; m < nsubq_; m++) {
            List<Float> c = getCentroids(m, code.get(m));
            if (m == nsubq_ - 1) {
                d = lastdsub_;
            }
            for (int j = 0; j < d; j++) {
                res += x[m * dsub_ + j] * c.get(j);
            }
        }
        return res;
    }

    /** Adds {@code alpha} times the reconstruction of encoded row {@code t} to {@code vector}. */
    public void addCode(Vector vector, byte[] codes, int t, float alpha) {
        addCode(vector.data(), codes, t, alpha);
    }

    private void addCode(float[] x, byte[] codes, int t, float alpha) {
        int d = dsub_;
        List<Byte> code = shiftBytes(asByteList(codes), nsubq_ * t);
        for (int m = 0; m < nsubq_; m++) {
            List<Float> c = getCentroids(m, code.get(m));
            if (m == nsubq_ - 1) {
                d = lastdsub_;
            }
            for (int j = 0; j < d; j++) {
                int idx = m * dsub_ + j;
                x[idx] = x[idx] + alpha * c.get(j);
            }
        }
    }

    /** Writes dim, nsubq, dsub, lastdsub, then all centroid floats. */
    public void save(FTOutputStream out) throws IOException {
        out.writeInt(dim_);
        out.writeInt(nsubq_);
        out.writeInt(dsub_);
        out.writeInt(lastdsub_);
        for (float c : centroids_) {
            out.writeFloat(c);
        }
    }

    /** Reads a quantizer in the layout written by {@link #save}. */
    public static ProductQuantizer load(IntFunction<RandomGenerator> randomFactory, FTInputStream in)
            throws IOException {
        ProductQuantizer pq = new ProductQuantizer(randomFactory);
        pq.dim_ = in.readInt();
        pq.nsubq_ = in.readInt();
        pq.dsub_ = in.readInt();
        pq.lastdsub_ = in.readInt();
        float[] centroids = new float[pq.dim_ * KSUB];
        for (int i = 0; i < centroids.length; i++) {
            centroids[i] = in.readFloat();
        }
        pq.centroids_ = asFloatList(centroids);
        return pq;
    }

    /** Fixed-size, write-through list view over {@code bytes}. */
    public static List<Byte> asByteList(byte... bytes) {
        return Bytes.asList(bytes);
    }

    /** Fixed-size, write-through list view over {@code ints}. */
    public static List<Integer> asIntList(int... ints) {
        return Ints.asList(ints);
    }

    /** Fixed-size, write-through list view over {@code floats}. */
    public static List<Float> asFloatList(float... floats) {
        return Floats.asList(floats);
    }

    /** Returns {@code list[i]}, or NaN when {@code i} is past the end. */
    private float getFloat(List<Float> list, int i) {
        if (i >= list.size()) {
            return Float.NaN;
        }
        return list.get(i);
    }

    private List<Byte> shiftBytes(List<Byte> list, int offset) {
        return shift(list, offset);
    }

    private List<Float> shiftFloats(List<Float> list, int offset) {
        return shift(list, offset);
    }

    /**
     * Returns a view of {@code list} from {@code offset} to its end.
     * Returns an immutable empty list when {@code offset >= list.size()}.
     */
    public static <T> List<T> shift(List<T> list, int offset) {
        return offset >= list.size() ? Collections.emptyList() : list.subList(offset, list.size());
    }
}
