package ai.djl.basicmodelzoo.cv.object_detection.yolo;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.BlockList;
import ai.djl.nn.Blocks;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.BiFunction;
import org.apache.commons.compress.archivers.cpio.CpioConstants;

/**
 * A YOLOv3 object-detection network built on a DarkNet-53 backbone.
 *
 * <p>The backbone returns the output of each of its top-level layers. The last three of them
 * (indices 5, 6 and 7: the 256-, 512- and 1024-filter stages) feed three detection heads:
 *
 * <ul>
 *   <li>Head 0 refines the deepest (1024-filter) feature map.
 *   <li>Head 1 takes head 0's branch through a 1x1 convolution and a 2x nearest-neighbour
 *       upsample, concatenates it with the 512-filter stage along axis 1, and refines the result.
 *   <li>Head 2 does the same with head 1's branch and the 256-filter stage.
 * </ul>
 *
 * <p>Every head ends in {@code 3 * (numClasses + 5)} channels. This is presumably 3 anchors, each
 * with 4 box coordinates, 1 objectness score and the class scores; confirm against the
 * post-processing code.
 *
 * <p>Use {@link #builder()} to configure and create an instance.
 */
public final class YOLOV3 extends AbstractBlock {

    private static final byte VERSION = 1;

    /** Epsilon shared by every batch-normalization layer in this network. */
    private static final float BATCH_NORM_EPSILON = 2.0E-5f;

    /** Number of residual blocks in each of the five DarkNet-53 downsampling stages. */
    static final int[] REPEATS = {1, 2, 8, 8, 4};

    /** Filter counts: index 0 is the stem convolution, indices 1-5 are the five stages. */
    static final int[] FILTERS = {32, 64, 128, 256, 512, 1024};

    private final SequentialBlock darkNet53;
    private final Block lastLayer0;
    private final Block layer0Output;
    private final Block lastLayer1Conv;
    private final Block lastLayer1UpSample;
    private final Block lastLayer1;
    private final Block layer1Output;
    private final Block lastLayer2Conv;
    private final Block lastLayer2UpSample;
    private final Block lastLayer2;
    private final Block layer2Output;

    /**
     * Collects the hyper-parameters used to create a {@link YOLOV3} network or a standalone
     * DarkNet-53 classifier.
     */
    public static final class Builder {

        int numClasses = 20;
        float batchNormMomentum = 0.9f;
        float leakyAlpha = 0.1f;
        int darkNetOutSize = 10;

        /**
         * Sets the number of object classes the detection heads predict.
         *
         * @param numClasses the number of classes (default {@code 20})
         * @return this builder
         */
        public Builder setNumClasses(int numClasses) {
            this.numClasses = numClasses;
            return this;
        }

        /**
         * Sets the momentum of every batch-normalization layer.
         *
         * @param batchNormMomentum the momentum (default {@code 0.9})
         * @return this builder
         */
        public Builder optBatchNormMomentum(float batchNormMomentum) {
            this.batchNormMomentum = batchNormMomentum;
            return this;
        }

        /**
         * Sets the alpha passed to every leaky-ReLU activation.
         *
         * @param leakyAlpha the leaky-ReLU alpha (default {@code 0.1})
         * @return this builder
         */
        public Builder optLeakyAlpha(float leakyAlpha) {
            this.leakyAlpha = leakyAlpha;
            return this;
        }

        /**
         * Sets the number of units of the final linear layer built by {@link #buildDarkNet()}.
         * {@link #build()} does not use this value.
         *
         * @param darkNetOutSize the number of output units (default {@code 10})
         * @return this builder
         */
        public Builder optDarkNetOutSize(int darkNetOutSize) {
            this.darkNetOutSize = darkNetOutSize;
            return this;
        }

        /**
         * Builds a YOLOv3 detection network from this configuration.
         *
         * @return a new {@link YOLOV3} block
         */
        public Block build() {
            return new YOLOV3(this);
        }

        /**
         * Builds a standalone DarkNet-53 classifier. It consists of the backbone (returning only
         * its final output), then global average pooling, then a linear layer with {@code
         * darkNetOutSize} units.
         *
         * @return a new classifier block
         */
        public Block buildDarkNet() {
            return new SequentialBlock()
                    .add(YOLOV3.darkNet53(this, false))
                    .add(Pool.globalAvgPool2dBlock())
                    .add(Linear.builder().setUnits(this.darkNetOutSize).build());
        }
    }

