Commit cad2cf5

redesign readme (rsxdalv#399)
* improve whisper extension
* write changelog for August, September and October
* improve readme formatting and organization
* clean up old files
* add banner
* switch to new samples
* reorder
* update screenshots
1 parent 03ed05b commit cad2cf5

32 files changed: +212 -170 lines

README.md: +162 -92 (large diff not rendered by default)

check_cuda.py: -23 (file deleted)

extensions/builtin/extension_whisper/main.py: +50 -39
@@ -1,10 +1,12 @@
 import gradio as gr
-import gc
 import torch
 import os

 from typing import TYPE_CHECKING

+from tts_webui.utils.manage_model_state import manage_model_state
+from tts_webui.utils.list_dir_models import unload_model_button
+
 if TYPE_CHECKING:
     from transformers import Pipeline

@@ -14,7 +16,7 @@ def extension__tts_generation_webui():
     return {
         "package_name": "extension_whisper",
         "name": "Whisper",
-        "version": "0.0.1",
+        "version": "0.0.2",
         "requirements": "git+https://github.com/rsxdalv/extension_whisper@main",
         "description": "Whisper allows transcribing audio files.",
         "extension_type": "interface",
@@ -28,40 +30,59 @@ def extension__tts_generation_webui():
     }


-local_dir = os.path.join("data", "models", "whisper")
-local_cache_dir = os.path.join(local_dir, "cache")
+@manage_model_state("whisper")
+def get_model(
+    model_name="openai/whisper-large-v3",
+    torch_dtype=torch.float16,
+    device="cuda:0",
+    compile=False,
+):
+    from transformers import AutoModelForSpeechSeq2Seq
+    from transformers import AutoProcessor
+
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+    ).to(device)
+    if compile:
+        model.generation_config.cache_implementation = "static"
+        model.generation_config.max_new_tokens = 256
+        model.forward = torch.compile(
+            model.forward, mode="reduce-overhead", fullgraph=True
+        )

-pipe = None
-last_model_name = None
+    processor = AutoProcessor.from_pretrained(model_name)

+    return model, processor

-def unload_models():
-    global pipe, last_model_name
-    pipe = None
-    last_model_name = None
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    return "Unloaded"

+local_dir = os.path.join("data", "models", "whisper")
+local_cache_dir = os.path.join(local_dir, "cache")

+
+@manage_model_state("whisper-pipe")
 def get_pipe(model_name, device="cuda:0") -> "Pipeline":
     from transformers import pipeline

-    global pipe, last_model_name
-    if pipe is not None:
-        if model_name == last_model_name:
-            return pipe
-        unload_models()
-    pipe = pipeline(
-        "automatic-speech-recognition",
+    torch_dtype = torch.float16
+
+    model, processor = get_model(
+        # model_name, torch_dtype=torch.float16, device=device, compile=False
         model_name,
+        torch_dtype=torch_dtype,
+        device=device,
+        compile=False,
+    )
+    return pipeline(
+        "automatic-speech-recognition",
+        model=model,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        # chunk_length_s=30,
+        # batch_size=16, # batch size for inference - set based on your device
         torch_dtype=torch.float16,
         model_kwargs={"cache_dir": local_cache_dir},
         device=device,
     )
-    last_model_name = model_name
-    return pipe


 def transcribe(inputs, model_name="openai/whisper-large-v3"):
@@ -72,13 +93,11 @@ def transcribe(inputs, model_name="openai/whisper-large-v3"):

     pipe = get_pipe(model_name)

-    generate_kwargs = (
-        {"task": "transcribe"} if model_name == "openai/whisper-large-v3" else {}
-    )
-
     result = pipe(
         inputs,
-        generate_kwargs=generate_kwargs,
+        generate_kwargs=(
+            {"task": "transcribe"} if model_name == "openai/whisper-large-v3" else {}
+        ),
         return_timestamps=True,
     )
     return result["text"]
@@ -108,7 +127,8 @@ def transcribe_ui():
         text = gr.Textbox(label="Transcription", interactive=False)

         with gr.Row():
-            unload_models_button = gr.Button("Unload models")
+            unload_model_button("whisper-pipe")
+            unload_model_button("whisper")

         transcribe_button = gr.Button("Transcribe", variant="primary")

@@ -117,21 +137,12 @@ def transcribe_ui():
             inputs=[audio, model_dropdown],
             outputs=[text],
            api_name="whisper_transcribe",
-        ).then(
-            fn=lambda: gr.Button(value="Unload models"),
-            outputs=[unload_models_button],
-        )
-
-        unload_models_button.click(
-            fn=unload_models,
-            outputs=[unload_models_button],
-            api_name="whisper_unload_models",
         )


 if __name__ == "__main__":
     if "demo" in locals():
-        demo.close()
+        locals()["demo"].close()

     with gr.Blocks() as demo:
         with gr.Tab("Whisper"):
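The main change in this file is swapping the module-level pipe / last_model_name globals and the manual unload_models() helper for the @manage_model_state(...) decorator and unload_model_button(...) from tts_webui.utils. Those utilities are not part of this diff, so the sketch below is only an illustration of the caching and unloading behaviour the new code appears to rely on; the names _MODEL_CACHE and unload_model are hypothetical stand-ins, not the actual tts_webui implementation.

# Illustrative sketch only, not the tts_webui implementation: cache one loaded
# model per named state, reloading only when the loader's arguments change,
# plus the kind of unload helper an "Unload model" button would call.
import functools

_MODEL_CACHE = {}  # hypothetical registry: state name -> (load args, loaded object)


def manage_model_state(name):
    def decorator(load_fn):
        @functools.wraps(load_fn)
        def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))
            cached = _MODEL_CACHE.get(name)
            if cached is not None and cached[0] == key:
                return cached[1]  # same arguments: reuse the already-loaded object
            result = load_fn(*args, **kwargs)
            _MODEL_CACHE[name] = (key, result)
            return result
        return wrapper
    return decorator


def unload_model(name):
    # Drop the cached entry so Python/torch can reclaim the memory.
    _MODEL_CACHE.pop(name, None)

Under a pattern like this, get_pipe(model_name) can be called on every transcription request: the decorated loader returns the cached pipeline unless the model name changes, and the two unload_model_button("whisper-pipe") / unload_model_button("whisper") buttons replace the old unload_models wiring in the UI.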
File renamed without changes.

samples/Bark Japanese.mp4: 40.6 KB
samples/Bark Narration.mp4: 187 KB
samples/MusicGen.mp4: 135 KB
(3 additional binary files changed; names not shown)

screenshots/banner.png: 76.6 KB
screenshots/gradio (1).png: 447 KB
screenshots/gradio (2).png: 450 KB
screenshots/gradio (3).png: 432 KB
screenshots/react_ui (1).png: 589 KB
screenshots/react_ui (2).png: 570 KB
screenshots/react_ui (3).png: 521 KB

screenshots/screenshot (1).png: -236 KB
screenshots/screenshot (2).png: -95.8 KB
screenshots/screenshot (3).png: -221 KB
screenshots/screenshot (4).png: -122 KB
screenshots/screenshot (5).png: -112 KB
screenshots/screenshot (6).png: -168 KB
screenshots/screenshot (7).png: -178 KB
screenshots/screenshot (8).png: -85.2 KB
screenshots/v2/cloning.png: -84.5 KB
screenshots/v2/generation.jpg: -178 KB
screenshots/v2/history.jpg: -177 KB
screenshots/v2/musicgen.png: -192 KB
screenshots/v2/react.png: -150 KB
screenshots/v2/rvc.png: -142 KB

update.py: -16 (file deleted)

0 commit comments