package com.simiacryptus.text.gpt2;

import com.google.protobuf.InvalidProtocolBufferException;
import com.simiacryptus.ref.wrappers.RefString;
import com.simiacryptus.text.GraphModifier;
import com.simiacryptus.text.LanguageCodeModel;
import com.simiacryptus.text.TextGenerator;
import com.simiacryptus.util.Util;
import java.io.File;
import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.GraphOptions;
import org.tensorflow.framework.OptimizerOptions;

/* loaded from: input_file:com/simiacryptus/text/gpt2/GPT2Model.class */
/**
 * A {@link LanguageCodeModel} that evaluates a GPT-2 TensorFlow graph one token code at a time.
 *
 * <p>The serialized {@link #graphDef} is imported into {@link #graph} lazily on the first call to
 * {@link #eval(int)}. While no state tensor is cached (a fresh or {@link #clear() cleared} model),
 * evaluation runs a variant of the graph imported under the {@code "init/"} prefix, produced by
 * {@link GraphModifier#edit}; afterwards the fetched state tensor is fed back as
 * {@code input_past} into the unprefixed graph.
 *
 * <p>Instances produced by {@link #copy()} share this model's {@link Graph}, {@link Session},
 * codec and {@link #loadedSubnets} set, but own an independent token history and state tensor.
 * {@link #eval(int)}, {@link #clear()} and {@link #copy()} synchronize on the instance; session
 * runs additionally synchronize on the (possibly shared) session, and subnet imports on the
 * (possibly shared) {@link #loadedSubnets} set.
 */
public class GPT2Model implements LanguageCodeModel {
    protected static final Logger logger = LoggerFactory.getLogger(GPT2Model.class);

    /** Graph-name prefix of the subnet evaluated while no state tensor is cached. */
    private static final String INIT_PREFIX = "init/";

    /** Name given to this model at construction. */
    public final String name;
    /** Serialized GraphDef bytes; imported as-is and also parsed to build the {@code "init/"} variant. */
    protected final byte[] graphDef;
    /** Token codes passed to {@link #eval(int)} since construction or the last {@link #clear()}. */
    protected final ArrayList<Integer> code_history;
    /** Used to derive the {@code "init/"} subnet from the parsed graph definition. */
    protected final GraphModifier graphModifier;
    /** Used here to decode token codes to text for the filter function. */
    protected final GPT2Codec codec;
    /**
     * Prefixes of the subnets already imported into {@link #graph}: {@code ""} for the base graph and
     * {@code "init/"} for the initial-step variant. Shared with instances produced by {@link #copy()}.
     */
    public HashSet<String> loadedSubnets;
    public Graph graph;
    public Session session;
    /**
     * {@code 0} while no state is cached; set from dimension 4 of the first fetched state tensor and
     * incremented on every later evaluation. Presumably the number of cached past positions — TODO
     * confirm.
     */
    protected int history_size;

    /** State tensor fetched by the last evaluation and fed back as {@code input_past}; {@code null} after a reset. */
    @Nullable
    protected Tensor<Float> tensor_state;
    private BiFunction<String, String, Boolean> filterFn;

    /** Creates a model whose graph definition is read from {@code file}. */
    public GPT2Model(String name, GraphModifier graphModifier, @Nonnull File file, GPT2Codec codec) {
        this(name, loadModel(file), graphModifier, codec);
    }

