Commit 9e86e9d

Add encoder part of whisper large v3 as an audio encoder model. (comfyanonymous#9894)
Not useful yet but some models use it.
1 parent 9d4eb9a commit 9e86e9d

2 files changed: +224 -20
comfy/audio_encoders/audio_encoders.py

Lines changed: 38 additions & 20 deletions
@@ -1,4 +1,5 @@
 from .wav2vec2 import Wav2Vec2Model
+from .whisper import WhisperLargeV3
 import comfy.model_management
 import comfy.ops
 import comfy.utils
@@ -11,13 +12,18 @@ def __init__(self, config):
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+        model_type = config.pop("model_type")
         model_config = dict(config)
         model_config.update({
             "dtype": self.dtype,
             "device": offload_device,
             "operations": comfy.ops.manual_cast
         })
-        self.model = Wav2Vec2Model(**model_config)
+
+        if model_type == "wav2vec2":
+            self.model = Wav2Vec2Model(**model_config)
+        elif model_type == "whisper3":
+            self.model = WhisperLargeV3(**model_config)
         self.model.eval()
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
         self.model_sample_rate = 16000
@@ -40,33 +46,45 @@ def encode_audio(self, audio, sample_rate):
 
 def load_audio_encoder_from_sd(sd, prefix=""):
     sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
-    embed_dim = sd["encoder.layer_norm.bias"].shape[0]
-    if embed_dim == 1024:# large
-        config = {
-            "embed_dim": 1024,
-            "num_heads": 16,
-            "num_layers": 24,
-            "conv_norm": True,
-            "conv_bias": True,
-            "do_normalize": True,
-            "do_stable_layer_norm": True
-        }
-    elif embed_dim == 768: # base
-        config = {
-            "embed_dim": 768,
-            "num_heads": 12,
-            "num_layers": 12,
-            "conv_norm": False,
-            "conv_bias": False,
-            "do_normalize": False, # chinese-wav2vec2-base has this False
-            "do_stable_layer_norm": False
-        }
-    else:
-        raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
+    if "encoder.layer_norm.bias" in sd: #wav2vec2
+        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
+        if embed_dim == 1024:# large
+            config = {
+                "model_type": "wav2vec2",
+                "embed_dim": 1024,
+                "num_heads": 16,
+                "num_layers": 24,
+                "conv_norm": True,
+                "conv_bias": True,
+                "do_normalize": True,
+                "do_stable_layer_norm": True
+            }
+        elif embed_dim == 768: # base
+            config = {
+                "model_type": "wav2vec2",
+                "embed_dim": 768,
+                "num_heads": 12,
+                "num_layers": 12,
+                "conv_norm": False,
+                "conv_bias": False,
+                "do_normalize": False, # chinese-wav2vec2-base has this False
+                "do_stable_layer_norm": False
+            }
+        else:
+            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
+    elif "model.encoder.embed_positions.weight" in sd:
+        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
+        config = {
+            "model_type": "whisper3",
+        }
+    else:
+        raise RuntimeError("ERROR: audio encoder not supported.")
 
     audio_encoder = AudioEncoderModel(config)
     m, u = audio_encoder.load_sd(sd)
     if len(m) > 0:
         logging.warning("missing audio encoder: {}".format(m))
+    if len(u) > 0:
+        logging.warning("unexpected audio encoder: {}".format(u))
 
     return audio_encoder
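
A minimal loading sketch for reference (not part of this commit): it assumes the loader above lives in comfy.audio_encoders.audio_encoders, as in the rest of the package, and the checkpoint path below is hypothetical. A Whisper large v3 encoder state dict is recognized by its model.encoder.embed_positions.weight key and routed to WhisperLargeV3 through the new "whisper3" model_type; wav2vec2 checkpoints keep taking the original path.

import comfy.utils
from comfy.audio_encoders.audio_encoders import load_audio_encoder_from_sd

# Hypothetical local path; any file readable by load_torch_file works.
sd = comfy.utils.load_torch_file("models/audio_encoders/whisper_large_v3_encoder.safetensors")
encoder = load_audio_encoder_from_sd(sd)  # detects the whisper keys, builds {"model_type": "whisper3"}

# encode_audio(audio, sample_rate) resamples to the 16 kHz model_sample_rate before
# running the encoder (see the hunk header above for its signature):
# features = encoder.encode_audio(waveform, sample_rate)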

comfy/audio_encoders/whisper.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+from typing import Optional
+from comfy.ldm.modules.attention import optimized_attention_masked
+import comfy.ops
+
+class WhisperFeatureExtractor(nn.Module):
+    def __init__(self, n_mels=128, device=None):
+        super().__init__()
+        self.sample_rate = 16000
+        self.n_fft = 400
+        self.hop_length = 160
+        self.n_mels = n_mels
+        self.chunk_length = 30
+        self.n_samples = 480000
+
+        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
+            sample_rate=self.sample_rate,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            n_mels=self.n_mels,
+            f_min=0,
+            f_max=8000,
+            norm="slaney",
+            mel_scale="slaney",
+        ).to(device)
+
+    def __call__(self, audio):
+        audio = torch.mean(audio, dim=1)
+        batch_size = audio.shape[0]
+        processed_audio = []
+
+        for i in range(batch_size):
+            aud = audio[i]
+            if aud.shape[0] > self.n_samples:
+                aud = aud[:self.n_samples]
+            elif aud.shape[0] < self.n_samples:
+                aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
+            processed_audio.append(aud)
+
+        audio = torch.stack(processed_audio)
+
+        mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
+
+        log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
+        log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
+        log_mel_spec = (log_mel_spec + 4.0) / 4.0
+
+        return log_mel_spec
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        assert d_model % n_heads == 0
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.d_k = d_model // n_heads
+
+        self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+        self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
+        self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+        self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, seq_len, _ = query.shape
+
+        q = self.q_proj(query)
+        k = self.k_proj(key)
+        v = self.v_proj(value)
+
+        attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
+        super().__init__()
+
+        self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
+        self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
+
+        self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
+        self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
+        self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        residual = x
+        x = self.self_attn_layer_norm(x)
+        x = self.self_attn(x, x, x, attention_mask)
+        x = residual + x
+
+        residual = x
+        x = self.final_layer_norm(x)
+        x = self.fc1(x)
+        x = F.gelu(x)
+        x = self.fc2(x)
+        x = residual + x
+
+        return x
+
+
+class AudioEncoder(nn.Module):
+    def __init__(
+        self,
+        n_mels: int = 128,
+        n_ctx: int = 1500,
+        n_state: int = 1280,
+        n_head: int = 20,
+        n_layer: int = 32,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+
+        self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
+        self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
+
+        self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
+
+        self.layers = nn.ModuleList([
+            EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
+            for _ in range(n_layer)
+        ])
+
+        self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.gelu(self.conv1(x))
+        x = F.gelu(self.conv2(x))
+
+        x = x.transpose(1, 2)
+
+        x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
+
+        all_x = ()
+        for layer in self.layers:
+            all_x += (x,)
+            x = layer(x)
+
+        x = self.layer_norm(x)
+        all_x += (x,)
+        return x, all_x
+
+
+class WhisperLargeV3(nn.Module):
+    def __init__(
+        self,
+        n_mels: int = 128,
+        n_audio_ctx: int = 1500,
+        n_audio_state: int = 1280,
+        n_audio_head: int = 20,
+        n_audio_layer: int = 32,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+
+        self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
+
+        self.encoder = AudioEncoder(
+            n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
+            dtype=dtype, device=device, operations=operations
+        )
+
+    def forward(self, audio):
+        mel = self.feature_extractor(audio)
+        x, all_x = self.encoder(mel)
+        return x, all_x
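
A standalone shape check (a sketch, not part of this commit): it assumes comfy.ops.manual_cast as the operations provider, matching what AudioEncoderModel passes in, and runs with random, uninitialized weights, so only the tensor shapes are meaningful. The feature extractor pads or crops to 30 s at 16 kHz (480000 samples), which yields 3000 mel frames, and conv2's stride of 2 halves that to the 1500 positions the embedding table covers.

import torch
import comfy.ops
from comfy.audio_encoders.whisper import WhisperLargeV3

# Uninitialized weights; real weights come from load_audio_encoder_from_sd.
model = WhisperLargeV3(dtype=torch.float32, device="cpu", operations=comfy.ops.manual_cast)
model.eval()

# Five seconds of fake 16 kHz stereo audio: (batch, channels, samples).
audio = torch.randn(1, 2, 16000 * 5)
with torch.no_grad():
    x, all_x = model(audio)

print(x.shape)     # torch.Size([1, 1500, 1280]): padded to 3000 mel frames, halved by the strided conv
print(len(all_x))  # 33: the input to each of the 32 encoder layers plus the final normalized output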
