package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import com.sun.jna.Function;

/* loaded from: input_file:ai/djl/basicmodelzoo/cv/classification/AlexNet.class */
/**
 * Factory for the AlexNet image-classification network (Krizhevsky et al., 2012) as a DJL
 * {@link Block}.
 *
 * <p>The network built by {@link #alexNet(Builder)} is five convolutional layers (each followed by
 * ReLU, with three 3x3/stride-2 max-pooling stages), a flatten step, two hidden fully-connected
 * layers with ReLU and dropout, and a final fully-connected output layer. Use {@link #builder()}
 * to configure layer widths, dropout rate and output size.
 */
public final class AlexNet {

    /** Collects the hyper-parameters used by {@link AlexNet#alexNet(Builder)}. */
    public static final class Builder {

        // Dropout rate applied by both dropout layers in the classifier head.
        float dropOutRate = 0.5f;
        // Number of entries required in numChannels: 5 convolutional + 2 hidden linear layers.
        int numLayers = 7;
        // Output width per layer, in order: conv1..conv5, then the two hidden linear layers.
        // These are the canonical AlexNet widths (conv3/conv4 use 384 filters).
        int[] numChannels = {96, 256, 384, 384, 256, 4096, 4096};
        // Number of units in the final linear layer (typically the number of classes).
        long outSize = 10;

        Builder() {
        }

        /**
         * Sets the dropout rate used by both dropout layers.
         *
         * @param dropOutRate fraction of activations to drop, in {@code [0, 1]}
         * @return this builder
         * @throws IllegalArgumentException if {@code dropOutRate} is NaN or outside {@code [0, 1]}
         */
        public Builder setDropOutRate(float dropOutRate) {
            // Written as a negated range check so that NaN is rejected too.
            if (!(dropOutRate >= 0f && dropOutRate <= 1f)) {
                throw new IllegalArgumentException(
                        "dropout rate must be in [0, 1], but got " + dropOutRate);
            }
            this.dropOutRate = dropOutRate;
            return this;
        }

        /**
         * Sets the output width of every configurable layer.
         *
         * @param numChannels exactly {@code numLayers} (7) positive widths: conv1..conv5 filter
         *     counts followed by the units of the two hidden linear layers
         * @return this builder
         * @throws NullPointerException if {@code numChannels} is null
         * @throws IllegalArgumentException if the length is wrong or any width is not positive
         */
        public Builder setNumChannels(int[] numChannels) {
            if (numChannels.length != this.numLayers) {
                throw new IllegalArgumentException(
                        "number of channels should be equal to " + this.numLayers);
            }
            for (int width : numChannels) {
                if (width <= 0) {
                    throw new IllegalArgumentException(
                            "number of channels must be positive, but got " + width);
                }
            }
            // Defensive copy: later changes to the caller's array must not alter this config.
            this.numChannels = numChannels.clone();
            return this;
        }

        /**
         * Sets the number of units of the final linear (output) layer.
         *
         * @param outSize output width, e.g. the number of classes; must be positive
         * @return this builder
         * @throws IllegalArgumentException if {@code outSize} is not positive
         */
        public Builder setOutSize(long outSize) {
            if (outSize <= 0) {
                throw new IllegalArgumentException("output size must be positive, but got " + outSize);
            }
            this.outSize = outSize;
            return this;
        }

        /**
         * Builds the AlexNet block from the current configuration.
         *
         * @return the configured AlexNet {@link Block}
         */
        public Block build() {
            return AlexNet.alexNet(this);
        }
    }

    private AlexNet() {
    }

    /**
     * Creates an AlexNet block using the hyper-parameters held by {@code builder}.
     *
     * <p>NOTE(review): input is presumably a batch of images (e.g. NCHW, 224x224 as in the
     * original paper); the spatial size must be large enough to survive the three pooling
     * stages — confirm against callers.
     *
     * @param builder the configuration to read layer widths, dropout rate and output size from
     * @return a flat {@link SequentialBlock} implementing AlexNet
     */
    public static Block alexNet(Builder builder) {
        int[] channels = builder.numChannels;
        float dropOutRate = builder.dropOutRate;
        // Every layer is added directly to one flat SequentialBlock, in the original order, so
        // the child structure of the resulting block is unchanged.
        return new SequentialBlock()
                // Stage 1: 11x11 convolution with stride 4, ReLU, max pooling.
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(11, 11))
                        .optStride(new Shape(4, 4))
                        .setFilters(channels[0])
                        .build())
                .add(Activation::relu)
                .add(maxPool3x3Stride2())
                // Stage 2: 5x5 convolution with padding 2, ReLU, max pooling.
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(5, 5))
                        .optPadding(new Shape(2, 2))
                        .setFilters(channels[1])
                        .build())
                .add(Activation::relu)
                .add(maxPool3x3Stride2())
                // Stage 3: three padded 3x3 convolutions with ReLU, then max pooling.
                .add(conv3x3(channels[2]))
                .add(Activation::relu)
                .add(conv3x3(channels[3]))
                .add(Activation::relu)
                .add(conv3x3(channels[4]))
                .add(Activation::relu)
                .add(maxPool3x3Stride2())
                // Classifier head: flatten, two hidden dense layers with dropout, output layer.
                .add(Blocks.batchFlattenBlock())
                .add(Linear.builder().setUnits(channels[5]).build())
                .add(Activation::relu)
                .add(Dropout.builder().optRate(dropOutRate).build())
                .add(Linear.builder().setUnits(channels[6]).build())
                .add(Activation::relu)
                .add(Dropout.builder().optRate(dropOutRate).build())
                .add(Linear.builder().setUnits(builder.outSize).build());
    }

    /** Returns a 3x3 convolution with padding 1 and the given number of filters. */
    private static Block conv3x3(int filters) {
        return Conv2d.builder()
                .setKernelShape(new Shape(3, 3))
                .optPadding(new Shape(1, 1))
                .setFilters(filters)
                .build();
    }

    /** Returns the 3x3 kernel, stride-2 max-pooling block used after stages 1, 2 and 3. */
    private static Block maxPool3x3Stride2() {
        return Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2));
    }

    /**
     * Creates a builder pre-filled with the canonical AlexNet configuration.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
