Skip to content

Commit 3cc1095

Browse files
authored
GPU Info in React (rsxdalv#305)
* Add GPU Info to React UI
* Improve display
1 parent 1858fc6 commit 3cc1095

File tree

5 files changed

+177
-3
lines changed

5 files changed

+177
-3
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ https://github.com/rsxdalv/tts-generation-webui/discussions/186#discussioncommen
4646
## Changelog
4747
Apr 28:
4848
* Add Maha TTS to React UI.
49+
* Add GPU Info to React UI.
4950

5051
Apr 6:
5152
* Add Vall-E-X generation demo tab.

react-ui/src/components/Header.tsx

+4
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ const routes: Route[] = [
112112
href: "/maha-tts",
113113
text: "Maha TTS",
114114
},
115+
{
116+
href: "/gpu_info",
117+
text: "GPU Info",
118+
},
115119
{
116120
href: "https://echo.ps.ai/?utm_source=bark_speaker_directory",
117121
text: <span>More Voices ↗</span>,

react-ui/src/pages/api/gradio/[name].tsx

+12
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,16 @@ async function maha_tts_refresh_voices() {
716716
return result?.data[0].choices.map((x) => x[0]);
717717
}
718718

719+
async function get_gpu_info() {
720+
const app = await getClient();
721+
722+
const result = (await app.predict("/get_gpu_info")) as {
723+
data: [Object];
724+
};
725+
726+
return result?.data[0];
727+
}
728+
719729
const endpoints = {
720730
maha,
721731
maha_tts_refresh_voices,
@@ -753,4 +763,6 @@ const endpoints = {
753763
save_environment_variables_bark,
754764
save_config_bark,
755765
get_config_bark,
766+
767+
get_gpu_info,
756768
};

react-ui/src/pages/gpu_info.tsx

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import React from "react";
2+
import { Template } from "../components/Template";
3+
import Head from "next/head";
4+
5+
type GPUInfo = {
6+
vram: number;
7+
name: string;
8+
cuda_capabilities: number[];
9+
used_vram: number;
10+
used_vram_total: number;
11+
cached_vram: number;
12+
torch_version: string;
13+
};
14+
15+
const REFRESH_RATE = 500;
16+
17+
const ProgressBar = ({
18+
label,
19+
value,
20+
total,
21+
}: {
22+
label: string;
23+
value: number;
24+
total: number;
25+
}) => {
26+
const percentage = (value / total) * 100;
27+
return (
28+
<div className="flex items-center">
29+
<p className="text-sm w-36">
30+
{label}: <br />
31+
[{value.toFixed(0)} MB]
32+
</p>
33+
<div className="flex w-2/3">
34+
<div
35+
style={{
36+
width: `${percentage}%`,
37+
height: "10px",
38+
}}
39+
className="bg-orange-400"
40+
></div>
41+
<div
42+
style={{
43+
width: `${100 - percentage}%`,
44+
height: "10px",
45+
}}
46+
className="bg-slate-300"
47+
></div>
48+
</div>
49+
</div>
50+
);
51+
};
52+
53+
/**
 * Polls /api/gradio/get_gpu_info every REFRESH_RATE ms and renders the
 * device name, compute capability, torch version and three VRAM bars.
 *
 * Fixes over the original:
 *  - fetch/JSON errors are caught: previously every failed tick surfaced
 *    as an unhandled promise rejection from the interval callback; now the
 *    last good snapshot stays on screen;
 *  - non-OK HTTP responses are skipped instead of parsed as JSON;
 *  - cuda_capabilities is rendered defensively — the backend's no-GPU
 *    branch returns `0` (not an array), which previously crashed `.join`.
 */
const GPUInfoWidget = ({}) => {
  const [gpuData, setGPUData] = React.useState<GPUInfo>({
    vram: 0,
    name: "",
    cuda_capabilities: [],
    used_vram: 0,
    used_vram_total: 0,
    cached_vram: 0,
    torch_version: "",
  });
  const [loading, setLoading] = React.useState<boolean>(false);

  const fetchGPUData = async () => {
    setLoading(true);
    try {
      const response = await fetch("/api/gradio/get_gpu_info", {
        method: "POST",
      });
      if (response.ok) {
        const result = await response.json();
        setGPUData(result);
      }
    } catch {
      // Transient network failure — keep the previous snapshot.
    } finally {
      setLoading(false);
    }
  };

  React.useEffect(() => {
    fetchGPUData();
    const interval = setInterval(fetchGPUData, REFRESH_RATE);
    return () => clearInterval(interval);
  }, []);

  // Backend returns a number[] normally, but a bare 0 when no CUDA GPU
  // is detected (see src/utils/gpu_info_tab.py no-GPU branch).
  const capabilities = Array.isArray(gpuData.cuda_capabilities)
    ? gpuData.cuda_capabilities.join(".")
    : String(gpuData.cuda_capabilities);

  return (
    <div className="flex flex-col gap-2 w-3/4">
      <h2 className="text-lg">
        {gpuData.name} [{Math.round(gpuData.vram / 1024)} GB]
      </h2>
      <h3>Compute Capability: {capabilities}</h3>
      <h3>PyTorch Version: {gpuData.torch_version}</h3>
      <ProgressBar
        label="Used VRAM"
        value={gpuData.used_vram}
        total={gpuData.vram}
      />
      <ProgressBar
        label="Cached VRAM"
        value={gpuData.cached_vram}
        total={gpuData.vram}
      />
      <ProgressBar
        label="Used VRAM System"
        value={gpuData.used_vram_total}
        total={gpuData.vram}
      />
    </div>
  );
};
107+
108+
const GPUInfoPage = () => {
109+
return (
110+
<Template>
111+
<Head>
112+
<title>GPU Info - TTS Generation Webui</title>
113+
</Head>
114+
<div className="gap-y-4 p-4 flex w-full flex-col items-center">
115+
<GPUInfoWidget />
116+
</div>
117+
</Template>
118+
);
119+
};
120+
121+
export default GPUInfoPage;

src/utils/gpu_info_tab.py

+39-3
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ def gpu_info_tab():
1010
fn=refresh_gpu_info, outputs=gpu_info, api_name="refresh_gpu_info"
1111
)
1212

13+
gpu_info_json = gr.JSON(get_gpu_info(), visible=False)
14+
1315
gr.Button("API_GET_GPU_INFO", visible=False).click(
14-
fn=get_gpu_info, api_name="get_gpu_info"
16+
fn=get_gpu_info, outputs=[gpu_info_json], api_name="get_gpu_info"
1517
)
1618

1719

@@ -24,23 +26,57 @@ def get_gpu_info():
2426
used_vram_total = (
2527
torch.cuda.mem_get_info(0)[1] - torch.cuda.mem_get_info(0)[0]
2628
) / 1024**2
29+
cached_vram = torch.cuda.memory_reserved(0) / 1024**2
30+
torch_version = torch.__version__
2731
return {
2832
"vram": vram,
2933
"name": name,
3034
"cuda_capabilities": cuda_capabilities,
3135
"used_vram": used_vram,
3236
"used_vram_total": used_vram_total,
37+
"cached_vram": cached_vram,
38+
"torch_version": torch_version,
3339
}
3440
else:
35-
return "No GPU with CUDA support detected by PyTorch"
41+
# return "No GPU with CUDA support detected by PyTorch"
42+
return {
43+
"vram": 0,
44+
"name": "No GPU with CUDA support detected by PyTorch",
45+
"cuda_capabilities": 0,
46+
"used_vram": 0,
47+
"used_vram_total": 0,
48+
"cached_vram": 0,
49+
"torch_version": 0,
50+
}
3651

3752

3853
def render_gpu_info(gpu_info):
    """Format a GPU snapshot as human-readable, blank-line separated text.

    A dict (the normal payload from ``get_gpu_info``) is expanded into a
    labelled report; any non-dict value (e.g. the "no CUDA GPU" message
    string) is passed through unchanged.
    """
    if not isinstance(gpu_info, dict):
        return gpu_info
    sections = [
        f"VRAM: {gpu_info['vram']} MB",
        f"Used VRAM: {gpu_info['used_vram']} MB",
        f"Total Used VRAM: {gpu_info['used_vram_total']} MB",
        f"Name: {gpu_info['name']}",
        f"CUDA Capabilities: {gpu_info['cuda_capabilities']}",
        f"Cached VRAM: {gpu_info['cached_vram']} MB",
        f"Torch Version: {gpu_info['torch_version']}",
    ]
    # Two newlines between sections matches the original triple-quoted
    # f-string layout (Markdown paragraph breaks).
    return "\n\n".join(sections)
4370

4471

4572
def refresh_gpu_info():
4673
return render_gpu_info(get_gpu_info())
74+
75+
76+
if __name__ == "__main__":
    # When re-run inside an interactive session (e.g. gradio reload),
    # close the previous demo before building a new one.
    if "demo" in locals():
        demo.close()  # type: ignore
    with gr.Blocks() as demo:
        gpu_info_tab()

    demo.launch()

0 commit comments

Comments
 (0)