package cc.fasttext;

import cc.fasttext.Args;
import cc.fasttext.FastText;
import cc.fasttext.io.FormatUtils;
import cc.fasttext.io.IOStreams;
import cc.fasttext.io.PrintLogs;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Scanner;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.function.DoubleConsumer;
import java.util.function.IntConsumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang.StringUtils;

/* loaded from: input_file:cc/fasttext/Main.class */
public class Main {
    // Shared factory through which every command loads, trains and quantizes models.
    // Starts as FastText.DEFAULT_FACTORY and is replaced as a whole by setFileSystem(IOStreams).
    private static FastText.Factory factory = FastText.DEFAULT_FACTORY;

    /**
     * Usage and help texts of the command line interface.
     *
     * <p>Each constant stores a raw template in which the placeholder {@code {fasttext}} stands for
     * the launcher command; {@link #getMessage()} substitutes it. The {@code toException} methods
     * wrap the rendered text into a {@link WrongInputException}, which {@code main} prints instead of
     * a stack trace.
     */
    public enum Usage {
        COMMON("usage: {fasttext} <command> <args>\n\n"
                + "The commands supported by fasttext are:\n\n"
                + "  supervised              train a supervised classifier\n"
                + "  quantize                quantize a model to reduce the memory usage\n"
                + "  test                    evaluate a supervised classifier\n"
                + "  predict                 predict most likely labels\n"
                + "  predict-prob            predict most likely labels with probabilities\n"
                + "  skipgram                train a skipgram model\n"
                + "  cbow                    train a cbow model\n"
                + "  print-word-vectors      print word vectors given a trained model\n"
                + "  print-sentence-vectors  print sentence vectors given a trained model\n"
                // print-ngrams is dispatched by the command runner, so it is listed here too.
                + "  print-ngrams            print ngrams given a trained model and word\n"
                + "  nn                      query for nearest neighbors\n"
                + "  analogies               query for analogies\n"),
        TRAIN("usage: {fasttext} {supervised|skipgram|cbow} <args>"),
        QUANTIZE("usage: {fasttext} quantize <args>"),
        TEST("usage: {fasttext} test <model> <test-data> [<k>]\n\n"
                + "  <model>      model filename\n"
                + "  <test-data>  test data filename (if -, read from stdin)\n"
                + "  <k>          (optional; 1 by default) predict top k labels\n"),
        PREDICT("usage: {fasttext} predict[-prob] <model> <test-data> [<k>]\n\n"
                + "  <model>      model filename\n"
                + "  <test-data>  test data filename (if -, read from stdin)\n"
                + "  <k>          (optional; 1 by default) predict top k labels\n"),
        PRINT_WORD_VECTORS("usage: {fasttext} print-word-vectors <model>\n\n"
                + "  <model>      model filename\n"),
        PRINT_SENTENCE_VECTORS("usage: {fasttext} print-sentence-vectors <model>\n\n"
                + "  <model>      model filename\n"),
        PRINT_NGRAMS("usage: {fasttext} print-ngrams <model> <word>\n\n"
                + "  <model>      model filename\n"
                + "  <word>       word to print\n"),
        // <k> is optional for nn and analogies, hence the square brackets.
        NN("usage: {fasttext} nn <model> [<k>]\n\n"
                + "  <model>      model filename\n"
                + "  <k>          (optional; 10 by default) predict top k labels\n"),
        ANALOGIES("usage: {fasttext} analogies <model> [<k>]\n\n"
                + "  <model>      model filename\n"
                + "  <k>          (optional; 10 by default) predict top k labels\n"),
        ARGS_BASIC_HELP("\nThe following arguments are mandatory:\n"
                + "  -input              training file uri\n"
                + "  -output             output file name\n"
                + "\nThe following arguments are optional:\n"
                + "  -verbose            verbosity level [integer]\n"),
        ARGS_DICTIONARY_HELP("\nThe following arguments for the dictionary are optional:\n"
                + "  -minCount           minimal number of word occurrences [integer]\n"
                + "  -minCountLabel      minimal number of label occurrences [integer]\n"
                + "  -wordNgrams         max length of word ngram [integer]\n"
                + "  -bucket             number of buckets [integer]\n"
                + "  -minn               min length of char ngram [integer]\n"
                + "  -maxn               max length of char ngram [integer]\n"
                + "  -t                  sampling threshold [double]\n"
                + "  -label              labels prefix [string]\n"),
        ARGS_TRAINING_HELP("\nThe following arguments for training are optional:\n"
                + "  -lr                 learning rate [double]\n"
                + "  -lrUpdateRate       change the rate of updates for the learning rate [integer]\n"
                + "  -dim                size of word vectors [integer]\n"
                + "  -ws                 size of the context window [integer]\n"
                + "  -epoch              number of epochs [integer]\n"
                + "  -neg                number of negatives sampled [integer]\n"
                + "  -loss               loss function {ns|hs|softmax} [string]\n"
                + "  -thread             number of threads [integer]\n"
                + "  -pretrainedVectors  pretrained word vectors for supervised learning [file uri]\n"
                + "  -saveOutput         whether output params should be saved [boolean]\n"),
        ARGS_QUANTIZATION_HELP("\nThe following arguments for quantization are optional:\n"
                + "  -cutoff             number of words and ngrams to retain [integer]\n"
                + "  -retrain            whether embeddings are finetuned if a cutoff is applied [boolean]\n"
                + "  -qnorm              whether the norm is quantized separately [boolean]\n"
                + "  -qout               whether the classifier is quantized [boolean]\n"
                + "  -dsub               size of each sub-vector [integer]\n"),
        // Concatenation of the four help sections above; they are declared earlier, so they are already initialised.
        ARGS(ARGS_BASIC_HELP.message + ARGS_DICTIONARY_HELP.message + ARGS_TRAINING_HELP.message + ARGS_QUANTIZATION_HELP.message);

