package ai.djl.repository.zoo;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.nn.Block;
import ai.djl.translate.DefaultTranslatorFactory;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/repository/zoo/Criteria.class */
/**
 * Describes which model to load and how to load it.
 *
 * <p>A criteria carries the model input/output types, the application, engine, device, model-zoo
 * coordinates ({@code groupId}/{@code artifactId}), filters, arguments, options and the translator
 * factory. Instances are created through {@link #builder()} and consumed by {@link #loadModel()}.
 *
 * <p>The filter/argument/option maps are copied when a criteria is built. Later changes to a
 * builder, or to the maps that were passed into it, do not affect an already-built criteria.
 *
 * @param <I> the model input type
 * @param <O> the model output type
 */
public class Criteria<I, O> {
    private final Application application;
    private final Class<I> inputClass;
    private final Class<O> outputClass;
    private final String engine;
    private final Device device;
    private final String groupId;
    private final String artifactId;
    private final ModelZoo modelZoo;
    private final Map<String, String> filters;
    private final Map<String, Object> arguments;
    private final Map<String, String> options;
    private final TranslatorFactory factory;
    private final Block block;
    private final String modelName;
    private final Progress progress;

    /**
     * Mutable builder for {@link Criteria}.
     *
     * <p>Map-valued settings passed through {@code optFilters}, {@code optArguments} and
     * {@code optOptions} are copied, so later calls to {@code optFilter}, {@code optArgument} and
     * {@code optOption} never modify a map owned by the caller or by a previously built criteria.
     *
     * @param <I> the model input type
     * @param <O> the model output type
     */
    public static final class Builder<I, O> {
        Application application;
        Class<I> inputClass;
        Class<O> outputClass;
        String engine;
        Device device;
        String groupId;
        String artifactId;
        ModelZoo modelZoo;
        Map<String, String> filters;
        Map<String, Object> arguments;
        Map<String, String> options;
        TranslatorFactory factory;
        Block block;
        String modelName;
        Progress progress;
        Translator<I, O> translator;

        /** Creates an empty builder whose application is {@link Application#UNDEFINED}. */
        Builder() {
            this.application = Application.UNDEFINED;
        }

        /**
         * Creates a builder for new input/output types, copying every other setting from
         * {@code parent}. The map settings are copied so the two builders do not share state.
         */
        // The translator was typed for the parent's I/O classes. Retyping it is unchecked; callers
        // that change types are expected to also supply a matching translator or factory.
        @SuppressWarnings("unchecked")
        private Builder(Class<I> inputClass, Class<O> outputClass, Builder<?, ?> parent) {
            this.inputClass = inputClass;
            this.outputClass = outputClass;
            this.application = parent.application;
            this.engine = parent.engine;
            this.device = parent.device;
            this.groupId = parent.groupId;
            this.artifactId = parent.artifactId;
            this.modelZoo = parent.modelZoo;
            this.filters = copyMap(parent.filters);
            this.arguments = copyMap(parent.arguments);
            this.options = copyMap(parent.options);
            this.factory = parent.factory;
            this.block = parent.block;
            this.modelName = parent.modelName;
            this.progress = parent.progress;
            this.translator = (Translator<I, O>) parent.translator;
        }

        /**
         * Returns a new builder with the given input/output types and all other settings copied.
         *
         * @param inputClass the model input class
         * @param outputClass the model output class
         * @param <P> the new input type
         * @param <Q> the new output type
         * @return a new builder; {@code this} is not modified
         */
        public <P, Q> Builder<P, Q> setTypes(Class<P> inputClass, Class<Q> outputClass) {
            return new Builder<>(inputClass, outputClass, this);
        }

        /**
         * Sets the application the model must match.
         *
         * @param application the application
         * @return this builder
         */
        public Builder<I, O> optApplication(Application application) {
            this.application = application;
            return this;
        }

        /**
         * Sets the engine name the model zoo must support.
         *
         * @param engine the engine name, or {@code null} for any engine
         * @return this builder
         */
        public Builder<I, O> optEngine(String engine) {
            this.engine = engine;
            return this;
        }

