@@ -10,8 +10,10 @@ def gpu_info_tab():
10
10
fn = refresh_gpu_info , outputs = gpu_info , api_name = "refresh_gpu_info"
11
11
)
12
12
13
+ gpu_info_json = gr .JSON (get_gpu_info (), visible = False )
14
+
13
15
gr .Button ("API_GET_GPU_INFO" , visible = False ).click (
14
- fn = get_gpu_info , api_name = "get_gpu_info"
16
+ fn = get_gpu_info , outputs = [ gpu_info_json ], api_name = "get_gpu_info"
15
17
)
16
18
17
19
@@ -24,23 +26,57 @@ def get_gpu_info():
24
26
used_vram_total = (
25
27
torch .cuda .mem_get_info (0 )[1 ] - torch .cuda .mem_get_info (0 )[0 ]
26
28
) / 1024 ** 2
29
+ cached_vram = torch .cuda .memory_reserved (0 ) / 1024 ** 2
30
+ torch_version = torch .__version__
27
31
return {
28
32
"vram" : vram ,
29
33
"name" : name ,
30
34
"cuda_capabilities" : cuda_capabilities ,
31
35
"used_vram" : used_vram ,
32
36
"used_vram_total" : used_vram_total ,
37
+ "cached_vram" : cached_vram ,
38
+ "torch_version" : torch_version ,
33
39
}
34
40
else :
35
- return "No GPU with CUDA support detected by PyTorch"
41
+ # return "No GPU with CUDA support detected by PyTorch"
42
+ return {
43
+ "vram" : 0 ,
44
+ "name" : "No GPU with CUDA support detected by PyTorch" ,
45
+ "cuda_capabilities" : 0 ,
46
+ "used_vram" : 0 ,
47
+ "used_vram_total" : 0 ,
48
+ "cached_vram" : 0 ,
49
+ "torch_version" : 0 ,
50
+ }
36
51
37
52
38
53
def render_gpu_info (gpu_info ):
39
54
if isinstance (gpu_info , dict ):
40
- return f"VRAM: { gpu_info ['vram' ]} MB\n \n Used VRAM: { gpu_info ['used_vram' ]} MB\n \n Total Used VRAM: { gpu_info ['used_vram_total' ]} MB\n \n Name: { gpu_info ['name' ]} \n \n CUDA Capabilities: { gpu_info ['cuda_capabilities' ]} "
55
+ return f"""VRAM: { gpu_info ['vram' ]} MB
56
+
57
+ Used VRAM: { gpu_info ['used_vram' ]} MB
58
+
59
+ Total Used VRAM: { gpu_info ['used_vram_total' ]} MB
60
+
61
+ Name: { gpu_info ['name' ]}
62
+
63
+ CUDA Capabilities: { gpu_info ['cuda_capabilities' ]}
64
+
65
+ Cached VRAM: { gpu_info ['cached_vram' ]} MB
66
+
67
+ Torch Version: { gpu_info ['torch_version' ]} """
41
68
else :
42
69
return gpu_info
43
70
44
71
45
72
def refresh_gpu_info ():
46
73
return render_gpu_info (get_gpu_info ())
74
+
75
+
76
+ if __name__ == "__main__" :
77
+ if "demo" in locals ():
78
+ demo .close () # type: ignore
79
+ with gr .Blocks () as demo :
80
+ gpu_info_tab ()
81
+
82
+ demo .launch ()
0 commit comments