package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import java.util.Arrays;
import org.apache.commons.compress.archivers.cpio.CpioConstants;

/* loaded from: input_file:ai/djl/basicmodelzoo/cv/classification/ResNetV1.class */
/**
 * Factory for ResNet (version 1) image-classification networks built from DJL blocks.
 *
 * <p>Configure a network with {@link #builder()}. {@link Builder#build()} derives the per-stage
 * filter counts, unit counts and unit type from the image height and requested depth, then
 * delegates to {@link #resnet(Builder)} to assemble the layers.
 */
public final class ResNetV1 {

    /** Configures and builds a ResNet V1 network. */
    public static final class Builder {

        /** Requested network depth; only specific depths are accepted by {@link #build()}. */
        int numLayers;
        /** Number of residual stages; derived in {@link #build()}. */
        int numStages;
        /** Number of output units of the final {@link Linear} layer. */
        long outSize;
        /** Momentum passed to every {@link BatchNorm} layer; defaults to 0.9. */
        float batchNormMomentum = 0.9f;
        /**
         * Input image shape. Only dimension 1 is read, and it is treated as the image height
         * (presumably a (channels, height, width) layout — TODO confirm with callers).
         */
        Shape imageShape;
        /** Whether residual units use the 1x1/3x3/1x1 bottleneck layout; derived in build(). */
        boolean bottleneck;
        /** Number of residual units per stage; derived in {@link #build()}. */
        int[] units;
        /** {@code filters[0]} is the stem width; {@code filters[s + 1]} is the width of stage s. */
        int[] filters;

        Builder() {}

        /**
         * Sets the network depth.
         *
         * @param numLayers the requested number of layers
         * @return this builder
         */
        public Builder setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return this;
        }

        /**
         * Sets the number of output units of the final linear layer.
         *
         * @param outSize the output size
         * @return this builder
         */
        public Builder setOutSize(long outSize) {
            this.outSize = outSize;
            return this;
        }

        /**
         * Overrides the batch-normalization momentum (default 0.9).
         *
         * @param batchNormMomentum the momentum for every batch-norm layer
         * @return this builder
         */
        public Builder optBatchNormMomentum(float batchNormMomentum) {
            this.batchNormMomentum = batchNormMomentum;
            return this;
        }

        /**
         * Sets the input image shape; required before {@link #build()}.
         *
         * @param imageShape the input image shape (dimension 1 is read as the height)
         * @return this builder
         */
        public Builder setImageShape(Shape imageShape) {
            this.imageShape = imageShape;
            return this;
        }

        /**
         * Derives the stage layout from the image height and depth, then builds the network.
         *
         * <p>Heights up to 28 use three equal stages: depths of the form 9n+2 (at least 164)
         * use bottleneck units, and depths of the form 6n+2 (below 164) use basic units. Larger
         * images use four stages and only the depths 18, 34, 50, 101, 152, 200 and 269.
         *
         * @return the assembled network
         * @throws IllegalArgumentException if {@code imageShape} is unset or the depth is
         *     unsupported for the given image height
         */
        public SequentialBlock build() {
            if (imageShape == null) {
                throw new IllegalArgumentException("Must set imageShape");
            }
            if (imageShape.get(1) <= 28) {
                numStages = 3;
                int unitsPerStage;
                if ((numLayers - 2) % 9 == 0 && numLayers >= 164) {
                    unitsPerStage = (numLayers - 2) / 9;
                    filters = new int[] {16, 64, 128, 256};
                    bottleneck = true;
                } else if ((numLayers - 2) % 6 == 0 && numLayers < 164) {
                    unitsPerStage = (numLayers - 2) / 6;
                    filters = new int[] {16, 16, 32, 64};
                    bottleneck = false;
                } else {
                    throw unsupportedDepth();
                }
                units = new int[numStages];
                Arrays.fill(units, unitsPerStage);
            } else {
                numStages = 4;
                if (numLayers >= 50) {
                    filters = new int[] {64, 256, 512, 1024, 2048};
                    bottleneck = true;
                } else {
                    // Shallow variants (18/34 layers) use basic two-3x3 units, not bottlenecks.
                    filters = new int[] {64, 64, 128, 256, 512};
                    bottleneck = false;
                }
                switch (numLayers) {
                    case 18:
                        units = new int[] {2, 2, 2, 2};
                        break;
                    case 34:
                    case 50:
                        units = new int[] {3, 4, 6, 3};
                        break;
                    case 101:
                        units = new int[] {3, 4, 23, 3};
                        break;
                    case 152:
                        units = new int[] {3, 8, 36, 3};
                        break;
                    case 200:
                        units = new int[] {3, 24, 36, 3};
                        break;
                    case 269:
                        units = new int[] {3, 30, 48, 8};
                        break;
                    default:
                        throw unsupportedDepth();
                }
            }
            return ResNetV1.resnet(this);
        }

