package com.simiacryptus.text;

import com.simiacryptus.ref.wrappers.RefString;
import com.simiacryptus.text.gpt2.GPT2Codec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/simiacryptus/text/TextGenerator.class */
/**
 * Samples token codes from a {@link LanguageCodeModel} one at a time and decodes them with a
 * {@link GPT2Codec}.
 *
 * <p>The generator remembers every code fed or generated since the last {@link #reset()} in
 * {@link #codes}, and keeps the model's latest output in {@link #nextSelections}. That output is
 * treated as the probability distribution over the next code. The last vocabulary index
 * ({@code vocabularySize - 1}) is fed as a seed when a prompt encodes to nothing, and it ends
 * generation when sampled.
 *
 * <p>Instances are mutable and not synchronized.
 */
public class TextGenerator {
    /** Logger shared by all generator instances. */
    protected static final Logger logger = LoggerFactory.getLogger(TextGenerator.class);

    /** Number of token codes; the last code ({@code vocabularySize - 1}) is the seed/stop token. */
    protected final int vocabularySize;
    /** Codec used to convert between text and token codes. */
    protected final GPT2Codec codec;
    /** When {@code true}, fed/selected tokens and candidate distributions are logged at INFO level. */
    protected boolean verbose = false;
    /** Number of top-ranked candidates written by {@link #log(float[], GPT2Codec, int)}. */
    protected int choicesToLog = 10;

    /** Token codes fed or generated since construction or the last {@link #reset()}. */
    @Nonnull
    List<Integer> codes = new ArrayList<>();

    /**
     * Model output used to choose the next code, or {@code null} when nothing has been fed since
     * construction or the last {@link #reset()}.
     */
    @Nullable
    float[] nextSelections;
    private LanguageCodeModel model;

    /**
     * Creates a generator.
     *
     * @param vocabularySize    number of token codes the model scores
     * @param languageCodeModel model evaluated after every fed or generated code
     * @param codec             codec used to encode prompts and decode codes
     */
    public TextGenerator(int vocabularySize, LanguageCodeModel languageCodeModel, GPT2Codec codec) {
        setModel(languageCodeModel);
        this.vocabularySize = vocabularySize;
        this.codec = codec;
    }

    /** @return how many top candidates are logged per distribution in verbose mode */
    public int getChoicesToLog() {
        return this.choicesToLog;
    }

    /**
     * Sets how many top candidates are logged per distribution in verbose mode.
     *
     * @param choicesToLog number of candidates to log
     * @return this generator, for chaining
     */
    @Nonnull
    public TextGenerator setChoicesToLog(int choicesToLog) {
        this.choicesToLog = choicesToLog;
        return this;
    }

    /** @return the model currently used for evaluation */
    public LanguageCodeModel getModel() {
        return this.model;
    }

    /**
     * Installs a new model. If a different model was installed before, it is {@code clear()}ed first.
     * This does not reset {@link #codes} or {@link #nextSelections}.
     *
     * @param languageCodeModel the model to use from now on
     * @return this generator, for chaining
     */
    @Nonnull
    public TextGenerator setModel(LanguageCodeModel languageCodeModel) {
        if (this.model == languageCodeModel) {
            return this;
        }
        if (null != this.model) {
            this.model.clear();
        }
        this.model = languageCodeModel;
        return this;
    }

    /** @return the decoded text of every code fed or generated since the last reset */
    public String getText() {
        return this.codec.decode(this.codes.toArray(new Integer[0]));
    }

    /** @return the number of token codes */
    public int getVocabularySize() {
        return this.vocabularySize;
    }

    /** @return whether verbose INFO logging is enabled */
    public boolean isVerbose() {
        return this.verbose;
    }

    /**
     * Enables or disables verbose INFO logging of fed/selected tokens and candidate distributions.
     *
     * @param verbose {@code true} to enable verbose logging
     * @return this generator, for chaining
     */
    @Nonnull
    public TextGenerator setVerbose(boolean verbose) {
        this.verbose = verbose;
        return this;
    }

