package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.util.Arrays;
import java.util.Locale;
import java.util.Optional;

/**
 * A block that looks up rows of an embedding weight table for a tensor of item indices.
 *
 * <p>The table is held in the {@link #embedding} parameter with shape {@code (numItems,
 * embeddingSize)}. Subclasses define how a single item is turned into a row index through {@code
 * embed(T)}. {@link #forward} gathers the rows for a tensor of such indices.
 *
 * @param <T> the type of item being embedded
 */
public abstract class Embedding<T> extends AbstractBlock implements AbstractIndexedEmbedding<T> {

    /** Encoding version written by {@link #saveParameters(DataOutputStream)}. */
    private static final byte VERSION = 5;

    /** Width of each row of the embedding table. */
    protected int embeddingSize;
    /** Whether the weight parameter uses a row-sparse gradient format instead of a dense one. */
    protected boolean sparseGrad;
    /** Data type handed to the engine's embedding operator in {@link #forward}. */
    protected DataType dataType;
    /** Number of rows in the embedding table. */
    protected int numItems;
    /**
     * Embedding configured by the builder constructor for items without their own row (may be
     * {@code null}). It is not consulted inside this class; presumably subclasses use it — verify
     * against the subclasses.
     */
    protected AbstractIndexedEmbedding<T> fallthroughEmbedding;
    /** The {@code (numItems, embeddingSize)} weight table. */
    protected Parameter embedding;

    /**
     * Base builder shared by {@link Embedding} subclasses.
     *
     * @param <T> the type of item being embedded
     * @param <B> the concrete builder type, returned by every setter to allow chaining
     */
    public abstract static class BaseBuilder<T, B extends BaseBuilder<T, B>> {
        protected Class<T> embeddingType;
        protected int embeddingSize;
        protected T defaultItem;
        protected AbstractIndexedEmbedding<T> fallthrough;
        protected boolean useDefault = true;
        protected boolean sparseGrad = true;
        protected DataType dataType = DataType.FLOAT32;

        /**
         * Returns the item type this builder was configured for.
         *
         * @return the item class, or {@code null} if not yet set
         */
        public Class<T> getEmbeddingType() {
            return embeddingType;
        }

        /**
         * Sets the item type; implemented by concrete builders.
         *
         * @param cls the item class
         * @return this builder
         */
        protected abstract B setType(Class<T> cls);

        /**
         * Sets the width of each embedding row.
         *
         * @param embeddingSize the row width
         * @return this builder
         */
        public B setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return self();
        }

        /**
         * Sets whether row 0 of the table is used as the fallthrough when neither a fallthrough
         * embedding nor a default item is configured (default {@code true}).
         *
         * @param useDefault whether to install the default fallthrough
         * @return this builder
         */
        public B optUseDefault(boolean useDefault) {
            this.useDefault = useDefault;
            return self();
        }

        /**
         * Sets an item whose embedding is used as the fallthrough. This is mutually exclusive with
         * {@link #optFallthrough}.
         *
         * @param defaultItem the default item
         * @return this builder
         */
        public B optDefaultItem(T defaultItem) {
            this.defaultItem = defaultItem;
            return self();
        }

        /**
         * Sets a custom fallthrough embedding. This is mutually exclusive with {@link
         * #optDefaultItem}.
         *
         * @param fallthrough the fallthrough embedding
         * @return this builder
         */
        public B optFallthrough(AbstractIndexedEmbedding<T> fallthrough) {
            this.fallthrough = fallthrough;
            return self();
        }

        /**
         * Sets whether the weight gradient is row-sparse (default {@code true}).
         *
         * @param sparseGrad {@code true} for a row-sparse gradient, {@code false} for dense
         * @return this builder
         */
        public B optSparseGrad(boolean sparseGrad) {
            this.sparseGrad = sparseGrad;
            return self();
        }

        /**
         * Sets the data type passed to the embedding operator (default {@link DataType#FLOAT32}).
         *
         * @param dataType the data type
         * @return this builder
         */
        public B optDataType(DataType dataType) {
            this.dataType = dataType;
            return self();
        }

