package ai.djl.timeseries.block;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * A block that embeds several features with one {@link FeatureEmbedding} per feature.
 *
 * <p>The single input array is split along its last axis into {@code numFeatures} slices. Each
 * slice has its trailing axis squeezed away, is fed to its own {@code FeatureEmbedding}
 * (registered as child block {@code cat_<i>_embedding}), and the per-feature embeddings are
 * concatenated along the last axis.
 *
 * <p>NOTE(review): the "cat_" child names suggest the inputs are categorical indices; confirm the
 * expected dtype against {@code FeatureEmbedding}.
 */
public class FeatureEmbedder extends AbstractBlock {
    private final List<Integer> cardinalities;
    private final List<Integer> embeddingDims;
    private final List<FeatureEmbedding> embedders = new ArrayList<>();
    private final int numFeatures;

    /** Builder for {@link FeatureEmbedder}. Both lists must be set before calling {@link #build()}. */
    public static final class Builder {
        private List<Integer> cardinalities;
        private List<Integer> embeddingDims;

        /**
         * Sets, per feature, the number of embeddings (rows) of that feature's embedding table.
         *
         * @param list one positive value per feature
         * @return this builder
         */
        public Builder setCardinalities(List<Integer> list) {
            this.cardinalities = list;
            return this;
        }

        /**
         * Sets, per feature, the size of the embedding vector produced for that feature.
         *
         * @param list one positive value per feature; must have the same length as the
         *     cardinalities
         * @return this builder
         */
        public Builder setEmbeddingDims(List<Integer> list) {
            this.embeddingDims = list;
            return this;
        }

        /**
         * Validates the configuration and builds the block.
         *
         * @return a new {@link FeatureEmbedder}
         * @throws IllegalArgumentException if either list is unset, the cardinalities are empty,
         *     the two lists differ in length, or any element is null or not positive
         */
        public FeatureEmbedder build() {
            if (cardinalities == null) {
                throw new IllegalArgumentException("'cardinalities' must be set");
            }
            if (cardinalities.isEmpty()) {
                throw new IllegalArgumentException(
                        "Length of 'cardinalities' list must be greater than zero");
            }
            if (embeddingDims == null) {
                throw new IllegalArgumentException("'embeddingDims' must be set");
            }
            if (cardinalities.size() != embeddingDims.size()) {
                throw new IllegalArgumentException(
                        "Length of `cardinalities` and `embedding_dims` should match");
            }
            for (Integer cardinality : cardinalities) {
                // Null elements would otherwise surface as an NPE during unboxing.
                if (cardinality == null || cardinality <= 0) {
                    throw new IllegalArgumentException("Elements of `cardinalities` should be > 0");
                }
            }
            for (Integer dim : embeddingDims) {
                if (dim == null || dim <= 0) {
                    throw new IllegalArgumentException("Elements of `embedding_dims` should be > 0");
                }
            }
            return new FeatureEmbedder(this);
        }
    }

    FeatureEmbedder(Builder builder) {
        // Defensive copies: later edits to the caller's lists must not desynchronize these
        // fields from the embedders created below.
        this.cardinalities = new ArrayList<>(builder.cardinalities);
        this.embeddingDims = new ArrayList<>(builder.embeddingDims);
        this.numFeatures = cardinalities.size();
        for (int i = 0; i < numFeatures; i++) {
            embedders.add(createEmbedding(i, cardinalities.get(i), embeddingDims.get(i)));
        }
    }

    /**
     * Embeds each feature slice of the single input and concatenates the results.
     *
     * @param parameterStore the parameter store
     * @param inputs a list holding exactly one array whose last axis indexes the features
     * @param training whether this is a training forward pass
     * @param params optional extra parameters, forwarded to each child embedding
     * @return a list holding the per-feature embeddings concatenated along the last axis
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray features = inputs.singletonOrThrow();
        // With a single feature the input is used directly instead of being split.
        NDList slices =
                numFeatures > 1
                        ? features.split(numFeatures, features.getShape().dimension() - 1)
                        : new NDList(features);
        NDList embedded = new NDList();
        for (int i = 0; i < numFeatures; i++) {
            NDArray slice = slices.get(i).squeeze(-1);
            embedded.add(
                    embedders
                            .get(i)
                            .forward(parameterStore, new NDList(slice), training, params)
                            .singletonOrThrow());
        }
        return new NDList(NDArrays.concat(embedded, -1));
    }

    /**
     * Computes the output shape: the input shape with its last (feature) axis replaced by the
     * sum of the children's embedding sizes.
     *
     * @param inputShapes the input shapes; only the first is used
     * @return a single-element array holding the output shape
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        Shape leadingShape = dropLastAxis(inputShapes[0]);
        Shape[] embedInputShapes = {leadingShape};
        long totalEmbeddingSize = 0;
        for (FeatureEmbedding embedder : embedders) {
            totalEmbeddingSize += embedder.getOutputShapes(embedInputShapes)[0].tail();
        }
        return new Shape[] {leadingShape.add(totalEmbeddingSize)};
    }

    /**
     * Initializes every child embedding.
     *
     * <p>Each child receives one squeezed feature slice in {@link #forwardInternal}, so it is
     * initialized with the input shape minus its last axis. This matches
     * {@link #getOutputShapes}.
     */
    @Override
    protected void initializeChildBlocks(
            NDManager manager, DataType dataType, Shape... inputShapes) {
        Shape[] childShapes =
                inputShapes.length == 0 || inputShapes[0].dimension() == 0
                        ? inputShapes
                        : new Shape[] {dropLastAxis(inputShapes[0])};
        for (FeatureEmbedding embedder : embedders) {
            embedder.initialize(manager, dataType, childShapes);
        }
    }

    /** Returns {@code shape} without its last axis. */
    private static Shape dropLastAxis(Shape shape) {
        return shape.slice(0, shape.dimension() - 1);
    }

    /**
     * Creates the embedding for feature {@code index} and registers it as a child block.
     *
     * @param index the feature index, used in the child block name
     * @param cardinality the number of embeddings for this feature
     * @param embeddingDim the embedding size for this feature
     * @return the registered embedding block
     */
    private FeatureEmbedding createEmbedding(int index, int cardinality, int embeddingDim) {
        FeatureEmbedding embedding =
                FeatureEmbedding.builder()
                        .setNumEmbeddings(cardinality)
                        .setEmbeddingSize(embeddingDim)
                        .build();
        // addChildBlock is generic in the block type; the embedding is passed as-is.
        addChildBlock(String.format("cat_%d_embedding", index), embedding);
        return embedding;
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