    /**
     * Returns indices of {@code values} ordered by descending value, truncated to {@code limit}
     * entries. Equal values keep ascending index order, because {@code Stream.sorted} is stable on
     * ordered streams.
     *
     * @param values values to rank
     * @param limit  maximum number of indices to return (must be non-negative)
     * @return up to {@code limit} indices, highest value first
     */
    public static int[] sortedIndices(@Nonnull float[] values, int limit) {
        // Sort ascending by the negated value so the ordering (including NaN and -0.0) matches
        // Float.compare on -value.
        return IntStream.range(0, values.length)
            .boxed()
            .sorted(Comparator.comparingDouble((Integer index) -> -values[index]))
            .limit(limit)
            .mapToInt(Integer::intValue)
            .toArray();
    }

    /**
     * Creates an independent generator with a copy of the model, the same codec and settings, and
     * copies of the accumulated codes and next-token distribution. The copy is always a plain
     * {@code TextGenerator}, even when called on a subclass.
     *
     * @return the copy
     */
    @Nonnull
    public TextGenerator copy() {
        TextGenerator duplicate = new TextGenerator(this.vocabularySize, getModel().copy(), this.codec);
        duplicate.codes.addAll(this.codes);
        duplicate.verbose = this.verbose;
        duplicate.choicesToLog = this.choicesToLog;
        duplicate.nextSelections = null == this.nextSelections
            ? null
            : Arrays.copyOf(this.nextSelections, this.nextSelections.length);
        return duplicate;
    }

    /**
     * Resets the generator and runs {@link #generate(Predicate)} with no prompt. This currently always
     * throws {@link UnsupportedOperationException}, because {@link #generate(Predicate)} is not
     * implemented.
     *
     * @param predicate condition passed to {@link #generate(Predicate)}
     * @return the full decoded text
     */
    @Nonnull
    public String generateText(@Nonnull Predicate<String> predicate) {
        return generateText(predicate, (String) null);
    }

    /**
     * Resets the generator and generates up to {@code maxTokens} codes with no prompt.
     *
     * @param maxTokens maximum number of codes to generate
     * @return the full decoded text, including the seed code fed for the empty prompt
     */
    @Nonnull
    public String generateText(int maxTokens) {
        return generateText(maxTokens, (String) null);
    }

    /**
     * Resets the generator, feeds {@code prompt}, and runs {@link #generate(Predicate)}. This currently
     * always throws {@link UnsupportedOperationException}, because {@link #generate(Predicate)} is not
     * implemented.
     *
     * @param predicate condition passed to {@link #generate(Predicate)}
     * @param prompt    text fed before generation; {@code null} or empty feeds the seed code
     * @return the full decoded text, including the prompt
     */
    @Nonnull
    public String generateText(@Nonnull Predicate<String> predicate, String prompt) {
        reset();
        feed(prompt);
        generate(predicate);
        return getText();
    }

    /**
     * Resets the generator, feeds {@code prompt}, and generates up to {@code maxTokens} codes.
     *
     * @param maxTokens maximum number of codes to generate
     * @param prompt    text fed before generation; {@code null} or empty feeds the seed code
     * @return the full decoded text, including the prompt
     */
    @Nonnull
    public String generateText(int maxTokens, String prompt) {
        reset();
        feed(prompt);
        generate(maxTokens);
        return getText();
    }

    /**
     * Predicate-guided generation. <b>Not implemented.</b>
     *
     * <p>The decompiler could not recover the original body, so this method always throws
     * {@link UnsupportedOperationException}. Decompiler hints suggest the original asserted
     * {@code nextSelections != null} and returned {@code codec.decode(...)} of a locally collected list
     * of codes. How {@code predicate} was applied is unknown. TODO: restore from the original source.
     *
     * @param predicate condition on the generated text (semantics unconfirmed)
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public String generate(@Nonnull Predicate<String> predicate) {
        throw new UnsupportedOperationException("Method not decompiled: com.simiacryptus.text.TextGenerator.generate(java.util.function.Predicate):java.lang.String");
    }

    /**
     * Samples and appends up to {@code maxTokens} codes. Generation stops early when the stop code
     * ({@code vocabularySize - 1}) is sampled; that code is not appended. Any error is logged at WARN
     * level and ends generation instead of propagating (deliberate best-effort behavior).
     *
     * @param maxTokens maximum number of codes to generate
     */
    public void generate(int maxTokens) {
        init();
        try {
            for (int step = 0; step < maxTokens; step++) {
                assert this.nextSelections != null;
                int selected = select(this.nextSelections);
                if (isVerbose()) {
                    // Skip step 0: in verbose mode feed() logs the distribution after each fed code,
                    // so the first one has usually been logged already.
                    if (step != 0) {
                        log(this.nextSelections, this.codec, getChoicesToLog());
                    }
                    logger.info("Selected New Text: '{}'", this.codec.decode(Integer.valueOf(selected)));
                }
                if (selected == getVocabularySize() - 1) {
                    break;
                }
                this.codes.add(selected);
                this.nextSelections = getModel().eval(selected);
            }
        } catch (Throwable e) {
            logger.warn("Error generating text", e);
        }
    }