        /**
         * Sets the device.
         *
         * @param device the device
         * @return this builder
         */
        public Builder<I, O> optDevice(Device device) {
            this.device = device;
            return this;
        }

        /**
         * Sets the model zoo group id to search.
         *
         * @param groupId the group id, or {@code null} for any group
         * @return this builder
         */
        public Builder<I, O> optGroupId(String groupId) {
            this.groupId = groupId;
            return this;
        }

        /**
         * Sets the artifact id.
         *
         * <p>A value containing {@code ':'} is treated as {@code groupId:artifactId}. Both parts
         * are assigned, an empty part becomes {@code null}, and segments after the second are
         * ignored.
         *
         * @param artifactId the artifact id, optionally prefixed by {@code groupId:}
         * @return this builder
         */
        public Builder<I, O> optArtifactId(String artifactId) {
            if (artifactId == null || !artifactId.contains(":")) {
                this.artifactId = artifactId;
                return this;
            }
            // A limit of -1 keeps trailing empty strings, so "group:" still yields two parts.
            String[] parts = artifactId.split(":", -1);
            this.groupId = parts[0].isEmpty() ? null : parts[0];
            this.artifactId = parts[1].isEmpty() ? null : parts[1];
            return this;
        }

        /**
         * Uses a {@code DefaultModelZoo} built from the given model URL(s). {@code null} is
         * ignored.
         *
         * @param urls the model URL(s)
         * @return this builder
         */
        public Builder<I, O> optModelUrls(String urls) {
            if (urls != null) {
                this.modelZoo = new DefaultModelZoo(urls);
            }
            return this;
        }

        /**
         * Uses a {@code DefaultModelZoo} pointing at a local model path. {@code null} is ignored.
         *
         * @param modelPath the local model path
         * @return this builder
         * @throws AssertionError if the path cannot be converted to a URL
         */
        public Builder<I, O> optModelPath(Path modelPath) {
            if (modelPath != null) {
                try {
                    this.modelZoo = new DefaultModelZoo(modelPath.toUri().toURL().toString());
                } catch (MalformedURLException e) {
                    throw new AssertionError("Invalid model path: " + modelPath, e);
                }
            }
            return this;
        }

        /**
         * Restricts the search to a single model zoo.
         *
         * @param modelZoo the model zoo, or {@code null} to search every registered zoo
         * @return this builder
         */
        public Builder<I, O> optModelZoo(ModelZoo modelZoo) {
            this.modelZoo = modelZoo;
            return this;
        }

        /**
         * Replaces all filters with a copy of {@code filters}.
         *
         * @param filters the filters, or {@code null} to clear them
         * @return this builder
         */
        public Builder<I, O> optFilters(Map<String, String> filters) {
            this.filters = copyMap(filters);
            return this;
        }

        /**
         * Adds a single filter.
         *
         * @param key the filter key
         * @param value the filter value
         * @return this builder
         */
        public Builder<I, O> optFilter(String key, String value) {
            if (filters == null) {
                filters = new HashMap<>();
            }
            filters.put(key, value);
            return this;
        }

        /**
         * Sets the block.
         *
         * @param block the block
         * @return this builder
         */
        public Builder<I, O> optBlock(Block block) {
            this.block = block;
            return this;
        }

        /**
         * Sets the model name.
         *
         * @param modelName the model name
         * @return this builder
         */
        public Builder<I, O> optModelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /**
         * Replaces all arguments with a copy of {@code arguments}.
         *
         * @param arguments the arguments, or {@code null} to clear them
         * @return this builder
         */
        public Builder<I, O> optArguments(Map<String, Object> arguments) {
            this.arguments = copyMap(arguments);
            return this;
        }

        /**
         * Adds a single argument.
         *
         * @param key the argument key
         * @param value the argument value
         * @return this builder
         */
        public Builder<I, O> optArgument(String key, Object value) {
            if (arguments == null) {
                arguments = new HashMap<>();
            }
            arguments.put(key, value);
            return this;
        }