    private YOLOV3(Builder builder) {
        super(VERSION);
        float momentum = builder.batchNormMomentum;
        float alpha = builder.leakyAlpha;
        int headChannels = 3 * (builder.numClasses + 5);

        // Registration order matters: it fixes the child-block (and thus parameter) names.
        darkNet53 = addChildBlock("darkNet53", darkNet53(builder, true));
        lastLayer0 =
                addChildBlock(
                        "lastLayer0", makeLastLayers(FILTERS[4], FILTERS[5], momentum, alpha));
        layer0Output =
                addChildBlock(
                        "layer0Output",
                        makeOutputLayers(FILTERS[5], headChannels, momentum, alpha));
        lastLayer1Conv =
                addChildBlock("lastLayer1Conv", convolutionBlock(256, 1, momentum, alpha));
        lastLayer1UpSample = addChildBlock("lastLayer1UpSample", upSampleBlockNearest());
        lastLayer1 =
                addChildBlock(
                        "lastLayer1", makeLastLayers(FILTERS[3], FILTERS[4], momentum, alpha));
        layer1Output =
                addChildBlock(
                        "layer1Output",
                        makeOutputLayers(FILTERS[4], headChannels, momentum, alpha));
        lastLayer2Conv =
                addChildBlock("lastLayer2Conv", convolutionBlock(128, 1, momentum, alpha));
        lastLayer2UpSample = addChildBlock("lastLayer2UpSample", upSampleBlockNearest());
        lastLayer2 =
                addChildBlock(
                        "lastLayer2", makeLastLayers(FILTERS[2], FILTERS[3], momentum, alpha));
        layer2Output =
                addChildBlock(
                        "layer2Output",
                        makeOutputLayers(FILTERS[3], headChannels, momentum, alpha));
    }

    /**
     * Creates a parameter-free block that doubles the height and width of a 4-D {@code (N, C, H,
     * W)} tensor using nearest-neighbour interpolation.
     *
     * @return the upsampling block
     */
    public static Block upSampleBlockNearest() {
        return new SequentialBlock()
                // (N, C, H, W) -> (N, H, W, C), the layout the image resize operates on.
                .addSingleton(x -> x.transpose(0, 2, 3, 1))
                .addSingleton(
                        x -> {
                            Shape shape = x.getShape();
                            int height = (int) (shape.get(1) * 2);
                            int width = (int) (shape.get(2) * 2);
                            // resize takes (width, height); passing them in axis order would
                            // swap the two dimensions for non-square feature maps.
                            return NDImageUtils.resize(
                                    x, width, height, Image.Interpolation.NEAREST);
                        })
                // (N, H, W, C) -> (N, C, H, W)
                .addSingleton(x -> x.transpose(0, 3, 1, 2));
    }

    /**
     * Creates a convolution, batch-norm, leaky-ReLU sequence. Each side is padded by {@code
     * (kernelSize - 1) / 2}, so odd kernel sizes preserve height and width at unit stride.
     *
     * @param filters number of output channels
     * @param kernelSize side length of the square kernel
     * @param batchNormMomentum batch-normalization momentum
     * @param leakyAlpha leaky-ReLU alpha
     * @return the block
     */
    public static Block convolutionBlock(
            int filters, int kernelSize, float batchNormMomentum, float leakyAlpha) {
        int padding = kernelSize > 0 ? (kernelSize - 1) >> 1 : 0;
        return new SequentialBlock()
                .add(
                        Conv2d.builder()
                                .setFilters(filters)
                                .setKernelShape(new Shape(kernelSize, kernelSize))
                                .optPadding(new Shape(padding, padding))
                                .build())
                .add(batchNormBlock(batchNormMomentum))
                .add(Activation.leakyReluBlock(leakyAlpha));
    }

    /**
     * Creates the five-layer neck that precedes each detection head. It alternates 1x1
     * convolution blocks with {@code filters} channels and 3x3 convolution blocks with {@code
     * expandedFilters} channels, and ends on a 1x1 block.
     *
     * @param filters channel count of the 1x1 blocks (and of the block's output)
     * @param expandedFilters channel count of the 3x3 blocks
     * @param batchNormMomentum batch-normalization momentum
     * @param leakyAlpha leaky-ReLU alpha
     * @return the block
     */
    public static Block makeLastLayers(
            int filters, int expandedFilters, float batchNormMomentum, float leakyAlpha) {
        return new SequentialBlock()
                .add(convolutionBlock(filters, 1, batchNormMomentum, leakyAlpha))
                .add(convolutionBlock(expandedFilters, 3, batchNormMomentum, leakyAlpha))
                .add(convolutionBlock(filters, 1, batchNormMomentum, leakyAlpha))
                .add(convolutionBlock(expandedFilters, 3, batchNormMomentum, leakyAlpha))
                .add(convolutionBlock(filters, 1, batchNormMomentum, leakyAlpha));
    }

