package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.function.Predicate;

/* loaded from: input_file:ai/djl/nn/AbstractBlock.class */
/**
 * Base implementation of {@link Block} that manages direct parameters, child blocks, the recorded
 * input shapes/names, and (de)serialization of parameters.
 *
 * <p>Subclasses implement {@link #forwardInternal(ParameterStore, NDList, boolean, PairList)}.
 * Subclasses that register child blocks must also override {@link #initializeChildBlocks}.
 */
public abstract class AbstractBlock implements Block {

    /**
     * Input shapes recorded by {@link #beforeInitialize} or adopted by {@link #readInputShapes}.
     * This is {@code null} until one of those runs.
     */
    protected Shape[] inputShapes;

    /** Version byte written first by {@link #saveParameters} and checked by {@link #loadMetadata}. */
    protected byte version;

    /** Input names; when left empty, default names {@code data0, data1, ...} are used. */
    protected List<String> inputNames = Collections.emptyList();

    /** Child blocks, keyed by an index-prefixed name (see {@link #addChildBlock}). */
    protected BlockList children = new BlockList();

    /** Direct (non-child) parameters keyed by parameter name, in insertion order. */
    protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>();

    /**
     * Creates a block with the given serialization version.
     *
     * @param version the version byte written by {@link #saveParameters} and expected by {@link
     *     #loadMetadata}
     */
    public AbstractBlock(byte version) {
        this.version = version;
    }

