package ai.djl.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.Ec2Utils;
import ai.djl.util.RandomUtils;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;
import com.ibm.icu.text.PluralRules;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Properties;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/engine/Engine.class */
public abstract class Engine {
    private static final Logger logger = LoggerFactory.getLogger(Engine.class);
    private static final Map<String, EngineProvider> ALL_ENGINES = new ConcurrentHashMap();
    private static final String DEFAULT_ENGINE = initEngine();
    private static final Pattern PATTERN = Pattern.compile("KEY|TOKEN|PASSWORD", 2);
    private Device defaultDevice;
    private Integer seed;

    private static synchronized String initEngine() {
        Iterator it2 = ServiceLoader.load(EngineProvider.class).iterator();
        while (it2.hasNext()) {
            registerEngine((EngineProvider) it2.next());
        }
        if (ALL_ENGINES.isEmpty()) {
            logger.debug("No engine found from EngineProvider");
            return null;
        }
        String str = Utils.getenv("DJL_DEFAULT_ENGINE", System.getProperty("ai.djl.default_engine"));
        if (str == null || str.isEmpty()) {
            int i = Integer.MAX_VALUE;
            for (EngineProvider engineProvider : ALL_ENGINES.values()) {
                if (engineProvider.getEngineRank() < i) {
                    str = engineProvider.getEngineName();
                    i = engineProvider.getEngineRank();
                }
            }
        } else if (!ALL_ENGINES.containsKey(str)) {
            throw new EngineException("Unknown default engine: " + str);
        }
        logger.debug("Found default engine: {}", str);
        Ec2Utils.callHome(str);
        return str;
    }

    public abstract Engine getAlternativeEngine();

    public abstract String getEngineName();

    public abstract int getRank();