        /**
         * Returns this builder typed as the concrete subclass.
         *
         * @return this builder
         */
        protected abstract B self();
    }

    /** Fallthrough that maps every item to row 0 of the enclosing embedding table. */
    protected class DefaultEmbedding implements AbstractIndexedEmbedding<T> {

        protected DefaultEmbedding() {}

        @Override
        public byte[] encode(T input) throws IOException {
            return Embedding.this.encode(input);
        }

        @Override
        public T decode(byte[] byteArray) throws IOException {
            return Embedding.this.decode(byteArray);
        }

        @Override
        public boolean hasItem(T item) {
            return true;
        }

        /**
         * Returns row 0 of the table repeated once per item.
         *
         * @param manager the manager to attach the result to
         * @param items the items; only their count is used
         * @return an array of shape {@code (items.length, embeddingSize)}
         */
        @Override
        public NDArray embed(NDManager manager, T[] items) {
            int length = items.length;
            NDArray base = embedding.getArray().get(0);
            base.attach(manager);
            return base.repeat(new Shape(length, embeddingSize));
        }

        @Override
        public long embed(T item) {
            return 0L;
        }

        @Override
        public Optional<T> unembed(long index) {
            return Optional.empty();
        }
    }

    /** Fallthrough that embeds every item as if it were a fixed default item. */
    protected class DefaultItem implements AbstractIndexedEmbedding<T> {

        private final T defaultItem;

        /**
         * Creates a fallthrough that substitutes {@code defaultItem} for any item.
         *
         * @param defaultItem the item to embed in place of every input
         */
        public DefaultItem(T defaultItem) {
            this.defaultItem = defaultItem;
        }

        @Override
        public byte[] encode(T input) throws IOException {
            return Embedding.this.encode(input);
        }

        @Override
        public T decode(byte[] byteArray) throws IOException {
            return Embedding.this.decode(byteArray);
        }

        @Override
        public boolean hasItem(T item) {
            return true;
        }

        /**
         * Embeds {@code items.length} copies of the default item with the enclosing embedding.
         *
         * @param manager the manager to create the result with
         * @param items the items; only their count is used
         * @return the enclosing embedding's result for the repeated default item
         */
        @Override
        public NDArray embed(NDManager manager, T[] items) {
            // T is erased, so this cast has no runtime effect; the array's runtime type is Object[].
            // NOTE(review): a subclass overriding embed(NDManager, T[]) for a concrete T would get a
            // ClassCastException from its bridge method here; confirm no subclass does that.
            @SuppressWarnings("unchecked")
            T[] defaults = (T[]) new Object[items.length];
            Arrays.fill(defaults, defaultItem);
            return Embedding.this.embed(manager, defaults);
        }

        @Override
        public long embed(T item) {
            return 0L;
        }

        @Override
        public Optional<T> unembed(long index) {
            return Optional.of(defaultItem);
        }
    }

    /**
     * Creates an embedding from builder settings. The table starts with a single row; subclasses
     * presumably adjust {@link #numItems} before initialization, since the parameter's shape is
     * computed lazily from it.
     *
     * @param baseBuilder the builder to read settings from
     * @throws IllegalArgumentException if both a fallthrough and a default item are set
     */
    public Embedding(BaseBuilder<T, ?> baseBuilder) {
        super(VERSION);
        embeddingSize = baseBuilder.embeddingSize;
        sparseGrad = baseBuilder.sparseGrad;
        dataType = baseBuilder.dataType;
        embedding =
                addParameter(
                        new Parameter(
                                "embedding",
                                this,
                                ParameterType.WEIGHT,
                                true,
                                sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE),
                        // Evaluated lazily, so it sees numItems as set below or by subclasses.
                        shapes -> new Shape(numItems, embeddingSize));
        if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) {
            throw new IllegalArgumentException(
                    "You can not specify both a fallthrough and a defaultItem");
        }
        if (baseBuilder.fallthrough != null) {
            fallthroughEmbedding = baseBuilder.fallthrough;
        } else if (baseBuilder.defaultItem != null) {
            fallthroughEmbedding = new DefaultItem(baseBuilder.defaultItem);
        } else if (baseBuilder.useDefault) {
            fallthroughEmbedding = new DefaultEmbedding();
        }
        numItems = 1;
        inputShapes = new Shape[] {new Shape(-1)};
    }

    /**
     * Creates an embedding from an existing weight table with a row-sparse gradient.
     *
     * @param embedding the table; dimension 0 is the number of items and dimension 1 the row width
     */
    public Embedding(NDArray embedding) {
        this(embedding, true);
    }

    /**
     * Creates an embedding from an existing weight table.
     *
     * @param embedding the table; dimension 0 is the number of items and dimension 1 the row width
     * @param sparseGrad whether the weight gradient is row-sparse
     */
    public Embedding(NDArray embedding, boolean sparseGrad) {
        super(VERSION);
        this.embeddingSize = Math.toIntExact(embedding.getShape().get(1));
        this.sparseGrad = sparseGrad;
        this.dataType = embedding.getDataType();
        this.embedding =
                addParameter(
                        new Parameter(
                                "embedding",
                                this,
                                ParameterType.WEIGHT,
                                true,
                                sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE),
                        shapes -> new Shape(numItems, embeddingSize));
        this.embedding.setArray(embedding);
        this.numItems = Math.toIntExact(embedding.getShape().size(0));
        this.inputShapes = new Shape[] {new Shape(-1)};
    }

    /**
     * Returns the input shape with one trailing axis of size {@code embeddingSize} appended.
     *
     * @param manager unused
     * @param inputShapes the input shapes; only the first is used
     * @return a single output shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))};
    }

    /**
     * Looks up the embedding rows for the index array at the head of {@code inputs}.
     *
     * @param parameterStore the store that supplies the weight table on the input's device
     * @param inputs a list whose head holds the indices
     * @param training whether this is a training pass
     * @param params extra operator parameters passed through to the engine
     * @return a singleton list with the looked-up rows
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDList opInputs = opInputs(parameterStore, inputs, training);
        NDList result =
                opInputs.head()
                        .getNDArrayInternal()
                        .embedding(opInputs, numItems, embeddingSize, sparseGrad, dataType, params);
        if (inputs.head().getShape().dimension() == 0) {
            // opInputs reshaped a scalar index to (1); drop that axis from the result again.
            result = new NDList(result.singletonOrThrow().reshape(embeddingSize));
        }
        return result;
    }

    /**
     * Writes the current encoding version, the input shapes, the gradient format, the data type
     * and the weight table.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        saveInputShapes(os);
        os.writeBoolean(sparseGrad);
        os.writeUTF(dataType.toString());
        embedding.save(os);
    }

    /**
     * Reads parameters written by {@link #saveParameters}, including the older encoding versions
     * 1 through 4.
     *
     * @param manager the manager that owns the loaded arrays
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the version is unsupported or the data is truncated
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        // Newer versions may have a different layout, so reject them instead of misreading them.
        if (version < 1 || version > VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) version));
        }

        // When true, a zero row is prepended after loading so that the stored rows start at index
        // 1. Index 0 is the row that the fallthroughs in this class map every item to.
        boolean addMissingZero = false;
        if (version == 2) {
            readInputShapes(is);
            addMissingZero = true;
        } else if (version >= 3) {
            readInputShapes(is);
            if (version == 3) {
                addMissingZero = !is.readBoolean();
            }
            sparseGrad = is.readBoolean();
            dataType = DataType.valueOf(is.readUTF().toUpperCase(Locale.ENGLISH));
            if (version <= 4) {
                skipLegacyItemTable(is);
            }
        }

        embedding.load(manager, is);
        NDArray table = embedding.getArray();
        numItems = Math.toIntExact(table.getShape().get(0));
        embeddingSize = Math.toIntExact(table.getShape().get(1));
        if (addMissingZero) {
            numItems++;
            // Match the table's dtype so the concatenation does not mix data types.
            NDArray zeroRow = manager.zeros(new Shape(1, embeddingSize), table.getDataType());
            embedding.setArray(NDArrays.concat(new NDList(zeroRow, table)));
        }
    }

    /**
     * Skips the table stored by encoding versions 3 and 4. Its layout is a count followed by that
     * many (length-prefixed byte key, int) entries, presumably an encoded-item-to-index map. This
     * class no longer uses it.
     *
     * @param is the stream positioned at the start of the table
     * @throws IOException if reading fails
     * @throws MalformedModelException if a key length is negative or the stream ends early
     */
    private static void skipLegacyItemTable(DataInputStream is)
            throws IOException, MalformedModelException {
        int entryCount = is.readInt();
        for (int entry = 0; entry < entryCount; entry++) {
            int keyLength = is.readInt();
            if (keyLength < 0) {
                throw new MalformedModelException("Model data is malformed");
            }
            byte[] key = new byte[keyLength];
            try {
                // readFully, unlike read, does not treat a legal short read as corruption.
                is.readFully(key);
            } catch (EOFException e) {
                throw new MalformedModelException("Model data is malformed", e);
            }
            is.readInt(); // value paired with the key; not needed anymore
        }
    }

    /**
     * Builds the operator inputs: the indices (a scalar reshaped to {@code (1)}) followed by the
     * weight table fetched for the indices' device.
     */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs, boolean training) {
        NDArray indices = inputs.head();
        Device device = indices.getDevice();
        NDList result = new NDList(2);
        result.add(indices.getShape().dimension() == 0 ? indices.reshape(1) : indices);
        result.add(parameterStore.getValue(embedding, device, training));
        return result;
    }

    /**
     * Converts each item to its row index with {@code embed(T)}.
     *
     * @param manager the manager used to create the index array
     * @param items the items to convert
     * @return a 1-D array of row indices, one per item
     */
    @Override
    public NDArray embed(NDManager manager, T[] items) {
        return manager.create(Arrays.stream(items).mapToLong(this::embed).toArray());
    }
}