    /**
     * Runs the forward pass.
     *
     * <p>If {@link #isInitialized()} reports {@code false}, the block is first initialized with
     * {@link DataType#FLOAT32} and the shapes of {@code inputs}.
     *
     * @param parameterStore store supplying the {@link NDManager} used for lazy initialization
     * @param inputs the input arrays
     * @param training flag forwarded unchanged to {@code forwardInternal} (presumably the training
     *     mode from the {@link Block} contract)
     * @param params extra named arguments forwarded unchanged
     * @return the result of {@code forwardInternal}
     */
    @Override // ai.djl.nn.Block
    public final NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDManager manager = parameterStore.getManager();
        if (!isInitialized()) {
            initialize(manager, DataType.FLOAT32, inputs.getShapes());
        }
        return forwardInternal(parameterStore, inputs, training, params);
    }

    /**
     * Runs the forward pass with labels.
     *
     * <p>If {@link #isInitialized()} reports {@code false}, the block is first initialized with
     * {@link DataType#FLOAT32} and the shapes of {@code data}.
     *
     * @param parameterStore store supplying the {@link NDManager} used for lazy initialization
     * @param data the input arrays
     * @param labels the label arrays, forwarded to the label-aware {@code forwardInternal}
     * @param params extra named arguments forwarded unchanged
     * @return the result of the label-aware {@code forwardInternal}
     */
    @Override // ai.djl.nn.Block
    public NDList forward(
            ParameterStore parameterStore,
            NDList data,
            NDList labels,
            PairList<String, Object> params) {
        NDManager manager = parameterStore.getManager();
        if (!isInitialized()) {
            initialize(manager, DataType.FLOAT32, data.getShapes());
        }
        return forwardInternal(parameterStore, data, labels, params);
    }

    /**
     * Implements the actual forward computation of this block.
     *
     * @param parameterStore the parameter store
     * @param inputs the input arrays
     * @param training the boolean flag passed through {@link #forward}
     * @param params extra named arguments
     * @return the output arrays
     */
    protected abstract NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params);

    /**
     * Label-aware forward computation.
     *
     * <p>This default ignores {@code labels} and delegates to {@link #forwardInternal(ParameterStore,
     * NDList, boolean, PairList)} with the boolean flag set to {@code true}.
     *
     * @param parameterStore the parameter store
     * @param data the input arrays
     * @param labels the label arrays (unused by this default)
     * @param params extra named arguments
     * @return the output arrays
     */
    public NDList forwardInternal(
            ParameterStore parameterStore,
            NDList data,
            NDList labels,
            PairList<String, Object> params) {
        return forwardInternal(parameterStore, data, true, params);
    }

    /**
     * Registers a child block.
     *
     * <p>The child is stored under {@code NN<name>}, where {@code NN} is its 1-based position,
     * zero-padded to two digits. The prefix keeps names unique and ordered.
     *
     * @param name the child's base name
     * @param block the child block
     * @param <B> the child block type
     * @return {@code block}, for chaining
     */
    public final <B extends Block> B addChildBlock(String name, B block) {
        int position = children.size() + 1;
        children.add(String.format(Locale.ENGLISH, "%02d%s", position, name), block);
        return block;
    }

    /**
     * Registers a direct parameter under its own name, replacing any parameter with the same name.
     *
     * @param parameter the parameter to register
     * @param <P> the parameter type
     * @return {@code parameter}, for chaining
     */
    public final <P extends Parameter> P addParameter(P parameter) {
        parameters.put(parameter.getName(), parameter);
        return parameter;
    }

    /**
     * Returns a shallow copy of the child-block list.
     *
     * <p>Callers may modify the returned list without affecting this block's children.
     */
    @Override // ai.djl.nn.Block
    public BlockList getChildren() {
        BlockList copy = new BlockList(children.size());
        for (Pair<String, Block> child : children) {
            copy.add(child);
        }
        return copy;
    }

    /**
     * Describes the inputs as (name, shape) pairs.
     *
     * <p>If no input names were set, default names {@code data0, data1, ...} are used.
     *
     * @return the input names paired with the recorded input shapes
     * @throws IllegalStateException if the block is not initialized or no input shapes have been
     *     recorded yet
     */
    @Override // ai.djl.nn.Block
    public PairList<String, Shape> describeInput() {
        // A block without parameters reports isInitialized() == true even when no shapes were ever
        // recorded; without this null check Arrays.asList(null) throws a bare NullPointerException.
        if (!isInitialized() || inputShapes == null) {
            throw new IllegalStateException(
                    "Parameter of this block are not initialised, please call model.newTrainer and trainer.initialize");
        }
        // readInputShapes() may restore inputShapes without inputNames ever being populated, so fall
        // back to the same default names beforeInitialize() would assign.
        List<String> names =
                inputNames.isEmpty() ? defaultInputNames(inputShapes.length) : inputNames;
        return new PairList<>(names, Arrays.asList(inputShapes));
    }

    /**
     * Sets {@code initializer} on every parameter (including those of child blocks) whose type
     * equals {@code type}.
     */
    @Override // ai.djl.nn.Block
    public void setInitializer(Initializer initializer, Parameter.Type type) {
        setInitializer(initializer, parameter -> parameter.getType().equals(type));
    }

    /**
     * Sets {@code initializer} on the direct parameter named {@code paramName}.
     *
     * <p>Only this block's own parameters are searched; child blocks are not.
     *
     * @throws IllegalArgumentException if no direct parameter has that name
     */
    @Override // ai.djl.nn.Block
    public void setInitializer(Initializer initializer, String paramName) {
        Parameter parameter = parameters.get(paramName);
        if (parameter == null) {
            throw new IllegalArgumentException("Could not find parameter " + paramName);
        }
        parameter.setInitializer(initializer);
    }

    /**
     * Sets {@code initializer} on every parameter that matches {@code predicate}.
     *
     * <p>Parameters of child blocks are included.
     */
    @Override // ai.djl.nn.Block
    public void setInitializer(Initializer initializer, Predicate<Parameter> predicate) {
        for (Parameter parameter : getParameters().values()) {
            if (predicate.test(parameter)) {
                parameter.setInitializer(initializer);
            }
        }
    }

    /**
     * Initializes this block. The steps run in this order:
     *
     * <ol>
     *   <li>record the input shapes and default names via {@link #beforeInitialize};
     *   <li>call {@link #prepare} if not yet initialized;
     *   <li>initialize every direct parameter;
     *   <li>delegate to {@link #initializeChildBlocks}.
     * </ol>
     *
     * @param manager manager passed to each parameter's {@code initialize}
     * @param dataType data type passed to each parameter's {@code initialize}
     * @param shapes the input shapes
     */
    @Override // ai.djl.nn.Block
    public void initialize(NDManager manager, DataType dataType, Shape... shapes) {
        beforeInitialize(shapes);
        if (!isInitialized()) {
            prepare(shapes);
        }
        for (Parameter parameter : parameters.values()) {
            parameter.initialize(manager, dataType);
        }
        initializeChildBlocks(manager, dataType, shapes);
    }

    /**
     * Records {@code shapes} as the input shapes.
     *
     * <p>If no input names were set, this also assigns the default names {@code data0 ..
     * data(n-1)}.
     *
     * @param shapes the input shapes
     */
    public void beforeInitialize(Shape... shapes) {
        if (inputNames.isEmpty()) {
            inputNames = defaultInputNames(shapes.length);
        }
        inputShapes = shapes;
    }

    /**
     * Initializes child blocks; subclasses with children must override this.
     *
     * @throws IllegalStateException if this block has children (the default cannot know their
     *     input shapes)
     */
    protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... shapes) {
        if (!children.isEmpty()) {
            throw new IllegalStateException(
                    getClass().getSimpleName()
                            + " has child blocks but initializeChildBlocks is not overwritten.");
        }
    }

    /**
     * Returns all parameters: this block's direct parameters followed by those of each child.
     *
     * <p>Child parameters are keyed as {@code <childName>_<paramName>}.
     */
    @Override // ai.djl.nn.Block
    public ParameterList getParameters() {
        ParameterList all = getDirectParameters();
        for (Pair<String, Block> child : getChildren()) {
            String prefix = child.getKey() + "_";
            for (Pair<String, Parameter> entry : child.getValue().getParameters()) {
                all.add(prefix + entry.getKey(), entry.getValue());
            }
        }
        return all;
    }

    /** Returns a new {@link ParameterList} built from this block's direct parameters. */
    @Override // ai.djl.nn.Block
    public ParameterList getDirectParameters() {
        return new ParameterList(parameters);
    }

    /**
     * Hook invoked by {@link #initialize} before parameters are initialized, only when the block
     * is not yet initialized. The default does nothing.
     *
     * @param shapes the input shapes
     */
    protected void prepare(Shape[] shapes) {}

    /**
     * Returns whether every parameter (including those of child blocks) is initialized.
     *
     * <p>Note that a block with no parameters always reports {@code true}.
     */
    @Override // ai.djl.nn.Block
    public boolean isInitialized() {
        for (Parameter parameter : getParameters().values()) {
            if (!parameter.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /** Closes every parameter, including those of child blocks. */
    @Override // ai.djl.nn.Block
    public void clear() {
        for (Parameter parameter : getParameters().values()) {
            parameter.close();
        }
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.nn.Block
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Writes this block's parameters and then, recursively, those of its children. The data is
     * written in this order:
     *
     * <ol>
     *   <li>the version byte;
     *   <li>the metadata;
     *   <li>each direct parameter;
     *   <li>each child block.
     * </ol>
     *
     * @param os the output stream
     * @throws IOException if writing fails
     */
    @Override // ai.djl.nn.Block
    public void saveParameters(DataOutputStream os) throws IOException {
        os.write(version);
        saveMetadata(os);
        for (Parameter parameter : parameters.values()) {
            parameter.save(os);
        }
        for (Block child : children.values()) {
            child.saveParameters(os);
        }
    }

    /**
     * Reads parameters in the same order {@link #saveParameters} writes them.
     *
     * @param manager manager passed to each parameter's {@code load}
     * @param is the input stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the stored version does not match {@link #version}
     */
    @Override // ai.djl.nn.Block
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        loadMetadata(is.readByte(), is);
        for (Parameter parameter : parameters.values()) {
            parameter.load(manager, is);
        }
        for (Block child : children.values()) {
            child.loadParameters(manager, is);
        }
    }

    /**
     * Writes block metadata. The default writes only the input shapes.
     *
     * @param os the output stream
     * @throws IOException if writing fails
     */
    protected void saveMetadata(DataOutputStream os) throws IOException {
        saveInputShapes(os);
    }

    /**
     * Reads block metadata after checking the stored version.
     *
     * @param loadVersion the version byte read from the stream
     * @param is the input stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if {@code loadVersion} differs from {@link #version}
     */
    protected void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion != version) {
            throw new MalformedModelException(
                    "Cannot load parameters for "
                            + getClass().getCanonicalName()
                            + ", expected version "
                            + ((int) version)
                            + ", got "
                            + ((int) loadVersion)
                            + ".");
        }
        readInputShapes(is);
    }

    /**
     * Writes the number of input shapes followed by each shape's encoded bytes.
     *
     * <p>Requires {@link #inputShapes} to be non-null; otherwise a {@link NullPointerException} is
     * thrown.
     *
     * @param os the output stream
     * @throws IOException if writing fails
     */
    public void saveInputShapes(DataOutputStream os) throws IOException {
        os.writeInt(inputShapes.length);
        for (Shape shape : inputShapes) {
            os.write(shape.getEncoded());
        }
    }

    /**
     * Reads input shapes written by {@link #saveInputShapes}.
     *
     * <p>The shapes are always consumed from the stream, but they are adopted only when no input
     * shapes have been recorded yet.
     *
     * @param is the input stream
     * @throws IOException if reading fails
     */
    public void readInputShapes(DataInputStream is) throws IOException {
        int count = is.readInt();
        Shape[] loaded = new Shape[count];
        for (int i = 0; i < count; i++) {
            loaded[i] = Shape.decode(is);
        }
        if (inputShapes == null) {
            inputShapes = loaded;
        }
    }

    /**
     * Renders the block as {@code Name(inputShapes -> outputShapes)}.
     *
     * <p>A trailing {@code "Block"} is stripped from the class name. The result is {@code
     * Name(Uninitialized)} when the block is not initialized or has no recorded input shapes.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        String name = getClass().getSimpleName();
        if (name.endsWith("Block")) {
            name = name.substring(0, name.length() - 5);
        }
        sb.append(name).append('(');
        // The inputShapes check covers parameterless blocks, which report isInitialized() == true
        // even before any shapes are known.
        if (isInitialized() && inputShapes != null) {
            Shape[] shapes = describeInput().values().toArray(new Shape[0]);
            appendShape(sb, shapes);
            sb.append(" -> ");
            appendShape(sb, getOutputShapes(shapes));
        } else {
            sb.append("Uninitialized");
        }
        sb.append(')');
        return sb.toString();
    }

    /**
     * Builds the mutable default input-name list {@code data0 .. data(count-1)}.
     *
     * @param count the number of inputs
     * @return a new mutable list of default names
     */
    private static List<String> defaultInputNames(int count) {
        List<String> names = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            names.add("data" + i);
        }
        return names;
    }

    /**
     * Appends {@code shapes} to {@code sb}, separated by {@code ", "}. Each shape is formatted as
     * follows:
     *
     * <ul>
     *   <li>a leading {@code -1} dimension is omitted;
     *   <li>a shape with no remaining dimensions prints as {@code ()};
     *   <li>a shape with one remaining dimension prints as a bare number;
     *   <li>otherwise the remaining dimensions are printed parenthesized.
     * </ul>
     */
    private void appendShape(StringBuilder sb, Shape[] shapes) {
        boolean first = true;
        for (Shape shape : shapes) {
            if (first) {
                first = false;
            } else {
                sb.append(", ");
            }
            long[] dims = shape.getShape();
            int start = dims.length > 0 && dims[0] == -1 ? 1 : 0;
            int remaining = dims.length - start;
            if (remaining == 0) {
                sb.append("()");
            } else if (remaining == 1) {
                sb.append(dims[start]);
            } else {
                sb.append('(');
                for (int i = start; i < dims.length; i++) {
                    if (i > start) {
                        sb.append(", ");
                    }
                    sb.append(dims[i]);
                }
                sb.append(')');
            }
        }
    }
}
