Skip to content

Commit 914c2a2

Browse files
Implement wav2vec2 as an audio encoder model. (Comfy-Org#9549)
This is useless on its own but there are multiple models that use it.
1 parent e633a47 commit 914c2a2

File tree

5 files changed

+302
-0
lines changed

5 files changed

+302
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from .wav2vec2 import Wav2Vec2Model
2+
import comfy.model_management
3+
import comfy.ops
4+
import comfy.utils
5+
import logging
6+
import torchaudio
7+
8+
9+
class AudioEncoderModel():
    """Wrap a Wav2Vec2 model together with its device-management patcher.

    The device and dtype are taken from the text-encoder settings exposed by
    ``comfy.model_management``. Input audio is resampled to
    ``model_sample_rate`` (16000) before it is encoded.
    """

    def __init__(self, config):
        """Build the model on the offload device and create its patcher.

        Args:
            config: Accepted for interface compatibility; not read here.
        """
        # Import explicitly instead of relying on some other module having
        # already loaded the ``comfy.model_patcher`` submodule: merely
        # importing ``comfy.model_management`` does not guarantee that the
        # ``model_patcher`` attribute exists on the ``comfy`` package.
        from comfy.model_patcher import ModelPatcher

        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
        self.model.eval()
        self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        """Load ``sd`` into the model non-strictly and return the load result."""
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        """Return the wrapped model's state dict."""
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
        """Resample ``audio`` to the model rate and run it through the model.

        Args:
            audio: Waveform tensor sampled at ``sample_rate``.
                NOTE(review): presumably (batch, channels, samples), since the
                model averages over dim 1 -- confirm against callers.
            sample_rate: Sample rate of ``audio``.

        Returns:
            dict with ``"encoded_audio"`` (final model output) and
            ``"encoded_audio_all_layers"`` (the per-layer outputs returned by
            the model).
        """
        comfy.model_management.load_model_gpu(self.patcher)
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        return outputs
33+
34+
35+
def load_audio_encoder_from_sd(sd, prefix=""):
    """Create an ``AudioEncoderModel`` and populate it from a state dict.

    Keys starting with ``"wav2vec2."`` have that prefix removed before
    loading. Missing keys are reported with a warning; unexpected keys are
    ignored. ``prefix`` is accepted but not used by this function.

    Returns:
        The constructed ``AudioEncoderModel``.
    """
    encoder = AudioEncoderModel(None)
    remapped = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    missing, _unexpected = encoder.load_sd(remapped)
    if missing:
        logging.warning("missing audio encoder: {}".format(missing))

    return encoder

comfy/audio_encoders/wav2vec2.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import torch
2+
import torch.nn as nn
3+
from comfy.ldm.modules.attention import optimized_attention_masked
4+
5+
6+
class LayerNormConv(nn.Module):
    """1D convolution followed by layer norm over channels and GELU."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, **factory)
        self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, **factory)

    def forward(self, x):
        conv_out = self.conv(x)
        # LayerNorm normalizes the last axis, so swap the last two axes to
        # put channels last, normalize, then swap back.
        normed = self.layer_norm(conv_out.transpose(-2, -1))
        return torch.nn.functional.gelu(normed.transpose(-2, -1))
15+
16+
17+
class ConvFeatureEncoder(nn.Module):
    """Stack of seven ``LayerNormConv`` layers applied to a raw waveform."""

    def __init__(self, conv_dim, dtype=None, device=None, operations=None):
        super().__init__()
        # (kernel_size, stride) for each layer, in order. The first layer
        # takes a single input channel; every later layer maps
        # conv_dim -> conv_dim.
        layer_specs = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]
        blocks = []
        for index, (kernel, step) in enumerate(layer_specs):
            in_dim = 1 if index == 0 else conv_dim
            blocks.append(
                LayerNormConv(in_dim, conv_dim, kernel_size=kernel, stride=step, bias=True, device=device, dtype=dtype, operations=operations)
            )
        self.conv_layers = nn.ModuleList(blocks)

    def forward(self, x):
        # Insert a channel axis at position 1 for the first convolution.
        hidden = x.unsqueeze(1)
        for block in self.conv_layers:
            hidden = block(hidden)
        # Swap axes 1 and 2 so the feature axis comes last.
        return hidden.transpose(1, 2)
37+
38+
39+
class FeatureProjection(nn.Module):
    """Layer-normalize features, then project them linearly to ``embed_dim``."""

    def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, **factory)
        self.projection = operations.Linear(conv_dim, embed_dim, **factory)

    def forward(self, x):
        return self.projection(self.layer_norm(x))
49+
50+
51+
class PositionalConvEmbedding(nn.Module):
52+
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
53+
super().__init__()
54+
self.conv = nn.Conv1d(
55+
embed_dim,
56+
embed_dim,
57+
kernel_size=kernel_size,
58+
padding=kernel_size // 2,
59+
groups=groups,
60+
)
61+
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
62+
self.activation = nn.GELU()
63+
64+
def forward(self, x):
65+
x = x.transpose(1, 2)
66+
x = self.conv(x)[:, :, :-1]
67+
x = self.activation(x)
68+
x = x.transpose(1, 2)
69+
return x
70+
71+
72+
class TransformerEncoder(nn.Module):
    """Positional conv embedding, a stack of encoder layers, and a final norm."""

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        num_layers=12,
        mlp_ratio=4.0,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)

        layer_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            device=device, dtype=dtype, operations=operations,
        )
        self.layers = nn.ModuleList([TransformerEncoderLayer(**layer_kwargs) for _ in range(num_layers)])

        self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        """Run the encoder.

        Returns:
            ``(x, all_x)`` where ``x`` is the layer-normed final output and
            ``all_x`` is a tuple holding the input to every layer followed by
            that final output.
        """
        x = x + self.pos_conv_embed(x)
        collected = []
        for block in self.layers:
            # Record the state entering each layer.
            collected.append(x)
            x = block(x, mask)
        x = self.layer_norm(x)
        collected.append(x)
        return x, tuple(collected)
105+
106+
107+
class Attention(nn.Module):
    """Multi-head self-attention with separate q/k/v and output projections.

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads.
        bias: Whether the four linear projections use a bias term.
        dtype, device: Forwarded to the linear layers.
        operations: Namespace providing the ``Linear`` layer class.
    """

    def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Raises:
            NotImplementedError: If ``mask`` is not ``None``; masking is not
                implemented here yet.
        """
        # Raise explicitly rather than ``assert``: assertions are stripped
        # under ``python -O``, which would let a mask be silently ignored.
        if mask is not None:
            raise NotImplementedError("Attention masks are not supported yet")  # TODO?
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        out = optimized_attention_masked(q, k, v, self.num_heads)
        return self.out_proj(out)
127+
128+
129+
class FeedForward(nn.Module):
    """Two-layer MLP: expand by ``mlp_ratio``, apply GELU, project back."""

    def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.intermediate_dense = operations.Linear(embed_dim, hidden_dim, device=device, dtype=dtype)
        self.output_dense = operations.Linear(hidden_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, x):
        expanded = torch.nn.functional.gelu(self.intermediate_dense(x))
        return self.output_dense(expanded)
140+
141+
142+
class TransformerEncoderLayer(nn.Module):
    """Pre-norm encoder layer: attention and feed-forward, each with a residual."""

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        factory = {"device": device, "dtype": dtype}

        self.attention = Attention(embed_dim, num_heads, operations=operations, **factory)

        self.layer_norm = operations.LayerNorm(embed_dim, **factory)
        self.feed_forward = FeedForward(embed_dim, mlp_ratio, operations=operations, **factory)
        self.final_layer_norm = operations.LayerNorm(embed_dim, **factory)

    def forward(self, x, mask=None):
        # Attention sub-block: normalize, attend, add back to the input.
        attended = self.attention(self.layer_norm(x), mask=mask)
        x = x + attended

        # Feed-forward sub-block, same normalize-then-residual layout.
        return x + self.feed_forward(self.final_layer_norm(x))
166+
167+
168+
class Wav2Vec2Model(nn.Module):
    """Complete Wav2Vec 2.0 model.

    Pipeline: convolutional feature extractor -> feature projection ->
    transformer encoder.

    Args:
        embed_dim: Transformer feature width.
        final_dim: Accepted for interface compatibility; not used here.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        dtype, device: Forwarded to the submodules and ``masked_spec_embed``.
        operations: Namespace providing the layer classes used by submodules.
    """

    def __init__(
        self,
        embed_dim=1024,
        final_dim=256,
        num_heads=16,
        num_layers=24,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        conv_dim = 512
        self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
        self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)

        # Not read in forward(); left uninitialized (torch.empty) and
        # registered as a parameter of the module.
        self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))

        self.encoder = TransformerEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            device=device, dtype=dtype, operations=operations
        )

    def forward(self, x, mask_time_indices=None, return_dict=False):
        """Encode a waveform batch.

        Args:
            x: Input waveform tensor; dim 1 is averaged away first.
                NOTE(review): presumably (batch, channels, samples) -- confirm.
            mask_time_indices: Accepted but not used.
            return_dict: Accepted but not used; a tuple is always returned.

        Returns:
            ``(x, all_x)`` exactly as produced by the transformer encoder.
        """
        # Collapse dim 1 by averaging.
        x = torch.mean(x, dim=1)

        # Zero-mean / unit-variance normalization computed over the whole
        # tensor (all batch items together), with a small epsilon for
        # stability.
        x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)

        features = self.feature_extractor(x)
        features = self.feature_projection(features)

        x, all_x = self.encoder(features)

        return x, all_x

