package com.simiacryptus.text.gpt2;

import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefStringBuilder;
import com.simiacryptus.util.Util;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/simiacryptus/text/gpt2/GPT2Codec.class */
/**
 * Text codec that converts between strings and token ids using a vocabulary map.
 *
 * <p>Encoding is greedy: at every position the longest vocabulary key that prefixes the remaining
 * text is emitted; when no key matches, one character is skipped. Decoding concatenates the token
 * strings of the given ids.
 *
 * <p>The prefix search relies on {@link TreeMap#floorKey(Object)}, so the encoder map is assumed to
 * use the natural {@link String} ordering (no custom comparator).
 */
public class GPT2Codec {
    protected static final Logger logger = LoggerFactory.getLogger(GPT2Codec.class);

    /** Token text to token id, sorted so that prefix lookups can use {@code floorKey}. */
    protected final TreeMap<String, Integer> encoder;

    /** Token id to token text; the inverse of {@link #encoder}. */
    @Nonnull
    protected final TreeMap<Integer, String> decoder;

    /** Vocabulary size as supplied by the caller; it is stored and reported, not derived from the map. */
    private final int vocabSize;

    /**
     * Creates a codec from an already-built vocabulary.
     *
     * @param treeMap token text to token id
     * @param i vocabulary size reported by {@link #getVocabSize()}
     * @throws IllegalStateException if two tokens share the same id (see {@link #buildDecoder})
     */
    public GPT2Codec(TreeMap<String, Integer> treeMap, int i) {
        this.encoder = treeMap;
        this.vocabSize = i;
        this.decoder = buildDecoder(this.encoder);
    }

    /**
     * Creates a codec from a JSON vocabulary file.
     *
     * @param file UTF-8 JSON file holding one object whose members map token text to integer id
     * @param i vocabulary size reported by {@link #getVocabSize()}
     */
    public GPT2Codec(@Nonnull File file, int i) {
        this(loadEncoder(file), i);
    }

    /**
     * Returns a function that rewrites a string character by character through the
     * {@link #byteEncoder()} table. Characters without a table entry are left unchanged, and the
     * result always has the same length as the input.
     *
     * @return the per-character transformer
     */
    @Nonnull
    public static Function<String, String> getCharacterTransformer() {
        final Map<Character, Character> table = byteEncoder();
        return text -> {
            char[] chars = text.toCharArray();
            for (int idx = 0; idx < chars.length; idx++) {
                chars[idx] = table.getOrDefault(chars[idx], chars[idx]);
            }
            return new String(chars);
        };
    }

    /**
     * @return the vocabulary size passed to the constructor
     */
    public int getVocabSize() {
        return this.vocabSize;
    }

    /**
     * Inverts an encoder map into an id-to-text map.
     *
     * @param treeMap token text to token id
     * @return token id to token text, sorted by id
     * @throws IllegalStateException if two tokens map to the same id
     */
    @Nonnull
    public static TreeMap<Integer, String> buildDecoder(@Nonnull TreeMap<String, Integer> treeMap) {
        // No merge function on purpose: duplicate ids are rejected by Collectors.toMap.
        Map<Integer, String> inverted = treeMap.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
        return new TreeMap<>(inverted);
    }

    /**
     * Reads a JSON vocabulary file and builds the encoder map, passing every key through
     * {@link #getCharacterTransformer()}.
     *
     * @param file UTF-8 encoded JSON vocabulary
     * @return token text to token id
     */
    @Nonnull
    public static TreeMap<String, Integer> loadEncoder(@Nonnull File file) {
        try {
            return toMap(FileUtils.readFileToString(file, "UTF-8"), getCharacterTransformer());
        } catch (IOException e) {
            throw Util.throwException(e);
        }
    }

    /**
     * Parses a JSON object of {@code "token": id} members into a sorted map.
     *
     * @param str JSON text containing a single object
     * @param function transformation applied to every member name before it is used as a key
     * @return transformed token text to token id; when several member names transform to the same
     *     key, the first one encountered is kept
     */
    @Nonnull
    public static TreeMap<String, Integer> toMap(String str, @Nonnull Function<String, String> function) {
        JsonObject vocabulary = new GsonBuilder().create().fromJson(str, JsonObject.class);
        Map<String, Integer> entries = vocabulary.keySet().stream().collect(Collectors.toMap(
                function,
                key -> vocabulary.get(key).getAsInt(),
                (first, second) -> first));
        return new TreeMap<>(entries);
    }