    public static String getDefaultEngineName() {
        return System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE);
    }

    public static Engine getInstance() {
        if (DEFAULT_ENGINE == null) {
            throw new EngineException("No deep learning engine found." + System.lineSeparator() + "Please refer to https://github.com/deepjavalibrary/djl/blob/master/docs/development/troubleshooting.md for more details.");
        }
        return getEngine(getDefaultEngineName());
    }

    public static boolean hasEngine(String str) {
        return ALL_ENGINES.containsKey(str);
    }

    public static void registerEngine(EngineProvider engineProvider) {
        logger.debug("Registering EngineProvider: {}", engineProvider.getEngineName());
        ALL_ENGINES.putIfAbsent(engineProvider.getEngineName(), engineProvider);
    }

    public static Set<String> getAllEngines() {
        return ALL_ENGINES.keySet();
    }

    public static Engine getEngine(String str) {
        EngineProvider engineProvider = ALL_ENGINES.get(str);
        if (engineProvider == null) {
            throw new IllegalArgumentException("Deep learning engine not found: " + str);
        }
        return engineProvider.getEngine();
    }

    public abstract String getVersion();

    public abstract boolean hasCapability(String str);

    public Device defaultDevice() {
        if (this.defaultDevice == null) {
            if (!hasCapability(StandardCapabilities.CUDA) || CudaUtils.getGpuCount() <= 0) {
                this.defaultDevice = Device.cpu();
            } else {
                this.defaultDevice = Device.gpu();
            }
        }
        return this.defaultDevice;
    }

    public Device[] getDevices() {
        return getDevices(Integer.MAX_VALUE);
    }

    public Device[] getDevices(int i) {
        int gpuCount = getGpuCount();
        if (i <= 0 || gpuCount <= 0) {
            return new Device[]{Device.cpu()};
        }
        int min = Math.min(i, gpuCount);
        Device[] deviceArr = new Device[min];
        for (int i2 = 0; i2 < min; i2++) {
            deviceArr[i2] = Device.gpu(i2);
        }
        return deviceArr;
    }

    public int getGpuCount() {
        if (hasCapability(StandardCapabilities.CUDA)) {
            return CudaUtils.getGpuCount();
        }
        return 0;
    }

    public SymbolBlock newSymbolBlock(NDManager nDManager) {
        throw new UnsupportedOperationException("Not supported.");
    }

    public abstract Model newModel(String str, Device device);

    public abstract NDManager newBaseManager();

    public abstract NDManager newBaseManager(Device device);

    public GradientCollector newGradientCollector() {
        throw new UnsupportedOperationException("Not supported.");
    }

    public ParameterServer newParameterServer(Optimizer optimizer) {
        return new LocalParameterServer(optimizer);
    }

    public void setRandomSeed(int i) {
        this.seed = Integer.valueOf(i);
        RandomUtils.RANDOM.setSeed(i);
    }

    public Integer getSeed() {
        return this.seed;
    }

    public static String getDjlVersion() {
        String specificationVersion = Engine.class.getPackage().getSpecificationVersion();
        if (specificationVersion != null) {
            return specificationVersion;
        }
        try {
            InputStream resourceAsStream = Engine.class.getResourceAsStream("api.properties");
            try {
                Properties properties = new Properties();
                properties.load(resourceAsStream);
                String property = properties.getProperty("djl_version");
                if (resourceAsStream != null) {
                    resourceAsStream.close();
                }
                return property;
            } finally {
            }
        } catch (IOException e) {
            throw new AssertionError("Failed to open api.properties", e);
        }
    }

    public String toString() {
        return getEngineName() + ':' + getVersion();
    }

    public static void debugEnvironment() {
        System.out.println("----------- System Properties -----------");
        System.getProperties().forEach((obj, obj2) -> {
            print((String) obj, obj2);
        });
        System.out.println();
        System.out.println("--------- Environment Variables ---------");
        Utils.getenv().forEach((v0, v1) -> {
            print(v0, v1);
        });
        System.out.println();
        System.out.println("-------------- Directories --------------");
        try {
            System.out.println("temp directory: " + Paths.get(System.getProperty("java.io.tmpdir"), new String[0]));
            Files.delete(Files.createTempFile("test", ".tmp", new FileAttribute[0]));
            System.out.println("DJL cache directory: " + Utils.getCacheDir().toAbsolutePath());
            Path engineCacheDir = Utils.getEngineCacheDir();
            System.out.println("Engine cache directory: " + engineCacheDir.toAbsolutePath());
            Files.createDirectories(engineCacheDir, new FileAttribute[0]);
            if (!Files.isWritable(engineCacheDir)) {
                System.out.println("Engine cache directory is not writable!!!");
            }
        } catch (Throwable th) {
            th.printStackTrace(System.out);
        }
        System.out.println();
        System.out.println("------------------ CUDA -----------------");
        int gpuCount = CudaUtils.getGpuCount();
        System.out.println("GPU Count: " + gpuCount);
        if (gpuCount > 0) {
            System.out.println("CUDA: " + CudaUtils.getCudaVersionString());
            System.out.println("ARCH: " + CudaUtils.getComputeCapability(0));
        }
        for (int i = 0; i < gpuCount; i++) {
            System.out.println("GPU(" + i + ") memory used: " + CudaUtils.getGpuMemory(Device.gpu(i)).getCommitted() + " bytes");
        }
        System.out.println();
        System.out.println("----------------- Engines ---------------");
        System.out.println("DJL version: " + getDjlVersion());
        System.out.println("Default Engine: " + getInstance());
        System.out.println("Default Device: " + getInstance().defaultDevice());
        for (EngineProvider engineProvider : ALL_ENGINES.values()) {
            System.out.println(engineProvider.getEngineName() + PluralRules.KEYWORD_RULE_SEPARATOR + engineProvider.getEngineRank());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static void print(String str, Object obj) {
        if (PATTERN.matcher(str).find()) {
            obj = "*********";
        }
        System.out.println(str + PluralRules.KEYWORD_RULE_SEPARATOR + obj);
    }
}
