package com.simiacryptus.text;

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.ProtocolStringList;
import com.simiacryptus.text.gpt2.GPT2Codec;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperationBuilder;
import org.tensorflow.Output;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

/* loaded from: input_file:com/simiacryptus/text/GraphModifier.class */
/**
 * Base class for rewriting a TensorFlow {@link GraphDef} and importing the result into a
 * {@link Graph}.
 *
 * <p>Subclasses decide which nodes to drop ({@link #getDeletes_Init()}), how to modify existing
 * nodes ({@link #edit(NodeDef.Builder)}) and which nodes to add ({@link #addNodes(Consumer)}).
 * {@link #edit(GraphDef, String, boolean)} applies those decisions and then prefixes the names of
 * every modified/added node and everything downstream of it.
 */
public abstract class GraphModifier {
    // Log under this class's own category (it previously pointed at GPT2Codec by mistake).
    protected static final Logger logger = LoggerFactory.getLogger(GraphModifier.class);

    /** Leading marker of a control-dependency entry in {@code NodeDef.input}, e.g. {@code "^node"}. */
    private static final String CONTROL_PREFIX = "^";

    /**
     * Returns the names of nodes that {@link #edit(GraphDef, String, boolean)} omits from its result.
     *
     * @return node names to delete; never null
     */
    @Nonnull
    public abstract HashSet<String> getDeletes_Init();

    /**
     * Adds the nodes of {@code graphDef} to {@code graph} in dependency order.
     *
     * <p>Nodes whose name already exists in {@code graph} are skipped. Each pass adds every node
     * whose inputs (data inputs and {@code ^control} inputs) are all available; when a pass makes no
     * progress, every node still missing is logged as a warning instead of failing.
     *
     * @param graph    graph that receives the new operations
     * @param graphDef nodes to import
     * @throws RuntimeException if a node cannot be built; the message names the offending node
     */
    public static void importGraphDef(@Nonnull Graph graph, @Nonnull GraphDef graphDef) {
        HashSet<String> available = new HashSet<>();
        graph.operations().forEachRemaining(operation -> available.add(operation.name()));
        while (true) {
            List<NodeDef> ready = graphDef.getNodeList().stream()
                    .filter(node -> !available.contains(node.getName()))
                    .filter(node -> node.getInputList().stream()
                            .allMatch(input -> available.contains(nodeName(input))))
                    .collect(Collectors.toList());
            if (ready.isEmpty()) {
                // Fixed point: anything left depends on an input that never became available.
                graphDef.getNodeList().stream()
                        .filter(node -> !available.contains(node.getName()))
                        .forEach(node -> logger.warn("Remaining Node: {}", node));
                return;
            }
            for (NodeDef node : ready) {
                available.add(node.getName());
                if (graph.operation(node.getName()) == null) {
                    addNode(graph, node);
                }
            }
        }
    }

    /**
     * Builds a single operation in {@code graph} from {@code node}; all of its inputs must exist.
     *
     * <p>Tensors created for TENSOR attributes are closed once the operation has been built, so
     * their native memory is released (they previously leaked).
     */
    private static void addNode(@Nonnull Graph graph, @Nonnull NodeDef node) {
        logger.debug("Adding node to graph: {} <= [{}]", node.getName(), String.join(",", node.getInputList()));
        List<Tensor<?>> attrTensors = new ArrayList<>();
        try {
            GraphOperationBuilder builder = graph.opBuilder(node.getOp(), node.getName());
            builder.setDevice(node.getDevice());
            node.getAttrMap().forEach((name, value) -> setAttr(builder, name, value, attrTensors));
            List<Output<?>> dataInputs = new ArrayList<>();
            for (String input : node.getInputList()) {
                if (input.startsWith(CONTROL_PREFIX)) {
                    builder.addControlInput(graph.operation(nodeName(input)));
                } else {
                    dataInputs.add(output(graph, input));
                }
            }
            addDataInputs(builder, node.getOp(), dataInputs.toArray(new Output<?>[0]));
            builder.build();
        } catch (RuntimeException e) {
            throw new RuntimeException("Error processing " + node, e);
        } finally {
            attrTensors.forEach(Tensor::close);
        }
    }

    /** Resolves a data input of the form {@code "node"} or {@code "node:index"} to an output of {@code graph}. */
    private static Output<?> output(@Nonnull Graph graph, @Nonnull String input) {
        String[] parts = input.split(":");
        int index = parts.length == 1 ? 0 : Integer.parseInt(parts[1]);
        return graph.operation(parts[0]).output(index);
    }

