package ai.djl.util.cuda;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import com.sun.jna.Native;
import java.io.File;
import java.lang.management.MemoryUsage;
import java.util.regex.Pattern;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/util/cuda/CudaUtils.class */
/**
 * Static helpers for querying the CUDA runtime ({@code cudart}) through JNA.
 *
 * <p>The runtime library is located and loaded once, when this class is initialized. If it could
 * not be loaded, {@link #getGpuCount()} reports zero GPUs and the other query methods throw {@link
 * IllegalStateException}.
 */
public final class CudaUtils {

    private static final Logger logger = LoggerFactory.getLogger(CudaUtils.class);

    /** CUDA runtime status {@code cudaSuccess}. */
    private static final int SUCCESS = 0;

    /** CUDA runtime status {@code cudaErrorNoDevice}. */
    private static final int ERROR_NO_DEVICE = 100;

    /** CUDA device attribute {@code cudaDevAttrComputeCapabilityMajor}. */
    private static final int ATTR_COMPUTE_CAPABILITY_MAJOR = 75;

    /** CUDA device attribute {@code cudaDevAttrComputeCapabilityMinor}. */
    private static final int ATTR_COMPUTE_CAPABILITY_MINOR = 76;

    // Must be declared after `logger`: loadLibrary() logs during static initialization.
    private static final CudaLibrary LIB = loadLibrary();

    private CudaUtils() {
    }

    /**
     * Returns whether at least one CUDA device is usable.
     *
     * @return {@code true} if {@link #getGpuCount()} is greater than zero
     */
    public static boolean hasCuda() {
        return getGpuCount() > 0;
    }

    /**
     * Returns the number of CUDA devices reported by the runtime.
     *
     * <p>This method never throws for runtime errors. A missing library or any non-success status
     * yields {@code 0}. "No device" is logged at debug level; every other failure is logged as a
     * warning.
     *
     * @return the device count, or {@code 0} if it cannot be determined
     */
    public static int getGpuCount() {
        if (LIB == null) {
            return 0;
        }
        int[] count = new int[1];
        int status = LIB.cudaGetDeviceCount(count);
        switch (status) {
            case SUCCESS:
                return count[0];
            case ERROR_NO_DEVICE:
                logger.debug(
                        "No GPU device found: {} ({})", LIB.cudaGetErrorString(status), status);
                return 0;
            default:
                // Initialization errors, insufficient driver, not-permitted and anything else.
                logger.warn(
                        "Failed to detect GPU count: {} ({})",
                        LIB.cudaGetErrorString(status),
                        status);
                return 0;
        }
    }

    /**
     * Returns the raw CUDA runtime version, as reported by {@code cudaRuntimeGetVersion}.
     *
     * @return the encoded runtime version ({@code 1000 * major + 10 * minor})
     * @throws IllegalStateException if the CUDA runtime library is not loaded
     * @throws EngineException if the runtime call fails
     */
    public static int getCudaVersion() {
        if (LIB == null) {
            throw new IllegalStateException("No cuda library is loaded.");
        }
        int[] version = new int[1];
        checkCall(LIB.cudaRuntimeGetVersion(version));
        return version[0];
    }

    /**
     * Returns the CUDA runtime version as a compact "major + minor" string.
     *
     * <p>For example, an encoded version of {@code 11020} becomes {@code "112"}.
     *
     * @return the concatenated major and minor version digits
     * @throws IllegalStateException if the CUDA runtime library is not loaded
     * @throws EngineException if the runtime call fails
     */
    public static String getCudaVersionString() {
        if (LIB == null) {
            throw new IllegalStateException("No cuda library is loaded.");
        }
        int version = getCudaVersion();
        int major = version / 1000;
        int minor = (version / 10) % 10;
        return String.valueOf(major) + minor;
    }

    /**
     * Returns the compute capability of a device as a "major + minor" string, e.g. {@code "75"}.
     *
     * @param i the CUDA device ordinal
     * @return the concatenated major and minor compute capability
     * @throws IllegalStateException if the CUDA runtime library is not loaded
     * @throws EngineException if a runtime call fails
     */
    public static String getComputeCapability(int i) {
        if (LIB == null) {
            throw new IllegalStateException("No cuda library is loaded.");
        }
        int[] major = new int[1];
        int[] minor = new int[1];
        checkCall(LIB.cudaDeviceGetAttribute(major, ATTR_COMPUTE_CAPABILITY_MAJOR, i));
        checkCall(LIB.cudaDeviceGetAttribute(minor, ATTR_COMPUTE_CAPABILITY_MINOR, i));
        return String.valueOf(major[0]) + minor[0];
    }

