package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import java.util.Arrays;
import org.apache.commons.compress.archivers.cpio.CpioConstants;

/* loaded from: input_file:ai/djl/basicmodelzoo/cv/classification/SqueezeNet.class */
/**
 * Static factory for a SqueezeNet image-classification network assembled from DJL blocks.
 *
 * <p>The network is a stem convolution, a stack of "fire" modules separated by max-pooling
 * stages, dropout, and a 1x1 convolutional classifier followed by global average pooling. This
 * class only holds static builders and cannot be instantiated.
 */
public final class SqueezeNet {
    private SqueezeNet() {
    }

    /**
     * Builds a fire module: a 1x1 "squeeze" convolution followed by two parallel "expand" branches
     * whose outputs are concatenated.
     *
     * <ol>
     *   <li>squeeze: 1x1 convolution with {@code squeezeChannels} filters, then ReLU;
     *   <li>expand branch A: 1x1 convolution with {@code expand1x1Channels} filters, then ReLU;
     *   <li>expand branch B: 3x3 convolution with padding (1, 1) and {@code expand3x3Channels}
     *       filters, then ReLU;
     *   <li>the outputs of branches A and B are concatenated along axis 1.
     * </ol>
     *
     * <p>NOTE(review): axis 1 is presumably the channel axis (NCHW input). If so, the module emits
     * {@code expand1x1Channels + expand3x3Channels} channels. Confirm before relying on it.
     *
     * @param squeezeChannels number of filters in the 1x1 squeeze convolution
     * @param expand1x1Channels number of filters in the 1x1 expand convolution
     * @param expand3x3Channels number of filters in the 3x3 expand convolution
     * @return the fire module as a single block
     */
    static Block fire(int squeezeChannels, int expand1x1Channels, int expand3x3Channels) {
        Block squeeze =
                convRelu(
                        Conv2d.builder()
                                .setFilters(squeezeChannels)
                                .setKernelShape(new Shape(1, 1))
                                .build());
        Block expand1x1 =
                convRelu(
                        Conv2d.builder()
                                .setFilters(expand1x1Channels)
                                .setKernelShape(new Shape(1, 1))
                                .build());
        // Padding (1, 1) on the 3x3 kernel is what allows its output to be concatenated with
        // the 1x1 branch (presumably equal spatial size; confirm).
        Block expand3x3 =
                convRelu(
                        Conv2d.builder()
                                .setFilters(expand3x3Channels)
                                .setKernelShape(new Shape(3, 3))
                                .optPadding(new Shape(1, 1))
                                .build());
        // Run both expand branches on the squeezed input and concatenate their outputs on axis 1.
        // Note: addAll appends branch B's arrays into branch A's output NDList in place.
        Block expand =
                new ParallelBlock(
                        branchOutputs ->
                                new NDList(
                                        NDArrays.concat(
                                                branchOutputs.get(0).addAll(branchOutputs.get(1)),
                                                1)),
                        Arrays.asList(expand1x1, expand3x3));
        return new SequentialBlock().add(squeeze).add(expand);
    }

    /**
     * Builds the full SqueezeNet classification network.
     *
     * <p>The final stages are a 1x1 convolution with {@code outSize} filters, ReLU, global average
     * pooling, and a batch flatten. The output is therefore presumably {@code outSize} values per
     * example (confirm against the input layout).
     *
     * @param outSize number of filters in the final 1x1 classifier convolution (the number of
     *     output classes)
     * @return the assembled network as a single sequential block
     */
    public static Block squeezenet(int outSize) {
        return new SequentialBlock()
                // Stem: 3x3 convolution, stride 2, 64 filters.
                .add(
                        Conv2d.builder()
                                .setFilters(64)
                                .setKernelShape(new Shape(3, 3))
                                .optStride(new Shape(2, 2))
                                .build())
                .add(Activation::relu)
                .add(maxPool3x3Stride2())
                .add(fire(16, 64, 64))
                .add(fire(16, 64, 64))
                .add(maxPool3x3Stride2())
                // Literal 128. The previous code used CpioConstants.C_IWUSR (an archive
                // file-permission bit that happens to equal 0200 octal = 128); it was never a
                // meaningful name for a channel count.
                .add(fire(32, 128, 128))
                .add(fire(32, 128, 128))
                .add(maxPool3x3Stride2())
                .add(fire(48, 192, 192))
                .add(fire(48, 192, 192))
                .add(fire(64, 256, 256))
                .add(fire(64, 256, 256))
                .add(Dropout.builder().optRate(0.5f).build())
                // Classifier head: 1x1 conv to outSize channels, ReLU, global average pool, flatten.
                .add(Conv2d.builder().setFilters(outSize).setKernelShape(new Shape(1, 1)).build())
                .add(Activation::relu)
                .add(Pool.globalAvgPool2dBlock())
                .add(Blocks.batchFlattenBlock());
    }

    /** Wraps {@code conv} and a trailing ReLU activation into one sequential block. */
    private static SequentialBlock convRelu(Block conv) {
        return new SequentialBlock().add(conv).add(Activation::relu);
    }

    /**
     * Returns the max-pooling stage used between fire groups.
     *
     * <p>The stage uses a 3x3 window, stride 2, zero padding and a trailing {@code true} flag.
     * The flag is presumably ceil mode in DJL's {@code Pool} API; confirm.
     */
    private static Block maxPool3x3Stride2() {
        return Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2), new Shape(0, 0), true);
    }
}