    /**
     * Wires data inputs into {@code builder}, grouping list-valued inputs for the ops that need it.
     *
     * <p>{@code Pack} takes all inputs as one list. {@code ConcatV2} takes all but the last input as
     * the value list and the last as the axis (any number of values is supported). Every other op
     * receives its inputs one by one.
     */
    private static void addDataInputs(@Nonnull GraphOperationBuilder builder, @Nonnull String op, @Nonnull Output<?>[] inputs) {
        switch (op) {
            case "Pack":
                builder.addInputList(inputs);
                break;
            case "ConcatV2": {
                Output<?> axis = inputs[inputs.length - 1];
                builder.addInputList(Arrays.copyOf(inputs, inputs.length - 1));
                builder.addInput(axis);
                // Preserved from the original implementation; the data edge already orders the axis
                // before this op. NOTE(review): purpose not evident — confirm before removing.
                builder.addControlInput(axis.op());
                break;
            }
            case "StridedSlice":
                for (int i = 0; i < inputs.length; i++) {
                    builder.addInput(inputs[i]);
                    if (i > 0) {
                        // Preserved from the original: begin/end/strides producers are also
                        // control dependencies. NOTE(review): purpose not evident.
                        builder.addControlInput(inputs[i].op());
                    }
                }
                break;
            default:
                for (Output<?> input : inputs) {
                    builder.addInput(input);
                }
        }
    }

    /**
     * Copies one attribute from a {@link NodeDef} onto {@code builder}.
     *
     * @param tensors receives every {@link Tensor} created here so the caller can close it
     * @throws RuntimeException for attribute kinds that are not supported (e.g. LIST, FUNC)
     */
    private static void setAttr(@Nonnull GraphOperationBuilder builder, @Nonnull String name, @Nonnull AttrValue value,
                                @Nonnull List<Tensor<?>> tensors) {
        switch (value.getValueCase()) {
            case TENSOR: {
                Tensor<?> tensor = toTensor(value.getTensor());
                tensors.add(tensor); // register before use so it is closed even if setAttr fails
                builder.setAttr(name, tensor);
                return;
            }
            case SHAPE:
                builder.setAttr(name, toShape(value.getShape()));
                return;
            case TYPE:
                // "DT_FLOAT" -> "FLOAT", "DT_INT32" -> "INT32", ...
                builder.setAttr(name, org.tensorflow.DataType.valueOf(value.getType().name().split("_")[1]));
                return;
            case I:
                builder.setAttr(name, value.getI());
                return;
            case F:
                builder.setAttr(name, value.getF());
                return;
            case B:
                builder.setAttr(name, value.getB());
                return;
            case S:
                builder.setAttr(name, value.getS().toByteArray());
                return;
            default:
                throw new RuntimeException(name + " = " + value);
        }
    }

    /**
     * Converts a DT_FLOAT or DT_INT32 {@link TensorProto} to a {@link Tensor}; the caller must close it.
     *
     * <p>Raw {@code tensor_content} is used when present. Otherwise the typed value list
     * ({@code float_val}/{@code int_val}) is expanded to the full shape: a shorter list is padded with
     * its last value and an empty list yields zeros, matching TensorFlow's own proto decoding.
     */
    private static Tensor<?> toTensor(@Nonnull TensorProto proto) {
        long[] shape = dims(proto.getTensorShape());
        Class<?> type;
        switch (proto.getDtype()) {
            case DT_FLOAT:
                type = Float.class;
                break;
            case DT_INT32:
                type = Integer.class;
                break;
            default:
                throw new RuntimeException(proto.getDtype().toString());
        }
        ByteString content = proto.getTensorContent();
        if (!content.isEmpty()) {
            return Tensor.create(type, shape, content.asReadOnlyByteBuffer());
        }
        int count = elementCount(shape);
        if (type == Float.class) {
            return Tensor.create(shape, FloatBuffer.wrap(floatValues(proto, count)));
        }
        return Tensor.create(shape, IntBuffer.wrap(intValues(proto, count)));
    }

    /** Expands {@code int_val} to {@code count} elements (last value repeated, zeros if empty). */
    private static int[] intValues(@Nonnull TensorProto proto, int count) {
        int available = proto.getIntValCount();
        int[] values = new int[count];
        if (available > 0) {
            for (int i = 0; i < count; i++) {
                values[i] = proto.getIntVal(Math.min(i, available - 1));
            }
        }
        return values;
    }

    /** Expands {@code float_val} to {@code count} elements (last value repeated, zeros if empty). */
    private static float[] floatValues(@Nonnull TensorProto proto, int count) {
        int available = proto.getFloatValCount();
        float[] values = new float[count];
        if (available > 0) {
            for (int i = 0; i < count; i++) {
                values[i] = proto.getFloatVal(Math.min(i, available - 1));
            }
        }
        return values;
    }

