package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Factory for a MobileNetV2-style image classification {@link Block}.
 *
 * <p>The network built by {@link #mobilenetV2(Builder)} is, in order: a stem of 1x1
 * convolution / batch-norm / ReLU6 layers, {@link #MULTILENGTH} stages of linear bottleneck
 * blocks (see {@link #linearBottleNeck}), a head of 1x1 convolution / batch-norm / ReLU6 layers,
 * global average pooling, and a final 1x1 convolution reshaped to {@code (batch, outSize)}. Use
 * {@link #builder()} to configure it.
 */
public final class MobileNetV2 {

    /** Required length of {@code filters}: one entry for the stem, each stage and the head. */
    public static final int FILTERLENGTH = 9;

    /** Required length of {@code repeatTimes}: one entry for the stem, each stage and the head. */
    public static final int REPEATLENGTH = 9;

    /** Required length of {@code strides}: one entry for the stem, each stage and the head. */
    public static final int STRIDELENGTH = 9;

    /** Required length of {@code multiTimes}: one expansion factor per bottleneck stage. */
    public static final int MULTILENGTH = 7;

    /** Epsilon shared by every batch-normalization layer of the network. */
    private static final float BATCH_NORM_EPSILON = 2.0E-5f;

    /** Configures and builds a MobileNetV2 {@link Block}. */
    public static final class Builder {

        float batchNormMomentum = 0.9f;
        long outSize = 10;
        int[] repeatTimes = {1, 1, 2, 3, 4, 3, 3, 1, 1};
        int[] filters = {32, 16, 24, 32, 64, 96, 160, 320, 1280};
        int[] strides = {2, 1, 2, 2, 2, 1, 2, 1, 1};
        int[] multiTimes = {1, 6, 6, 6, 6, 6, 6};

        Builder() {}

        /**
         * Sets the momentum used by every batch-normalization layer.
         *
         * @param batchNormMomentum the batch-norm momentum
         * @return this builder
         */
        public Builder optBatchNormMomentum(float batchNormMomentum) {
            this.batchNormMomentum = batchNormMomentum;
            return this;
        }

        /**
         * Sets the number of outputs produced per sample by the final layer.
         *
         * @param outSize the output size; must be between 1 and {@link Integer#MAX_VALUE}
         * @return this builder
         * @throws IllegalArgumentException if {@code outSize} is out of range
         */
        public Builder setOutSize(long outSize) {
            // The classifier convolution takes an int filter count; larger values would be
            // silently truncated by the cast in mobilenetV2, and non-positive values are invalid.
            if (outSize < 1 || outSize > Integer.MAX_VALUE) {
                throw new IllegalArgumentException(
                        String.format(
                                "setOutSize requires an outSize between 1 and %d, but was given %d instead",
                                Integer.MAX_VALUE, outSize));
            }
            this.outSize = outSize;
            return this;
        }

        /**
         * Sets the channel count of the stem, of each stage and of the head.
         *
         * @param filters exactly {@link #FILTERLENGTH} values; the array is copied
         * @return this builder
         * @throws IllegalArgumentException if the length is not {@link #FILTERLENGTH}
         */
        public Builder optFilters(int[] filters) {
            this.filters = checkedCopy("optFilters", "filters", filters, FILTERLENGTH);
            return this;
        }

        /**
         * Sets how many times the stem layer, each stage's bottleneck and the head layer repeat.
         *
         * @param repeatTimes exactly {@link #REPEATLENGTH} values; the array is copied
         * @return this builder
         * @throws IllegalArgumentException if the length is not {@link #REPEATLENGTH}
         */
        public Builder optRepeatTimes(int[] repeatTimes) {
            this.repeatTimes =
                    checkedCopy("optRepeatTimes", "repeatTimes", repeatTimes, REPEATLENGTH);
            return this;
        }

        /**
         * Sets the stride of the stem, of each stage's first bottleneck and of the head.
         *
         * @param strides exactly {@link #STRIDELENGTH} values; the array is copied
         * @return this builder
         * @throws IllegalArgumentException if the length is not {@link #STRIDELENGTH}
         */
        public Builder optStrides(int[] strides) {
            this.strides = checkedCopy("optStrides", "strides", strides, STRIDELENGTH);
            return this;
        }

        /**
         * Sets the channel expansion factor of each bottleneck stage.
         *
         * @param multiTimes exactly {@link #MULTILENGTH} values; the array is copied
         * @return this builder
         * @throws IllegalArgumentException if the length is not {@link #MULTILENGTH}
         */
        public Builder optMultiTimes(int[] multiTimes) {
            this.multiTimes = checkedCopy("optMultiTimes", "multiTimes", multiTimes, MULTILENGTH);
            return this;
        }

        /**
         * Builds the network described by this builder.
         *
         * @return the MobileNetV2 block
         */
        public Block build() {
            return MobileNetV2.mobilenetV2(this);
        }

        /** Validates the length of a configuration array and returns a defensive copy of it. */
        private static int[] checkedCopy(String method, String name, int[] values, int expected) {
            Objects.requireNonNull(values, name);
            if (values.length != expected) {
                throw new IllegalArgumentException(
                        String.format(
                                "%s requires %s of length %d, but was given %s of length %d instead",
                                method, name, expected, name, values.length));
            }
            // Copy so later mutation of the caller's array cannot alter this builder.
            return values.clone();
        }
    }

    private MobileNetV2() {}

    /**
     * Builds one linear bottleneck block: a 1x1 expansion convolution, a 3x3 depthwise
     * convolution and a 1x1 projection convolution. The first two are each followed by batch
     * norm and ReLU6; the projection is followed by batch norm only (no activation).
     *
     * <p>When {@code stride == 1} and {@code initialChannels == channels}, the block's output is
     * added element-wise to its input through an identity branch.
     *
     * @param initialChannels the expected number of input channels
     * @param channels the number of output channels
     * @param stride the stride of the depthwise convolution
     * @param t the expansion factor applied to {@code initialChannels}
     * @param batchNormMomentum the momentum of the batch-normalization layers
     * @return the bottleneck block
     */
    public static Block linearBottleNeck(
            int initialChannels, int channels, int stride, int t, float batchNormMomentum) {
        int expandedChannels = initialChannels * t;
        SequentialBlock bottleneck = new SequentialBlock();
        bottleneck
                .add(pointwiseConv(expandedChannels))
                .add(batchNorm(batchNormMomentum))
                .add(Activation.relu6Block())
                // groups == filters makes this a depthwise convolution
                .add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(3, 3))
                                .setFilters(expandedChannels)
                                .optStride(new Shape(stride, stride))
                                .optPadding(new Shape(1, 1))
                                .optGroups(expandedChannels)
                                .optBias(false)
                                .build())
                .add(batchNorm(batchNormMomentum))
                .add(Activation.relu6Block())
                .add(pointwiseConv(channels))
                .add(batchNorm(batchNormMomentum));

        // A residual connection is only shape-compatible when neither spatial size nor
        // channel count changes.
        if (stride != 1 || initialChannels != channels) {
            return bottleneck;
        }
        return new ParallelBlock(
                outputs ->
                        new NDList(
                                NDArrays.add(
                                        outputs.get(0).singletonOrThrow(),
                                        outputs.get(1).singletonOrThrow())),
                Arrays.asList(bottleneck, Blocks.identityBlock()));
    }

    /**
     * Builds one stage of {@code repeat} linear bottleneck blocks. The first block maps
     * {@code initialChannels} to {@code channels} with the given stride; the remaining
     * {@code repeat - 1} blocks keep {@code channels} and use stride 1. At least one block is
     * always added, even when {@code repeat < 1}.
     *
     * @param repeat the number of bottleneck blocks in the stage
     * @param initialChannels the number of input channels of the stage
     * @param channels the number of output channels of the stage
     * @param stride the stride of the first bottleneck block
     * @param t the expansion factor of every bottleneck block
     * @param batchNormMomentum the momentum of the batch-normalization layers
     * @return the stage block
     */
    public static Block makeStage(
            int repeat,
            int initialChannels,
            int channels,
            int stride,
            int t,
            float batchNormMomentum) {
        SequentialBlock stage = new SequentialBlock();
        stage.add(linearBottleNeck(initialChannels, channels, stride, t, batchNormMomentum));
        for (int i = 1; i < repeat; i++) {
            stage.add(linearBottleNeck(channels, channels, 1, t, batchNormMomentum));
        }
        return stage;
    }

    /**
     * Builds the full MobileNetV2 network described by {@code builder}.
     *
     * @param builder the configuration
     * @return the network block
     */
    public static Block mobilenetV2(Builder builder) {
        float momentum = builder.batchNormMomentum;
        int last = FILTERLENGTH - 1;

        // Snapshot values used by the lambdas below so mutating the builder after build()
        // cannot change the behaviour of an already-built network.
        final int headFilters = builder.filters[last];
        final long outSize = builder.outSize;

        SequentialBlock stem =
                convBnRelu6Layers(
                        builder.repeatTimes[0], builder.filters[0], builder.strides[0], momentum);

        List<Block> stages = new ArrayList<>(MULTILENGTH);
        for (int stage = 0; stage < MULTILENGTH; stage++) {
            stages.add(
                    makeStage(
                            builder.repeatTimes[stage + 1],
                            builder.filters[stage],
                            builder.filters[stage + 1],
                            builder.strides[stage + 1],
                            builder.multiTimes[stage],
                            momentum));
        }

        SequentialBlock head =
                convBnRelu6Layers(
                        builder.repeatTimes[last], headFilters, builder.strides[last], momentum);

        return new SequentialBlock()
                .add(stem)
                .addAll(stages)
                .add(head)
                .add(Pool.globalAvgPool2dBlock())
                .addSingleton(x -> x.reshape(x.getShape().get(0), headFilters, 1, 1))
                .add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(1, 1))
                                .setFilters((int) outSize)
                                .build())
                .addSingleton(x -> x.reshape(x.getShape().get(0), outSize));
    }

    /**
     * Creates a builder with the default configuration.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builds {@code repeat} consecutive [1x1 conv, batch norm, ReLU6] triples. Every repetition
     * uses the same filter count and the same stride.
     */
    private static SequentialBlock convBnRelu6Layers(
            int repeat, int filters, int stride, float momentum) {
        SequentialBlock layers = new SequentialBlock();
        for (int i = 0; i < repeat; i++) {
            layers.add(
                            Conv2d.builder()
                                    .setKernelShape(new Shape(1, 1))
                                    .setFilters(filters)
                                    .optStride(new Shape(stride, stride))
                                    .optBias(false)
                                    .build())
                    .add(batchNorm(momentum))
                    .add(Activation.relu6Block());
        }
        return layers;
    }

    /** Builds a bias-free 1x1 convolution with the given number of filters. */
    private static Block pointwiseConv(int filters) {
        return Conv2d.builder()
                .setFilters(filters)
                .setKernelShape(new Shape(1, 1))
                .optBias(false)
                .build();
    }

    /** Builds a batch-normalization layer with the shared epsilon and the given momentum. */
    private static Block batchNorm(float momentum) {
        return BatchNorm.builder().optEpsilon(BATCH_NORM_EPSILON).optMomentum(momentum).build();
    }
}
