package com.simiacryptus.text.gpt2;

import com.simiacryptus.text.LanguageCodeModel;
import com.simiacryptus.text.SumModel;
import com.simiacryptus.text.TextGenerator;
import com.simiacryptus.util.Util;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.ZipFile;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:com/simiacryptus/text/gpt2/GPT2Util.class */
/**
 * Static helpers for obtaining the "345M" GPT-2 resources (encoder file, graph file, codec,
 * model) and for building {@link TextGenerator} instances on top of them.
 *
 * <p>Remote resources are resolved relative to the {@code GPT2_MODEL_URL} system property, which
 * defaults to a public S3 location. A file with the expected name in the current working
 * directory always takes precedence over a download.
 */
public class GPT2Util {
    /** Base URL that all model resources are resolved against; overridable via {@code -DGPT2_MODEL_URL=...}. */
    private static final String MODEL_URL_BASE = System.getProperty("GPT2_MODEL_URL", "https://s3-us-west-2.amazonaws.com/simiacryptus/gpt2/");

    /**
     * Lazy holder for the shared prototype generator.
     *
     * <p>The prototype is created when this nested class is first initialized (i.e. on the first
     * call to {@link #getTextGenerator()}), rather than when {@code GPT2Util} itself is loaded.
     * This keeps the plain file helpers usable without loading the model, and a failure to load
     * the model no longer poisons the whole utility class.
     */
    private static final class PrototypeHolder {
        private static final TextGenerator INSTANCE = get345M().setVerbose(false);
    }

    /**
     * Builds a fresh generator from the 345M model and codec.
     *
     * @return a new {@link TextGenerator}; the literal 50257 is passed through to the constructor
     *     (presumably the vocabulary size - confirm against {@code TextGenerator})
     */
    @Nonnull
    public static TextGenerator get345M() {
        return new TextGenerator(50257, getModel_345M(), getCodec_345M());
    }

    /**
     * Builds a codec from the 345M encoder file.
     *
     * @return a new {@link GPT2Codec} backed by {@link #getEncoderFile_345M()}
     */
    @Nonnull
    public static GPT2Codec getCodec_345M() {
        return new GPT2Codec(getEncoderFile_345M(), 50257);
    }

    /**
     * Locates {@code encoder_345M.json}, extracting it from {@code encoder_345M.zip} under the
     * model base URL when it is not already in the working directory.
     *
     * @return the encoder file
     */
    @Nonnull
    public static File getEncoderFile_345M() {
        return loadZippedInternetFile(MODEL_URL_BASE + "encoder_345M.zip", "encoder_345M.json");
    }

    /**
     * Locates {@code 345M.pb}, fetching it from the model base URL when it is not already in the
     * working directory.
     *
     * @return the graph file
     */
    @Nonnull
    public static File getGraphFile_345M() {
        return loadRawInternetFile(MODEL_URL_BASE, "345M.pb");
    }

    /**
     * Builds the 345M model from the default graph file.
     *
     * @return a new {@link GPT2Model}
     */
    @Nonnull
    public static GPT2Model getModel_345M() {
        return getModel_345M(getGraphFile_345M());
    }

    /**
     * Returns an independent copy of the shared, non-verbose prototype generator. The prototype is
     * created on first use.
     *
     * @return a copy of the prototype generator
     */
    @Nonnull
    public static TextGenerator getTextGenerator() {
        return PrototypeHolder.INSTANCE.copy();
    }

