Skip to content

Commit

Permalink
Merge branch 'dev/lyuxiang.lx' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
aluminumbox authored Jan 8, 2025
2 parents 2a0dd54 + 1e52c60 commit 92f1c65
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
22 changes: 21 additions & 1 deletion cosyvoice/flow/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
Expand All @@ -30,6 +31,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
self.lock = threading.Lock()

@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
Expand Down Expand Up @@ -120,7 +122,25 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):
return sol[-1].float()

def forward_estimator(self, x, mask, mu, t, spks, cond):
return self.estimator.forward(x, mask, mu, t, spks, cond)
if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond)
else:
with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()])
return x

def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ conformer==0.3.2
deepspeed==0.14.2; sys_platform == 'linux'
diffusers==0.27.2
gdown==5.1.0
gradio==4.32.2
gradio==5.4.0
grpcio==1.57.0
grpcio-tools==1.57.0
huggingface-hub==0.25.2
Expand Down

0 comments on commit 92f1c65

Please sign in to comment.