    /**
     * Returns the memory usage of a GPU device.
     *
     * <p>The calling thread's current CUDA device is temporarily switched to {@code device} for
     * the query. It is always switched back afterwards, even when the query fails.
     *
     * @param device a GPU device
     * @return usage with {@code init = -1}, {@code used = committed = total - free} and {@code max
     *     = total}
     * @throws IllegalArgumentException if {@code device} is not a GPU device
     * @throws IllegalStateException if the CUDA runtime library is not loaded
     * @throws EngineException if a runtime call fails
     */
    public static MemoryUsage getGpuMemory(Device device) {
        if (!Device.Type.GPU.equals(device.getDeviceType())) {
            throw new IllegalArgumentException("Only GPU device is allowed.");
        }
        if (LIB == null) {
            throw new IllegalStateException("No GPU device detected.");
        }
        int[] previousDevice = new int[1];
        checkCall(LIB.cudaGetDevice(previousDevice));
        checkCall(LIB.cudaSetDevice(device.getDeviceId()));

        long[] free = new long[1];
        long[] total = new long[1];
        int queryStatus;
        int restoreStatus;
        try {
            queryStatus = LIB.cudaMemGetInfo(free, total);
        } finally {
            // Always restore the caller's device; a failed query must not leave it switched.
            restoreStatus = LIB.cudaSetDevice(previousDevice[0]);
        }
        // Report the query failure first, matching the original error precedence.
        checkCall(queryStatus);
        checkCall(restoreStatus);

        long used = total[0] - free[0];
        return new MemoryUsage(-1L, used, used, total[0]);
    }

    /**
     * Locates and loads the CUDA runtime library.
     *
     * <p>On non-Windows platforms, {@code cudart} is loaded by its base name. On Windows, the
     * directories of {@code %CUDA_PATH%\bin} (if set) and {@code %PATH%} are scanned in that order.
     * The first directory containing a {@code cudart64_<digits>.dll} is used; if a directory
     * contains several matches, the first one returned by {@link File#listFiles} is chosen.
     *
     * @return the loaded library, or {@code null} if it could not be found or loaded
     */
    private static CudaLibrary loadLibrary() {
        try {
            if (!System.getProperty("os.name").startsWith("Win")) {
                return Native.load("cudart", CudaLibrary.class);
            }
            String path = System.getenv("PATH");
            if (path == null) {
                return null;
            }
            Pattern cudartDll = Pattern.compile("cudart64_\\d+\\.dll");
            String cudaPath = System.getenv("CUDA_PATH");
            String searchPath = cudaPath == null ? path : cudaPath + "\\bin\\;" + path;
            for (String dir : searchPath.split(";")) {
                File[] matches =
                        new File(dir).listFiles(f -> cudartDll.matcher(f.getName()).matches());
                if (matches != null && matches.length > 0) {
                    String fileName = matches[0].getName();
                    // JNA expects the library name without the ".dll" extension.
                    String libName = fileName.substring(0, fileName.length() - 4);
                    logger.debug("Found cudart: {}", matches[0].getAbsolutePath());
                    return Native.load(libName, CudaLibrary.class);
                }
            }
            logger.debug("No cudart library found in path.");
            return null;
        } catch (UnsatisfiedLinkError e) {
            logger.debug("cudart library not found.");
            logger.trace("", e);
            return null;
        }
    }

    /**
     * Throws if a CUDA runtime call did not succeed.
     *
     * @param i the status code returned by a CUDA runtime call
     * @throws IllegalStateException if the CUDA runtime library is not loaded
     * @throws EngineException if {@code i} is not {@code cudaSuccess}
     */
    private static void checkCall(int i) {
        if (LIB == null) {
            throw new IllegalStateException("No cuda library is loaded.");
        }
        if (i != SUCCESS) {
            throw new EngineException(
                    "CUDA API call failed: " + LIB.cudaGetErrorString(i) + " (" + i + ')');
        }
    }
}