        /**
         * Replaces all options with a copy of {@code options}.
         *
         * @param options the options, or {@code null} to clear them
         * @return this builder
         */
        public Builder<I, O> optOptions(Map<String, String> options) {
            this.options = copyMap(options);
            return this;
        }

        /**
         * Adds a single option.
         *
         * @param key the option key
         * @param value the option value
         * @return this builder
         */
        public Builder<I, O> optOption(String key, String value) {
            if (options == null) {
                options = new HashMap<>();
            }
            options.put(key, value);
            return this;
        }

        /**
         * Sets an explicit translator and clears any translator factory.
         *
         * @param translator the translator
         * @return this builder
         */
        public Builder<I, O> optTranslator(Translator<I, O> translator) {
            this.factory = null;
            this.translator = translator;
            return this;
        }

        /**
         * Sets a translator factory and clears any explicit translator.
         *
         * @param translatorFactory the translator factory
         * @return this builder
         */
        public Builder<I, O> optTranslatorFactory(TranslatorFactory translatorFactory) {
            this.translator = null;
            this.factory = translatorFactory;
            return this;
        }

        /**
         * Sets the progress tracker.
         *
         * @param progress the progress tracker
         * @return this builder
         */
        public Builder<I, O> optProgress(Progress progress) {
            this.progress = progress;
            return this;
        }

        /**
         * Builds an immutable-snapshot {@link Criteria}. The builder itself is left unchanged, so
         * it can be reused or retyped with {@link #setTypes} afterwards.
         *
         * @return the new criteria
         */
        public Criteria<I, O> build() {
            return new Criteria<>(this);
        }
    }

    /**
     * Snapshots the builder's state. Map settings are copied, and an explicit translator is
     * wrapped into a {@link DefaultTranslatorFactory} when no factory was supplied.
     */
    Criteria(Builder<I, O> builder) {
        this.application = builder.application;
        this.inputClass = builder.inputClass;
        this.outputClass = builder.outputClass;
        this.engine = builder.engine;
        this.device = builder.device;
        this.groupId = builder.groupId;
        this.artifactId = builder.artifactId;
        this.modelZoo = builder.modelZoo;
        this.filters = copyMap(builder.filters);
        this.arguments = copyMap(builder.arguments);
        this.options = copyMap(builder.options);
        this.factory = resolveTranslatorFactory(builder);
        this.block = builder.block;
        this.modelName = builder.modelName;
        this.progress = builder.progress;
    }

    /**
     * Returns the builder's factory. If no factory is set but an explicit translator is, it
     * returns a new {@link DefaultTranslatorFactory} holding that translator for the builder's
     * I/O types. The builder is not modified.
     */
    private static <I, O> TranslatorFactory resolveTranslatorFactory(Builder<I, O> builder) {
        if (builder.factory != null || builder.translator == null) {
            return builder.factory;
        }
        DefaultTranslatorFactory defaultFactory = new DefaultTranslatorFactory();
        defaultFactory.registerTranslator(
                builder.inputClass, builder.outputClass, builder.translator);
        return defaultFactory;
    }

    /** Returns an order-preserving copy of {@code map}, or {@code null} if it is {@code null}. */
    private static <V> Map<String, V> copyMap(Map<String, V> map) {
        return map == null ? null : new LinkedHashMap<>(map);
    }