    /**
     * Makes sure a next-code distribution is available. If none is available, the empty prompt
     * (and therefore the seed code) is fed.
     *
     * @return this generator, for chaining
     */
    @Nonnull
    public TextGenerator init() {
        if (this.nextSelections == null) {
            feed("");
        }
        return this;
    }

    /**
     * Encodes {@code text} and feeds each code to the model, appending it to {@link #codes}. When the
     * text encodes to nothing, the seed code ({@code vocabularySize - 1}) is fed instead.
     *
     * <p>For every code fed while a distribution is available, the cost {@code -ln(p)} is accumulated.
     * A zero probability is charged {@code ln(vocabularySize)} (the uniform-distribution cost) rather
     * than infinity. Codes fed with no prior distribution contribute nothing.
     *
     * @param text text to feed; {@code null} is treated as empty
     * @return the total accumulated cost, in bits
     */
    public double feed(String text) {
        List<Integer> tokens = new ArrayList<>();
        tokens.addAll(this.codec.encode(null == text ? "" : text));
        if (tokens.isEmpty()) {
            tokens.add(getVocabularySize() - 1);
        }
        double nats = 0.0d;
        for (Integer token : tokens) {
            if (null != this.nextSelections) {
                float probability = this.nextSelections[token];
                nats += probability != 0.0f ? -Math.log(probability) : Math.log(getVocabularySize());
            }
            this.codes.add(token);
            this.nextSelections = getModel().eval(token.intValue());
            if (isVerbose()) {
                logger.info("Feed token: '{}'", this.codec.decode(token));
                assert this.nextSelections != null;
                log(this.nextSelections, this.codec, getChoicesToLog());
            }
        }
        // Convert natural-log units to bits.
        return nats / Math.log(2.0d);
    }

    /**
     * Clears the accumulated codes, the cached next-code distribution, and the model's state. The
     * next {@link #feed(String)} or {@link #init()} then starts from scratch.
     *
     * @return this generator, for chaining
     */
    @Nonnull
    public TextGenerator reset() {
        this.codes.clear();
        // The cached distribution belongs to the model state being cleared; keeping it would make
        // feed() score against, and generate() sample from, a stale distribution.
        this.nextSelections = null;
        getModel().clear();
        return this;
    }

    /**
     * Samples a code index from {@code probabilities} using one uniform random draw. Candidates are
     * visited from most to least likely, and probability mass is subtracted until it covers the draw.
     *
     * @param probabilities per-code probabilities, presumably summing to about 1
     * @return the chosen index
     * @throws IllegalArgumentException if {@code probabilities} is empty
     */
    protected int select(@Nonnull float[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("Cannot select from an empty distribution");
        }
        double fate = Math.random();
        double remaining = fate;
        int[] ranked = sortedIndices(probabilities, probabilities.length);
        int rank = 0;
        // Stop at the last rank: if rounding leaves the total mass slightly below the draw, fall back
        // to the least likely candidate instead of indexing past the end of the array.
        while (rank < ranked.length - 1 && remaining > probabilities[ranked[rank]]) {
            remaining -= probabilities[ranked[rank]];
            rank++;
        }
        int selected = ranked[rank];
        logger.debug("Chose #{} with fate {}", selected, fate);
        return selected;
    }

    /**
     * Logs the {@code count} most probable codes at INFO level, each with its index, probability in
     * percent, and decoded text.
     *
     * @param probabilities per-code probabilities
     * @param tokenCodec    codec used to decode each listed code
     * @param count         number of candidates to log
     */
    protected void log(@Nonnull float[] probabilities, @Nonnull GPT2Codec tokenCodec, int count) {
        for (int index : sortedIndices(probabilities, count)) {
            logger.info(RefString.format("\t#%d %.4f%% '%s'", new Object[]{Integer.valueOf(index), Float.valueOf(probabilities[index] * 100.0f), tokenCodec.decode(Integer.valueOf(index))}));
        }
    }
}