    /** Creates a model backed by a new, empty {@link Graph}. */
    public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec) {
        this(name, graphDef, graphModifier, codec, new Graph());
    }

    /** Creates a model on {@code graph} with a new {@link Session} configured by {@link #defaultSessionConfig()}. */
    public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec,
                     @Nonnull Graph graph) {
        this(name, graphDef, graphModifier, codec, graph, new Session(graph, defaultSessionConfig()));
    }

    /** Creates a model on an existing graph and session, with an empty history and an accept-all filter. */
    public GPT2Model(String name, byte[] graphDef, GraphModifier graphModifier, GPT2Codec codec,
                     Graph graph, Session session) {
        this.name = name;
        this.graphDef = graphDef;
        this.graphModifier = graphModifier;
        this.codec = codec;
        this.graph = graph;
        this.session = session;
        this.code_history = new ArrayList<>();
        this.loadedSubnets = new HashSet<>();
        this.history_size = 0;
        this.tensor_state = null;
        this.filterFn = (history, token) -> true;
    }

    /**
     * Builds the serialized {@link ConfigProto} used when no session is supplied. It enables constant
     * folding, function inlining and common-subexpression elimination. It also forces GPU compatibility,
     * allows GPU memory growth and caps the per-process GPU memory fraction at 0.5.
     */
    private static byte[] defaultSessionConfig() {
        final OptimizerOptions optimizer = OptimizerOptions.newBuilder()
            .setDoConstantFolding(true)
            .setDoFunctionInlining(true)
            .setDoCommonSubexpressionElimination(true)
            .build();
        final GPUOptions gpu = GPUOptions.newBuilder()
            .setForceGpuCompatible(true)
            .setAllowGrowth(true)
            .setPerProcessGpuMemoryFraction(0.5d)
            .build();
        return ConfigProto.newBuilder()
            .setGraphOptions(GraphOptions.newBuilder().setOptimizerOptions(optimizer).build())
            .setGpuOptions(gpu)
            .build()
            .toByteArray();
    }

    @Override
    public BiFunction<String, String, Boolean> getFilterFn() {
        return this.filterFn;
    }

    /**
     * Reads the whole of {@code file} into memory.
     *
     * <p>An {@link IOException} is rethrown via {@code Util.throwException}.
     */
    public static byte[] loadModel(@Nonnull File file) {
        try {
            return FileUtils.readFileToByteArray(file);
        } catch (IOException e) {
            throw Util.throwException(e);
        }
    }

    /** Returns a new float tensor with the same shape and contents as {@code tensor}; the caller owns it. */
    @Nonnull
    public static Tensor<Float> copy(@Nonnull Tensor<Float> tensor) {
        final FloatBuffer buffer = FloatBuffer.allocate(tensor.numElements());
        tensor.writeTo(buffer);
        // writeTo advances the position; rewind so create() reads the data just written.
        buffer.flip();
        return Tensor.create(tensor.shape(), buffer);
    }

    /**
     * Returns a model sharing this model's graph, session, codec, graph modifier and
     * {@link #loadedSubnets}, with its own copy of the history, the state tensor (if any) and the
     * filter function reference.
     */
    @Override
    @Nonnull
    public synchronized LanguageCodeModel copy() {
        final GPT2Model duplicate =
            new GPT2Model(this.name, this.graphDef, this.graphModifier, this.codec, this.graph, this.session);
        duplicate.tensor_state = null == this.tensor_state ? null : copy(this.tensor_state);
        duplicate.history_size = this.history_size;
        // Same graph, so the record of which subnets it already contains must be shared too.
        duplicate.loadedSubnets = this.loadedSubnets;
        duplicate.code_history.addAll(this.code_history);
        duplicate.filterFn = this.filterFn;
        return duplicate;
    }

    /**
     * Converts logits in place into a probability distribution over the tokens accepted by the filter.
     *
     * <p>Each index except the last one is kept only if {@link #getFilterFn()} accepts
     * {@code (decoded history, decoded token)}. The last index ({@code fArr.length - 1}) is always
     * kept; presumably it is the end-of-text token (TODO confirm). Kept logits are soft-maxed relative
     * to the largest finite kept logit, with non-finite weights counted as zero. If every weight is
     * zero, the result is all zeros. Every rejected index is set to {@code 0}.
     *
     * @param fArr logits indexed by token code; overwritten with the probabilities
     * @return {@code fArr}
     */
    @Nonnull
    public float[] logitsToProbabilities(@Nonnull float[] fArr) {
        final String history = this.codec.decode(this.code_history.toArray(new Integer[0]));
        final int lastIndex = fArr.length - 1;
        final int[] kept = Arrays.stream(TextGenerator.sortedIndices(fArr, Integer.MAX_VALUE))
            .filter(index -> index == lastIndex || acceptToken(history, index))
            .toArray();
        final double[] logits = Arrays.stream(kept).mapToDouble(index -> fArr[index]).toArray();
        assert logits.length > 1 : "input.length() = " + logits.length;

        // Subtract the largest finite logit before exponentiating, for numerical stability.
        final double max = Arrays.stream(logits)
            .filter(Double::isFinite)
            .max()
            .orElse(Double.NEGATIVE_INFINITY);
        final double[] weights = Arrays.stream(logits).map(logit -> {
            final double weight = Math.exp(logit - max);
            return Double.isFinite(weight) ? weight : 0.0d;
        }).toArray();
        final double total = Arrays.stream(weights).sum();
        // Avoid dividing by zero when every weight vanished.
        final double sum = total > 0.0d ? total : 1.0d;
        assert Double.isFinite(sum);

        Arrays.fill(fArr, 0.0f);
        for (int k = 0; k < kept.length; k++) {
            fArr[kept[k]] = (float) (weights[k] / sum);
        }
        return fArr;
    }

    /** Applies the filter function to the decoded history and the decoded candidate {@code token}. */
    private boolean acceptToken(String history, int token) {
        final String decoded = this.codec.decode(Integer.valueOf(token));
        assert getFilterFn() != null;
        return getFilterFn().apply(history, decoded).booleanValue();
    }

    /** Closes and drops the cached state, and clears the history and {@link #history_size}. */
    @Override
    @Nonnull
    public synchronized LanguageCodeModel clear() {
        logger.debug("Reset Language Model State");
        if (null != this.tensor_state) {
            this.tensor_state.close();
            this.tensor_state = null;
        }
        this.history_size = 0;
        this.code_history.clear();
        return this;
    }

    /**
     * Appends token code {@code i} to the history and evaluates it.
     *
     * <p>On first use this imports the base graph. If no state is cached, it also imports the
     * {@code "init/"} subnet on first use. It then runs {@link #eval(String, int...)} on the single code
     * {@code i}: under the {@code "init/"} prefix when no state is cached, otherwise unprefixed.
     *
     * <p>If {@link #graphDef} cannot be parsed, the {@link InvalidProtocolBufferException} is rethrown
     * via {@code Util.throwException}.
     *
     * @param i token code to feed
     * @return the probabilities returned by {@link #eval(String, int...)}
     */
    @Override
    @Nonnull
    public synchronized float[] eval(int i) {
        logger.debug("Eval {}", i);
        try {
            final boolean initialStep = null == this.tensor_state;
            // Copies share graph and loadedSubnets but not this instance's monitor, so make the
            // check-then-import atomic on the shared set. A subnet is marked loaded only after its
            // import has succeeded, so a failed import is retried on the next call.
            synchronized (this.loadedSubnets) {
                if (!this.loadedSubnets.contains("")) {
                    this.graph.importGraphDef(this.graphDef);
                    this.loadedSubnets.add("");
                }
                if (initialStep && !this.loadedSubnets.contains(INIT_PREFIX)) {
                    GraphModifier.importGraphDef(this.graph,
                        this.graphModifier.edit(GraphDef.parseFrom(this.graphDef), INIT_PREFIX, false));
                    this.loadedSubnets.add(INIT_PREFIX);
                }
            }
            // Record the code first: logitsToProbabilities decodes code_history for the filter.
            this.code_history.add(i);
            // Only the newest code is fed. The cached tensor_state (fed as input_past) stands in for
            // the earlier codes.
            return eval(initialStep ? INIT_PREFIX : "", i);
        } catch (InvalidProtocolBufferException e) {
            throw Util.throwException(e);
        }
    }

    /**
     * Runs one forward pass of the subnet named by prefix {@code str}.
     *
     * <p>{@code iArr} is fed to {@code input_X} as a {@code [1, iArr.length]} int tensor. The cached
     * state, if any, is fed to {@code str + "input_past"}. Logits are fetched from
     * {@code str + "output/strided_slice_1"}. The new state is fetched from
     * {@code str + "model/stack"} when {@code history_size == 0}, otherwise from
     * {@code str + "output/concat"}. The new state replaces the cached one, which is closed. All other
     * native tensors created here are released, even if the run fails.
     *
     * @param str  graph-name prefix of the subnet to run ({@code ""} or {@code "init/"})
     * @param iArr token codes to feed
     * @return the fetched logits converted by {@link #logitsToProbabilities(float[])}
     */
    @Nonnull
    public synchronized float[] eval(String str, @Nonnull int... iArr) {
        // The session may be shared with copies of this model, so serialize runs on it.
        synchronized (this.session) {
            logger.debug("Eval({},{})", this.session, Arrays.toString(iArr));
            final float[] logits;
            try (Tensor<Integer> input = Tensor.create(new long[]{1, iArr.length}, IntBuffer.wrap(iArr))) {
                // NOTE(review): input_X is fed without the prefix, unlike every other name here.
                // Presumably GraphModifier.edit keeps the prefixed subnet wired to it — confirm.
                Session.Runner runner = this.session.runner().feed("input_X", input);
                if (null != this.tensor_state) {
                    runner = runner.feed(str + "input_past", this.tensor_state);
                }
                logger.debug("Input Codes: {}", Arrays.toString(iArr));
                logger.debug("Input State: {}",
                    null == this.tensor_state ? "null" : Arrays.toString(this.tensor_state.shape()));
                final String stateOutput = 0 == this.history_size ? str + "model/stack" : str + "output/concat";
                final List<Tensor<?>> outputs =
                    runner.fetch(str + "output/strided_slice_1").fetch(stateOutput).run();
                try (Tensor<Float> logitTensor = outputs.get(0).expect(Float.class)) {
                    final Tensor<Float> newState = outputs.get(1).expect(Float.class);
                    logger.debug("Output Logits: {}", Arrays.toString(logitTensor.shape()));
                    logger.debug("Output State: {}", Arrays.toString(newState.shape()));
                    logits = new float[logitTensor.numElements()];
                    logitTensor.writeTo(FloatBuffer.wrap(logits));
                    if (null == this.tensor_state) {
                        // Assumes the state tensor has rank >= 5 — TODO confirm the meaning of dim 4.
                        this.history_size = (int) newState.shape()[4];
                    } else {
                        this.history_size++;
                        // The superseded state was only needed as this run's input_past.
                        this.tensor_state.close();
                    }
                    this.tensor_state = newState;
                }
            }
            return logitsToProbabilities(logits);
        }
    }

    @Override
    @Nonnull
    public LanguageCodeModel setFilterFn(BiFunction<String, String, Boolean> biFunction) {
        this.filterFn = biFunction;
        return this;
    }

    /** Returns the cached state tensor (still owned by this model), or {@code null} if none is cached. */
    @Override
    @Nullable
    public Tensor<?> state() {
        return this.tensor_state;
    }
}