        /** Raw template, still containing the {@code {fasttext}} placeholder. */
        private final String message;

        /**
         * Marks a failure caused by wrong command line input; carries the usage text as its message.
         */
        public static class WrongInputException extends IllegalArgumentException {
            WrongInputException(String str) {
                super(str);
            }
        }

        Usage(String str) {
            this.message = str;
        }

        /**
         * Renders the usage text.
         *
         * @return the template with {@code {fasttext}} replaced by {@code java -jar fasttext.jar}
         */
        public String getMessage() {
            return this.message.replace("{fasttext}", "java -jar fasttext.jar");
        }

        /**
         * @return a {@link WrongInputException} whose message is this usage text
         */
        public IllegalArgumentException toException() {
            return createException(getMessage());
        }

        /**
         * @param str explanation placed on its own line before the usage text
         * @return a {@link WrongInputException} with the combined message
         */
        public IllegalArgumentException toException(String str) {
            return createException(str + "\n" + getMessage());
        }

        /**
         * @param str   explanation placed on its own line before the usage text
         * @param usage additional usage text appended on a new line after this one
         * @return a {@link WrongInputException} with the combined message
         */
        public IllegalArgumentException toException(String str, Usage usage) {
            return createException(str + "\n" + getMessage() + "\n" + usage.getMessage());
        }

        private static IllegalArgumentException createException(String str) {
            return new WrongInputException(str);
        }
    }

    /**
     * Swaps the shared factory for the one returned by {@code factory.setFileSystem(iOStreams)}.
     *
     * @param iOStreams the file system implementation handed to the factory
     */
    public static void setFileSystem(IOStreams iOStreams) {
        FastText.Factory reconfigured = factory.setFileSystem(iOStreams);
        factory = reconfigured;
    }

    /**
     * @return the file system currently held by the shared factory
     */
    public static IOStreams fileSystem() {
        IOStreams current = factory.getFileSystem();
        return current;
    }

