package cc.fasttext;

import cc.fasttext.Args;
import cc.fasttext.Dictionary;
import cc.fasttext.io.FTInputStream;
import cc.fasttext.io.FTOutputStream;
import cc.fasttext.io.FormatUtils;
import cc.fasttext.io.IOStreams;
import cc.fasttext.io.PrintLogs;
import cc.fasttext.io.impl.LocalIOStreams;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.TreeMultimap;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.UncheckedIOException;
import java.lang.ref.Reference;
import java.lang.ref.SoftReference;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.Spliterators;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.Validate;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.FastMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:cc/fasttext/FastText.class */
public class FastText {
    // Highest model-file version accepted by Factory.load(InputStream); also the version recorded for freshly trained models.
    public static final int FASTTEXT_VERSION = 12;
    // Magic number expected as the first int of a binary model file (checked in Factory.load(InputStream)).
    public static final int FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314;
    // Read from system property "parallel" (default true). NOTE(review): usage is not visible in this chunk.
    public static final boolean USE_PARALLEL_COMPUTATION = Boolean.parseBoolean(System.getProperty("parallel", "true"));
    // Read from system property "parallel.factor" (default 100). NOTE(review): usage is not visible in this chunk.
    static final int PARALLEL_THRESHOLD_FACTOR = Integer.parseInt(System.getProperty("parallel.factor", "100"));
    // Backend for SimpleLogger; declared before DEFAULT_FACTORY so it is initialised first.
    private static final Logger LOGGER = LoggerFactory.getLogger(FastText.class);
    // Factory used by the static train/load helpers: local file streams, Well19937c generators seeded by an int, SLF4J logging, UTF-8.
    public static final Factory DEFAULT_FACTORY = new Factory(new LocalIOStreams(), Well19937c::new, new SimpleLogger(), StandardCharsets.UTF_8);
    // NOTE(review): presumably a tolerance used by nearest-neighbour search; usage is not visible in this chunk.
    private static final double FIND_NN_THRESHOLD = 1.0E-8d;
    private final Args args;
    private final Dictionary dict;
    private final Model model;
    // File-format version this instance was loaded from (or FASTTEXT_VERSION when trained).
    private final int version;
    private final IOStreams fs;
    private final PrintLogs logs;
    // Maps an int seed to a fresh random generator.
    private final IntFunction<RandomGenerator> random;
    // NOTE(review): presumably caches the result of computeWordVectors(); the accessor is not visible in this chunk.
    private Reference<Matrix> precomputedWordVectors;

    /* loaded from: input_file:cc/fasttext/FastText$Factory.class */
    public static class Factory {
        // Locale used for all String.format calls so numeric output is stable across platforms.
        public static final Locale LOCALE = Locale.ENGLISH;
        // NOTE(review): buffer size constant; usage is not visible in this chunk.
        public static final int BUFF_SIZE = 8192;
        // Opens, sizes and checks readability of data/model files.
        private final IOStreams fs;
        // Sink for progress and diagnostic messages.
        private final PrintLogs logs;
        // Maps an int seed (a thread id in createModel, 1 for matrix initialisation) to a random generator.
        private final IntFunction<RandomGenerator> random;
        // Charset used to decode text inputs (data files, dictionaries, pre-trained vectors).
        private final Charset charset;

        /* JADX INFO: Access modifiers changed from: protected */
        /* loaded from: input_file:cc/fasttext/FastText$Factory$Trainer.class */
        /**
         * Trains a model over {@code file}, optionally with several worker threads that share the
         * {@code input} and {@code output} matrices.
         * <p>
         * Each worker opens its own reader, seeks to {@code workerId * size / args.thread()}, builds
         * its own {@link Model} over the shared matrices and keeps consuming lines until the shared
         * token counter reaches {@code args.epoch() * dictionary.ntokens()}.
         */
        public class Trainer {
            private final String file;
            private final long size;
            private final Args args;
            private final Dictionary dictionary;
            private final Matrix input;
            private final Matrix output;
            // Training start time; set by perform() before any worker starts.
            private Instant start;
            // Tokens processed by all workers; drives the linear learning-rate decay.
            private AtomicLong tokenCount;

            protected Trainer(Args args, String str, long j, Dictionary dictionary, Matrix matrix, Matrix matrix2) {
                this.args = (Args) Objects.requireNonNull(args, "Null args");
                this.file = (String) Objects.requireNonNull(str, "Null file");
                this.size = j;
                this.dictionary = (Dictionary) Objects.requireNonNull(dictionary, "Null dictionary");
                this.input = (Matrix) Objects.requireNonNull(matrix, "Null input matrix");
                this.output = (Matrix) Objects.requireNonNull(matrix2, "Null output matrix");
            }

            /** Opens a fresh seekable reader over the training file. */
            protected Dictionary.SeekableReader createReader() throws IOException {
                return this.dictionary.createReader(Factory.this.fs.openScrollable(this.file));
            }

            /**
             * Runs training and wraps the trained matrices into a model.
             *
             * @return the resulting model (random generator seeded with 0)
             */
            public Model train() throws IOException, ExecutionException, IllegalArgumentException {
                perform();
                Events.CREATE_RES_MODEL.start();
                try {
                    return Factory.this.createModel(this.args, this.dictionary, this.input, this.output, 0);
                } finally {
                    Events.CREATE_RES_MODEL.end();
                }
            }

