package ai.djl.util.cuda;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.util.Utils;
import com.sun.jna.Native;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.MemoryUsage;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utilities for querying NVIDIA GPUs through the CUDA runtime library ({@code cudart}).
 *
 * <p>By default {@code cudart} is loaded in-process via JNA. When the system property {@code
 * ai.djl.util.cuda.fork} is {@code true}, the library is not loaded in this JVM; instead each query
 * launches a child JVM running {@link #main(String[])} and parses its comma-separated output.
 */
public final class CudaUtils {

    /** System property that switches all queries to the forked child-JVM mode. */
    private static final String FORK_PROPERTY = "ai.djl.util.cuda.fork";

    /** {@code cudaDevAttrComputeCapabilityMajor} in the CUDA runtime {@code cudaDeviceAttr} enum. */
    private static final int ATTR_COMPUTE_CAPABILITY_MAJOR = 75;

    /** {@code cudaDevAttrComputeCapabilityMinor} in the CUDA runtime {@code cudaDeviceAttr} enum. */
    private static final int ATTR_COMPUTE_CAPABILITY_MINOR = 76;

    private static final Logger logger = LoggerFactory.getLogger(CudaUtils.class);

    /** The loaded {@code cudart} binding, or {@code null} if unavailable or in fork mode. */
    private static final CudaLibrary LIB = loadLibrary();

    /**
     * Cached output of the forked no-argument query, {@code [gpuCount, cudaVersion]}. Populated
     * lazily and only used in fork mode.
     */
    private static String[] gpuInfo;

    private CudaUtils() {}

    /**
     * Returns whether at least one CUDA GPU is visible.
     *
     * @return {@code true} if {@link #getGpuCount()} is positive
     */
    public static boolean hasCuda() {
        return getGpuCount() > 0;
    }

    /**
     * Returns the number of visible CUDA GPUs.
     *
     * <p>Returns 0 when {@code cudart} could not be loaded or the device query fails. A "no device"
     * result is logged at debug level; any other failure is logged as a warning.
     *
     * @return the GPU count, never negative
     */
    public static int getGpuCount() {
        if (isForkMode()) {
            String[] info = getForkedGpuInfo();
            try {
                return Integer.parseInt(info[0]);
            } catch (NumberFormatException e) {
                // Mirror the in-process path: a failed query means "no usable GPU", not a crash.
                logger.warn("Unexpected output from forked GPU query: {}", String.join(",", info));
                logger.debug("", e);
                return 0;
            }
        }
        if (LIB == null) {
            return 0;
        }
        int[] count = new int[1];
        int ret = LIB.cudaGetDeviceCount(count);
        if (ret == 0) {
            return count[0];
        }
        if (ret == CudaLibrary.ERROR_NO_DEVICE) {
            logger.debug("No GPU device found: {} ({})", LIB.cudaGetErrorString(ret), ret);
        } else {
            logger.warn("Failed to detect GPU count: {} ({})", LIB.cudaGetErrorString(ret), ret);
        }
        return 0;
    }

    /**
     * Returns the CUDA runtime version as the raw integer reported by {@code cudart}.
     *
     * <p>The value is decoded by {@link #getCudaVersionString()} as {@code major = v / 1000} and
     * {@code minor = (v / 10) % 10}.
     *
     * @return the CUDA runtime version
     * @throws IllegalStateException if {@code cudart} is not loaded (in-process mode)
     * @throws IllegalArgumentException in fork mode, if no device was found or the child output is
     *     malformed
     * @throws EngineException if the CUDA call fails
     */
    public static int getCudaVersion() {
        if (isForkMode()) {
            String[] info = getForkedGpuInfo();
            if (info.length < 2) {
                throw new IllegalArgumentException(
                        "Unexpected output from forked GPU query: " + String.join(",", info));
            }
            int version = Integer.parseInt(info[1]);
            // main() reports -1 as the version when no GPU is present.
            if (version == -1) {
                throw new IllegalArgumentException("No cuda device found.");
            }
            return version;
        }
        if (LIB == null) {
            throw new IllegalStateException("No cuda library is loaded.");
        }
        int[] version = new int[1];
        checkCall(LIB.cudaRuntimeGetVersion(version));
        return version[0];
    }

    /**
     * Returns the CUDA version as a compact string: a two-digit zero-padded major version followed
     * by the minor digit (for example a raw version of 11080 yields {@code "118"}).
     *
     * @return the CUDA version string
     */
    public static String getCudaVersionString() {
        int cudaVersion = getCudaVersion();
        return String.format(Locale.ROOT, "%02d", cudaVersion / 1000) + ((cudaVersion / 10) % 10);
    }

    /**
     * Returns the compute capability of a GPU as major and minor digits concatenated.
     *
     * @param device the GPU device id
     * @return the compute capability, for example {@code "75"}
     * @throws IllegalStateException if {@code cudart} is not loaded (in-process mode)
     * @throws IllegalArgumentException in fork mode, if the child reports an invalid device
     * @throws EngineException if a CUDA call fails
     */
    public static String getComputeCapability(int device) {
        if (isForkMode()) {
            String[] result = execute(device);
            if (result.length != 3) {
                // The child printed an error message (e.g. "Invalid device: N") instead of data.
                throw new IllegalArgumentException(result[0]);
            }
            return result[0];
        }
        if (LIB == null) {
            throw new IllegalStateException("No cuda library is loaded.");
        }
        int[] major = new int[1];
        int[] minor = new int[1];
        checkCall(LIB.cudaDeviceGetAttribute(major, ATTR_COMPUTE_CAPABILITY_MAJOR, device));
        checkCall(LIB.cudaDeviceGetAttribute(minor, ATTR_COMPUTE_CAPABILITY_MINOR, device));
        return String.valueOf(major[0]) + minor[0];
    }

    /**
     * Returns memory usage for a GPU device.
     *
     * <p>The result has {@code init = -1}, {@code used == committed} (total minus free), and
     * {@code max} equal to the total device memory. In-process, the calling thread's current CUDA
     * device is temporarily switched and always restored afterwards.
     *
     * @param device the GPU device to query
     * @return the device's memory usage
     * @throws IllegalArgumentException if {@code device} is not a GPU, or in fork mode if the child
     *     reports an error
     * @throws IllegalStateException if {@code cudart} is not loaded (in-process mode)
     * @throws EngineException if a CUDA call fails
     */
    public static MemoryUsage getGpuMemory(Device device) {
        if (!device.isGpu()) {
            throw new IllegalArgumentException("Only GPU device is allowed.");
        }
        if (isForkMode()) {
            String[] result = execute(device.getDeviceId());
            if (result.length != 3) {
                throw new IllegalArgumentException(result[0]);
            }
            // Child output format: computeCapability,total,used
            long total = Long.parseLong(result[1]);
            long used = Long.parseLong(result[2]);
            return new MemoryUsage(-1L, used, used, total);
        }
        if (LIB == null) {
            throw new IllegalStateException("No GPU device detected.");
        }
        int[] currentDevice = new int[1];
        checkCall(LIB.cudaGetDevice(currentDevice));
        checkCall(LIB.cudaSetDevice(device.getDeviceId()));
        long[] free = new long[1];
        long[] total = new long[1];
        try {
            checkCall(LIB.cudaMemGetInfo(free, total));
        } finally {
            // Restore the caller's device even when the memory query fails.
            checkCall(LIB.cudaSetDevice(currentDevice[0]));
        }
        long used = total[0] - free[0];
        return new MemoryUsage(-1L, used, used, total[0]);
    }

    /**
     * Entry point of the child JVM used in fork mode. Always prints exactly one line.
     *
     * <p>With no arguments, prints {@code "<gpuCount>,<cudaVersion>"}, or {@code "0,-1"} when no
     * GPU is found. With a device id argument, prints {@code
     * "<computeCapability>,<totalMemory>,<usedMemory>"} for that device, or {@code "Invalid device:
     * <arg>"} if the argument is not a valid GPU index.
     *
     * @param args optional GPU device id
     */
    public static void main(String[] args) {
        int gpuCount = getGpuCount();
        if (args.length == 0) {
            if (gpuCount > 0) {
                System.out.println(gpuCount + "," + getCudaVersion());
            } else {
                System.out.println("0,-1");
            }
            return;
        }
        try {
            int deviceId = Integer.parseInt(args[0]);
            if (deviceId < 0 || deviceId >= gpuCount) {
                System.out.println("Invalid device: " + deviceId);
                return;
            }
            MemoryUsage memory = getGpuMemory(Device.gpu(deviceId));
            System.out.println(
                    getComputeCapability(deviceId) + ',' + memory.getMax() + ',' + memory.getUsed());
        } catch (NumberFormatException e) {
            System.out.println("Invalid device: " + args[0]);
        }
    }

    /** Returns whether queries should run in a forked child JVM. */
    private static boolean isForkMode() {
        return Boolean.getBoolean(FORK_PROPERTY);
    }

    /** Returns the cached forked no-argument query result, running the child JVM on first use. */
    private static String[] getForkedGpuInfo() {
        if (gpuInfo == null) {
            gpuInfo = execute(-1);
        }
        return gpuInfo;
    }

    /**
     * Loads the {@code cudart} library via JNA.
     *
     * <p>On non-Windows systems it loads {@code "cudart"} by name. On Windows it scans {@code
     * %CUDA_PATH%\bin} (if set) followed by each {@code PATH} entry, and loads the first file
     * matching {@code cudart64_<digits>.dll}.
     *
     * @return the binding, or {@code null} in fork mode, when the library is not found, or when
     *     loading fails
     */
    private static CudaLibrary loadLibrary() {
        try {
            if (isForkMode()) {
                return null;
            }
            if (!System.getProperty("os.name").startsWith("Win")) {
                return Native.load("cudart", CudaLibrary.class);
            }
            String path = Utils.getenv("PATH");
            if (path == null) {
                return null;
            }
            String cudaPath = Utils.getenv("CUDA_PATH");
            if (cudaPath != null) {
                path = cudaPath + "\\bin\\;" + path;
            }
            Pattern cudartDll = Pattern.compile("cudart64_\\d+\\.dll");
            for (String dir : path.split(";")) {
                File[] matches =
                        new File(dir).listFiles(f -> cudartDll.matcher(f.getName()).matches());
                if (matches != null && matches.length > 0) {
                    String fileName = matches[0].getName();
                    // Strip the ".dll" extension: JNA expects the bare library name.
                    String libName = fileName.substring(0, fileName.length() - 4);
                    logger.debug("Found cudart: {}", matches[0].getAbsolutePath());
                    return Native.load(libName, CudaLibrary.class);
                }
            }
            logger.debug("No cudart library found in path.");
            return null;
        } catch (SecurityException e) {
            logger.warn("Access denied during loading cudart library.");
            logger.trace("", e);
            return null;
        } catch (UnsatisfiedLinkError e) {
            logger.debug("cudart library not found.");
            logger.trace("", e);
            return null;
        } catch (LinkageError e) {
            // Any other linkage failure (UnsatisfiedLinkError is handled above).
            logger.warn("You have a conflict version of JNA in the classpath.");
            logger.debug("", e);
            return null;
        }
    }

    /**
     * Runs {@link #main(String[])} in a child JVM that shares this JVM's {@code java.home} and
     * class path, and returns the child's result split on commas.
     *
     * @param deviceId the device id to pass to the child, or a negative value for the
     *     no-argument (count/version) query
     * @return the comma-separated fields of the child's final output line
     * @throws IllegalArgumentException if the child process cannot be started or read
     */
    private static String[] execute(int deviceId) {
        String javaHome = System.getProperty("java.home");
        List<String> cmd = new ArrayList<>(5);
        if (System.getProperty("os.name").startsWith("Win")) {
            cmd.add(javaHome + "\\bin\\java.exe");
        } else {
            cmd.add(javaHome + "/bin/java");
        }
        cmd.add("-cp");
        cmd.add(System.getProperty("java.class.path"));
        cmd.add("ai.djl.util.cuda.CudaUtils");
        if (deviceId >= 0) {
            cmd.add(String.valueOf(deviceId));
        }
        try {
            Process process = new ProcessBuilder(cmd).redirectErrorStream(true).start();
            try (InputStream is = process.getInputStream()) {
                String output = Utils.toString(is).trim();
                // stderr is merged into stdout, so JVM notices (e.g. "Picked up
                // JAVA_TOOL_OPTIONS") or log lines written by the child may precede the result.
                // main() prints its result as a single final line, so parse only that line.
                String result = output.substring(output.lastIndexOf('\n') + 1).trim();
                return result.split(",");
            }
        } catch (IOException e) {
            throw new IllegalArgumentException("Failed get GPU information", e);
        }
    }

    /**
     * Throws if a CUDA runtime call did not return success (0).
     *
     * @param ret the CUDA error code returned by the call
     * @throws IllegalStateException if {@code cudart} is not loaded
     * @throws EngineException if {@code ret} is non-zero
     */
    private static void checkCall(int ret) {
        if (LIB == null) {
            throw new IllegalStateException("No cuda library is loaded.");
        }
        if (ret != 0) {
            throw new EngineException(
                    "CUDA API call failed: " + LIB.cudaGetErrorString(ret) + " (" + ret + ')');
        }
    }
}
