package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.pooling.Pool;
import com.sun.jna.Function;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.compress.archivers.cpio.CpioConstants;

/* loaded from: input_file:ai/djl/basicmodelzoo/cv/classification/GoogLeNet.class */
/**
 * Factory for a GoogLeNet-style image-classification network built from DJL blocks.
 *
 * <p>The network is five sequential stages followed by a {@link Linear} output layer:
 *
 * <ol>
 *   <li>a 7x7 stride-2 convolution (64 filters) with ReLU and 3x3 stride-2 max pooling;
 *   <li>a 1x1 convolution (64 filters) and a 3x3 convolution (192 filters), each with ReLU, then
 *       3x3 stride-2 max pooling;
 *   <li>two inception blocks, then 3x3 stride-2 max pooling;
 *   <li>five inception blocks, then 3x3 stride-2 max pooling;
 *   <li>two inception blocks, then global average pooling.
 * </ol>
 *
 * <p>Use {@link #builder()} to configure and build the network.
 */
public final class GoogLeNet {

    /** Builder that configures and creates a GoogLeNet {@link Block}. */
    public static final class Builder {

        /** Number of units in the final linear layer; defaults to 10. */
        long outSize = 10;

        Builder() {}

        /**
         * Sets the number of output units of the final linear layer.
         *
         * @param outSize the number of output units
         * @return this builder
         */
        public Builder setOutSize(long outSize) {
            this.outSize = outSize;
            return this;
        }

        /**
         * Builds the GoogLeNet block using this builder's configuration.
         *
         * @return the GoogLeNet block
         */
        public Block build() {
            return GoogLeNet.googLeNet(this);
        }
    }

    private GoogLeNet() {}

