package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicmodelzoo.BasicModelZoo;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.modality.cv.zoo.ImageClassificationModelLoader;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.repository.Artifact;
import ai.djl.repository.Repository;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:ai/djl/basicmodelzoo/cv/classification/ResNetModelLoader.class */
/**
 * Model loader for the ResNet image-classification artifact of the basic model zoo.
 *
 * <p>Instead of loading a serialized graph, the network is rebuilt from the artifact's argument
 * map using {@link ResNetV1#builder()}. The required arguments are {@code numLayers},
 * {@code outSize} and {@code imageShape} (a list of numbers). {@code batchNormMomentum} is
 * optional.
 */
public class ResNetModelLoader extends ImageClassificationModelLoader {
    private static final String GROUP_ID = "ai.djl.zoo";
    private static final String ARTIFACT_ID = "resnet";
    private static final String VERSION = "0.0.2";

    /**
     * Creates a loader for the {@value #GROUP_ID}:{@value #ARTIFACT_ID}:{@value #VERSION}
     * artifact.
     *
     * @param repository the repository the artifact is resolved from
     */
    public ResNetModelLoader(Repository repository) {
        super(repository, GROUP_ID, ARTIFACT_ID, VERSION, new BasicModelZoo());
    }

    /**
     * Creates a new {@link Model} and attaches a ResNet block built from {@code arguments}.
     *
     * @param name the model name, passed to {@link Model#newInstance}
     * @param device the device, passed to {@link Model#newInstance}
     * @param artifact the resolved artifact (not used here; the block is built from arguments)
     * @param arguments the artifact arguments describing the network
     * @param engineName passed through as the third argument of {@link Model#newInstance}
     * @return the model with its block set
     * @throws IllegalArgumentException if a required argument is missing or not numeric
     */
    @Override
    protected Model createModel(
            String name,
            Device device,
            Artifact artifact,
            Map<String, Object> arguments,
            String engineName) {
        Model model = Model.newInstance(name, device, engineName);
        model.setBlock(resnetBlock(arguments));
        return model;
    }

    /** Builds the ResNetV1 block described by the argument map. */
    private Block resnetBlock(Map<String, Object> arguments) {
        ResNetV1.Builder builder =
                ResNetV1.builder()
                        .setNumLayers(requireNumber(arguments, "numLayers").intValue())
                        .setOutSize(requireNumber(arguments, "outSize").longValue())
                        .setImageShape(toShape(arguments.get("imageShape")));
        Object momentum = arguments.get("batchNormMomentum");
        if (momentum != null) {
            // Optional argument: only override the builder's default when it is supplied.
            builder.optBatchNormMomentum(asNumber("batchNormMomentum", momentum).floatValue());
        }
        return builder.build();
    }

    /** Returns the required numeric argument {@code key}, failing with a descriptive message. */
    private static Number requireNumber(Map<String, Object> arguments, String key) {
        return asNumber(key, arguments.get(key));
    }

    /**
     * Casts {@code value} to {@link Number}.
     *
     * <p>Accepting any {@code Number} (not just {@code Double}) keeps this working whether the
     * argument map holds doubles or integral values.
     */
    private static Number asNumber(String key, Object value) {
        if (!(value instanceof Number)) {
            throw new IllegalArgumentException(
                    "ResNet argument '" + key + "' must be a number but was: " + value);
        }
        return (Number) value;
    }

    /** Converts the {@code imageShape} argument (a list of numbers) into a {@link Shape}. */
    private static Shape toShape(Object value) {
        if (!(value instanceof List)) {
            throw new IllegalArgumentException(
                    "ResNet argument 'imageShape' must be a list of numbers but was: " + value);
        }
        List<?> dims = (List<?>) value;
        long[] shape = new long[dims.size()];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = asNumber("imageShape[" + i + "]", dims.get(i)).longValue();
        }
        return new Shape(shape);
    }
}
