package ai.djl.timeseries.block;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.core.Embedding;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

/* loaded from: input_file:ai/djl/timeseries/block/FeatureEmbedding.class */
/**
 * A block that holds a single trainable weight of shape {@code (numEmbeddings, embeddingSize)}
 * and applies {@link Embedding#embedding(NDArray, NDArray, SparseFormat)} to its input using
 * that weight.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public final class FeatureEmbedding extends AbstractBlock {
    private static final String EMBEDDING_PARAM_NAME = "embedding";

    /** Size of the trailing dimension appended to the input shape by this block. */
    private final int embeddingSize;

    /** Number of rows of the embedding weight. */
    private final int numEmbeddings;

    /** The embedding weight, registered on this block under the name {@code "embedding"}. */
    private final Parameter embedding;

    /** Builder that collects and validates the configuration of a {@link FeatureEmbedding}. */
    public static final class Builder {
        private int embeddingSize;
        private int numEmbeddings;

        /**
         * Sets the size of each embedding vector.
         *
         * @param embeddingSize the embedding size; must be positive by the time {@link #build()}
         *     is called
         * @return this builder
         */
        public Builder setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets the number of embeddings (the first dimension of the weight).
         *
         * @param numEmbeddings the number of embeddings; must be positive by the time {@link
         *     #build()} is called
         * @return this builder
         */
        public Builder setNumEmbeddings(int numEmbeddings) {
            this.numEmbeddings = numEmbeddings;
            return this;
        }

        /**
         * Validates the configuration and creates the block.
         *
         * @return a new {@link FeatureEmbedding}
         * @throws IllegalArgumentException if {@code numEmbeddings} or {@code embeddingSize} is
         *     not strictly positive
         */
        public FeatureEmbedding build() {
            if (numEmbeddings <= 0) {
                throw new IllegalArgumentException(
                        "You must specify the dictionary Size for the embedding.");
            }
            // Reject negative sizes as well as the unset default of 0; a negative value would
            // otherwise end up as a dimension of the weight shape.
            if (embeddingSize <= 0) {
                throw new IllegalArgumentException("You must specify the embedding size");
            }
            return new FeatureEmbedding(this);
        }
    }

    /**
     * Creates the block and registers its weight parameter.
     *
     * @param builder the validated builder supplying both dimensions of the weight
     */
    FeatureEmbedding(Builder builder) {
        this.embeddingSize = builder.embeddingSize;
        this.numEmbeddings = builder.numEmbeddings;
        this.embedding =
                addParameter(
                        Parameter.builder()
                                .setName(EMBEDDING_PARAM_NAME)
                                .setType(Parameter.Type.WEIGHT)
                                .optShape(new Shape(numEmbeddings, embeddingSize))
                                .build());
    }

    /**
     * Applies the embedding to the single array contained in {@code inputs}.
     *
     * @param parameterStore the store from which the weight value is fetched
     * @param inputs a list that must contain exactly one array ({@code singletonOrThrow} is used)
     * @param training forwarded to {@link ParameterStore#getValue}
     * @param params not used by this block
     * @return the result of {@link Embedding#embedding(NDArray, NDArray, SparseFormat)} with
     *     {@link SparseFormat#DENSE}
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        // Fetch the weight for the same device the input lives on.
        NDArray weight = parameterStore.getValue(embedding, input.getDevice(), training);
        return Embedding.embedding(input, weight, SparseFormat.DENSE);
    }

    /**
     * Computes the output shape by appending {@code embeddingSize} to the first input shape.
     *
     * @param inputShapes the input shapes; only the first entry is read
     * @return a one-element array holding the first input shape extended by {@code embeddingSize}
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        Shape outputShape = inputShapes[0].addAll(new Shape(embeddingSize));
        return new Shape[] {outputShape};
    }

    /**
     * Creates a new, empty builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