            /**
             * Runs {@link #trainThread(int)} on {@code args.thread()} workers and waits for all of them.
             * <p>
             * The pool is force-stopped only after every worker has finished, or after one failed.
             * The first worker failure is propagated as an {@link ExecutionException}. If the calling
             * thread is interrupted while waiting, the interrupt flag is restored, the workers are
             * cancelled, and the method returns normally, so training may be incomplete.
             *
             * @throws ExecutionException if a worker fails
             * @throws IOException        if single-threaded training fails with an I/O error
             */
            protected void perform() throws ExecutionException, IOException {
                this.start = Instant.now();
                this.tokenCount = new AtomicLong(0L);
                int threads = this.args.thread();
                if (threads <= 1) {
                    trainThread(0);
                    return;
                }
                ExecutorService pool = Executors.newFixedThreadPool(threads, runnable -> {
                    Thread worker = Executors.defaultThreadFactory().newThread(runnable);
                    worker.setDaemon(true);
                    return worker;
                });
                ExecutorCompletionService<Void> completion = new ExecutorCompletionService<>(pool);
                try {
                    for (int id = 0; id < threads; id++) {
                        final int workerId = id;
                        completion.submit(() -> {
                            Thread.currentThread().setName("FT-TrainThread-" + workerId);
                            trainThread(workerId);
                            return null;
                        });
                    }
                    pool.shutdown();
                    // Wait for every worker. shutdownNow() must not run between waits: it would
                    // interrupt workers that are still training.
                    for (int remaining = threads; remaining > 0; remaining--) {
                        completion.take().get();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    // Cancels workers still running after a failure or interrupt; a no-op on success.
                    pool.shutdownNow();
                }
            }

            /**
             * Runs the training loop of a single worker.
             *
             * @param i worker id; worker {@code 0} reports progress
             * @throws IOException           if reading the training file fails
             * @throws IllegalStateException if the configured model is not SUP, CBOW or SG
             */
            protected void trainThread(int i) throws IOException {
                Args.ModelName modelName = this.args.model();
                if (Args.ModelName.SUP != modelName && Args.ModelName.CBOW != modelName && Args.ModelName.SG != modelName) {
                    // Without a matching branch no tokens would ever be counted and the loop would never end.
                    throw new IllegalStateException("Unsupported model: " + modelName);
                }
                try (Dictionary.SeekableReader reader = createReader()) {
                    // Spread the workers evenly over the file.
                    long offset = (i * this.size) / this.args.thread();
                    Events.FILE_SEEK.start();
                    reader.seek(offset);
                    Events.FILE_SEEK.end();
                    Model model = Factory.this.createModel(this.args, this.dictionary, this.input, this.output, i);
                    // Widen before multiplying so epoch * ntokens cannot overflow int arithmetic.
                    long totalTokens = (long) this.args.epoch() * this.dictionary.ntokens();
                    long localTokens = 0;
                    List<Integer> line = new ArrayList<>();
                    List<Integer> labels = new ArrayList<>();
                    while (this.tokenCount.longValue() < totalTokens) {
                        float progress = this.tokenCount.floatValue() / ((float) totalTokens);
                        float lr = (float) (this.args.lr() * (1.0f - progress));
                        if (Args.ModelName.SUP == modelName) {
                            Events.DIC_GET_LINE.start();
                            localTokens += this.dictionary.getLine(reader, line, labels);
                            Events.DIC_GET_LINE.end();
                            Events.TRAIN_CALC.start();
                            supervised(model, lr, line, labels);
                            Events.TRAIN_CALC.end();
                        } else if (Args.ModelName.CBOW == modelName) {
                            Events.DIC_GET_LINE.start();
                            localTokens += this.dictionary.getLine(reader, line, model.random());
                            Events.DIC_GET_LINE.end();
                            Events.TRAIN_CALC.start();
                            cbow(model, lr, line);
                            Events.TRAIN_CALC.end();
                        } else {
                            // SG: the only remaining option after the check above.
                            Events.DIC_GET_LINE.start();
                            localTokens += this.dictionary.getLine(reader, line, model.random());
                            Events.DIC_GET_LINE.end();
                            Events.TRAIN_CALC.start();
                            skipgram(model, lr, line);
                            Events.TRAIN_CALC.end();
                        }
                        // Publish progress in batches to limit contention on the shared counter.
                        if (localTokens > this.args.lrUpdateRate()) {
                            this.tokenCount.addAndGet(localTokens);
                            localTokens = 0;
                            if (i == 0 && Factory.this.logs.isDebugEnabled()) {
                                Factory.this.logs.debug(progressMessage(progress, model.getLoss()), new Object[0]);
                            }
                        }
                    }
                    if (Factory.this.logs.isInfoEnabled() && i == 0) {
                        Factory.this.logs.infoln(progressMessage(1.0f, model.getLoss()), new Object[0]);
                    }
                }
            }

            /**
             * Formats a progress line.
             *
             * @param f  progress in {@code [0, 1]}
             * @param f2 current loss
             * @return the formatted message
             */
            protected String progressMessage(float f, float f2) {
                float seconds = ((float) ChronoUnit.NANOS.between(this.start, Instant.now())) / 1.0E9f;
                float wordsPerSec = seconds > 0.0f ? ((float) this.tokenCount.get()) / seconds : 0.0f;
                // Default ETA of one month while progress is still zero (as upstream fastText does).
                // Otherwise seconds / 0 would print an Integer.MAX_VALUE-hour ETA.
                int eta = 720 * 3600;
                if (f > 0.0f) {
                    eta = (int) (((seconds / f) * (1.0f - f)) / this.args.thread());
                }
                int hours = eta / 3600;
                int minutes = (eta - (hours * 3600)) / 60;
                return String.format(Factory.LOCALE, "\rProgress: %.1f%%  words/sec/thread: %.0f  lr: %.6f  loss: %.6f  eta: %dh%dm ", Float.valueOf(100.0f * f), Float.valueOf(wordsPerSec), Float.valueOf((float) (this.args.lr() * (1.0f - f))), Float.valueOf(f2), Integer.valueOf(hours), Integer.valueOf(minutes));
            }

            /** Performs one supervised update using a uniformly chosen label of the line. */
            protected void supervised(Model model, float f, List<Integer> list, List<Integer> list2) {
                if (list2.isEmpty() || list.isEmpty()) {
                    return;
                }
                int sample = new UniformIntegerDistribution(model.random(), 0, list2.size() - 1).sample();
                Events.MODEL_UPDATE.start();
                model.update(list, list2.get(sample).intValue(), f);
                Events.MODEL_UPDATE.end();
            }

            /** CBOW: predicts each word from the subwords of its randomly sized context window. */
            protected void cbow(Model model, float f, List<Integer> list) {
                UniformIntegerDistribution windowSize = new UniformIntegerDistribution(model.random(), 1, this.args.ws());
                for (int w = 0; w < list.size(); w++) {
                    List<Integer> bow = new ArrayList<>();
                    int boundary = windowSize.sample();
                    for (int c = -boundary; c <= boundary; c++) {
                        int pos = w + c;
                        if (c != 0 && pos >= 0 && pos < list.size()) {
                            bow.addAll(this.dictionary.getSubwords(list.get(pos).intValue()));
                        }
                    }
                    Events.MODEL_UPDATE.start();
                    model.update(bow, list.get(w).intValue(), f);
                    Events.MODEL_UPDATE.end();
                }
            }

            /** Skip-gram: predicts each context word from the subwords of the centre word. */
            protected void skipgram(Model model, float f, List<Integer> list) {
                UniformIntegerDistribution windowSize = new UniformIntegerDistribution(model.random(), 1, this.args.ws());
                for (int w = 0; w < list.size(); w++) {
                    int boundary = windowSize.sample();
                    List<Integer> subwords = this.dictionary.getSubwords(list.get(w).intValue());
                    for (int c = -boundary; c <= boundary; c++) {
                        int pos = w + c;
                        if (c != 0 && pos >= 0 && pos < list.size()) {
                            Events.MODEL_UPDATE.start();
                            model.update(subwords, list.get(pos).intValue(), f);
                            Events.MODEL_UPDATE.end();
                        }
                    }
                }
            }
        }

        /**
         * Creates a factory from its collaborators.
         *
         * @param iOStreams   file-system abstraction; must not be null
         * @param intFunction maps an int seed to a random generator; must not be null
         * @param printLogs   log sink; must not be null
         * @param charset     charset for text inputs; must not be null
         * @throws NullPointerException if any argument is null
         */
        public Factory(IOStreams iOStreams, IntFunction<RandomGenerator> intFunction, PrintLogs printLogs, Charset charset) {
            // requireNonNull is generic, so no casts are needed; checks run in declaration order.
            this.fs = Objects.requireNonNull(iOStreams, "Null io-factory.");
            this.random = Objects.requireNonNull(intFunction, "Null random-factory.");
            this.logs = Objects.requireNonNull(printLogs, "Null logs.");
            this.charset = Objects.requireNonNull(charset, "Null charset.");
        }

        /** Returns a copy of this factory that uses {@code iOStreams} for file access. */
        public Factory setFileSystem(IOStreams iOStreams) {
            Factory copy = new Factory(iOStreams, random, logs, charset);
            return copy;
        }

        /** Returns a copy of this factory that logs to {@code printLogs}. */
        public Factory setLogs(PrintLogs printLogs) {
            Factory copy = new Factory(fs, random, printLogs, charset);
            return copy;
        }