    /**
     * Runs the {@code test} command: loads the model, evaluates it on the test data and prints the
     * string form of the result to stdout.
     *
     * @param strArr {@code [command, model, test-data]}, optionally followed by {@code k} (1 if absent);
     *               a test-data value of {@code "-"} makes the model read from stdin
     * @throws IOException              propagated from loading the model or evaluating it
     * @throws IllegalArgumentException a usage error if the argument count is wrong or {@code k} is not an integer
     */
    public static void test(String[] strArr) throws IOException, IllegalArgumentException {
        int k = 1;
        if (strArr.length == 4) {
            try {
                k = Integer.parseInt(strArr[3]);
            } catch (NumberFormatException e) {
                // Report a malformed <k> as a usage error, like every other bad option, not as a stack trace.
                IllegalArgumentException failure = Usage.TEST.toException("Wrong value for <k>: " + e.getMessage());
                failure.initCause(e);
                throw failure;
            }
        } else if (strArr.length != 3) {
            throw Usage.TEST.toException();
        }
        FastText model = loadModel(strArr[1]);
        String data = strArr[2];
        System.out.println(("-".equals(data) ? model.test(System.in, k) : model.test(data, k)).toString());
    }

    /**
     * Runs the {@code predict} / {@code predict-prob} command: prints one line per prediction with
     * the predicted labels separated by spaces. For {@code predict-prob} each label is followed by its
     * probability formatted with {@code FormatUtils.toString(value, 6)}.
     *
     * @param strArr {@code [command, model, test-data]}, optionally followed by {@code k} (1 if absent);
     *               a test-data value of {@code "-"} makes the model read from stdin
     * @throws IOException              propagated from loading the model or producing predictions
     * @throws IllegalArgumentException a usage error if the argument count is wrong or {@code k} is not an integer
     */
    public static void predict(String[] strArr) throws IOException, IllegalArgumentException {
        int k = 1;
        if (strArr.length == 4) {
            try {
                k = Integer.parseInt(strArr[3]);
            } catch (NumberFormatException e) {
                // Report a malformed <k> as a usage error, like every other bad option, not as a stack trace.
                IllegalArgumentException failure = Usage.PREDICT.toException("Wrong value for <k>: " + e.getMessage());
                failure.initCause(e);
                throw failure;
            }
        } else if (strArr.length != 3) {
            throw Usage.PREDICT.toException();
        }
        boolean withProbabilities = "predict-prob".equalsIgnoreCase(strArr[0]);
        FastText model = loadModel(strArr[1]);
        String data = strArr[2];
        // The prediction stream is closed when the block exits, whether printing succeeds or fails.
        try (Stream<Map<String, Float>> predictions =
                     "-".equals(data) ? model.predict(System.in, k) : model.predict(data, k)) {
            predictions
                    .map(labels -> labels.entrySet().stream()
                            .map(entry -> withProbabilities
                                    ? entry.getKey() + " " + FormatUtils.toString(entry.getValue().floatValue(), 6)
                                    : entry.getKey())
                            .collect(Collectors.joining(" ")))
                    .forEach(System.out::println);
        }
    }

    /**
     * Runs the {@code print-word-vectors} command: for every line read from stdin prints the line,
     * a space, and the string form of {@code getWordVector(line)}.
     *
     * @param strArr {@code [command, model]}
     * @throws IOException              propagated from loading the model
     * @throws IllegalArgumentException a usage error if the argument count is not 2
     */
    public static void printWordVectors(String[] strArr) throws IOException, IllegalArgumentException {
        if (strArr.length != 2) {
            throw Usage.PRINT_WORD_VECTORS.toException();
        }
        FastText model = loadModel(strArr[1]);
        // The scanner is deliberately left open: closing it would close System.in.
        for (Scanner stdin = new Scanner(System.in); stdin.hasNextLine(); ) {
            String word = stdin.nextLine();
            System.out.println(word + " " + model.getWordVector(word));
        }
    }