    /** Returns the number of elements of a fully-defined shape; a scalar ({@code []}) has one. */
    private static int elementCount(@Nonnull long[] shape) {
        long count = 1;
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("Unknown dimension in tensor shape: " + Arrays.toString(shape));
            }
            count = Math.multiplyExact(count, dim);
        }
        return Math.toIntExact(count);
    }

    /** Converts a shape proto to a {@link Shape}, handling unknown rank and scalars. */
    private static Shape toShape(@Nonnull TensorShapeProto proto) {
        if (proto.getUnknownRank()) {
            return Shape.unknown();
        }
        long[] dims = dims(proto);
        if (dims.length == 0) {
            return Shape.scalar();
        }
        return Shape.make(dims[0], Arrays.copyOfRange(dims, 1, dims.length));
    }

    /** Returns the dimension sizes of {@code shape} in order (-1 marks an unknown dimension). */
    private static long[] dims(@Nonnull TensorShapeProto shape) {
        return shape.getDimList().stream().mapToLong(TensorShapeProto.Dim::getSize).toArray();
    }

    /**
     * Returns the node name referenced by a {@code NodeDef.input} entry, stripping a leading
     * {@code ^} (control dependency) and any {@code :index} output suffix.
     */
    private static String nodeName(@Nonnull String input) {
        String name = input.startsWith(CONTROL_PREFIX) ? input.substring(CONTROL_PREFIX.length()) : input;
        int colon = name.indexOf(':');
        return colon < 0 ? name : name.substring(0, colon);
    }

    /** Prefixes the node referenced by an input entry, keeping a leading {@code ^} in front. */
    private static String prefixInput(@Nonnull String input, String prefix) {
        return input.startsWith(CONTROL_PREFIX)
                ? CONTROL_PREFIX + prefix + input.substring(CONTROL_PREFIX.length())
                : prefix + input;
    }

    /**
     * Returns a copy of {@code byteBuffer} whose ints have been edited by {@code consumer}.
     *
     * <p>The int view uses the copy's byte order, which is big-endian (the {@link ByteBuffer#allocate}
     * default). NOTE(review): TensorFlow {@code tensor_content} is little-endian (see
     * {@link #tensor2}), so consumers editing such data presumably need to byte-swap — confirm.
     *
     * @param byteBuffer source bytes; not modified
     * @param consumer   mutates the copy through an {@link IntBuffer} view
     * @return the edited copy, positioned at 0
     */
    @Nonnull
    public static ByteBuffer edit(@Nonnull ByteBuffer byteBuffer, @Nonnull Consumer<IntBuffer> consumer) {
        ByteBuffer copy = copy(byteBuffer);
        consumer.accept(copy.asIntBuffer());
        return copy;
    }

    /**
     * Copies the remaining bytes of {@code byteBuffer} into a new heap buffer of the same capacity.
     *
     * <p>The source's position is left untouched (it used to be consumed by the copy).
     *
     * @return the copy with position 0 and limit equal to its capacity
     */
    @Nonnull
    public static ByteBuffer copy(@Nonnull ByteBuffer byteBuffer) {
        ByteBuffer copy = ByteBuffer.allocate(byteBuffer.capacity());
        copy.put(byteBuffer.duplicate());
        copy.clear();
        return copy;
    }

    /**
     * Builds a DT_INT32 {@link TensorProto} whose values are stored in {@code int_val}.
     *
     * @param dims   tensor shape
     * @param values flattened values
     */
    @Nonnull
    public static TensorProto tensor1(int[] dims, @Nonnull int... values) {
        TensorProto.Builder builder = TensorProto.newBuilder().setTensorShape(shape(dims)).setDtype(DataType.DT_INT32);
        for (int value : values) {
            builder.addIntVal(value);
        }
        return builder.build();
    }

    /**
     * Builds a DT_INT32 {@link TensorProto} whose values are stored as raw {@code tensor_content}.
     *
     * @param dims   tensor shape
     * @param values flattened values, encoded little-endian
     */
    @Nonnull
    public static TensorProto tensor2(int[] dims, @Nonnull int... values) {
        byte[] bytes = new byte[values.length * 4];
        IntBuffer ints = ByteBuffer.wrap(bytes).asIntBuffer();
        for (int value : values) {
            // The view is big-endian; reversing each value yields little-endian bytes.
            ints.put(Integer.reverseBytes(value));
        }
        return TensorProto.newBuilder()
                .setTensorShape(shape(dims))
                .setTensorContent(ByteString.copyFrom(bytes))
                .setDtype(DataType.DT_INT32)
                .build();
    }

    /**
     * Builds a {@link TensorShapeProto} with the given dimension sizes.
     *
     * @param dims dimension sizes, outermost first
     */
    @Nonnull
    public static TensorShapeProto shape(@Nonnull int... dims) {
        TensorShapeProto.Builder builder = TensorShapeProto.newBuilder();
        for (int size : dims) {
            builder.addDim(TensorShapeProto.Dim.newBuilder().setSize(size).build());
        }
        return builder.build();
    }

    /**
     * Applies this modifier to {@code graphDef}.
     *
     * <p>Nodes named by {@link #getDeletes_Init()} are dropped. Nodes for which
     * {@link #edit(NodeDef.Builder)} returns non-null are replaced by the edit. Nodes from
     * {@link #addNodes(Consumer)} are appended. Then every edited/added node and everything
     * downstream of it is renamed with {@code prefix} (see {@link #prefixRewrite}).
     *
     * @param graphDef          source graph; not modified
     * @param prefix            prefix for renamed nodes
     * @param includeUnmodified whether untouched nodes are kept in the result
     * @return the rewritten graph
     * @throws InvalidProtocolBufferException declared for source compatibility with existing callers
     */
    @Nonnull
    public GraphDef edit(@Nonnull GraphDef graphDef, String prefix, boolean includeUnmodified) throws InvalidProtocolBufferException {
        // Protobuf messages are immutable, so graphDef is read directly (no defensive deep copy needed).
        GraphDef.Builder result = GraphDef.newBuilder();
        HashSet<String> deletes = getDeletes_Init();
        HashSet<String> modified = new HashSet<>();
        for (NodeDef node : graphDef.getNodeList()) {
            if (deletes.contains(node.getName())) {
                logger.debug("Omit Node: {}", node.getName());
                continue;
            }
            NodeDef.Builder edited = edit(node.toBuilder());
            if (edited != null) {
                logger.debug("Edit Node: {}", node.getName());
                result.addNode(edited.build());
                modified.add(node.getName());
            } else {
                result.addNode(node);
            }
        }
        addNodes(node -> {
            result.addNode(node);
            modified.add(node.getName());
        });
        return prefixRewrite(result.build(), modified, prefix, includeUnmodified);
    }

    /**
     * Edits a single node.
     *
     * @param builder mutable copy of the original node
     * @return the edited builder, or {@code null} to keep the node unchanged
     */
    @Nullable
    public abstract NodeDef.Builder edit(NodeDef.Builder builder);

    /**
     * Supplies extra nodes to append to the edited graph.
     *
     * @param consumer receives each new node
     */
    protected abstract void addNodes(Consumer<NodeDef> consumer);

    /**
     * Prefixes the nodes in {@code renamed}, and every node that transitively consumes one of them,
     * and rewrites input references to match.
     *
     * <p>Inputs are matched by node name, ignoring an {@code :index} suffix and a leading {@code ^}
     * (control dependency). A renamed control input keeps its {@code ^} in front of the prefix.
     *
     * @param graphDef          graph to rewrite
     * @param renamed           names of nodes to rename; NOTE: mutated — nodes renamed by propagation
     *                          are added to it
     * @param prefix            string prepended to renamed node names
     * @param includeUnmodified if true, nodes that are neither renamed nor rewritten are copied
     *                          unchanged; if false they are dropped
     * @return the rewritten graph
     */
    @Nonnull
    protected GraphDef prefixRewrite(@Nonnull GraphDef graphDef, @Nonnull HashSet<String> renamed, String prefix, boolean includeUnmodified) {
        // Propagate the rename to consumers until a fixed point is reached.
        while (true) {
            List<String> touched = graphDef.getNodeList().stream()
                    .filter(node -> !renamed.contains(node.getName()))
                    .filter(node -> node.getInputList().stream().anyMatch(input -> renamed.contains(nodeName(input))))
                    .map(NodeDef::getName)
                    .collect(Collectors.toList());
            if (touched.isEmpty()) {
                break;
            }
            for (String name : touched) {
                logger.debug("Item touched by rename: {}", name);
            }
            renamed.addAll(touched);
        }
        GraphDef.Builder result = GraphDef.newBuilder();
        for (NodeDef node : graphDef.getNodeList()) {
            NodeDef.Builder builder = renamed.contains(node.getName())
                    ? node.toBuilder().setName(prefix + node.getName())
                    : null;
            if (node.getInputList().stream().anyMatch(input -> renamed.contains(nodeName(input)))) {
                if (builder == null) {
                    builder = node.toBuilder();
                }
                builder.clearInput();
                for (String input : node.getInputList()) {
                    if (renamed.contains(nodeName(input))) {
                        logger.debug("{} [ {} ] += {}", node.getName(), input, prefix);
                        builder.addInput(prefixInput(input, prefix));
                    } else {
                        builder.addInput(input);
                    }
                }
            }
            if (builder != null) {
                logger.debug("Edit in renaming: {}", builder.getName());
                result.addNode(builder.build());
            } else if (includeUnmodified) {
                result.addNode(node);
            }
        }
        return result.build();
    }
}