    /**
     * Loads the first model that matches this criteria.
     *
     * <p>Candidate zoos are either the configured {@link ModelZoo} or every zoo from
     * {@code ModelZoo.listModelZoo()} that matches {@code groupId} and {@code engine}. Each zoo's
     * loaders are tried in order, skipping those whose artifact id or application does not match.
     * The first successful load is returned.
     *
     * @return the loaded model
     * @throws IllegalArgumentException if the input or output class is not set
     * @throws IOException if loading fails with an I/O error
     * @throws ModelNotFoundException if no loader could load a matching model; the last loader
     *     failure, if any, is attached as the cause
     * @throws MalformedModelException if a matched model is malformed
     */
    public ZooModel<I, O> loadModel()
            throws IOException, ModelNotFoundException, MalformedModelException {
        if (inputClass == null || outputClass == null) {
            throw new IllegalArgumentException("inputClass and outputClass are required.");
        }
        Logger logger = LoggerFactory.getLogger(ModelZoo.class);
        logger.debug("Loading model with {}", this);

        ModelNotFoundException lastFailure = null;
        for (ModelZoo zoo : findCandidateZoos(logger)) {
            String zooGroupId = zoo.getGroupId();
            for (ModelLoader loader : zoo.getModelLoaders()) {
                Application loaderApplication = loader.getApplication();
                String loaderArtifactId = loader.getArtifactId();
                logger.debug("Checking ModelLoader: {}", loader);
                if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
                    logger.debug(
                            "artifactId mismatch for ModelLoader: {}:{}",
                            zooGroupId,
                            loaderArtifactId);
                    continue;
                }
                // UNDEFINED on either side acts as a wildcard.
                boolean applicationMatches =
                        application == Application.UNDEFINED
                                || loaderApplication == Application.UNDEFINED
                                || loaderApplication.matches(application);
                if (!applicationMatches) {
                    logger.debug(
                            "application mismatch for ModelLoader: {}:{}",
                            zooGroupId,
                            loaderArtifactId);
                    continue;
                }
                try {
                    return loader.loadModel(this);
                } catch (ModelNotFoundException e) {
                    // Keep trying other loaders; remember the failure for the final exception.
                    lastFailure = e;
                    logger.trace("", e);
                    logger.debug(
                            "{} for ModelLoader: {}:{}",
                            e.getMessage(),
                            zooGroupId,
                            loaderArtifactId);
                }
            }
        }
        throw new ModelNotFoundException(
                "No model with the specified URI or the matching Input/Output type is found.",
                lastFailure);
    }

    /**
     * Selects the zoos to search.
     *
     * <p>If a model zoo was configured, it is the only candidate, and it must agree with
     * {@code groupId} and {@code engine}. Otherwise, every registered zoo that matches them is
     * returned.
     *
     * @throws ModelNotFoundException if the configured zoo conflicts with {@code groupId} or does
     *     not support {@code engine}
     */
    private List<ModelZoo> findCandidateZoos(Logger logger) throws ModelNotFoundException {
        List<ModelZoo> candidates = new ArrayList<>();
        if (modelZoo != null) {
            String zooGroupId = modelZoo.getGroupId();
            logger.debug("Searching model in specified model zoo: {}", zooGroupId);
            if (groupId != null && !zooGroupId.equals(groupId)) {
                throw new ModelNotFoundException(
                        "groupId conflict with ModelZoo criteria. "
                                + zooGroupId
                                + " v.s. "
                                + groupId);
            }
            Set<String> supportedEngines = modelZoo.getSupportedEngines();
            if (engine != null && !supportedEngines.contains(engine)) {
                throw new ModelNotFoundException(
                        "ModelZoo doesn't support specified engine: " + engine);
            }
            candidates.add(modelZoo);
            return candidates;
        }
        for (ModelZoo zoo : ModelZoo.listModelZoo()) {
            if (groupId != null && !zoo.getGroupId().equals(groupId)) {
                logger.debug("Ignore ModelZoo {} by groupId: {}", zoo.getGroupId(), groupId);
                continue;
            }
            Set<String> supportedEngines = zoo.getSupportedEngines();
            if (engine != null && !supportedEngines.contains(engine)) {
                logger.debug("Ignore ModelZoo {} by engine: {}", zoo.getGroupId(), engine);
                continue;
            }
            candidates.add(zoo);
        }
        return candidates;
    }

    /** @return the application; the builder defaults it to {@link Application#UNDEFINED} */
    public Application getApplication() {
        return application;
    }

    /** @return the model input class */
    public Class<I> getInputClass() {
        return inputClass;
    }

    /** @return the model output class */
    public Class<O> getOutputClass() {
        return outputClass;
    }

    /** @return the engine name, or {@code null} if unset */
    public String getEngine() {
        return engine;
    }

    /** @return the device, or {@code null} if unset */
    public Device getDevice() {
        return device;
    }

    /** @return the model zoo group id, or {@code null} if unset */
    public String getGroupId() {
        return groupId;
    }

    /** @return the artifact id, or {@code null} if unset */
    public String getArtifactId() {
        return artifactId;
    }

    /** @return the explicitly configured model zoo, or {@code null} to search all zoos */
    public ModelZoo getModelZoo() {
        return modelZoo;
    }

    /** @return this criteria's filters, or {@code null} if unset */
    public Map<String, String> getFilters() {
        return filters;
    }

    /** @return this criteria's arguments, or {@code null} if unset */
    public Map<String, Object> getArguments() {
        return arguments;
    }

    /** @return this criteria's options, or {@code null} if unset */
    public Map<String, String> getOptions() {
        return options;
    }

    /** @return the translator factory, or {@code null} if neither a factory nor a translator was set */
    public TranslatorFactory getTranslatorFactory() {
        return factory;
    }

    /** @return the block, or {@code null} if unset */
    public Block getBlock() {
        return block;
    }

    /** @return the model name, or {@code null} if unset */
    public String getModelName() {
        return modelName;
    }

    /** @return the progress tracker, or {@code null} if unset */
    public Progress getProgress() {
        return progress;
    }

    /** Returns a multi-line, human-readable summary of the non-null settings, used for logging. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(128);
        sb.append("Criteria:\n");
        if (application != null) {
            sb.append("\tApplication: ").append(application).append('\n');
        }
        sb.append("\tInput: ").append(inputClass);
        sb.append("\n\tOutput: ").append(outputClass).append('\n');
        if (engine != null) {
            sb.append("\tEngine: ").append(engine).append('\n');
        }
        if (modelZoo != null) {
            sb.append("\tModelZoo: ").append(modelZoo.getGroupId()).append('\n');
        }
        if (groupId != null) {
            sb.append("\tGroupID: ").append(groupId).append('\n');
        }
        if (artifactId != null) {
            sb.append("\tArtifactId: ").append(artifactId).append('\n');
        }
        if (filters != null) {
            sb.append("\tFilter: ").append(JsonUtils.GSON.toJson(filters)).append('\n');
        }
        if (arguments != null) {
            // Only @Expose-annotated fields of argument values are serialized.
            String json =
                    JsonUtils.builder()
                            .excludeFieldsWithoutExposeAnnotation()
                            .create()
                            .toJson(arguments);
            sb.append("\tArguments: ").append(json).append('\n');
        }
        if (options != null) {
            sb.append("\tOptions: ").append(JsonUtils.GSON.toJson(options)).append('\n');
        }
        if (factory == null) {
            sb.append("\tNo translator supplied\n");
        }
        return sb.toString();
    }

    /**
     * Returns a builder pre-populated with this criteria's settings. The map settings are copied,
     * so modifying the returned builder never affects this criteria.
     *
     * @return a new builder
     */
    public Builder<I, O> toBuilder() {
        return builder()
                .setTypes(inputClass, outputClass)
                .optApplication(application)
                .optEngine(engine)
                .optDevice(device)
                .optGroupId(groupId)
                .optArtifactId(artifactId)
                .optModelZoo(modelZoo)
                .optFilters(filters)
                .optArguments(arguments)
                .optOptions(options)
                .optTranslatorFactory(factory)
                .optBlock(block)
                .optModelName(modelName)
                .optProgress(progress);
    }

    /**
     * Creates an untyped builder; call {@link Builder#setTypes} to fix the input/output types.
     *
     * @return a new builder
     */
    public static Builder<?, ?> builder() {
        return new Builder<>();
    }
}