    /**
     * Runs the {@code print-sentence-vectors} command: for every line read from stdin prints the
     * result of {@code getSentenceVector(line)}.
     *
     * @param strArr {@code [command, model]}
     * @throws IOException              propagated from loading the model
     * @throws IllegalArgumentException a usage error if the argument count is not 2
     */
    public static void printSentenceVectors(String[] strArr) throws IOException, IllegalArgumentException {
        if (strArr.length != 2) {
            throw Usage.PRINT_SENTENCE_VECTORS.toException();
        }
        FastText model = loadModel(strArr[1]);
        // The scanner is deliberately left open: closing it would close System.in.
        for (Scanner stdin = new Scanner(System.in); stdin.hasNextLine(); ) {
            String sentence = stdin.nextLine();
            System.out.println(model.getSentenceVector(sentence));
        }
    }

    /**
     * Runs the {@code print-ngrams} command: prints every entry returned by
     * {@code ngramVectors(word)} as {@code "<ngram> <vector>"}, one per line.
     *
     * @param strArr {@code [command, model, word]}
     * @throws IOException              propagated from loading the model
     * @throws IllegalArgumentException a usage error if the argument count is not 3
     */
    public static void printNgrams(String[] strArr) throws IOException, IllegalArgumentException {
        if (strArr.length != 3) {
            throw Usage.PRINT_NGRAMS.toException();
        }
        FastText model = loadModel(strArr[1]);
        String word = strArr[2];
        model.ngramVectors(word).forEach((ngram, vector) -> System.out.println(ngram + " " + vector));
    }

    /**
     * Runs the interactive {@code nn} command: repeatedly prompts on stdout, reads one word from
     * stdin and prints the entries returned by {@code nn(k, word)} as {@code "<word> <score>"} lines.
     * Returns when stdin has no more tokens.
     *
     * @param strArr {@code [command, model]}, optionally followed by {@code k} (10 if absent)
     * @throws IOException              propagated from loading the model
     * @throws IllegalArgumentException a usage error if the argument count is wrong or {@code k} is not an integer
     */
    public static void nn(String[] strArr) throws IOException, IllegalArgumentException {
        int k = 10;
        if (strArr.length == 3) {
            try {
                k = Integer.parseInt(strArr[2]);
            } catch (NumberFormatException e) {
                // Report a malformed <k> as a usage error, like every other bad option, not as a stack trace.
                IllegalArgumentException failure = Usage.NN.toException("Wrong value for <k>: " + e.getMessage());
                failure.initCause(e);
                throw failure;
            }
        } else if (strArr.length != 2) {
            throw Usage.NN.toException();
        }
        FastText model = loadModel(strArr[1]);
        // Return value intentionally ignored; presumably this warms the word-vector cache used by nn().
        model.getPrecomputedWordVectors();
        Scanner stdin = new Scanner(System.in);
        PrintStream out = System.out;
        out.println("Query word?");
        // End of input is detected with hasNext() rather than by catching NoSuchElementException,
        // so an exception of that type thrown by the model itself is no longer mistaken for EOF.
        while (stdin.hasNext()) {
            model.nn(k, stdin.next()).forEach((word, score) ->
                    out.println(word + " " + FormatUtils.toString(score.floatValue())));
            out.println("Query word?");
        }
    }

    /**
     * Runs the interactive {@code analogies} command: repeatedly prompts on stdout, reads three
     * words (A, B, C) from stdin and prints the entries returned by {@code analogies(k, A, B, C)} as
     * {@code "<word> <score>"} lines. Returns when stdin runs out of tokens, even in the middle of a
     * triplet.
     *
     * @param strArr {@code [command, model]}, optionally followed by {@code k} (10 if absent)
     * @throws IOException              propagated from loading the model
     * @throws IllegalArgumentException a usage error if the argument count is wrong or {@code k} is not an integer
     */
    public static void analogies(String[] strArr) throws IOException, IllegalArgumentException {
        int k = 10;
        if (strArr.length == 3) {
            try {
                k = Integer.parseInt(strArr[2]);
            } catch (NumberFormatException e) {
                // Report a malformed <k> as a usage error, like every other bad option, not as a stack trace.
                IllegalArgumentException failure = Usage.ANALOGIES.toException("Wrong value for <k>: " + e.getMessage());
                failure.initCause(e);
                throw failure;
            }
        } else if (strArr.length != 2) {
            throw Usage.ANALOGIES.toException();
        }
        FastText model = loadModel(strArr[1]);
        // Return value intentionally ignored; presumably this warms the word-vector cache used by analogies().
        model.getPrecomputedWordVectors();
        Scanner stdin = new Scanner(System.in);
        PrintStream out = System.out;
        while (true) {
            out.println("Query triplet (A - B + C)?");
            String[] triplet = new String[3];
            for (int idx = 0; idx < triplet.length; idx++) {
                // hasNext() replaces the former catch of NoSuchElementException as the EOF signal.
                if (!stdin.hasNext()) {
                    return;
                }
                triplet[idx] = stdin.next();
            }
            model.analogies(k, triplet[0], triplet[1], triplet[2]).forEach((word, score) ->
                    out.println(word + " " + FormatUtils.toString(score.floatValue())));
        }
    }