    /**
     * Creates the GoogLeNet block described in the class documentation.
     *
     * <p>Channel widths are written as plain literals. A decompiled version of this code expressed
     * 128 and 384 via unrelated library constants that merely happened to share those values.
     *
     * @param builder the configuration; only {@code outSize} is read
     * @return a sequential block containing the five stages and the output layer
     */
    public static Block googLeNet(Builder builder) {
        GoogLeNet googLeNet = new GoogLeNet();

        // Stage 1: 7x7/2 convolution stem followed by 3x3/2 max pooling.
        SequentialBlock stage1 = new SequentialBlock();
        stage1.add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(7, 7))
                                .optPadding(new Shape(3, 3))
                                .optStride(new Shape(2, 2))
                                .setFilters(64)
                                .build())
                .add(Activation::relu)
                .add(maxPoolStride2());

        // Stage 2: 1x1 then 3x3 convolutions, followed by 3x3/2 max pooling.
        SequentialBlock stage2 = new SequentialBlock();
        stage2.add(conv1x1(64))
                .add(Activation::relu)
                .add(conv(192, 3, 1))
                .add(Activation::relu)
                .add(maxPoolStride2());

        // Stage 3: two inception blocks.
        SequentialBlock stage3 = new SequentialBlock();
        stage3.add(googLeNet.inceptionBlock(64, new int[] {96, 128}, new int[] {16, 32}, 32))
                .add(googLeNet.inceptionBlock(128, new int[] {128, 192}, new int[] {32, 96}, 64))
                .add(maxPoolStride2());

        // Stage 4: five inception blocks.
        SequentialBlock stage4 = new SequentialBlock();
        stage4.add(googLeNet.inceptionBlock(192, new int[] {96, 208}, new int[] {16, 48}, 64))
                .add(googLeNet.inceptionBlock(160, new int[] {112, 224}, new int[] {24, 64}, 64))
                .add(googLeNet.inceptionBlock(128, new int[] {128, 256}, new int[] {24, 64}, 64))
                .add(googLeNet.inceptionBlock(112, new int[] {144, 288}, new int[] {32, 64}, 64))
                .add(googLeNet.inceptionBlock(256, new int[] {160, 320}, new int[] {32, 128}, 128))
                .add(maxPoolStride2());

        // Stage 5: two inception blocks, then global average pooling before the classifier.
        SequentialBlock stage5 = new SequentialBlock();
        stage5.add(googLeNet.inceptionBlock(256, new int[] {160, 320}, new int[] {32, 128}, 128))
                .add(googLeNet.inceptionBlock(384, new int[] {192, 384}, new int[] {48, 128}, 128))
                .add(Pool.globalAvgPool2dBlock());

        return new SequentialBlock()
                .addAll(
                        stage1,
                        stage2,
                        stage3,
                        stage4,
                        stage5,
                        Linear.builder().setUnits(builder.outSize).build());
    }

    /**
     * Creates an inception block: four parallel paths whose outputs are concatenated along axis 1.
     *
     * <p>The block's output width on axis 1 is {@code c1 + c2[1] + c3[1] + c4}.
     *
     * @param c1 filters of path 1 (a single 1x1 convolution)
     * @param c2 path 2 widths: {@code c2[0]} filters for the 1x1 reduction and {@code c2[1]} for
     *     the 3x3 convolution; elements beyond index 1 are ignored
     * @param c3 path 3 widths: {@code c3[0]} filters for the 1x1 reduction and {@code c3[1]} for
     *     the 5x5 convolution; elements beyond index 1 are ignored
     * @param c4 filters of the 1x1 convolution following the 3x3/1 max pooling in path 4
     * @return the parallel block
     */
    public ParallelBlock inceptionBlock(int c1, int[] c2, int[] c3, int c4) {
        // Path 1: 1x1 convolution.
        SequentialBlock path1 = new SequentialBlock().add(conv1x1(c1)).add(Activation::relu);

        // Path 2: 1x1 reduction, then 3x3 convolution (padding 1 keeps the spatial size).
        SequentialBlock path2 =
                new SequentialBlock()
                        .add(conv1x1(c2[0]))
                        .add(Activation::relu)
                        .add(conv(c2[1], 3, 1))
                        .add(Activation::relu);

        // Path 3: 1x1 reduction, then 5x5 convolution (padding 2 keeps the spatial size).
        SequentialBlock path3 =
                new SequentialBlock()
                        .add(conv1x1(c3[0]))
                        .add(Activation::relu)
                        .add(conv(c3[1], 5, 2))
                        .add(Activation::relu);

        // Path 4: 3x3 stride-1 max pooling, then 1x1 convolution.
        SequentialBlock path4 =
                new SequentialBlock()
                        .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                        .add(conv1x1(c4))
                        .add(Activation::relu);

        return new ParallelBlock(
                GoogLeNet::concatHeadsOnAxis1, Arrays.asList(path1, path2, path3, path4));
    }

    /**
     * Creates a new builder for a GoogLeNet block.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Takes the first array of each path's output and concatenates them along axis 1. */
    private static NDList concatHeadsOnAxis1(List<NDList> pathOutputs) {
        // Typed collection into an NDList replaces the former raw-type List cast.
        NDList heads =
                pathOutputs.stream().map(NDList::head).collect(Collectors.toCollection(NDList::new));
        // Axis 1 is the channel axis, assuming NCHW layout (TODO confirm for all callers).
        return new NDList(NDArrays.concat(heads, 1));
    }

    /** Builds an unpadded 1x1 convolution with the given number of filters. */
    private static Block conv1x1(int filters) {
        return Conv2d.builder().setFilters(filters).setKernelShape(new Shape(1, 1)).build();
    }

    /** Builds a square convolution with the given filters, kernel size, and symmetric padding. */
    private static Block conv(int filters, int kernel, int padding) {
        return Conv2d.builder()
                .setFilters(filters)
                .setKernelShape(new Shape(kernel, kernel))
                .optPadding(new Shape(padding, padding))
                .build();
    }

    /** Builds the 3x3, stride-2, padding-1 max pooling used between stages. */
    private static Block maxPoolStride2() {
        return Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2), new Shape(1, 1));
    }
}