    /**
     * Creates a detection head. It is a 3x3 convolution block followed by a plain 1x1 convolution
     * (no batch-norm or activation) that produces the prediction channels.
     *
     * @param filters channel count of the 3x3 convolution block
     * @param outChannels channel count of the final 1x1 convolution
     * @param batchNormMomentum batch-normalization momentum
     * @param leakyAlpha leaky-ReLU alpha
     * @return the block
     */
    public static Block makeOutputLayers(
            int filters, int outChannels, float batchNormMomentum, float leakyAlpha) {
        return new SequentialBlock()
                .add(convolutionBlock(filters, 3, batchNormMomentum, leakyAlpha))
                .add(Conv2d.builder().setFilters(outChannels).setKernelShape(new Shape(1, 1)).build());
    }

    /**
     * Runs the backbone and the three detection heads.
     *
     * @return the three head outputs, deepest feature level first
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDList features = darkNet53.forward(parameterStore, inputs, training);
        NDArray deepFeatures = features.get(7); // 1024-filter stage
        NDArray midFeatures = features.get(6); // 512-filter stage
        NDArray shallowFeatures = features.get(5); // 256-filter stage

        NDList branch0 = lastLayer0.forward(parameterStore, new NDList(deepFeatures), training);
        NDArray out0 = layer0Output.forward(parameterStore, branch0, training).singletonOrThrow();

        NDList fused1 =
                upsampleAndConcat(
                        parameterStore,
                        branch0,
                        lastLayer1Conv,
                        lastLayer1UpSample,
                        midFeatures,
                        training);
        NDList branch1 = lastLayer1.forward(parameterStore, fused1, training);
        NDArray out1 = layer1Output.forward(parameterStore, branch1, training).singletonOrThrow();

        NDList fused2 =
                upsampleAndConcat(
                        parameterStore,
                        branch1,
                        lastLayer2Conv,
                        lastLayer2UpSample,
                        shallowFeatures,
                        training);
        NDList branch2 = lastLayer2.forward(parameterStore, fused2, training);
        NDArray out2 = layer2Output.forward(parameterStore, branch2, training).singletonOrThrow();

        return new NDList(out0, out1, out2);
    }

    /**
     * Passes {@code branch} through {@code conv} and then {@code upSample}. The result is
     * concatenated with {@code skip} along axis 1 (channels).
     */
    private static NDList upsampleAndConcat(
            ParameterStore parameterStore,
            NDList branch,
            Block conv,
            Block upSample,
            NDArray skip,
            boolean training) {
        NDList reduced = conv.forward(parameterStore, branch, training);
        NDArray upsampled = upSample.forward(parameterStore, reduced, training).singletonOrThrow();
        return new NDList(upsampled.concat(skip, 1));
    }

    /**
     * Returns the output shapes of the three detection heads, deepest feature level first.
     *
     * @param inputShapes the input shapes of the network
     * @return the three head output shapes
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return traceShapes(inputShapes, (block, shapes) -> block.getOutputShapes(shapes));
    }

    /**
     * Initializes every child block, in registration order, with the input shapes it receives
     * during {@link #forwardInternal}.
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        traceShapes(
                inputShapes,
                (block, shapes) -> {
                    block.initialize(manager, dataType, shapes);
                    return block.getOutputShapes(shapes);
                });
    }

    /**
     * Walks the child blocks in the same order and wiring as {@link #forwardInternal}. Each block
     * is handed to {@code step} together with its input shapes, and {@code step} must return that
     * block's output shapes.
     *
     * @return the output shapes of the three detection heads
     */
    private Shape[] traceShapes(Shape[] inputShapes, BiFunction<Block, Shape[], Shape[]> step) {
        Shape[] features = step.apply(darkNet53, inputShapes);
        Shape[] heads = new Shape[3];

        Shape[] branch = step.apply(lastLayer0, new Shape[] {features[7]});
        heads[0] = step.apply(layer0Output, branch)[0];

        branch = step.apply(lastLayer1Conv, branch);
        branch = concatChannelShapes(step.apply(lastLayer1UpSample, branch), features[6]);
        branch = step.apply(lastLayer1, branch);
        heads[1] = step.apply(layer1Output, branch)[0];

        branch = step.apply(lastLayer2Conv, branch);
        branch = concatChannelShapes(step.apply(lastLayer2UpSample, branch), features[5]);
        branch = step.apply(lastLayer2, branch);
        heads[2] = step.apply(layer2Output, branch)[0];

        return heads;
    }

