 import gradio as gr
-import gc
 import torch
 import os
 
 from typing import TYPE_CHECKING
 
+from tts_webui.utils.manage_model_state import manage_model_state
+from tts_webui.utils.list_dir_models import unload_model_button
+
 if TYPE_CHECKING:
     from transformers import Pipeline
 
@@ -14,7 +16,7 @@ def extension__tts_generation_webui():
     return {
         "package_name": "extension_whisper",
         "name": "Whisper",
-        "version": "0.0.1",
+        "version": "0.0.2",
         "requirements": "git+https://github.com/rsxdalv/extension_whisper@main",
         "description": "Whisper allows transcribing audio files.",
         "extension_type": "interface",
@@ -28,40 +30,59 @@ def extension__tts_generation_webui():
     }
 
 
-local_dir = os.path.join("data", "models", "whisper")
-local_cache_dir = os.path.join(local_dir, "cache")
+@manage_model_state("whisper")
+def get_model(
+    model_name="openai/whisper-large-v3",
+    torch_dtype=torch.float16,
+    device="cuda:0",
+    compile=False,
+):
+    from transformers import AutoModelForSpeechSeq2Seq
+    from transformers import AutoProcessor
+
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+    ).to(device)
+    if compile:
+        model.generation_config.cache_implementation = "static"
+        model.generation_config.max_new_tokens = 256
+        model.forward = torch.compile(
+            model.forward, mode="reduce-overhead", fullgraph=True
+        )
 
-pipe = None
-last_model_name = None
+    processor = AutoProcessor.from_pretrained(model_name)
 
+    return model, processor
 
-def unload_models():
-    global pipe, last_model_name
-    pipe = None
-    last_model_name = None
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    return "Unloaded"
 
+local_dir = os.path.join("data", "models", "whisper")
+local_cache_dir = os.path.join(local_dir, "cache")
 
+
+@manage_model_state("whisper-pipe")
 def get_pipe(model_name, device="cuda:0") -> "Pipeline":
     from transformers import pipeline
 
-    global pipe, last_model_name
-    if pipe is not None:
-        if model_name == last_model_name:
-            return pipe
-        unload_models()
-    pipe = pipeline(
-        "automatic-speech-recognition",
+    torch_dtype = torch.float16
+
+    model, processor = get_model(
+        # model_name, torch_dtype=torch.float16, device=device, compile=False
         model_name,
+        torch_dtype=torch_dtype,
+        device=device,
+        compile=False,
+    )
+    return pipeline(
+        "automatic-speech-recognition",
+        model=model,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        # chunk_length_s=30,
+        # batch_size=16, # batch size for inference - set based on your device
         torch_dtype=torch.float16,
         model_kwargs={"cache_dir": local_cache_dir},
         device=device,
     )
-    last_model_name = model_name
-    return pipe
 
 
 def transcribe(inputs, model_name="openai/whisper-large-v3"):
@@ -72,13 +93,11 @@ def transcribe(inputs, model_name="openai/whisper-large-v3"):
 
     pipe = get_pipe(model_name)
 
-    generate_kwargs = (
-        {"task": "transcribe"} if model_name == "openai/whisper-large-v3" else {}
-    )
-
     result = pipe(
         inputs,
-        generate_kwargs=generate_kwargs,
+        generate_kwargs=(
+            {"task": "transcribe"} if model_name == "openai/whisper-large-v3" else {}
+        ),
         return_timestamps=True,
     )
     return result["text"]
@@ -108,7 +127,8 @@ def transcribe_ui():
         text = gr.Textbox(label="Transcription", interactive=False)
 
     with gr.Row():
-        unload_models_button = gr.Button("Unload models")
+        unload_model_button("whisper-pipe")
+        unload_model_button("whisper")
 
     transcribe_button = gr.Button("Transcribe", variant="primary")
 
@@ -117,21 +137,12 @@ def transcribe_ui():
         inputs=[audio, model_dropdown],
         outputs=[text],
         api_name="whisper_transcribe",
-    ).then(
-        fn=lambda: gr.Button(value="Unload models"),
-        outputs=[unload_models_button],
-    )
-
-    unload_models_button.click(
-        fn=unload_models,
-        outputs=[unload_models_button],
-        api_name="whisper_unload_models",
     )
 
 
 if __name__ == "__main__":
     if "demo" in locals():
-        demo.close()
+        locals()["demo"].close()
 
     with gr.Blocks() as demo:
         with gr.Tab("Whisper"):
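
Note: this commit swaps the module-level `pipe` / `last_model_name` globals and the manual `unload_models()` helper for the `manage_model_state` decorator and `unload_model_button` helper from tts_webui. Their implementation is not part of this diff; as a rough mental model only (an assumption, not the actual tts_webui code), the decorator can be read as a per-namespace cache that loads a model once, hands back the cached instance on later calls, and lets the "Unload model" buttons clear that cache:

# Hypothetical sketch of a manage_model_state-style cache (assumption;
# the real tts_webui implementation may differ in keying and eviction).
import functools
import gc
import torch

_loaded = {}  # (namespace, call arguments) -> loaded object


def manage_model_state(namespace):
    def decorator(load_fn):
        @functools.wraps(load_fn)
        def wrapper(*args, **kwargs):
            key = (namespace, args, tuple(sorted(kwargs.items())))
            if key not in _loaded:  # load once, reuse on subsequent calls
                _loaded[key] = load_fn(*args, **kwargs)
            return _loaded[key]
        return wrapper
    return decorator


def unload_model(namespace):
    # Roughly what an "Unload model" button for one namespace would trigger.
    for key in [k for k in _loaded if k[0] == namespace]:
        del _loaded[key]
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()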
0 commit comments
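The transcription endpoint is still registered with `api_name="whisper_transcribe"`, so it can be driven programmatically while the web UI is running. A minimal sketch with `gradio_client`; the URL and audio path are placeholders, and on newer Gradio versions the file argument may need to be wrapped with `gradio_client.handle_file`:

from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # placeholder: address of the running web UI
text = client.predict(
    "sample.wav",                # placeholder: path to a local audio file
    "openai/whisper-large-v3",   # model_name, as offered in the dropdown
    api_name="/whisper_transcribe",
)
print(text)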