package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.inference.streaming.StreamingBlock;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * A {@link Block} that runs its child blocks one after another, feeding the output of each child
 * into the next one.
 *
 * <p>By default the forward pass returns only the output of the last child. When {@link
 * #setReturnIntermediate(boolean)} is enabled, the outputs of all children are concatenated, in
 * child order, into a single {@link NDList}; {@link #getOutputShapes(Shape[])} reports shapes the
 * same way.
 */
public class SequentialBlock extends AbstractBlock implements StreamingBlock {

    /**
     * Metadata encoding version written by this class. Version 2, which has no {@code
     * returnIntermediate} flag, can still be loaded.
     */
    private static final byte VERSION = 3;

    private boolean returnIntermediate;

    /**
     * Runs the children lazily: each call to {@link #next()} executes the next child on the
     * previous child's output (the original inputs for the first child) and returns that output.
     */
    private final class StreamIterator implements Iterator<NDList> {

        private final ParameterStore parameterStore;
        private final boolean training;
        private NDList current;
        private int childIndex;

        private StreamIterator(ParameterStore parameterStore, NDList inputs, boolean training) {
            this.parameterStore = parameterStore;
            this.current = inputs;
            this.training = training;
            this.childIndex = 0;
        }

        /** Returns {@code true} while at least one child block has not been run yet. */
        @Override
        public boolean hasNext() {
            return childIndex < children.size();
        }

        /**
         * Runs the next child block and returns its output.
         *
         * @return the output of the child that was just run
         * @throws NoSuchElementException if every child block has already been run
         */
        @Override
        public NDList next() {
            if (!hasNext()) {
                // Iterator contract: report exhaustion explicitly rather than indexing past the end.
                throw new NoSuchElementException(
                        "All " + children.size() + " child blocks have already been run");
            }
            Block child = children.get(childIndex++).getValue();
            current = child.forward(parameterStore, current, training);
            return current;
        }
    }

    /**
     * Creates an empty {@code SequentialBlock}. Children are appended with the {@code add}
     * methods.
     */
    public SequentialBlock() {
        super(VERSION);
    }

    /**
     * Appends every given block, in order, as a child.
     *
     * @param blocks the blocks to append; {@code null} elements are skipped
     * @return this block, for chaining
     */
    public SequentialBlock addAll(Block... blocks) {
        addAll(Arrays.asList(blocks));
        return this;
    }

    /**
     * Appends every block in the collection, in iteration order, as a child.
     *
     * @param blocks the blocks to append; {@code null} elements are skipped
     * @return this block, for chaining
     */
    public SequentialBlock addAll(Collection<Block> blocks) {
        blocks.forEach(this::add);
        return this;
    }

    /**
     * Appends a block as the last child. The block's simple class name is passed to {@code
     * addChildBlock} as the child name.
     *
     * @param block the block to append; {@code null} is ignored
     * @return this block, for chaining
     */
    public SequentialBlock add(Block block) {
        if (block != null) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
        return this;
    }

    /**
     * Appends a {@link LambdaBlock} wrapping the given function.
     *
     * @param function the function applied to the incoming {@link NDList}
     * @return this block, for chaining
     */
    public SequentialBlock add(Function<NDList, NDList> function) {
        add(new LambdaBlock(function));
        return this;
    }

    /**
     * Appends a {@link LambdaBlock} wrapping the given function.
     *
     * @param function the function applied to the incoming {@link NDList}
     * @param name a string passed through to the {@link LambdaBlock} constructor (presumably its
     *     name; confirm against {@code LambdaBlock})
     * @return this block, for chaining
     */
    public SequentialBlock add(Function<NDList, NDList> function, String name) {
        add(new LambdaBlock(function, name));
        return this;
    }

    /**
     * Appends the {@link LambdaBlock} produced by {@code LambdaBlock.singleton} for a function over
     * a single {@link NDArray}.
     *
     * @param function the array-to-array function to wrap
     * @return this block, for chaining
     */
    public SequentialBlock addSingleton(Function<NDArray, NDArray> function) {
        add(LambdaBlock.singleton(function));
        return this;
    }

    /**
     * Appends the {@link LambdaBlock} produced by {@code LambdaBlock.singleton} for a function over
     * a single {@link NDArray}.
     *
     * @param function the array-to-array function to wrap
     * @param name a string passed through to {@code LambdaBlock.singleton} (presumably the block
     *     name; confirm against {@code LambdaBlock})
     * @return this block, for chaining
     */
    public SequentialBlock addSingleton(Function<NDArray, NDArray> function, String name) {
        add(LambdaBlock.singleton(function, name));
        return this;
    }

    /** Removes the last child block. The block is expected to have at least one child. */
    public void removeLastBlock() {
        children.remove(children.size() - 1);
    }

    /**
     * Replaces the last child block with the given one.
     *
     * @param block the new last child; if {@code null}, the last child is simply removed
     */
    public void replaceLastBlock(Block block) {
        removeLastBlock();
        // add(...) ignores null, so a null argument leaves the block one child shorter.
        add(block);
    }

    /**
     * Returns whether the forward pass returns the outputs of all children instead of only the
     * last one.
     *
     * @return {@code true} if intermediate outputs are returned
     */
    public boolean isReturnIntermediate() {
        return returnIntermediate;
    }

    /**
     * Sets whether the forward pass returns the concatenated outputs of all children instead of
     * only the last child's output.
     *
     * @param returnIntermediate {@code true} to return every child's output
     * @return this block, for chaining
     */
    public SequentialBlock setReturnIntermediate(boolean returnIntermediate) {
        this.returnIntermediate = returnIntermediate;
        return this;
    }

    /**
     * Runs every child in order. Note that {@code params} is not forwarded to the children.
     *
     * @return the last child's output, or all children's outputs concatenated if {@code
     *     returnIntermediate} is set; with no children the inputs are returned unchanged
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        List<NDList> outputs = new ArrayList<>(children.size());
        NDList current = inputs;
        for (Block child : children.values()) {
            current = child.forward(parameterStore, current, training);
            outputs.add(current);
        }
        return returnIntermediate ? flatten(outputs) : current;
    }

    /**
     * Runs every child in order, passing the same {@code labels} and {@code params} to each.
     *
     * @return the last child's output, or all children's outputs concatenated if {@code
     *     returnIntermediate} is set; with no children the data is returned unchanged
     */
    @Override
    public NDList forwardInternal(
            ParameterStore parameterStore,
            NDList data,
            NDList labels,
            PairList<String, Object> params) {
        List<NDList> outputs = new ArrayList<>(children.size());
        NDList current = data;
        for (Block child : children.values()) {
            current = child.forward(parameterStore, current, labels, params);
            outputs.add(current);
        }
        return returnIntermediate ? flatten(outputs) : current;
    }

    /** Concatenates the given lists, in order, into a single {@link NDList}. */
    private static NDList flatten(List<NDList> lists) {
        return new NDList(lists.stream().flatMap(NDList::stream).collect(Collectors.toList()));
    }

    /**
     * Returns an iterator that runs one child per call to {@code next()} and yields that child's
     * output. {@code params} is not used and {@code returnIntermediate} has no effect here.
     */
    @Override
    public Iterator<NDList> forwardStreamIter(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return new StreamIterator(parameterStore, inputs, training);
    }

    /**
     * Initializes the children in order. Each child is initialized with the output shapes of the
     * previous child (the given input shapes for the first). The previous child's output data
     * types ({@code null} for the first child) are passed to {@code getOutputShapes}.
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        Shape[] shapes = inputShapes;
        DataType[] previousDataTypes = null;
        for (Block child : getChildren().values()) {
            child.initialize(manager, dataType, shapes);
            shapes = child.getOutputShapes(shapes, previousDataTypes);
            previousDataTypes = child.getOutputDataTypes();
        }
    }

    /**
     * Computes the output shapes by chaining every child's {@code getOutputShapes}.
     *
     * @return the last child's output shapes, or all children's output shapes concatenated if
     *     {@code returnIntermediate} is set
     * @throws IllegalArgumentException if the block has no children
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        if (children.isEmpty()) {
            throw new IllegalArgumentException("The sequential block is empty");
        }
        List<Shape[]> outputs = new ArrayList<>(children.size());
        Shape[] current = inputShapes;
        for (Block child : children.values()) {
            current = child.getOutputShapes(current);
            outputs.add(current);
        }
        if (!returnIntermediate) {
            return current;
        }
        return outputs.stream().flatMap(Arrays::stream).toArray(Shape[]::new);
    }

    /** Writes the input shapes followed by the {@code returnIntermediate} flag. */
    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        saveInputShapes(os);
        os.writeBoolean(returnIntermediate);
    }

    /**
     * Reads metadata written by {@link #saveMetadata}. The current version also restores {@code
     * returnIntermediate}. Version 2 only contains the input shapes.
     *
     * @throws MalformedModelException if {@code loadVersion} is neither the current version nor 2
     */
    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion == version) {
            readInputShapes(is);
            returnIntermediate = is.readBoolean();
        } else if (loadVersion == 2) {
            // Version 2 predates the flag, so returnIntermediate keeps its current value.
            readInputShapes(is);
        } else {
            throw new MalformedModelException("Unsupported encoding version: " + (int) loadVersion);
        }
    }
}