comfy_api/latest/_io.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,14 @@ class AnyType(ComfyTypeIO):
730730
class MODEL_PATCH(ComfyTypeIO):
731731
Type = Any
732732

733+
# IO type registered under the name "AUDIO_ENCODER"; its payload type is
# left unconstrained (Any).
@comfytype(io_type="AUDIO_ENCODER")
class AUDIO_ENCODER(ComfyTypeIO):
    Type = Any
736+
737+
# IO type registered under the name "AUDIO_ENCODER_OUTPUT"; its payload type
# is left unconstrained (Any).
@comfytype(io_type="AUDIO_ENCODER_OUTPUT")
class AUDIO_ENCODER_OUTPUT(ComfyTypeIO):
    Type = Any
740+
733741
@comfytype(io_type="COMFY_MULTITYPED_V3")
734742
class MultiType:
735743
Type = Any
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import folder_paths
2+
import comfy.audio_encoders.audio_encoders
3+
import comfy.utils
4+
5+
6+
class AudioEncoderLoader:
    """Node that loads an audio encoder from the "audio_encoders" folder."""

    @classmethod
    def INPUT_TYPES(s):
        # The selectable values are the files currently listed for the
        # "audio_encoders" folder.
        choices = folder_paths.get_filename_list("audio_encoders")
        return {"required": {"audio_encoder_name": (choices,)}}

    RETURN_TYPES = ("AUDIO_ENCODER",)
    FUNCTION = "load_model"

    CATEGORY = "loaders"

    def load_model(self, audio_encoder_name):
        """Resolve the file, load its state dict and build the encoder.

        Returns:
            A one-element tuple containing the loaded audio encoder.

        Raises:
            RuntimeError: If the loader function returns ``None``.
        """
        full_path = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
        state_dict = comfy.utils.load_torch_file(full_path, safe_load=True)
        encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(state_dict)
        if encoder is None:
            raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
        return (encoder,)
23+
24+
25+
class AudioEncoderEncode:
    """Node that runs an audio encoder on an audio input."""

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "audio_encoder": ("AUDIO_ENCODER",),
            "audio": ("AUDIO",),
        }
        return {"required": required}

    RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"

    def encode(self, audio_encoder, audio):
        """Encode ``audio`` (a dict with "waveform" and "sample_rate" keys).

        Returns:
            A one-element tuple holding whatever ``encode_audio`` returned.
        """
        waveform = audio["waveform"]
        rate = audio["sample_rate"]
        return (audio_encoder.encode_audio(waveform, rate),)
39+
40+
41+
# Maps each node identifier to the class in this file that implements it.
NODE_CLASS_MAPPINGS = {
    "AudioEncoderLoader": AudioEncoderLoader,
    "AudioEncoderEncode": AudioEncoderEncode,
}

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,6 +2324,7 @@ async def init_builtin_extra_nodes():
23242324
"nodes_qwen.py",
23252325
"nodes_model_patch.py",
23262326
"nodes_easycache.py",
2327+
"nodes_audio_encoder.py",
23272328
]
23282329

23292330
import_failed = []

0 commit comments

Comments
 (0)