    /**
     * Resolves a single file that is distributed inside a remote zip archive.
     *
     * <p>If a file named like the last path segment of {@code str2} already exists in the working
     * directory it is returned untouched. Otherwise the archive is obtained through
     * {@code Util.cacheFile}, the entry is extracted and written to that working-directory file.
     *
     * @param str  URL of the zip archive
     * @param str2 name of the entry inside the archive; its simple file name is used locally
     * @return the local file holding the entry content
     * @throws RuntimeException (via {@code Util.throwException}) when the archive cannot be
     *     obtained or read, or when it has no entry named {@code str2}
     */
    @Nonnull
    public static File loadZippedInternetFile(@Nonnull String str, @Nonnull String str2) {
        File localFile = new File(new File(str2).getName());
        if (localFile.exists()) {
            return localFile;
        }
        // try-with-resources guarantees the archive handle is released on every path; closing the
        // ZipFile also closes any entry streams obtained from it.
        try (ZipFile zipFile = new ZipFile(Util.cacheFile(new URI(str)))) {
            java.util.zip.ZipEntry entry = zipFile.getEntry(str2);
            if (null == entry) {
                // getEntry returns null for unknown names; fail with a useful message instead of an NPE.
                throw new IOException("Entry '" + str2 + "' not found in archive " + str);
            }
            byte[] content = IOUtils.toByteArray(zipFile.getInputStream(entry));
            FileUtils.writeByteArrayToFile(localFile, content);
            return localFile;
        } catch (Exception e) {
            throw Util.throwException(e);
        }
    }

    /**
     * Resolves a file that is distributed as-is at a remote location.
     *
     * <p>If a file named like the last path segment of {@code str2} already exists in the working
     * directory it is returned. Otherwise the result of {@code Util.cacheFile} for the URL
     * {@code str + str2} is returned.
     *
     * @param str  base URL, concatenated directly with {@code str2}
     * @param str2 resource name relative to the base URL
     * @return the local or cached file
     * @throws RuntimeException (via {@code Util.throwException}) when the URL is malformed or the
     *     resource cannot be cached
     */
    @Nonnull
    public static File loadRawInternetFile(String str, @Nonnull String str2) {
        File localFile = new File(new File(str2).getName());
        if (localFile.exists()) {
            return localFile;
        }
        try {
            return Util.cacheFile(new URI(str + str2));
        } catch (Exception e) {
            throw Util.throwException(e);
        }
    }

    /**
     * Builds a model named {@code "345M"} from the given graph file.
     *
     * @param file the graph file
     * @return a new {@link GPT2Model}
     */
    @Nonnull
    public static GPT2Model getModel_345M(@Nonnull File file) {
        return getModel_345M("345M", file);
    }

    /**
     * Builds a model with the given name from the given graph file, using the 345M graph edit and
     * codec.
     *
     * @param str  name passed to the model
     * @param file the graph file
     * @return a new {@link GPT2Model}
     */
    @Nonnull
    public static GPT2Model getModel_345M(String str, @Nonnull File file) {
        return new GPT2Model(str, new GPT2Edit_345M(), file, getCodec_345M());
    }

    /**
     * Installs a filter on the generator's model that restricts candidate text by character
     * whitelist and/or word list, and returns the same generator.
     *
     * <p>A candidate is rejected when:
     * <ul>
     *   <li>{@code str} is non-empty and the candidate contains a character outside the regex
     *       character class {@code [str]}; or</li>
     *   <li>a word list is available and any word but the last is not in the list; or</li>
     *   <li>a word list is available and the last word is neither in the list nor a prefix of
     *       some listed word (it may still be completed by later text).</li>
     * </ul>
     * Words are obtained by splitting on runs of non-word characters and are compared in lower
     * case; empty fragments are ignored.
     *
     * @param textGenerator generator whose model receives the filter
     * @param str           characters allowed in generated text, spliced verbatim into a regex
     *                      character class; {@code null} or empty disables the check
     * @param uri           location of a whitespace-separated word list, or {@code null} for none
     * @return {@code textGenerator}, for chaining
     * @throws IOException              if the word list cannot be read
     * @throws NoSuchAlgorithmException declared for the underlying cache download
     * @throws KeyManagementException   declared for the underlying cache download
     */
    @Nonnull
    public static TextGenerator getTextGenerator(@Nonnull TextGenerator textGenerator, @Nullable String str, @Nullable URI uri) throws IOException, NoSuchAlgorithmException, KeyManagementException {
        final TreeSet<String> wordList;
        if (null == uri) {
            wordList = null;
        } else {
            wordList = Arrays.stream(FileUtils.readFileToString(Util.cacheFile(uri), "UTF-8").split("\\s+"))
                .map(entry -> entry.trim().toLowerCase())
                .collect(Collectors.toCollection(TreeSet::new));
        }
        // Built once here rather than on every filter invocation.
        final String forbiddenCharRegex = (null == str || str.isEmpty()) ? null : ".*[^" + str + "].*";
        textGenerator.getModel().setFilterFn((prefix, candidate) -> {
            if (null != forbiddenCharRegex && candidate.matches(forbiddenCharRegex)) {
                return false;
            }
            if (null == wordList || wordList.isEmpty()) {
                return true;
            }
            String[] words = candidate.split("[^\\w]+");
            for (int i = 0; i < words.length; i++) {
                String word = words[i].toLowerCase();
                if (word.isEmpty() || wordList.contains(word)) {
                    continue;
                }
                if (i < words.length - 1) {
                    // A completed word must be an exact member of the list.
                    return false;
                }
                // The trailing word may be incomplete: accept it when it is a prefix of a listed
                // word. Any such word sorts at or after the fragment, so the ceiling is the only
                // candidate worth checking (the floor can only match when it equals the word,
                // which the contains() test above has already ruled out).
                String ceiling = wordList.ceiling(word);
                if (null == ceiling || !ceiling.startsWith(word)) {
                    return false;
                }
            }
            return true;
        });
        return textGenerator;
    }