    /**
     * Runs a training command ({@code supervised}, {@code skipgram} or {@code cbow}).
     *
     * <p>Validates {@code -input}, {@code -output} and {@code -pretrainedVectors}, trains the model
     * and saves {@code <output>.bin} and {@code <output>.vec}; when {@code -saveOutput} is present
     * {@code <output>.output} is saved as well.
     *
     * @param strArr the full command line; {@code strArr[0]} is the model name
     * @throws IOException              propagated from training or saving
     * @throws ExecutionException       propagated from training
     * @throws IllegalArgumentException a usage error when a mandatory option is missing or a file is
     *                                  not readable/writable, or whatever the option parsers throw
     */
    public static void train(String[] strArr) throws IOException, ExecutionException, IllegalArgumentException {
        if (strArr.length == 0) {
            throw Usage.TRAIN.toException("Empty args specified.", Usage.ARGS);
        }
        Map<String, String> options = toMap(strArr);
        Args.ModelName modelName = Args.ModelName.fromName(strArr[0]);

        String input = options.get("-input");
        if (StringUtils.isEmpty(input)) {
            throw Usage.TRAIN.toException("Empty -input", Usage.ARGS);
        }
        if (!fileSystem().canRead(input)) {
            throw Usage.TRAIN.toException("Wrong -input: can't read " + input, Usage.ARGS);
        }

        String output = options.get("-output");
        if (StringUtils.isEmpty(output)) {
            throw Usage.TRAIN.toException("Empty -output", Usage.ARGS);
        }
        // NOTE(review): mere presence of -saveOutput enables it, so "-saveOutput false" also saves — confirm intent.
        String outputFile = options.containsKey("-saveOutput") ? output + ".output" : null;
        String binFile = output + ".bin";
        String vecFile = output + ".vec";
        boolean unwritable = Stream.of(binFile, vecFile, outputFile)
                .filter(Objects::nonNull)
                .anyMatch(file -> !fileSystem().canWrite(file));
        if (unwritable) {
            // Report the output prefix here; the message used to show the input path by mistake.
            throw Usage.TRAIN.toException("Wrong -output: can't write model " + output, Usage.ARGS);
        }

        String pretrainedVectors = options.get("-pretrainedVectors");
        if (!StringUtils.isEmpty(pretrainedVectors) && !fileSystem().canRead(pretrainedVectors)) {
            throw Usage.TRAIN.toException("Wrong -pretrainedVectors: can't read " + pretrainedVectors, Usage.ARGS);
        }

        FastText model = factory.setLogs(createStdErrLogger(parseVerbose(options, Usage.TRAIN)))
                .train(parseArgs(modelName, options), input, pretrainedVectors);
        model.saveModel(binFile);
        model.saveVectors(vecFile);
        if (outputFile != null) {
            model.saveOutput(outputFile);
        }
    }