        /** Returns a copy of this factory that creates random generators with {@code intFunction}. */
        public Factory setRandom(IntFunction<RandomGenerator> intFunction) {
            Factory copy = new Factory(fs, intFunction, logs, charset);
            return copy;
        }

        /** @return the file-system abstraction */
        public IOStreams getFileSystem() {
            return fs;
        }

        /** @return the log sink */
        public PrintLogs getLogs() {
            return logs;
        }

        /** @return the seed-to-generator function */
        public IntFunction<RandomGenerator> getRandom() {
            return random;
        }

        /** @return the charset used for text inputs */
        public Charset getCharset() {
            return charset;
        }

        /**
         * Loads a binary model from a file.
         * <p>
         * The stream is always closed, including when parsing fails. Failures are logged at info
         * level and rethrown.
         *
         * @param str model file reference
         * @return the loaded model
         * @throws NullPointerException     if {@code str} is null
         * @throws IllegalArgumentException if the file is unreadable or has a wrong format
         * @throws IOException              on read errors
         */
        public FastText load(String str) throws IOException, IllegalArgumentException {
            if (!this.fs.canRead(Objects.requireNonNull(str, "Null file ref specified."))) {
                throw new IllegalArgumentException("Model file cannot be opened for loading: <" + str + Dictionary.EOW);
            }
            try (InputStream openInput = this.fs.openInput(str)) {
                this.logs.debug("Load model %s ... ", str);
                FastText result = load(openInput);
                this.logs.debugln("done.", new Object[0]);
                return result;
            } catch (Exception e) {
                this.logs.infoln("error: %s", e);
                throw e;
            }
        }

        /**
         * Parses a binary model, in the same layout as reference fastText, from {@code inputStream}.
         * <p>
         * The layout is: magic int, version int, args, dictionary, quantized-input flag, input
         * matrix, qout flag, output matrix. Reading starts at the current stream position; the
         * stream is not closed.
         *
         * @throws IllegalArgumentException if the magic or version is wrong, or the file is an old pruned model
         * @throws IOException              on read errors
         */
        public FastText load(InputStream inputStream) throws IOException, IllegalArgumentException {
            FTInputStream in = new FTInputStream(new BufferedInputStream(inputStream));
            // 793712314 == FASTTEXT_FILEFORMAT_MAGIC_INT32
            int magic = in.readInt();
            if (magic != 793712314) {
                throw new IllegalArgumentException("Model file has wrong format!");
            }
            // 12 == FASTTEXT_VERSION: newer files cannot be read.
            int fileVersion = in.readInt();
            if (fileVersion > 12) {
                throw new IllegalArgumentException("Model file has wrong format!");
            }
            Args modelArgs = Args.load(in);
            if (fileVersion == 11 && modelArgs.model() == Args.ModelName.SUP) {
                // Backward compatibility: version-11 supervised models do not use char n-grams.
                modelArgs = new Args.Builder().copy(modelArgs).setMaxN(0).build();
            }
            Dictionary dictionary = Dictionary.load(modelArgs, this.charset, in);

            boolean quantInput = in.readBoolean();
            QMatrix qInput = quantInput ? QMatrix.load(this.random, in) : QMatrix.empty();
            Matrix inputMatrix = quantInput ? Matrix.empty() : Matrix.load(in);
            if (!quantInput && dictionary.isPruned()) {
                throw new IllegalArgumentException("Invalid model file.\nPlease download the updated model from www.fasttext.cc.\nSee issue #332 on Github for more information.\n");
            }

            Args finalArgs = new Args.Builder().copy(modelArgs).setQOut(in.readBoolean()).build();
            boolean quantOutput = quantInput && finalArgs.qout();
            QMatrix qOutput = quantOutput ? QMatrix.load(this.random, in) : QMatrix.empty();
            Matrix outputMatrix = quantOutput ? Matrix.empty() : Matrix.load(in);

            Model model = createModel(finalArgs, dictionary, inputMatrix, outputMatrix, 0).setQuantizePointer(qInput, qOutput);
            return createFastText(finalArgs, dictionary, model, fileVersion);
        }

        /**
         * Builds the input matrix from a text file of pre-trained vectors ({@code .vec} format).
         * <p>
         * The first line is {@code "n dim"}; each of the next {@code n} lines is a word followed by
         * {@code dim} floats. Every word is added to {@code dictionary}. Rows of words kept by the
         * dictionary are copied into a fresh, uniformly initialised
         * {@code (nwords + bucket) x dim} matrix.
         *
         * @throws IllegalArgumentException if the file is unreadable or malformed, or dim mismatches {@code args.dim()}
         * @throws IOException              on read errors
         */
        protected Matrix loadInput(Args args, Dictionary dictionary, String str) throws IOException, IllegalArgumentException {
            if (!this.fs.canRead(str)) {
                throw new IllegalArgumentException("Pre-trained vectors file cannot be opened!");
            }
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(this.fs.openInput(str), this.charset))) {
                String header = reader.readLine();
                // A missing header (empty file) used to cause an NPE.
                if (header == null || !header.trim().matches("\\d+\\s+\\d+")) {
                    throw new IllegalArgumentException("Wrong pre-trained vectors file: first line should contain 'n dim' pair");
                }
                String[] nDim = header.trim().split("\\s+");
                int n = Integer.parseInt(nDim[0]);
                int dim = Integer.parseInt(nDim[1]);
                if (dim != args.dim()) {
                    throw new IllegalArgumentException("Dimension of pretrained vectors does not match -dim option: found " + dim + ", expected " + args.dim());
                }
                Matrix vectors = new Matrix(n, dim);
                List<String> words = new ArrayList<>(n);
                for (int i = 0; i < n; i++) {
                    String line = reader.readLine();
                    String[] parts = StringUtils.isEmpty(line) ? new String[0] : line.trim().split("\\s+");
                    // Only malformed lines are rejected; previously every line reached this throw.
                    if (parts.length == 0 || StringUtils.isEmpty(parts[0])) {
                        throw new IllegalArgumentException("Wrong line: " + line);
                    }
                    String word = parts[0];
                    // Take at most dim values so an extra token cannot overflow the matrix row.
                    List<Float> values = Arrays.stream(parts).skip(1L).limit(dim).map(Float::parseFloat).collect(Collectors.toList());
                    if (values.size() < dim) {
                        throw new IllegalArgumentException("Wrong numbers in the line: " + values.size() + ". Expected " + dim);
                    }
                    words.add(word);
                    dictionary.add(word);
                    for (int j = 0; j < dim; j++) {
                        vectors.set(i, j, values.get(j).floatValue());
                    }
                }
                dictionary.threshold(1L, 0L);
                Matrix result = new Matrix(dictionary.nwords() + args.bucket(), args.dim());
                result.uniform(this.random.apply(1), 1.0f / args.dim());
                for (int i = 0; i < n; i++) {
                    int id = dictionary.getId(words.get(i));
                    // Skip words the dictionary dropped or mapped outside the word range.
                    if (id >= 0 && id < dictionary.nwords()) {
                        for (int j = 0; j < dim; j++) {
                            result.set(id, j, vectors.get(i, j));
                        }
                    }
                }
                return result;
            }
        }

        /**
         * Builds a dictionary by reading the training file {@code str}.
         * <p>
         * The stream is closed exactly once, even if reading or closing fails.
         *
         * @throws IOException on read errors
         */
        protected Dictionary readDictionary(Args args, String str) throws IOException {
            try (InputStream openInput = this.fs.openInput(str)) {
                return Dictionary.read(openInput, args, this.charset, this.logs);
            }
        }

