Skip to content

Commit

Permalink
[api] Use a forked Java process to avoid the JVM consuming GPU memory (#2882)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 26, 2024
1 parent 5f1f0fd commit 5c66606
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 12 deletions.
108 changes: 105 additions & 3 deletions api/src/main/java/ai/djl/util/cuda/CudaUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.MemoryUsage;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;

Expand All @@ -33,6 +37,8 @@ public final class CudaUtils {

private static final CudaLibrary LIB = loadLibrary();

private static String[] gpuInfo;

private CudaUtils() {}

/**
Expand All @@ -49,7 +55,15 @@ public static boolean hasCuda() {
*
* @return the number of GPUs available in the system
*/
@SuppressWarnings("PMD.NonThreadSafeSingleton")
public static int getGpuCount() {
if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
if (gpuInfo == null) {
gpuInfo = execute(-1); // NOPMD
}
return Integer.parseInt(gpuInfo[0]);
}

if (LIB == null) {
return 0;
}
Expand Down Expand Up @@ -79,7 +93,19 @@ public static int getGpuCount() {
*
* @return the version of CUDA runtime
*/
@SuppressWarnings("PMD.NonThreadSafeSingleton")
public static int getCudaVersion() {
if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
if (gpuInfo == null) {
gpuInfo = execute(-1);
}
int version = Integer.parseInt(gpuInfo[1]);
if (version == -1) {
throw new IllegalArgumentException("No cuda device found.");
}
return version;
}

if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
}
Expand All @@ -95,9 +121,6 @@ public static int getCudaVersion() {
* @return the version string of CUDA runtime
*/
public static String getCudaVersionString() {
if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
}
int version = getCudaVersion();
int major = version / 1000;
int minor = (version / 10) % 10;
Expand All @@ -111,6 +134,14 @@ public static String getCudaVersionString() {
* @return the CUDA compute capability
*/
public static String getComputeCapability(int device) {
if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
String[] ret = execute(device);
if (ret.length != 3) {
throw new IllegalArgumentException(ret[0]);
}
return ret[0];
}

if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
}
Expand All @@ -137,6 +168,16 @@ public static MemoryUsage getGpuMemory(Device device) {
throw new IllegalArgumentException("Only GPU device is allowed.");
}

if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
String[] ret = execute(device.getDeviceId());
if (ret.length != 3) {
throw new IllegalArgumentException(ret[0]);
}
long total = Long.parseLong(ret[1]);
long used = Long.parseLong(ret[2]);
return new MemoryUsage(-1, used, used, total);
}

if (LIB == null) {
throw new IllegalStateException("No GPU device detected.");
}
Expand All @@ -155,8 +196,42 @@ public static MemoryUsage getGpuMemory(Device device) {
return new MemoryUsage(-1, committed, committed, total[0]);
}

/**
 * Command line entrypoint that prints CUDA information to standard output.
 *
 * <p>Without arguments it prints {@code <gpuCount>,<cudaVersion>}, or {@code 0,-1} when no GPU
 * is reported. With an argument, the first argument is parsed as a device id and the line
 * {@code <computeCapability>,<max>,<used>} is printed, where the last two values come from the
 * {@link MemoryUsage} of that device. An unparsable or out of range device id prints {@code
 * Invalid device: <value>} instead.
 *
 * @param args the command line arguments.
 */
@SuppressWarnings("PMD.SystemPrintln")
public static void main(String[] args) {
    // The GPU count is always queried first, it is needed by both modes.
    int available = getGpuCount();

    if (args.length > 0) {
        // Device mode: report compute capability and memory of a single GPU.
        try {
            int id = Integer.parseInt(args[0]);
            boolean inRange = id >= 0 && id < available;
            if (!inRange) {
                System.out.println("Invalid device: " + id);
                return;
            }
            MemoryUsage usage = getGpuMemory(Device.gpu(id));
            String capability = getComputeCapability(id);
            System.out.println(capability + "," + usage.getMax() + "," + usage.getUsed());
        } catch (NumberFormatException e) {
            System.out.println("Invalid device: " + args[0]);
        }
        return;
    }

    // Summary mode: report GPU count and CUDA runtime version.
    if (available <= 0) {
        System.out.println("0,-1");
        return;
    }
    int runtimeVersion = getCudaVersion();
    System.out.println(available + "," + runtimeVersion);
}

private static CudaLibrary loadLibrary() {
try {
if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
return null;
}
if (System.getProperty("os.name").startsWith("Win")) {
String path = Utils.getenv("PATH");
if (path == null) {
Expand Down Expand Up @@ -199,6 +274,33 @@ private static CudaLibrary loadLibrary() {
}
}

/**
 * Runs {@code ai.djl.util.cuda.CudaUtils} in a separate java process and returns its output.
 *
 * <p>The child process is started with the java executable found under {@code java.home} and
 * the class path of the current JVM. Its error stream is merged into its standard output.
 *
 * @param deviceId the GPU device id passed as the only program argument; a negative value
 *     starts the child process without any program argument
 * @return the trimmed output of the child process, split by {@code ","}
 * @throws IllegalArgumentException if the child process cannot be started or its output cannot
 *     be read
 */
private static String[] execute(int deviceId) {
    String javaHome = System.getProperty("java.home");
    String classPath = System.getProperty("java.class.path");
    String os = System.getProperty("os.name");
    List<String> cmd = new ArrayList<>(5);
    if (os.startsWith("Win")) {
        cmd.add(javaHome + "\\bin\\java.exe");
    } else {
        cmd.add(javaHome + "/bin/java");
    }
    cmd.add("-cp");
    cmd.add(classPath);
    cmd.add("ai.djl.util.cuda.CudaUtils");
    if (deviceId >= 0) {
        cmd.add(String.valueOf(deviceId));
    }
    try {
        Process ps = new ProcessBuilder(cmd).redirectErrorStream(true).start();
        try (InputStream is = ps.getInputStream()) {
            // Reads until the child closes its output, so the complete result is available.
            String line = Utils.toString(is).trim();
            return line.split(",");
        } finally {
            // Release the remaining process resources and make sure no child is left behind,
            // even when reading the output failed.
            ps.destroy();
        }
    } catch (IOException e) {
        throw new IllegalArgumentException("Failed get GPU information", e);
    }
}

private static void checkCall(int ret) {
if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
Expand Down
9 changes: 6 additions & 3 deletions api/src/test/java/ai/djl/util/SecurityManagerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ public void checkPermission(Permission perm) {
}
};
System.setSecurityManager(sm);

Assert.assertFalse(CudaUtils.hasCuda());
Assert.assertEquals(CudaUtils.getGpuCount(), 0);
try {
Assert.assertFalse(CudaUtils.hasCuda());
Assert.assertEquals(CudaUtils.getGpuCount(), 0);
} finally {
System.setSecurityManager(null);
}
}
}
21 changes: 15 additions & 6 deletions api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import org.testng.annotations.Test;

import java.lang.management.MemoryUsage;
import java.util.Arrays;
import java.util.List;

public class CudaUtilsTest {

Expand All @@ -30,23 +28,34 @@ public class CudaUtilsTest {
@Test
public void testCudaUtils() {
if (!CudaUtils.hasCuda()) {
Assert.assertThrows(CudaUtils::getCudaVersionString);
Assert.assertThrows(() -> CudaUtils.getComputeCapability(0));
Assert.assertThrows(() -> CudaUtils.getGpuMemory(Device.gpu()));
return;
}
// Possible to have CUDA and not have a GPU.
if (CudaUtils.getGpuCount() == 0) {
return;
}

int cudaVersion = CudaUtils.getCudaVersion();
String cudaVersion = CudaUtils.getCudaVersionString();
String smVersion = CudaUtils.getComputeCapability(0);
MemoryUsage memoryUsage = CudaUtils.getGpuMemory(Device.gpu());

logger.info("CUDA runtime version: {}, sm: {}", cudaVersion, smVersion);
logger.info("Memory usage: {}", memoryUsage);

Assert.assertTrue(cudaVersion >= 9020, "cuda 9.2+ required.");
Assert.assertNotNull(cudaVersion);
Assert.assertNotNull(smVersion);
}

List<String> supportedSm = Arrays.asList("37", "52", "60", "61", "70", "75");
Assert.assertTrue(supportedSm.contains(smVersion), "Unsupported cuda sm: " + smVersion);
/**
 * Runs {@code testCudaUtils()} with the {@code ai.djl.util.cuda.folk} system property set to
 * {@code true}, then puts the property back to the state it had before the test.
 */
@Test
public void testCudaUtilsWithFolk() {
    String key = "ai.djl.util.cuda.folk";
    // System.setProperty returns the previous value, or null if the property was not set.
    String previous = System.setProperty(key, "true");
    try {
        testCudaUtils();
    } finally {
        // Restore rather than blindly clear, so a value configured for the whole test run
        // (e.g. with -D on the command line) is not lost for the tests that follow.
        if (previous == null) {
            System.clearProperty(key);
        } else {
            System.setProperty(key, previous);
        }
    }
}
}

0 comments on commit 5c66606

Please sign in to comment.