    /**
     * Creates a new non-verbose 345M generator whose model combines one model per seed text.
     *
     * @param strArr seed texts
     * @return the configured generator
     */
    @Nonnull
    public static TextGenerator getTextGenerator(String... strArr) {
        return getTextGenerator(get345M().setVerbose(false), strArr);
    }

    /**
     * Replaces the generator's model with a combination of one model per seed text.
     *
     * @param textGenerator generator to reconfigure; also used as the template for each seed copy
     * @param strArr        seed texts
     * @return {@code textGenerator}, for chaining
     */
    @Nonnull
    public static TextGenerator getTextGenerator(@Nonnull TextGenerator textGenerator, String... strArr) {
        return getTextGenerator(textGenerator, new ArrayList<LanguageCodeModel>(), strArr);
    }

    /**
     * Replaces the generator's model with a {@link SumModel} over one model per seed text followed
     * by the supplied extra models.
     *
     * <p>For each seed, the generator is copied, the seed is fed to the copy, and the copy's model
     * is collected.
     *
     * @param textGenerator generator to reconfigure; also used as the template for each seed copy
     * @param list          additional models appended after the seed-derived models
     * @param strArr        seed texts
     * @return {@code textGenerator}, for chaining
     */
    @Nonnull
    public static TextGenerator getTextGenerator(@Nonnull TextGenerator textGenerator, @Nonnull List<LanguageCodeModel> list, @Nonnull String... strArr) {
        Stream<LanguageCodeModel> seededModels = Arrays.stream(strArr).map(seed -> {
            TextGenerator copy = textGenerator.copy();
            copy.feed(seed);
            return copy.getModel();
        });
        LanguageCodeModel[] models = Stream.concat(seededModels, list.stream()).toArray(LanguageCodeModel[]::new);
        textGenerator.setModel(new SumModel(models));
        return textGenerator;
    }

    /**
     * Applies the character-whitelist / word-list filter to a fresh copy of the prototype
     * generator.
     *
     * @param str characters allowed in generated text, or {@code null}/empty for no restriction
     * @param uri location of a whitespace-separated word list, or {@code null} for none
     * @return the filtered generator
     * @throws IOException              if the word list cannot be read
     * @throws NoSuchAlgorithmException declared for the underlying cache download
     * @throws KeyManagementException   declared for the underlying cache download
     */
    @Nonnull
    protected static TextGenerator getTextGenerator(String str, URI uri) throws IOException, NoSuchAlgorithmException, KeyManagementException {
        return getTextGenerator(getTextGenerator(), str, uri);
    }
}