        /** Creates the exception reported for a depth with no predefined layout. */
        private IllegalArgumentException unsupportedDepth() {
            return new IllegalArgumentException(
                    "no experiments done on num_layers "
                            + numLayers
                            + ", you can do it yourself");
        }
    }

    private ResNetV1() {}

    /**
     * Builds one residual unit. A main branch and a shortcut branch run on the same input; their
     * single-array outputs are summed and passed through ReLU.
     *
     * @param numFilters number of output channels of the unit
     * @param stride stride of the first main-branch convolution and of the projection shortcut
     * @param dimMatch {@code true} for an identity shortcut; {@code false} for a 1x1 convolution
     *     plus batch-norm projection shortcut
     * @param bottleneck {@code true} for a 1x1/3x3/1x1 main branch whose inner convolutions use
     *     {@code numFilters / 4} channels; {@code false} for two 3x3 convolutions
     * @param batchNormMomentum momentum for every batch-norm layer in the unit
     * @return the residual unit
     */
    public static Block residualUnit(
            int numFilters,
            Shape stride,
            boolean dimMatch,
            boolean bottleneck,
            float batchNormMomentum) {
        SequentialBlock mainBranch = new SequentialBlock();
        if (bottleneck) {
            mainBranch
                    .add(conv2d(new Shape(1, 1), numFilters / 4, stride, new Shape(0, 0), true))
                    .add(batchNorm(1e-5f, batchNormMomentum))
                    .add(Activation::relu)
                    .add(
                            conv2d(
                                    new Shape(3, 3),
                                    numFilters / 4,
                                    new Shape(1, 1),
                                    new Shape(1, 1),
                                    false))
                    .add(batchNorm(2e-5f, batchNormMomentum))
                    .add(Activation::relu)
                    .add(
                            conv2d(
                                    new Shape(1, 1),
                                    numFilters,
                                    new Shape(1, 1),
                                    new Shape(0, 0),
                                    true))
                    .add(batchNorm(1e-5f, batchNormMomentum));
        } else {
            mainBranch
                    .add(conv2d(new Shape(3, 3), numFilters, stride, new Shape(1, 1), false))
                    .add(batchNorm(1e-5f, batchNormMomentum))
                    .add(Activation::relu)
                    .add(
                            conv2d(
                                    new Shape(3, 3),
                                    numFilters,
                                    new Shape(1, 1),
                                    new Shape(1, 1),
                                    false))
                    .add(batchNorm(1e-5f, batchNormMomentum));
        }

        SequentialBlock shortcut = new SequentialBlock();
        if (dimMatch) {
            shortcut.add(Blocks.identityBlock());
        } else {
            // Projection so the shortcut matches the main branch's channels and stride.
            shortcut.add(conv2d(new Shape(1, 1), numFilters, stride, new Shape(0, 0), false))
                    .add(batchNorm(1e-5f, batchNormMomentum));
        }

        return new ParallelBlock(
                outputs -> {
                    NDList mainOut = outputs.get(0);
                    NDList shortcutOut = outputs.get(1);
                    return new NDList(
                            mainOut.singletonOrThrow()
                                    .add(shortcutOut.singletonOrThrow())
                                    .getNDArrayInternal()
                                    .relu());
                },
                Arrays.asList(mainBranch, shortcut));
    }

    /**
     * Assembles the network described by a configured builder. This is normally invoked through
     * {@link Builder#build()}, which fills in {@code units}, {@code filters} and
     * {@code bottleneck}.
     *
     * <p>Images of height at most 32 get a single 3x3 stride-1 stem convolution. Larger images
     * get a 7x7 stride-2 convolution, batch-norm, ReLU and 3x3 stride-2 max-pooling. The first
     * unit of every stage after the first uses stride 2. The head is global average pooling, a
     * flatten, a linear layer of {@code outSize} units and a final flatten.
     *
     * @param builder a builder whose derived fields have been populated
     * @return the assembled network
     */
    public static SequentialBlock resnet(Builder builder) {
        int numStages = builder.units.length;
        long height = builder.imageShape.get(1);
        SequentialBlock resNet = new SequentialBlock();
        if (height <= 32) {
            resNet.add(
                    conv2d(
                            new Shape(3, 3),
                            builder.filters[0],
                            new Shape(1, 1),
                            new Shape(1, 1),
                            false));
        } else {
            resNet.add(
                            conv2d(
                                    new Shape(7, 7),
                                    builder.filters[0],
                                    new Shape(2, 2),
                                    new Shape(3, 3),
                                    false))
                    .add(batchNorm(2e-5f, builder.batchNormMomentum))
                    .add(Activation.reluBlock())
                    .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2), new Shape(1, 1)));
        }

        Shape stageStride = new Shape(1, 1);
        for (int stage = 0; stage < numStages; stage++) {
            int stageFilters = builder.filters[stage + 1];
            // The first unit of a stage changes width (and possibly resolution): projection shortcut.
            resNet.add(
                    residualUnit(
                            stageFilters,
                            stageStride,
                            false,
                            builder.bottleneck,
                            builder.batchNormMomentum));
            for (int unit = 1; unit < builder.units[stage]; unit++) {
                resNet.add(
                        residualUnit(
                                stageFilters,
                                new Shape(1, 1),
                                true,
                                builder.bottleneck,
                                builder.batchNormMomentum));
            }
            if (stage == 0) {
                stageStride = new Shape(2, 2);
            }
        }

        return resNet.add(Pool.globalAvgPool2dBlock())
                .add(Blocks.batchFlattenBlock())
                .add(Linear.builder().setUnits(builder.outSize).build())
                .add(Blocks.batchFlattenBlock());
    }

    /**
     * Returns a new builder for configuring a ResNet V1 network.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builds a 2-D convolution with the given kernel, output channels, stride, padding and bias. */
    private static Block conv2d(
            Shape kernel, int filters, Shape stride, Shape padding, boolean bias) {
        return Conv2d.builder()
                .setKernelShape(kernel)
                .setFilters(filters)
                .optStride(stride)
                .optPadding(padding)
                .optBias(bias)
                .build();
    }

    /** Builds a batch-normalization layer with the given epsilon and momentum. */
    private static Block batchNorm(float epsilon, float momentum) {
        return BatchNorm.builder().optEpsilon(epsilon).optMomentum(momentum).build();
    }
}
