Skip to content

Commit

Permalink
address #279 again
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 8, 2024
1 parent 8fe5984 commit 5d17c09
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 5 additions & 4 deletions audiolm_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def train_step(self):

self.accelerator.wait_for_everyone()

if self.is_main and not (steps % self.save_results_every):
if not (steps % self.save_results_every):
models = [(self.unwrapped_soundstream, str(steps))]
if self.use_ema:
models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema'))
Expand All @@ -682,9 +682,10 @@ def train_step(self):
with torch.inference_mode():
recons = model(wave, return_recons_only = True)

for ind, recon in enumerate(recons.unbind(dim = 0)):
filename = str(self.results_folder / f'sample_{label}.flac')
torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)
if self.is_main:
for ind, recon in enumerate(recons.unbind(dim = 0)):
filename = str(self.results_folder / f'sample_{label}.flac')
torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)

self.print(f'{steps}: saving to {str(self.results_folder)}')

Expand Down
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.2.2'
__version__ = '2.2.3'

0 comments on commit 5d17c09

Please sign in to comment.