package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import org.apache.commons.compress.archivers.cpio.CpioConstants;

/* loaded from: input_file:ai/djl/basicmodelzoo/cv/classification/VGG.class */
/**
 * Factory for VGG-style image classification networks assembled from DJL blocks.
 *
 * <p>The network built by {@link #vgg(Builder)} is a {@link SequentialBlock} made of one
 * convolutional block per entry of the builder's {@code convArch}, followed by a flatten block and
 * a classifier head of three {@link Linear} layers (4096, 4096 and 10 units) with ReLU activations
 * and dropout (rate 0.5) between them.
 */
public final class VGG {

    /**
     * Number of layers counted in {@code numLayers} that are not convolutions. This matches the
     * three {@link Linear} layers appended by {@link #vgg(Builder)}.
     */
    private static final int NON_CONV_LAYERS = 3;

    /** Configuration holder for a {@link VGG} network. */
    public static final class Builder {
        /** Declared layer count; only used to validate the argument of {@link #setConvArch}. */
        int numLayers = 11;

        /**
         * One row per convolutional block: {@code {numConvs, numChannels}}. The default rows hold
         * 8 convolutions in total, which together with the 3 linear layers gives 11 layers.
         */
        int[][] convArch = {
            new int[] {1, 64},
            new int[] {1, 128},
            new int[] {2, 256},
            new int[] {2, 512},
            new int[] {2, 512}
        };

        /**
         * Sets the declared number of layers.
         *
         * <p>The value is only consulted by {@link #setConvArch(int[][])} at the moment that
         * method is called, so call this method first when changing both.
         *
         * @param numLayers the declared number of layers
         * @return this builder
         */
        public Builder setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return this;
        }

        /**
         * Sets the convolutional architecture.
         *
         * @param convArch one row per convolutional block, each row being
         *     {@code {numConvs, numChannels}}
         * @return this builder
         * @throws IllegalArgumentException if a row has fewer than two elements, or if the sum of
         *     all {@code numConvs} values differs from {@code numLayers - 3}
         */
        public Builder setConvArch(int[][] convArch) {
            int totalConvs = 0;
            for (int[] blockSpec : convArch) {
                // Fail fast here instead of with an ArrayIndexOutOfBoundsException at build time.
                if (blockSpec.length < 2) {
                    throw new IllegalArgumentException(
                            "each convArch entry must have the form {numConvs, numChannels}");
                }
                totalConvs += blockSpec[0];
            }
            if (totalConvs != this.numLayers - NON_CONV_LAYERS) {
                throw new IllegalArgumentException("total sum of channels in the array should be equal to the ( numLayers - 3 )");
            }
            this.convArch = convArch;
            return this;
        }

        /**
         * Builds the network described by this builder.
         *
         * @return the assembled block, as produced by {@link VGG#vgg(Builder)}
         */
        public Block build() {
            return VGG.vgg(this);
        }
    }

    private VGG() {
    }

    /**
     * Assembles a VGG network from the given builder's {@code convArch}.
     *
     * @param builder the configuration to read the convolutional architecture from
     * @return a {@link SequentialBlock} holding the convolutional blocks and the classifier head
     */
    public static Block vgg(Builder builder) {
        SequentialBlock network = new SequentialBlock();
        // vggBlock is an instance method, so an instance is needed to call it.
        VGG vgg = new VGG();
        for (int[] blockSpec : builder.convArch) {
            network.add(vgg.vggBlock(blockSpec[0], blockSpec[1]));
        }
        network.add(Blocks.batchFlattenBlock())
                .add(Linear.builder().setUnits(4096L).build())
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.5f).build())
                .add(Linear.builder().setUnits(4096L).build())
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.5f).build())
                .add(Linear.builder().setUnits(10L).build());
        return network;
    }

    /**
     * Creates one convolutional block: {@code numConvs} repetitions of a 3x3 convolution (padding
     * 1) followed by ReLU, closed by a 2x2 max pooling with stride 2.
     *
     * @param numConvs number of convolution + ReLU pairs in the block
     * @param numChannels number of filters of each convolution
     * @return the assembled block
     */
    public SequentialBlock vggBlock(int numConvs, int numChannels) {
        SequentialBlock block = new SequentialBlock();
        for (int conv = 0; conv < numConvs; conv++) {
            block.add(
                            Conv2d.builder()
                                    .setFilters(numChannels)
                                    .setKernelShape(new Shape(3, 3))
                                    .optPadding(new Shape(1, 1))
                                    .build())
                    .add(Activation::relu);
        }
        block.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));
        return block;
    }

    /**
     * Creates a builder with the default configuration.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