    /**
     * Returns the shape of {@code upsampled[0]} concatenated with {@code skip} along axis 1.
     */
    private static Shape[] concatChannelShapes(Shape[] upsampled, Shape skip) {
        Shape up = upsampled[0];
        return new Shape[] {new Shape(up.get(0), up.get(1) + skip.get(1), up.get(2), up.get(3))};
    }

    /** Creates a batch-norm layer with the shared epsilon and the given momentum. */
    private static Block batchNormBlock(float momentum) {
        return BatchNorm.builder().optEpsilon(BATCH_NORM_EPSILON).optMomentum(momentum).build();
    }

    /**
     * Creates a residual block. A 1x1 convolution to {@code filters / 2} channels and a padded 3x3
     * convolution back to {@code filters} channels (each followed by batch-norm and leaky ReLU)
     * are summed with the unchanged input. The input must therefore already have {@code filters}
     * channels.
     *
     * @param filters channel count of the input and output
     * @param batchNormMomentum batch-normalization momentum
     * @param leakyAlpha leaky-ReLU alpha
     * @return the block
     */
    public static Block basicBlock(int filters, float batchNormMomentum, float leakyAlpha) {
        SequentialBlock residual =
                new SequentialBlock()
                        .add(
                                Conv2d.builder()
                                        .setFilters(filters / 2)
                                        .setKernelShape(new Shape(1, 1))
                                        .build())
                        .add(batchNormBlock(batchNormMomentum))
                        .add(Activation.leakyReluBlock(leakyAlpha))
                        .add(
                                Conv2d.builder()
                                        .setFilters(filters)
                                        .setKernelShape(new Shape(3, 3))
                                        .optPadding(new Shape(1, 1))
                                        .build())
                        .add(batchNormBlock(batchNormMomentum))
                        .add(Activation.leakyReluBlock(leakyAlpha));
        return new ParallelBlock(
                outputs ->
                        new NDList(
                                NDArrays.add(
                                        outputs.get(0).singletonOrThrow(),
                                        outputs.get(1).singletonOrThrow())),
                Arrays.asList(residual, Blocks.identityBlock()));
    }

    /**
     * Creates one DarkNet-53 stage. It starts with a stride-2, padded 3x3 convolution block to
     * {@code filters} channels, followed by {@code repeats} residual {@link #basicBlock}s.
     *
     * @param filters channel count of the stage
     * @param repeats number of residual blocks
     * @param batchNormMomentum batch-normalization momentum
     * @param leakyAlpha leaky-ReLU alpha
     * @return the block
     */
    public static Block makeLayer(
            int filters, int repeats, float batchNormMomentum, float leakyAlpha) {
        SequentialBlock downsample =
                new SequentialBlock()
                        .add(
                                Conv2d.builder()
                                        .setFilters(filters)
                                        .setKernelShape(new Shape(3, 3))
                                        .optStride(new Shape(2, 2))
                                        .optPadding(new Shape(1, 1))
                                        .build())
                        .add(batchNormBlock(batchNormMomentum))
                        .add(Activation.leakyReluBlock(leakyAlpha));
        SequentialBlock stage = new SequentialBlock().add(downsample);
        for (int i = 0; i < repeats; i++) {
            stage.add(basicBlock(filters, batchNormMomentum, leakyAlpha));
        }
        return stage;
    }

    /**
     * Creates the DarkNet-53 backbone. It has eight top-level layers: a 3x3 stem convolution,
     * batch-norm, leaky ReLU, and the five {@link #makeLayer} stages.
     *
     * @param builder supplies the batch-norm momentum and leaky-ReLU alpha
     * @param returnIntermediate whether the block returns the output of every top-level layer
     *     (which {@link #forwardInternal} relies on) instead of only the last one
     * @return the backbone block
     */
    public static SequentialBlock darkNet53(Builder builder, boolean returnIntermediate) {
        float momentum = builder.batchNormMomentum;
        float alpha = builder.leakyAlpha;
        SequentialBlock backbone = new SequentialBlock();
        backbone.setReturnIntermediate(returnIntermediate);
        backbone.add(
                        Conv2d.builder()
                                .setFilters(FILTERS[0])
                                .optPadding(new Shape(1, 1))
                                .setKernelShape(new Shape(3, 3))
                                .build())
                .add(batchNormBlock(momentum))
                .add(Activation.leakyReluBlock(alpha));
        for (int stage = 0; stage < REPEATS.length; stage++) {
            backbone.add(makeLayer(FILTERS[stage + 1], REPEATS[stage], momentum, alpha));
        }
        return backbone;
    }

    /**
     * Creates a builder for {@link YOLOV3} networks.
     *
     * @return a new builder with default settings
     */
    public static Builder builder() {
        return new Builder();
    }
}