        /**
         * Creates a fresh input matrix with {@code nwords + bucket} rows and {@code dim} columns,
         * filled via {@code uniform} with bound {@code 1/dim} (generator seeded with 1).
         */
        protected Matrix createInput(Args args, Dictionary dictionary) {
            Matrix weights = new Matrix(dictionary.nwords() + args.bucket(), args.dim());
            RandomGenerator rnd = this.random.apply(1);
            weights.uniform(rnd, 1.0f / args.dim());
            return weights;
        }

        /**
         * Creates the output matrix. It has one row per label for supervised models and one row
         * per word otherwise.
         */
        protected Matrix createOutput(Args args, Dictionary dictionary) {
            if (Args.ModelName.SUP.equals(args.model())) {
                return new Matrix(dictionary.nlabels(), args.dim());
            }
            return new Matrix(dictionary.nwords(), args.dim());
        }

        /**
         * Prepares a trainer for data file {@code str}. It reads the dictionary and creates fresh
         * matrices, or loads the input matrix from the pre-trained vectors file {@code str2} when
         * that is non-null.
         *
         * @throws NullPointerException     if {@code str} is null
         * @throws IllegalArgumentException if {@code str} cannot be read
         */
        protected Trainer newTrainer(Args args, String str, String str2) throws IOException {
            String dataFile = Objects.requireNonNull(str, "Null data file specified");
            if (!this.fs.canRead(dataFile)) {
                throw new IllegalArgumentException("Input file cannot be opened: " + str);
            }

            Events.GET_FILE_SIZE.start();
            long fileSize = this.fs.size(str);
            Events.GET_FILE_SIZE.end();

            Events.READ_DICT.start();
            Dictionary dictionary = readDictionary(args, str);
            Events.READ_DICT.end();

            Events.IN_MATRIX_CREATE.start();
            Matrix inputMatrix;
            if (str2 != null) {
                inputMatrix = loadInput(args, dictionary, str2);
            } else {
                inputMatrix = createInput(args, dictionary);
            }
            Events.IN_MATRIX_CREATE.end();

            Events.OUT_MATRIX_CREATE.start();
            Matrix outputMatrix = createOutput(args, dictionary);
            Events.OUT_MATRIX_CREATE.end();

            return new Trainer(args, str, fileSize, dictionary, inputMatrix, outputMatrix);
        }

        /**
         * Prepares a trainer over data file {@code str} with an existing dictionary and matrices.
         *
         * @throws NullPointerException     if {@code str} is null
         * @throws IllegalArgumentException if {@code str} cannot be read
         */
        protected Trainer newTrainer(Args args, String str, Dictionary dictionary, Matrix matrix, Matrix matrix2) throws IOException {
            String dataFile = Objects.requireNonNull(str, "Null data file specified");
            if (!this.fs.canRead(dataFile)) {
                throw new IllegalArgumentException("Input file cannot be opened: " + str);
            }
            long fileSize = this.fs.size(str);
            return new Trainer(args, str, fileSize, dictionary, matrix, matrix2);
        }

        /**
         * Trains a model on data file {@code str}, optionally starting from the pre-trained
         * vectors in {@code str2} (may be null).
         *
         * @return the trained model, tagged with file-format version 12
         */
        public FastText train(Args args, String str, String str2) throws IOException, ExecutionException {
            Events.TRAIN.start();
            try {
                Trainer trainer = newTrainer(args, str, str2);
                Dictionary dictionary = trainer.dictionary;
                Model trained = trainer.train();
                return createFastText(args, dictionary, trained, 12);
            } finally {
                Events.TRAIN.end();
            }
        }

        /**
         * Wraps the matrices into a {@link Model} whose generator is seeded with {@code i}.
         * Target counts come from labels for supervised models and from words otherwise.
         */
        protected Model createModel(Args args, Dictionary dictionary, Matrix matrix, Matrix matrix2, int i) {
            Model result = new Model(matrix, matrix2, args, this.random.apply(i));
            Dictionary.EntryType targets = Args.ModelName.SUP.equals(args.model())
                    ? Dictionary.EntryType.LABEL
                    : Dictionary.EntryType.WORD;
            result.setTargetCounts(dictionary.getCounts(targets));
            return result;
        }