    /**
     * Runs the {@code quantize} command: loads {@code <output>.bin}, quantizes it and saves
     * {@code <output>.ftz} and {@code <output>.vec}.
     *
     * <p>{@code -input} is only consulted (and then required and checked for readability) when
     * {@code -retrain} is present. {@code -saveOutput} is rejected.
     *
     * @param strArr the full command line
     * @throws IOException              propagated from loading, quantizing or saving
     * @throws ExecutionException       propagated from quantizing
     * @throws IllegalArgumentException a usage error for missing/invalid options or inaccessible files
     */
    public static void quantize(String[] strArr) throws IOException, ExecutionException, IllegalArgumentException {
        if (strArr.length == 0) {
            throw Usage.QUANTIZE.toException("Empty args specified.", Usage.ARGS);
        }
        Map<String, String> options = toMap(strArr);

        String output = options.get("-output");
        if (StringUtils.isEmpty(output)) {
            throw Usage.QUANTIZE.toException("No model (-output)", Usage.ARGS);
        }
        String binFile = output + ".bin";
        if (!fileSystem().canRead(binFile)) {
            throw Usage.QUANTIZE.toException("Wrong -output: can't read file " + binFile, Usage.ARGS);
        }

        // Training data stays null unless retraining was requested.
        String retrainInput = null;
        if (options.containsKey("-retrain")) {
            retrainInput = options.get("-input");
            if (StringUtils.isEmpty(retrainInput)) {
                throw Usage.QUANTIZE.toException("Wrong args: -input is required if -retrain specified.", Usage.ARGS);
            }
            if (!fileSystem().canRead(retrainInput)) {
                throw Usage.QUANTIZE.toException("Wrong -input: can't read file " + retrainInput, Usage.ARGS);
            }
        }

        String ftzFile = output + ".ftz";
        String vecFile = output + ".vec";
        if (!(fileSystem().canWrite(ftzFile) && fileSystem().canWrite(vecFile))) {
            throw Usage.QUANTIZE.toException("Wrong -output: can't write model " + output, Usage.ARGS);
        }
        if (options.containsKey("-saveOutput")) {
            throw Usage.QUANTIZE.toException("Option -saveOutput is not supported for quantized models", Usage.ARGS);
        }

        PrintLogs logs = createStdErrLogger(parseVerbose(options, Usage.QUANTIZE));
        FastText quantized = factory.setLogs(logs).load(binFile).quantize(parseArgs(Args.ModelName.SUP, options), retrainInput);
        quantized.saveModel(ftzFile);
        quantized.saveVectors(vecFile);
    }