    /**
     * Builds the character substitution table used by {@link #getCharacterTransformer()}.
     *
     * <p>The table maps {@code (char) (i + 256)} to {@code (char) i} for every {@code i} in
     * {@code 0..255}, and maps the characters U+0021..U+007E, U+00A1..U+00AC and U+00AE..U+00FF to
     * themselves.
     *
     * <p>NOTE(review): this looks like an approximation of the inverse of GPT-2's
     * {@code bytes_to_unicode} table; if so, the {@code i + 256} rule only agrees with that table
     * for bytes 0..32 - confirm before relying on it for other byte values.
     *
     * @return a new mutable map from vocabulary character to output character
     */
    @Nonnull
    public static Map<Character, Character> byteEncoder() {
        Map<Character, Character> table = new HashMap<>();
        for (int i = 0; i < 256; i++) {
            table.put((char) (i + 256), (char) i);
        }
        // Upper bounds are inclusive so that '~' (U+007E), U+00AC and U+00FF get their identity
        // entries as well; exclusive bounds would leave the last character of each range out.
        putIdentityRange(table, '!', '~');
        putIdentityRange(table, 0xA1, 0xAC);
        putIdentityRange(table, 0xAE, 0xFF);
        return table;
    }

    /** Adds {@code c -> c} for every character code in the inclusive range {@code [first, last]}. */
    private static void putIdentityRange(Map<Character, Character> table, int first, int last) {
        for (int code = first; code <= last; code++) {
            table.put((char) code, (char) code);
        }
    }

    /**
     * Converts token ids back to text.
     *
     * @param numArr token ids, in order
     * @return the concatenated token strings; an unknown id contributes
     *     {@code "<Not Found: id>"}; an empty input yields {@code ""}
     */
    public String decode(@Nonnull Integer... numArr) {
        return Arrays.stream(numArr)
                .map(id -> this.decoder.getOrDefault(id, "<Not Found: " + id + ">"))
                .collect(Collectors.joining());
    }

    /**
     * Converts text to token ids by repeatedly consuming the longest matching vocabulary key.
     *
     * @param str text to encode; {@code null} or empty yields an empty list
     * @return token ids in order of appearance; characters that start no vocabulary key are skipped
     */
    @Nonnull
    public List<Integer> encode(@Nullable String str) {
        List<Integer> tokens = new ArrayList<>();
        if (null == str || str.isEmpty()) {
            return tokens;
        }
        RefStringBuilder remaining = new RefStringBuilder(str);
        while (remaining.length() > 0) {
            Optional<String> match = lookup(remaining.toString());
            String token = match.isPresent() ? RefUtil.get(match) : null;
            if (null != token && !token.isEmpty()) {
                remaining.delete(0, token.length());
                tokens.add(this.encoder.get(token));
            } else {
                // No usable token starts here (a zero-length match would never make progress):
                // drop one character and try again.
                remaining.delete(0, 1);
            }
        }
        return tokens;
    }

    /**
     * Finds the longest non-empty vocabulary key that is a prefix of the given text.
     *
     * <p>Every prefix of a string sorts at or before that string, so the longest prefix key is
     * reachable through {@code floorKey}. When the floor key is not itself a prefix, no prefix key
     * can be longer than the part the floor key shares with the probe, so the search continues
     * from that shared part. The probe gets strictly shorter on every round, which bounds the loop.
     *
     * @param str text whose beginning should be matched; may be {@code null}
     * @return the longest matching key, or empty if {@code str} is {@code null}/empty or no
     *     non-empty key prefixes it
     */
    protected Optional<String> lookup(@Nullable String str) {
        if (null == str || str.isEmpty()) {
            return Optional.empty();
        }
        String probe = str;
        while (!probe.isEmpty()) {
            String candidate = this.encoder.floorKey(probe);
            if (null == candidate || candidate.isEmpty()) {
                // Nothing sorts at or before the probe except possibly the empty key,
                // which cannot be used as a token.
                return Optional.empty();
            }
            if (probe.startsWith(candidate)) {
                return Optional.of(candidate);
            }
            probe = probe.substring(0, commonPrefixLength(probe, candidate));
        }
        return Optional.empty();
    }

    /** Returns the number of leading characters that {@code a} and {@code b} have in common. */
    private static int commonPrefixLength(String a, String b) {
        int limit = Math.min(a.length(), b.length());
        int shared = 0;
        while (shared < limit && a.charAt(shared) == b.charAt(shared)) {
            shared++;
        }
        return shared;
    }
}