        /** Creates a {@link FastText} that shares this factory's I/O, logs and random source. */
        protected FastText createFastText(Args args, Dictionary dictionary, Model model, int i) {
            FastText created = new FastText(args, dictionary, model, i, fs, logs, random);
            return created;
        }
    }

    /* loaded from: input_file:cc/fasttext/FastText$SimpleLogger.class */
    /**
     * {@link PrintLogs} adapter that forwards to the class SLF4J logger.
     * <p>
     * Messages are passed through {@code FormatUtils.toNonHyphenatedLine} and then
     * {@code String.format} (only when arguments are given). Empty messages are dropped.
     */
    private static class SimpleLogger implements PrintLogs {
        private SimpleLogger() {
        }

        @Override // cc.fasttext.io.PrintLogs
        public boolean isTraceEnabled() {
            return FastText.LOGGER.isTraceEnabled();
        }

        @Override // cc.fasttext.io.PrintLogs
        public boolean isDebugEnabled() {
            return FastText.LOGGER.isDebugEnabled();
        }

        @Override // cc.fasttext.io.PrintLogs
        public boolean isInfoEnabled() {
            return FastText.LOGGER.isInfoEnabled();
        }

        @Override // cc.fasttext.io.PrintLogs
        public void trace(String str, Object... objArr) {
            if (isTraceEnabled()) {
                String message = format(str, objArr);
                if (message != null) {
                    FastText.LOGGER.trace(message);
                }
            }
        }

        @Override // cc.fasttext.io.PrintLogs
        public void debug(String str, Object... objArr) {
            if (isDebugEnabled()) {
                String message = format(str, objArr);
                if (message != null) {
                    FastText.LOGGER.debug(message);
                }
            }
        }

        @Override // cc.fasttext.io.PrintLogs
        public void info(String str, Object... objArr) {
            if (isInfoEnabled()) {
                String message = format(str, objArr);
                if (message != null) {
                    FastText.LOGGER.info(message);
                }
            }
        }

        /** Returns the rendered message, or null when {@code str} is empty. */
        private static String format(String str, Object... objArr) {
            if (StringUtils.isEmpty(str)) {
                return null;
            }
            String line = FormatUtils.toNonHyphenatedLine(str);
            if (objArr.length == 0) {
                return line;
            }
            return String.format(Factory.LOCALE, line, objArr);
        }
    }

    /* loaded from: input_file:cc/fasttext/FastText$TestInfo.class */
    /**
     * Evaluation summary: precision and recall at {@code k} over a number of examples.
     */
    public class TestInfo {
        // Numerator of both P@k and R@k (see toString).
        // NOTE(review): presumably the number of correct top-k predictions; confirm at the call site.
        private final double precision;
        private final int examples;
        // Denominator of R@k (see toString).
        private final int labels;
        private final int k;

        private TestInfo(int i, double d, int i2, int i3) {
            this.k = i;
            this.precision = d;
            this.examples = i2;
            this.labels = i3;
        }

        /** @return the raw P@k/R@k numerator (not normalised) */
        public double getPrecision() {
            return this.precision;
        }

        public int getNExamples() {
            return this.examples;
        }

        public int getNLabels() {
            return this.labels;
        }

        public int getK() {
            return this.k;
        }

        /**
         * Renders N, P@k = precision / (k * examples), R@k = precision / labels, and the example count.
         */
        @Override
        public String toString() {
            // Multiply in double: k * examples in int arithmetic overflows for large test sets.
            double predictions = (double) this.k * this.examples;
            return String.format(Factory.LOCALE, "N\t%d%nP@%d: %.3f%nR@%d: %.3f%nNumber of examples: %d%n", Integer.valueOf(this.examples), Integer.valueOf(this.k), Double.valueOf(this.precision / predictions), Integer.valueOf(this.k), Double.valueOf(this.precision / this.labels), Integer.valueOf(this.examples));
        }
    }

    /**
     * Creates an instance; used through {@link Factory#createFastText}.
     *
     * @param i file-format version of the model
     */
    private FastText(Args args, Dictionary dictionary, Model model, int i, IOStreams iOStreams, PrintLogs printLogs, IntFunction<RandomGenerator> intFunction) {
        // Environment collaborators.
        this.fs = iOStreams;
        this.logs = printLogs;
        this.random = intFunction;
        // Model state.
        this.version = i;
        this.args = args;
        this.dict = dictionary;
        this.model = model;
    }

    /** Trains on data file {@code str} with {@link #DEFAULT_FACTORY} and no pre-trained vectors. */
    public static FastText train(Args args, String str) throws IOException, ExecutionException {
        return DEFAULT_FACTORY.train(args, str, null);
    }

    /** Trains on data file {@code str} with {@link #DEFAULT_FACTORY}; {@code str2} is an optional pre-trained vectors file. */
    public static FastText train(Args args, String str, String str2) throws IOException, ExecutionException {
        Factory factory = DEFAULT_FACTORY;
        return factory.train(args, str, str2);
    }

    /** Loads a binary model file with {@link #DEFAULT_FACTORY}. */
    public static FastText load(String str) throws IOException, IllegalArgumentException {
        Factory factory = DEFAULT_FACTORY;
        return factory.load(str);
    }

    public Args getArgs() {
        return this.args;
    }

    public Dictionary getDictionary() {
        return this.dict;
    }

    public Model getModel() {
        return this.model;
    }

    /** @return the version number stored with this instance */
    public int getVersion() {
        return version;
    }

    /**
     * Creates a factory sharing this instance's I/O streams, random source, logger and charset.
     *
     * @return a new factory configured like this instance
     */
    protected Factory toFactory() {
        final Dictionary dictionary = this.dict;
        return new Factory(fs, random, logs, dictionary.charset());
    }

    /**
     * Computes a word's vector as the mean of the input rows of its subword ids.
     *
     * @param str the word
     * @return a vector of size {@code args.dim()}; all zeros when the word has no subwords
     */
    public Vector getWordVector(String str) {
        final Vector mean = new Vector(this.args.dim());
        final List<Integer> ids = this.dict.getSubwords(str);
        for (Integer id : ids) {
            addInputVector(mean, id.intValue());
        }
        if (!ids.isEmpty()) {
            mean.mul(1.0f / ids.size());
        }
        return mean;
    }

    /**
     * Computes a sentence vector.
     *
     * <p>For supervised models it is the mean of the input rows of the line's token ids.
     * Otherwise, each whitespace-separated word vector with a positive norm is L2-normalized, and
     * the result is the mean of those normalized vectors.
     *
     * @param str the sentence (a trailing newline is appended internally)
     * @return a vector of size {@code args.dim()}; all zeros when nothing contributes
     */
    public Vector getSentenceVector(String str) {
        String line = str + "\n";
        Vector sentence = new Vector(this.args.dim());
        if (Args.ModelName.SUP.equals(this.args.model())) {
            List<Integer> ids = this.dict.getLine(line);
            if (ids.isEmpty()) {
                return sentence;
            }
            for (Integer id : ids) {
                addInputVector(sentence, id.intValue());
            }
            sentence.mul(1.0f / ids.size());
            return sentence;
        }
        int count = 0;
        for (String word : line.split("\\s+")) {
            // split() yields a leading "" when the line starts with whitespace; that is not a word.
            if (word.isEmpty()) {
                continue;
            }
            Vector wordVector = getWordVector(word);
            float norm = wordVector.norm();
            if (norm > 0.0f) {
                wordVector.mul(1.0f / norm);
                sentence.addVector(wordVector);
                count++;
            }
        }
        if (count > 0) {
            sentence.mul(1.0f / count);
        }
        return sentence;
    }

    /**
     * Builds a matrix holding every dictionary word's vector scaled to unit L2 norm.
     * Words whose vector norm is not positive leave their row untouched.
     *
     * @return a {@code nwords x dim} matrix
     */
    private Matrix computeWordVectors() {
        this.logs.info("Pre-computing word vectors... ", new Object[0]);
        final int words = this.dict.nwords();
        final Matrix normalized = new Matrix(words, this.args.dim());
        for (int row = 0; row < words; row++) {
            final Vector vec = getWordVector(this.dict.getWord(row));
            final float length = vec.norm();
            // Written as !(x > 0) so NaN norms are skipped, exactly as before.
            if (!(length > 0.0f)) {
                continue;
            }
            normalized.addRow(vec, row, 1.0f / length);
        }
        this.logs.infoln("done.", new Object[0]);
        return normalized;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Returns the normalized word-vector matrix, computing it on first use.
     *
     * <p>The result is cached behind a {@link SoftReference}, so the JVM may reclaim it under
     * memory pressure; it is then recomputed on the next call.
     *
     * @return the cached or freshly computed matrix
     */
    public Matrix getPrecomputedWordVectors() {
        // Read the field once so the null check and get() see the same reference.
        Reference<Matrix> cached = this.precomputedWordVectors;
        Matrix matrix = cached == null ? null : cached.get();
        if (matrix == null) {
            matrix = computeWordVectors();
            this.precomputedWordVectors = new SoftReference<>(matrix);
        }
        return matrix;
    }

    /**
     * Ranks every dictionary word by {@code matrix.dotRow(vector, row) / |vector|}. It returns the
     * {@code i} highest-scoring words that are not contained in {@code set}.
     *
     * <p>If the absolute query norm is below {@code FIND_NN_THRESHOLD}, it is treated as 1.
     * Words with equal scores are ordered by descending string order.
     *
     * @param matrix word matrix, one row per dictionary word
     * @param vector query vector
     * @param i      maximum number of results
     * @param set    words to exclude from the result
     * @return score to word, ordered by descending score
     */
    private Multimap<Float, String> findNN(Matrix matrix, Vector vector, int i, Set<String> set) {
        float queryNorm = vector.norm();
        if (FastMath.abs(queryNorm) < FIND_NN_THRESHOLD) {
            queryNorm = 1.0f;
        }
        TreeMultimap<Float, String> ranked =
                TreeMultimap.create(Comparator.<Float>reverseOrder(), Comparator.<String>reverseOrder());
        for (int row = 0; row < this.dict.nwords(); row++) {
            ranked.put(matrix.dotRow(vector, row) / queryNorm, this.dict.getWord(row));
        }
        TreeMultimap<Float, String> result =
                TreeMultimap.create(Comparator.<Float>reverseOrder(), Comparator.<String>reverseOrder());
        // entries() iterates by descending score, then descending word: best candidates first.
        for (Map.Entry<Float, String> entry : ranked.entries()) {
            if (result.size() >= i) {
                break;
            }
            if (!set.contains(entry.getValue())) {
                result.put(entry.getKey(), entry.getValue());
            }
        }
        return result;
    }

    /**
     * Finds the nearest neighbours of a word, excluding the word itself.
     *
     * @param i   maximum number of neighbours (must be positive)
     * @param str query word (must be non-empty)
     * @return word to similarity score
     * @throws IllegalArgumentException if {@code str} is empty or {@code i <= 0}
     */
    public Multimap<String, Float> nn(int i, String str) throws IllegalArgumentException {
        Validate.notEmpty(str, "Empty query word");
        Validate.isTrue(i > 0, "Not positive factor");
        Matrix wordVectors = getPrecomputedWordVectors();
        Set<String> excluded = new HashSet<>();
        excluded.add(str);
        return Multimaps.invertFrom(findNN(wordVectors, getWordVector(str), i, excluded), ArrayListMultimap.create());
    }

    /**
     * Solves a word analogy using the query vector {@code v(str) - v(str2) + v(str3)}.
     * All three query words are excluded from the result.
     *
     * @param i    maximum number of answers (must be positive)
     * @param str  first query word
     * @param str2 second query word (subtracted)
     * @param str3 third query word
     * @return word to similarity score
     * @throws IllegalArgumentException if any word is empty or {@code i <= 0}
     */
    public Multimap<String, Float> analogies(int i, String str, String str2, String str3) {
        Validate.notEmpty(str, "Empty first query word");
        Validate.notEmpty(str2, "Empty second query word");
        Validate.notEmpty(str3, "Empty third query word");
        Validate.isTrue(i > 0, "Not positive factor");
        Matrix wordVectors = getPrecomputedWordVectors();
        Set<String> excluded = new HashSet<>();
        Vector query = new Vector(this.args.dim());
        excluded.add(str);
        query.addVector(getWordVector(str), 1.0f);
        excluded.add(str2);
        query.addVector(getWordVector(str2), -1.0f);
        excluded.add(str3);
        query.addVector(getWordVector(str3), 1.0f);
        return Multimaps.invertFrom(findNN(wordVectors, query, i, excluded), ArrayListMultimap.create());
    }

    /**
     * Returns a lazily transformed view mapping each subword of {@code str} to its input vector.
     * Null or negative ids map to an all-zero vector.
     *
     * @param str the word (must be non-empty)
     * @return subword to vector view; vectors are recomputed on each access
     */
    public Multimap<String, Vector> ngramVectors(String str) {
        Validate.notEmpty(str, "Empty word");
        return Multimaps.transformValues(this.dict.getSubwordsMap(str), id -> {
            Vector ngram = new Vector(this.args.dim());
            if (id == null || id.intValue() < 0) {
                return ngram;
            }
            addInputVector(ngram, id.intValue());
            return ngram;
        });
    }

    /**
     * Adds row {@code i} of the input matrix to {@code vector}. The quantized matrix is used when
     * the model is quantized.
     */
    private void addInputVector(Vector vector, int i) {
        if (!this.model.isQuant()) {
            vector.addRow(this.model.input(), i);
            return;
        }
        vector.addRow(this.model.qinput(), i);
    }

    /**
     * Writes every dictionary word with its word vector as text.
     *
     * @param str destination
     * @throws IllegalArgumentException if the destination is not writable
     */
    public void saveVectors(String str) throws IOException, IllegalArgumentException {
        final int rows = this.dict.nwords();
        final Dictionary dictionary = this.dict;
        writeVectors("vectors", str, rows, dictionary::getWord,
                index -> getWordVector(this.dict.getWord(index)));
    }

    /**
     * Writes the rows of the output matrix as text. Rows are named by label for supervised
     * models and by word otherwise.
     *
     * @param str destination
     * @throws IllegalStateException    if the model is quantized
     * @throws IllegalArgumentException if the destination is not writable
     */
    public void saveOutput(String str) throws IOException, IllegalArgumentException, IllegalStateException {
        if (getModel().isQuant()) {
            throw new IllegalStateException("Saving output is not supported for quantized models.");
        }
        final boolean supervised = Args.ModelName.SUP.equals(this.args.model());
        final int rows = supervised ? this.dict.nlabels() : this.dict.nwords();
        IntFunction<String> names = index ->
                Args.ModelName.SUP.equals(this.args.model()) ? this.dict.getLabel(index) : this.dict.getWord(index);
        IntFunction<Vector> rowVectors = index -> {
            Vector row = new Vector(this.args.dim());
            row.addRow(this.model.output(), index);
            return row;
        };
        writeVectors("output", str, rows, names, rowVectors);
    }

    /**
     * Writes named vectors as text. The first line is {@code "<count> <dim>"}. Each following
     * line holds a row's name, a space, and the vector's {@code toString()}.
     *
     * @param description what is being saved; used only in the log message
     * @param destination output location, resolved through the I/O streams
     * @param count       number of rows to write
     * @param names       maps a row index to its name
     * @param vectors     maps a row index to its vector
     * @throws IllegalArgumentException if the destination is not writable
     * @throws IOException              on write or close failure
     */
    private void writeVectors(String description, String destination, int count, IntFunction<String> names,
                              IntFunction<Vector> vectors) throws IOException, IllegalArgumentException {
        if (!this.fs.canWrite(destination)) {
            throw new IllegalArgumentException("Can't write to " + destination);
        }
        this.logs.infoln("Saving %s to %s", description, destination);
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(this.fs.createOutput(destination), this.dict.charset()))) {
            writer.write(count + " " + this.args.dim() + "\n");
            for (int row = 0; row < count; row++) {
                writer.write(names.apply(row));
                writer.write(" ");
                writer.write(vectors.apply(row).toString());
                writer.write("\n");
            }
        }
    }

    /**
     * Saves the model in binary form. The output contains, in order:
     * <ol>
     *   <li>the signature (magic number and version);</li>
     *   <li>the args and the dictionary;</li>
     *   <li>a quantization flag followed by the input matrix (quantized if the model is);</li>
     *   <li>the {@code qout} flag followed by the output matrix (quantized only when the model is
     *       quantized and {@code qout} is set).</li>
     * </ol>
     *
     * @param str destination
     * @throws IllegalArgumentException if the destination is not writable
     * @throws IOException              on write or close failure
     */
    public void saveModel(String str) throws IOException, IllegalArgumentException {
        Events.SAVE_BIN.start();
        if (!this.fs.canWrite(str)) {
            throw new IllegalArgumentException("Can't write to " + str);
        }
        this.logs.infoln("Saving model to %s", str);
        // try-with-resources: a failing close() is attached as suppressed instead of masking a write error.
        try (FTOutputStream out = new FTOutputStream(new BufferedOutputStream(this.fs.createOutput(str)))) {
            signModel(out);
            this.args.save(out);
            this.dict.save(out);
            boolean isQuant = this.model.isQuant();
            out.writeBoolean(isQuant);
            if (isQuant) {
                this.model.qinput().save(out);
            } else {
                this.model.input().save(out);
            }
            out.writeBoolean(this.args.qout());
            if (isQuant && this.args.qout()) {
                this.model.qoutput().save(out);
            } else {
                this.model.output().save(out);
            }
            Events.SAVE_BIN.end();
        }
    }

    /**
     * Writes the binary header: the file-format magic number followed by the format version.
     */
    private static void signModel(FTOutputStream fTOutputStream) throws IOException {
        // Same literal that quantize() passes to createFastText as the version.
        final int version = 12;
        final int magic = FASTTEXT_FILEFORMAT_MAGIC_INT32;
        fTOutputStream.writeInt(magic);
        fTOutputStream.writeInt(version);
    }

    /**
     * Evaluates the model on labelled input.
     *
     * <p>Only lines that have both words and labels count as examples. For each such example,
     * the top-{@code i} predictions that appear among its gold labels are added to the precision
     * accumulator.
     *
     * @param inputStream labelled input; not closed by this method
     * @param i           number of top predictions per example (must be positive)
     * @return the accumulated evaluation summary
     * @throws IOException on read failure
     */
    public TestInfo test(InputStream inputStream, int i) throws IOException {
        Objects.requireNonNull(inputStream, "Null input");
        Validate.isTrue(i > 0, "Not positive factor");
        int examples = 0;
        int goldLabels = 0;
        double correct = 0.0d;
        List<Integer> words = new ArrayList<>();
        List<Integer> labels = new ArrayList<>();
        Dictionary.SeekableReader reader = this.dict.createReader(inputStream);
        while (!reader.isEnd() && this.dict.getLine(reader, words, labels) != 0) {
            if (!labels.isEmpty() && !words.isEmpty()) {
                correct += this.model.predict(words, i).values().stream().filter(labels::contains).count();
                examples++;
                goldLabels += labels.size();
            }
        }
        return new TestInfo(i, correct, examples, goldLabels);
    }

    /**
     * Evaluates the model on a labelled file opened through the I/O streams.
     *
     * @param str file to read
     * @param i   number of top predictions per example (must be positive)
     * @return the evaluation summary
     * @throws IllegalArgumentException if the file is not readable
     * @throws IOException              on read or close failure
     */
    public TestInfo test(String str, int i) throws IOException {
        if (!this.fs.canRead(str)) {
            throw new IllegalArgumentException("Can't read file " + str);
        }
        // try-with-resources keeps the primary exception if close() also fails.
        try (InputStream in = this.fs.openInput(str)) {
            return test(in, i);
        }
    }

    /**
     * Orders two labels.
     *
     * <p>If both labels consist only of ASCII digits after every occurrence of {@code str} is
     * removed, they are ordered by numeric value, so {@code 2} sorts before {@code 10}. Otherwise,
     * and to break ties between numerically equal labels such as {@code 07} and {@code 7}, plain
     * string order is used. A zero is returned only for equal strings, so distinct labels never
     * collide as keys in a sorted map.
     *
     * @param str  label prefix to strip before the numeric check
     * @param str2 first label
     * @param str3 second label
     * @return negative, zero or positive as {@code str2} sorts before, equal to or after {@code str3}
     */
    private static int compareLabels(String str, String str2, String str3) {
        String leftSuffix = str2.replace(str, "");
        String rightSuffix = str3.replace(str, "");
        if (isAsciiDigits(leftSuffix) && isAsciiDigits(rightSuffix)) {
            // Compare magnitudes without parsing: Integer.valueOf overflowed on long digit runs.
            int byValue = compareDigitStrings(leftSuffix, rightSuffix);
            if (byValue != 0) {
                return byValue;
            }
        }
        return str2.compareTo(str3);
    }

    /** Returns true if {@code s} is non-empty and consists only of the characters '0'..'9'. */
    private static boolean isAsciiDigits(String s) {
        if (s.isEmpty()) {
            return false;
        }
        for (int pos = 0; pos < s.length(); pos++) {
            char c = s.charAt(pos);
            if (c < '0' || c > '9') {
                return false;
            }
        }
        return true;
    }

    /** Compares two non-empty ASCII digit strings by numeric value, ignoring leading zeros. */
    private static int compareDigitStrings(String a, String b) {
        int fromA = 0;
        while (fromA < a.length() - 1 && a.charAt(fromA) == '0') {
            fromA++;
        }
        int fromB = 0;
        while (fromB < b.length() - 1 && b.charAt(fromB) == '0') {
            fromB++;
        }
        int lengthA = a.length() - fromA;
        int lengthB = b.length() - fromB;
        if (lengthA != lengthB) {
            return Integer.compare(lengthA, lengthB);
        }
        return a.substring(fromA).compareTo(b.substring(fromB));
    }

    /**
     * Copies a multimap into an insertion-ordered map, transforming each entry's value.
     *
     * @param multimap source; iteration order of its entries is preserved
     * @param function computes the value stored for each entry
     * @return a {@link LinkedHashMap} with one value per key
     * @throws IllegalStateException if the multimap contains a key more than once
     */
    private static <K, V> Map<K, V> toStandardMap(Multimap<K, V> multimap, Function<Map.Entry<K, V>, V> function) throws IllegalStateException {
        return multimap.entries().stream().collect(Collectors.toMap(
                Map.Entry::getKey,
                function,
                (first, second) -> {
                    throw new IllegalStateException("Duplicate label");
                },
                LinkedHashMap::new));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Converts label scores into an ordered label-to-{@code exp(score)} map.
     *
     * @param multimap label to score
     * @return label to {@code exp(score)}, in the multimap's iteration order
     * @throws IllegalStateException if a label occurs more than once
     */
    public static Map<String, Float> toProbabilityMap(Multimap<String, Float> multimap) throws IllegalStateException {
        Function<Map.Entry<String, Float>, Float> toProbability =
                entry -> Float.valueOf((float) FastMath.exp(entry.getValue().floatValue()));
        return toStandardMap(multimap, toProbability);
    }

    /**
     * Lazily predicts labels for each line of {@code inputStream}. Each element is a
     * label-to-probability map; lines that yield no prediction are dropped.
     *
     * @param inputStream input lines; not closed by this method or by the returned stream
     * @param i           number of top labels per line (must be positive)
     * @return a sequential, lazily evaluated stream; read errors surface as {@link UncheckedIOException}
     */
    public Stream<Map<String, Float>> predict(InputStream inputStream, final int i) {
        Objects.requireNonNull(inputStream, "Null input");
        Validate.isTrue(i > 0, "Not positive factor");
        final Dictionary.SeekableReader reader = this.dict.createReader(inputStream);
        Iterator<Map<String, Float>> perLine = new Iterator<Map<String, Float>>() {
            @Override
            public boolean hasNext() {
                return !reader.isEnd();
            }

            @Override
            public Map<String, Float> next() {
                if (reader.isEnd()) {
                    throw new NoSuchElementException();
                }
                try {
                    return FastText.toProbabilityMap(FastText.this.predict(reader, i));
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            }
        };
        Stream<Map<String, Float>> all = StreamSupport.stream(Spliterators.spliteratorUnknownSize(perLine, 0), false);
        return all.filter(probabilities -> !probabilities.isEmpty());
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Reads one line from {@code seekableReader} and predicts its top-{@code i} labels.
     *
     * @param seekableReader source positioned at the next line
     * @param i              number of labels to predict
     * @return label to raw score, ordered by {@code compareLabels}; empty if the line has no words
     * @throws IOException on read failure
     */
    public Multimap<String, Float> predict(Dictionary.SeekableReader seekableReader, int i) throws IOException {
        List<Integer> words = new ArrayList<>();
        this.dict.getLine(seekableReader, words, new ArrayList<>());
        if (words.isEmpty()) {
            return ImmutableListMultimap.of();
        }
        TreeMultimap<Float, Integer> predictions =
                this.model.predict(words, i, new Vector(this.args.dim()), new Vector(this.dict.nlabels()));
        TreeMultimap<String, Float> byLabel = TreeMultimap.<String, Float>create(
                (left, right) -> compareLabels(this.args.label(), left, right),
                predictions.keySet().comparator());
        predictions.forEach((score, labelId) -> byLabel.put(this.dict.getLabel(labelId.intValue()), score));
        return byLabel;
    }

    /**
     * Lazily predicts labels for each line of a file.
     *
     * <p>The returned stream owns the opened file and closes it in its {@code onClose} handler, so
     * callers should close the stream (for example with try-with-resources).
     *
     * @param str file to read
     * @param i   number of top labels per line (must be positive)
     * @return stream of label-to-probability maps
     * @throws IllegalArgumentException if the file is not readable or {@code i <= 0}
     * @throws IOException              if the file cannot be opened
     */
    public Stream<Map<String, Float>> predict(String str, int i) throws IOException, IllegalArgumentException {
        if (!this.fs.canRead(str)) {
            throw new IllegalArgumentException("Can't read file " + str);
        }
        InputStream in = this.fs.openInput(str);
        try {
            return predict(in, i).onClose(() -> {
                try {
                    in.close();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            });
        } catch (RuntimeException e) {
            // predict(InputStream, int) validates its arguments eagerly. If it fails, no stream
            // (and hence no onClose hook) is returned, so release the file here.
            try {
                in.close();
            } catch (IOException suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
    }

    /**
     * Predicts the top-{@code i} labels for a single line of text.
     *
     * @param str the line (must be non-empty)
     * @param i   number of labels (must be positive)
     * @return label to probability ({@code exp} of the score), ordered by {@code compareLabels};
     *         empty if the line yields no tokens
     * @throws IllegalArgumentException if {@code str} is empty or {@code i <= 0}
     * @throws IllegalStateException    if a label would occur twice in the result
     */
    public Map<String, Float> predictLine(String str, int i) throws IllegalStateException, IllegalArgumentException {
        Validate.notEmpty(str, "Null line specified.");
        Validate.isTrue(i > 0, "Negative or zero factor");
        List<Integer> words = this.dict.getLine(str);
        if (words.isEmpty()) {
            return Collections.emptyMap();
        }
        TreeMultimap<Float, Integer> predictions =
                this.model.predict(words, i, new Vector(this.args.dim()), new Vector(this.dict.nlabels()));
        TreeMultimap<String, Float> byLabel = TreeMultimap.<String, Float>create(
                (left, right) -> compareLabels(this.args.label(), left, right),
                predictions.keySet().comparator());
        predictions.forEach((score, labelId) -> byLabel.put(this.dict.getLabel(labelId.intValue()), score));
        return toProbabilityMap(byLabel);
    }

    /**
     * Selects the input rows to keep when pruning. The EOS row comes first, then the remaining
     * rows in descending order of their L2 norm.
     *
     * @param i number of rows to keep; must not exceed the number of input rows
     * @return the first {@code i} row indices in that order (a sub-list view)
     */
    private List<Integer> selectEmbeddings(int i) {
        Vector norms = this.model.input().l2NormRow();
        List<Integer> rows = IntStream.iterate(0, row -> row + 1)
                .limit(this.model.input().getM())
                .boxed()
                .collect(Collectors.toList());
        int eosId = this.dict.getId(Dictionary.EOS);
        // A proper three-way comparator. The previous lambda never returned 0 and was not
        // antisymmetric for equal norms, which can make List.sort throw IllegalArgumentException.
        Comparator<Integer> eosFirst = Comparator.comparing(row -> row.intValue() != eosId); // false < true
        Comparator<Integer> byNormDescending =
                (a, b) -> Double.compare(norms.get(b.intValue()), norms.get(a.intValue()));
        rows.sort(eosFirst.thenComparing(byNormDescending));
        return rows.subList(0, i);
    }

    /**
     * Returns a new, quantized copy of this supervised model.
     *
     * <p>The {@code qout}, {@code cutoff}, {@code qnorm} and {@code dsub} settings are taken from
     * {@code args}; all other settings are copied from this model's arguments. The method works
     * on copies of the dictionary and matrices obtained via {@code copy()}.
     *
     * <p>When {@code 0 < cutoff < input rows}, the dictionary copy is pruned to the rows chosen by
     * {@code selectEmbeddings(cutoff)}, and only those input rows are kept. If {@code str} is also
     * non-empty, the pruned model is retrained (with {@code epoch}, {@code lr} and {@code thread}
     * from {@code args}) before quantization.
     *
     * @param args quantization (and optional retraining) settings
     * @param str  passed to the trainer as its input when retraining. Presumably a training-data
     *             path (TODO confirm). Null or empty skips retraining.
     * @return a new quantized FastText instance
     * @throws IllegalStateException    if this model is already quantized
     * @throws IllegalArgumentException if this model is not supervised
     */
    public FastText quantize(Args args, String str) throws IOException, ExecutionException, IllegalStateException, IllegalArgumentException {
        Matrix copy;
        if (this.model.isQuant()) {
            throw new IllegalStateException("Already quantized.");
        }
        if (!Args.ModelName.SUP.equals(this.args.model())) {
            throw new IllegalArgumentException("For now we only support quantization of supervised models");
        }
        // Start from this model's args, overriding only the quantization-related settings.
        Args build = new Args.Builder().copy(this.args).setQOut(args.qout()).setCutOff(args.cutoff()).setQNorm(args.qnorm()).setDSub(args.dsub()).build();
        Dictionary copy2 = this.dict.copy();
        Matrix copy3 = this.model.output().copy();
        Factory factory = toFactory();
        // Prune only when cutoff is strictly between 0 and the number of input rows.
        if (build.cutoff() <= 0 || build.cutoff() >= this.model.input().getM()) {
            copy = this.model.input().copy();
        } else {
            // prune() is applied to the dictionary copy; its result gives the original input rows to keep.
            List<Integer> prune = copy2.prune(selectEmbeddings(build.cutoff()));
            copy = new Matrix(prune.size(), build.dim());
            for (int i = 0; i < prune.size(); i++) {
                for (int i2 = 0; i2 < build.dim(); i2++) {
                    copy.put(i, i2, this.model.input().at(prune.get(i).intValue(), i2));
                }
            }
            if (!StringUtils.isEmpty(str)) {
                // Retraining settings come from the caller's args.
                build = new Args.Builder().copy(build).setEpoch(args.epoch()).setLR(args.lr()).setThread(args.thread()).build();
                this.logs.traceln("Start retraining ...", new Object[0]);
                Model train = factory.newTrainer(build, str, copy2, copy, copy3).train();
                copy = train.input();
                copy3 = train.output();
            }
        }
        // The input matrix is always quantized (with dsub and qnorm). The output matrix is quantized
        // (with dsub 2) only when qout is set; otherwise QMatrix.empty() is used. The trailing 12 is
        // the same literal signModel writes after the magic number. NOTE(review): the 0 passed to
        // createModel is not explained by this class.
        return factory.createFastText(build, copy2, factory.createModel(build, copy2, copy, copy3, 0).setQuantizePointer(new QMatrix(copy, this.random, build.dsub(), build.qnorm()), build.qout() ? new QMatrix(copy3, this.random, 2, build.qnorm()) : QMatrix.empty()), 12);
    }
}