    /**
     * Command line entry point. Delegates to {@link #run(String...)}; a usage error is printed to
     * stdout (message only, no trailing newline added), any other exception prints its stack trace
     * and terminates the JVM with exit code 1.
     *
     * @param strArr the command line arguments
     */
    public static void main(String... strArr) {
        try {
            run(strArr);
        } catch (Exception failure) {
            if (failure instanceof Usage.WrongInputException) {
                System.out.print(failure.getMessage());
                return;
            }
            failure.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Dispatches to the command named by the first argument (matched case-insensitively).
     *
     * @param strArr the command line; {@code strArr[0]} selects the command
     * @throws Exception whatever the selected command throws; a usage error if no argument is
     *                   given or the command is unknown
     */
    public static void run(String... strArr) throws Exception {
        if (strArr.length < 1) {
            throw Usage.COMMON.toException();
        }
        String command = strArr[0];
        if ("skipgram".equalsIgnoreCase(command)
                || "cbow".equalsIgnoreCase(command)
                || "supervised".equalsIgnoreCase(command)) {
            train(strArr);
        } else if ("quantize".equalsIgnoreCase(command)) {
            quantize(strArr);
        } else if ("test".equalsIgnoreCase(command)) {
            test(strArr);
        } else if ("print-word-vectors".equalsIgnoreCase(command)) {
            printWordVectors(strArr);
        } else if ("print-sentence-vectors".equalsIgnoreCase(command)) {
            printSentenceVectors(strArr);
        } else if ("print-ngrams".equalsIgnoreCase(command)) {
            printNgrams(strArr);
        } else if ("nn".equalsIgnoreCase(command)) {
            nn(strArr);
        } else if ("analogies".equalsIgnoreCase(command)) {
            analogies(strArr);
        } else if ("predict".equalsIgnoreCase(command) || "predict-prob".equalsIgnoreCase(command)) {
            predict(strArr);
        } else {
            throw Usage.COMMON.toException();
        }
    }

    /**
     * Loads a model through the shared factory, configured with an INFO-level stderr logger.
     *
     * @param str the model location passed to {@code load}
     * @return the loaded model
     * @throws IOException propagated from {@code load}
     */
    private static FastText loadModel(String str) throws IOException {
        PrintLogs logs = createStdErrLogger(PrintLogs.Level.INFO);
        return factory.setLogs(logs).load(str);
    }

    /**
     * Creates a logger of the given level that writes to {@code System.err}.
     *
     * @param level the level whose {@code createLogger} is invoked
     * @return the logger produced by {@code level.createLogger(System.err)}
     */
    public static PrintLogs createStdErrLogger(PrintLogs.Level level) {
        PrintStream stdErr = System.err;
        return level.createLogger(stdErr);
    }

    /**
     * Resolves the log level from the {@code -verbose} option.
     *
     * @param map   parsed command line options
     * @param usage usage text to attach when the value is not an integer
     * @return {@code PrintLogs.Level.ALL} when {@code -verbose} is absent, otherwise
     *         {@code PrintLogs.Level.at(value)}
     * @throws IllegalArgumentException a usage error when the value cannot be parsed as an integer
     */
    private static PrintLogs.Level parseVerbose(Map<String, String> map, Usage usage) {
        if (!map.containsKey("-verbose")) {
            return PrintLogs.Level.ALL;
        }
        try {
            return PrintLogs.Level.at(Integer.parseInt(map.get("-verbose")));
        } catch (NumberFormatException e) {
            // Name the offending option (same wording as the other numeric options) and keep the cause.
            IllegalArgumentException failure = usage.toException("Wrong value for -verbose: " + e.getMessage());
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Builds an {@link Args} instance from parsed command line options. Only options present in
     * the map are applied to the builder; everything else keeps the builder's own default.
     *
     * @param modelName the model to set on the builder
     * @param map       parsed command line options (option name to raw string value)
     * @return the built arguments
     * @throws IllegalArgumentException a usage error when a numeric option cannot be parsed, or
     *                                  whatever {@code Args.LossName.fromName} or the builder throw
     */
    public static Args parseArgs(Args.ModelName modelName, Map<String, String> map) throws IllegalArgumentException {
        Args.Builder builder = new Args.Builder().setModel(modelName);

        putIntegerArg(map, "-lrUpdateRate", builder::setLRUpdateRate);
        putIntegerArg(map, "-dim", builder::setDim);
        putIntegerArg(map, "-ws", builder::setWS);
        putIntegerArg(map, "-epoch", builder::setEpoch);
        putIntegerArg(map, "-minCount", builder::setMinCount);
        putIntegerArg(map, "-minCountLabel", builder::setMinCountLabel);
        putIntegerArg(map, "-neg", builder::setNeg);
        putIntegerArg(map, "-wordNgrams", builder::setWordNgrams);
        putIntegerArg(map, "-bucket", builder::setBucket);
        putIntegerArg(map, "-minn", builder::setMinN);
        putIntegerArg(map, "-maxn", builder::setMaxN);
        putIntegerArg(map, "-thread", builder::setThread);
        putIntegerArg(map, "-cutoff", builder::setCutOff);
        putIntegerArg(map, "-dsub", builder::setDSub);

        putDoubleArg(map, "-lr", builder::setLR);
        putDoubleArg(map, "-t", builder::setSamplingThreshold);

        putBooleanArg(map, "-qnorm", builder::setQNorm);
        putBooleanArg(map, "-qout", builder::setQOut);

        putStringArg(map, "-label", builder::setLabel);
        if (map.containsKey("-loss")) {
            builder.setLossName(Args.LossName.fromName(map.get("-loss")));
        }
        return builder.build();
    }

    /**
     * Converts a raw command line into an insertion-ordered option map.
     *
     * <p>A token starting with {@code "-"} is an option: its value is the following token, unless
     * there is none or the following token also starts with {@code "-"}, in which case the value is
     * {@code "true"}. Any other token is stored as a key with a {@code null} value. A repeated key
     * keeps its last value.
     *
     * @param strArr the command line tokens
     * @return the option map
     * @throws IllegalArgumentException a usage error carrying the help text when {@code -h} is present
     */
    public static Map<String, String> toMap(String... strArr) throws IllegalArgumentException {
        Map<String, String> options = new LinkedHashMap<>();
        for (int i = 0; i < strArr.length; i++) {
            String token = strArr[i];
            String value = null;
            if (token.startsWith("-")) {
                boolean hasValue = i + 1 < strArr.length && !strArr[i + 1].startsWith("-");
                // ++i consumes the value token so the next iteration does not treat it as a key.
                value = hasValue ? strArr[++i] : Boolean.TRUE.toString();
            }
            options.put(token, value);
        }
        if (options.containsKey("-h")) {
            throw Usage.ARGS.toException("Here is the help! Usage:");
        }
        return options;
    }

    /**
     * Passes the value of option {@code str} to {@code consumer} when the option is present.
     *
     * @param map      parsed command line options
     * @param str      option name
     * @param consumer receiver of the value
     * @throws NullPointerException if the option is present with a {@code null} value
     */
    private static void putStringArg(Map<String, String> map, String str, Consumer<String> consumer) {
        if (!map.containsKey(str)) {
            return;
        }
        String value = map.get(str);
        consumer.accept(Objects.requireNonNull(value, "Null value for " + str));
    }

    /**
     * Parses option {@code str} as an {@code int} and passes it to {@code intConsumer} when the
     * option is present.
     *
     * @param map         parsed command line options
     * @param str         option name
     * @param intConsumer receiver of the parsed value
     * @throws IllegalArgumentException a usage error (with the parse failure as cause) when the value is not an integer
     * @throws NullPointerException     if the option is present with a {@code null} value
     */
    private static void putIntegerArg(Map<String, String> map, String str, IntConsumer intConsumer) {
        if (!map.containsKey(str)) {
            return;
        }
        try {
            intConsumer.accept(Integer.parseInt(Objects.requireNonNull(map.get(str), "Null int value for " + str)));
        } catch (NumberFormatException e) {
            IllegalArgumentException failure = Usage.ARGS.toException("Wrong value for " + str + ": " + e.getMessage());
            // Keep the original parse failure reachable for debugging.
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Parses option {@code str} as a {@code double} and passes it to {@code doubleConsumer} when
     * the option is present.
     *
     * @param map            parsed command line options
     * @param str            option name
     * @param doubleConsumer receiver of the parsed value
     * @throws IllegalArgumentException a usage error (with the parse failure as cause) when the value is not a number
     * @throws NullPointerException     if the option is present with a {@code null} value
     */
    private static void putDoubleArg(Map<String, String> map, String str, DoubleConsumer doubleConsumer) {
        if (!map.containsKey(str)) {
            return;
        }
        try {
            doubleConsumer.accept(Double.parseDouble(Objects.requireNonNull(map.get(str), "Null double value for " + str)));
        } catch (NumberFormatException e) {
            IllegalArgumentException failure = Usage.ARGS.toException("Wrong value for " + str + ": " + e.getMessage());
            // Keep the original parse failure reachable for debugging.
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Parses option {@code str} with {@link Boolean#parseBoolean(String)} and passes the result to
     * {@code consumer} when the option is present. Anything other than "true" (ignoring case)
     * yields {@code false}.
     *
     * @param map      parsed command line options
     * @param str      option name
     * @param consumer receiver of the parsed flag
     * @throws NullPointerException if the option is present with a {@code null} value
     */
    private static void putBooleanArg(Map<String, String> map, String str, Consumer<Boolean> consumer) {
        if (!map.containsKey(str)) {
            return;
        }
        String raw = Objects.requireNonNull(map.get(str), "Null value for " + str);
        boolean flag = Boolean.parseBoolean(raw);
        consumer.accept(flag);
    }
}
