From 432557a4b582ce6ff0260aa695c90d98b97ee009 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Thu, 15 Dec 2022 00:00:10 +0100 Subject: [PATCH 01/23] fix: remove print statement --- audio_diffusion_pytorch/model.py | 1 - setup.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py index a5ac7af..890bb48 100644 --- a/audio_diffusion_pytorch/model.py +++ b/audio_diffusion_pytorch/model.py @@ -120,7 +120,6 @@ def forward( # type: ignore self, x: Tensor, with_info: bool = False, **kwargs ) -> Union[Tensor, Tuple[Tensor, Any]]: latent, info = self.encode(x, with_info=True) - print(latent.shape) loss = super().forward(x, channels_list=[latent], **kwargs) return (loss, info) if with_info else loss diff --git a/setup.py b/setup.py index 07b1641..9f9a047 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.96", + version="0.0.97", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From ef685fe230a0b157fc94cac2bb43b545670c7af9 Mon Sep 17 00:00:00 2001 From: zaptrem Date: Mon, 19 Dec 2022 17:53:56 -0500 Subject: [PATCH 02/23] fix: ae default encoder & torch.compile error (#36) Co-authored-by: Ryan Tremblay --- audio_diffusion_pytorch/model.py | 4 ++-- audio_diffusion_pytorch/modules.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py index 890bb48..fedc82a 100644 --- a/audio_diffusion_pytorch/model.py +++ b/audio_diffusion_pytorch/model.py @@ -131,7 +131,7 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor: length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) # Compute noise by inferring shape from latent length noise = torch.randn(b, self.in_channels, length, device=latent.device) - # Compute context form latent + # Compute context from latent default_kwargs = dict(channels_list=[latent]) # Decode by sampling while conditioning on latent channels return super().sample(noise, **{**default_kwargs, **kwargs}) @@ -351,7 +351,7 @@ def __init__(self, in_channels: int, *args, **kwargs): in_channels=in_channels, patch_size=16, channels=16, - multipliers=[1, 2, 4, 4, 4, 4, 4], + multipliers=[2, 2, 4, 4, 4, 4, 4], factors=[4, 4, 4, 2, 2, 2], num_blocks=[2, 2, 2, 2, 2, 2], out_channels=64, diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index 2324ad5..a0af226 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -5,7 +5,6 @@ import torch.nn as nn from einops import rearrange, reduce, repeat from einops.layers.torch import Rearrange -from einops_exts import rearrange_many from torch import Tensor, einsum from .utils import closest_power_2, default, exists, groupby @@ -351,7 +350,9 @@ def __init__( def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # Split heads - q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) + q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads) + k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads) + v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads) # Compute similarity matrix sim = einsum("... n d, ... m d -> ... 
n m", q, k) sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim @@ -1361,13 +1362,14 @@ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: magnitude, phase = torch.abs(stft), torch.angle(stft) stft_a, stft_b = magnitude, phase - return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) + return rearrange(stft_a, "(b c) f l -> b c f l", b=b), rearrange(stft_b, "(b c) f l -> b c f l", b=b) def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: b, l = stft_a.shape[0], stft_a.shape[-1] # noqa length = closest_power_2(l * self.hop_length) - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") + stft_a = rearrange(stft_a, "b c f l -> (b c) f l") + stft_b = rearrange(stft_b, "b c f l -> (b c) f l") if self.use_complex: real, imag = stft_a, stft_b @@ -1393,11 +1395,13 @@ def encode1d( self, wave: Tensor, stacked: bool = True ) -> Union[Tensor, Tuple[Tensor, Tensor]]: stft_a, stft_b = self.encode(wave) - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") + stft_a = rearrange(stft_a, "b c f l -> b (c f) l") + stft_b = rearrange(stft_b, "b c f l -> b (c f) l") return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) def decode1d(self, stft_pair: Tensor) -> Tensor: f = self.num_fft // 2 + 1 stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) + stft_a = rearrange(stft_a, "b (c f) l -> b c f l", f=f) + stft_b = rearrange(stft_b, "b (c f) l -> b c f l", f=f) return self.decode(stft_a, stft_b) From abc8493528488b0a8c6a965e297ee7a6cbfcb636 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Mon, 2 Jan 2023 23:44:00 +0100 Subject: [PATCH 03/23] feat: add a-unet model, ae model --- README.md | 328 +----- audio_diffusion_pytorch/__init__.py | 29 +- audio_diffusion_pytorch/diffusion.py | 654 +----------- audio_diffusion_pytorch/model.py | 432 -------- audio_diffusion_pytorch/models.py | 78 ++ audio_diffusion_pytorch/modules.py | 1407 -------------------------- audio_diffusion_pytorch/unets.py | 155 +++ setup.py | 1 + 8 files changed, 294 insertions(+), 2790 deletions(-) delete mode 100644 audio_diffusion_pytorch/model.py create mode 100644 audio_diffusion_pytorch/models.py delete mode 100644 audio_diffusion_pytorch/modules.py create mode 100644 audio_diffusion_pytorch/unets.py diff --git a/README.md b/README.md index 9502381..52d06ac 100644 --- a/README.md +++ b/README.md @@ -1,333 +1,9 @@ -Unconditional audio generation using diffusion models, in PyTorch. The goal of this repository is to explore different architectures and diffusion models to generate audio (speech and music) directly from/to the waveform. -Progress will be documented in the [experiments](#experiments) section. You can use the [`audio-diffusion-pytorch-trainer`](https://github.com/archinetai/audio-diffusion-pytorch-trainer) to run your own experiments – please share your findings in the [discussions](https://github.com/archinetai/audio-diffusion-pytorch/discussions) page! Pretrained models can be found at [`archisound`](https://github.com/archinetai/archisound). +Nightly branch. 
## Install ```bash -pip install audio-diffusion-pytorch -``` - -[![PyPI - Python Version](https://img.shields.io/pypi/v/audio-diffusion-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/audio-diffusion-pytorch/) -[![Downloads](https://static.pepy.tech/personalized-badge/audio-diffusion-pytorch?period=total&units=international_system&left_color=black&right_color=black&left_text=Downloads)](https://pepy.tech/project/audio-diffusion-pytorch) -[![HuggingFace](https://img.shields.io/badge/Trained%20Models-%F0%9F%A4%97-yellow?style=flat&colorA=black&colorB=black)](https://huggingface.co/archinetai/audio-diffusion-pytorch/tree/main) -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)]( -https://colab.research.google.com/gist/flavioschneider/d1f67b07ffcbf6fd09fdd27515ba3701/audio-diffusion-pytorch-v0-2.ipynb) - - - -## Usage - -### Generation -```py -from audio_diffusion_pytorch import AudioDiffusionModel - -model = AudioDiffusionModel(in_channels=1) - -# Train model with audio sources -x = torch.randn(2, 1, 2 ** 18) # [batch, in_channels, samples], 2**18 ≈ 12s of audio at a frequency of 22050 -loss = model(x) -loss.backward() # Do this many times - -# Sample 2 sources given start noise -noise = torch.randn(2, 1, 2 ** 18) -sampled = model.sample( - noise=noise, - num_steps=5 # Suggested range: 2-50 -) # [2, 1, 2 ** 18] -``` - -### Upsampling -```py -from audio_diffusion_pytorch import AudioDiffusionUpsampler - -upsampler = AudioDiffusionUpsampler( - in_channels=1, - factor=8, -) - -# Train on high frequency data -x = torch.randn(2, 1, 2 ** 18) -loss = upsampler(x) -loss.backward() - -# Given start undersampled source, samples upsampled source -undersampled = torch.randn(1, 1, 2 ** 15) -upsampled = upsampler.sample( - undersampled, - num_steps=5 -) # [1, 1, 2 ** 18] -``` - -### Autoencoding -```py -from audio_diffusion_pytorch import AudioDiffusionAE - -autoencoder = AudioDiffusionAE(in_channels=1) - -# Train on audio samples -x = torch.randn(2, 1, 2 ** 18) -loss = autoencoder(x) -loss.backward() - -# Encode audio source into latent -x = torch.randn(2, 1, 2 ** 18) -latent = autoencoder.encode(x) # [2, 32, 128] - -# Decode latent by diffusion sampling -decoded = autoencoder.decode( - latent, - num_steps=5 -) # [2, 32, 2**18] -``` - - -### Conditional Generation -```py -from audio_diffusion_pytorch import AudioDiffusionConditional - -model = AudioDiffusionConditional( - in_channels=1, - embedding_max_length=64, - embedding_features=768, - embedding_mask_proba=0.1 # Conditional dropout of batch elements -) - -# Train on pairs of audio and embedding data (e.g. from a transformer output) -x = torch.randn(2, 1, 2 ** 18) -embedding = torch.randn(2, 64, 768) -loss = model(x, embedding=embedding) -loss.backward() - -# Given start embedding and noise sample new source -embedding = torch.randn(2, 64, 768) -noise = torch.randn(2, 1, 2 ** 18) -sampled = model.sample( - noise, - embedding=embedding, - embedding_scale=5.0, # Classifier-free guidance scale - num_steps=5 -) # [2, 1, 2 ** 18] -``` - -#### Text Conditional Generation -You can generate embeddings from text by using a pretrained frozen T5 transformer with `T5Embedder`, as follows (note that this requires `pip install transformers`): - -```py -from audio_diffusion_pytorch import T5Embedder - -embedder = T5Embedder(model='t5-base', max_length=64) -embedding = embedder(["First batch item text...", "Second batch item text..."]) # [2, 64, 768] - -loss = model(x, embedding=embedding) -# ... 
-sampled = model.sample( - noise, - embedding=embedding, - embedding_scale=5.0, # Classifier-free guidance scale - num_steps=5 -) -``` - -#### Number Conditional Generation - -```py -from audio_diffusion_pytorch import NumberEmbedder - -embedder = NumberEmbedder(features=768) -embedding = embedder([0.1, 0.2]) # [2, 768] -``` - - -## Usage with Components - -### UNet1d -```py -from audio_diffusion_pytorch import UNet1d - -# UNet used to denoise our 1D (audio) data -unet = UNet1d( - in_channels=1, - channels=128, - patch_size=16, - multipliers=[1, 2, 4, 4, 4, 4, 4], - factors=[4, 4, 4, 2, 2, 2], - attentions=[0, 0, 0, 1, 1, 1, 1], - num_blocks=[2, 2, 2, 2, 2, 2], - attention_heads=8, - attention_features=64, - attention_multiplier=2, - resnet_groups=8, - kernel_multiplier_downsample=2, - use_nearest_upsample=False, - use_skip_scale=True, - use_context_time=True -) - -x = torch.randn(3, 1, 2 ** 16) -t = torch.tensor([0.2, 0.8, 0.3]) - -y = unet(x, t) # [3, 1, 32768], compute 3 samples of ~1.5 seconds at 22050Hz with the given noise levels t -``` - -### Diffusion - -#### Training -```python -from audio_diffusion_pytorch import KDiffusion, LogNormalDistribution -from audio_diffusion_pytorch import VDiffusion, UniformDistribution - -# Either use KDiffusion -diffusion = KDiffusion( - net=unet, - sigma_distribution=LogNormalDistribution(mean = -3.0, std = 1.0), - sigma_data=0.1, - dynamic_threshold=0.0 -) - -# Or use VDiffusion -diffusion = VDiffusion( - net=unet, - sigma_distribution=UniformDistribution() -) - -x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples -loss = diffusion(x) -loss.backward() # Do this many times -``` - -#### Sampling -```python -from audio_diffusion_pytorch import DiffusionSampler, KarrasSchedule - -sampler = DiffusionSampler( - diffusion, - num_steps=5, # Suggested range 2-100, higher better quality but takes longer - sampler=ADPM2Sampler(rho=1), - sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0) -) -# Generate a sample starting from the provided noise -y = sampler(noise = torch.randn(1,1,2 ** 18)) -``` - -#### Inpainting - -```py -from audio_diffusion_pytorch import DiffusionInpainter, KarrasSchedule, ADPM2Sampler - -inpainter = DiffusionInpainter( - diffusion, - num_steps=5, # Suggested range 2-100, higher for better quality - num_resamples=1, # Suggested range 1-10, higher for better quality - sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), - sampler=ADPM2Sampler(rho=1.0), -) - -inpaint = torch.randn(1,1,2 ** 18) # Start track, e.g. one sampled with DiffusionSampler -inpaint_mask = torch.randint(0,2, (1,1,2 ** 18), dtype=torch.bool) # Set to `True` the parts you want to keep -y = inpainter(inpaint = inpaint, inpaint_mask = inpaint_mask) -``` - -#### Infinite Generation -```python -from audio_diffusion_pytorch import SpanBySpanComposer - -composer = SpanBySpanComposer( - inpainter, - num_spans=4 # Number of spans to inpaint after provided input -) -y_long = composer(y, keep_start=True) # [1, 1, 98304] -``` - - -## Experiments - - -| Report | Snapshot | Description | -| --- | --- | --- | -| [Alpha](https://wandb.ai/schneider/audio/reports/Audio-Diffusion-UNet-Alpha---VmlldzoyMjk3MzIz?accessToken=y0l3igdvnm4ogn4d3ph3b0i8twwcf7meufbviwt15f0qtasyn1i14hg340bkk1te) | [6bd9279f19](https://github.com/archinetai/audio-diffusion-pytorch/tree/6bd9279f192fc0c11eb8a21cd919d9c41181bf35) | Initial tests on LJSpeech dataset with new architecture and basic DDPM diffusion model. 
| -| [Bravo](https://wandb.ai/schneider/audio/reports/Audio-Diffusion-Bravo---VmlldzoyMzE4NjIx?accessToken=qt2w1jeqch9l5v3ffjns99p69jsmexk849dszyiennfbivgg396378u6ken2fm2d) | [a05f30aa94](https://github.com/archinetai/audio-diffusion-pytorch/tree/a05f30aa94e07600038d36cfb96f8492ef735a99) | Elucidated diffusion, improved architecture with patching, longer duration, initial good (unsupervised) results on LJSpeech. -| [Charlie](https://wandb.ai/schneider/audio/reports/Audio-Diffusion-Charlie---VmlldzoyMzYyNDA1?accessToken=71gmurcwndv5e2abqrjnlh3n74j5555j3tycpd7h40tnv8fvb17k5pjkb57j9xxa) | [50ecc30d70](https://github.com/archinetai/audio-diffusion-pytorch/tree/50ecc30d70a211b92cb9c38d4b0250d7cc30533f) | Train on music with [YoutubeDataset](https://github.com/archinetai/audio-data-pytorch), larger patch tests for longer tracks, inpainting tests, initial test with infinite generation using SpanBySpanComposer. | -| [Delta](https://wandb.ai/schneider/audio/reports/Audio-Diffusion-Delta---VmlldzoyNDYyMzk1?accessToken=n1d34n35qserpx7nhskkfdm1q12hlcxx1qcmfw5ypz53kjkzoh0ge2uvhshiseqx) | [672876bf13](https://github.com/archinetai/audio-diffusion-pytorch/tree/672876bf1373b1f10afd3adc8b3b984495bca91a) | Test model with the faster `ADPM2` sampler and dynamic thresholding. | -| [Echo](https://wandb.ai/schneider/audio/reports/Audio-Diffusion-Echo---VmlldzoyNTU2NTcw?accessToken=sthdn25n8is30gjo2x0w4fs9hwbua23rlbg7o4bv8h17y47xdtruiiyb33aoc5h4) | (current) | Test `AudioDiffusionUpsampler`. - -## TODO - -- [x] Add elucidated diffusion. -- [x] Add ancestral DPM2 sampler. -- [x] Add dynamic thresholding. -- [x] Add (variational) autoencoder option to compress audio before diffusion (removed). -- [x] Fix inpainting and make it work with ADPM2 sampler. -- [x] Add trainer with experiments. -- [x] Add diffusion upsampler. -- [x] Add ancestral euler sampler `AEulerSampler`. -- [x] Add diffusion autoencoder. -- [x] Add diffusion upsampler. -- [x] Add autoencoder bottleneck option for quantization. -- [x] Add option to provide context tokens (cross attention). -- [x] Add conditional model with classifier-free guidance. -- [x] Add option to provide context features mapping. -- [x] Add option to change number of (cross) attention blocks. -- [x] Add `VDiffusionn` option. -- [ ] Add flash attention. - - -## Appreciation - -* [StabilityAI](https://stability.ai/) for the compute, [Zach](https://github.com/zqevans) and everyone else from [HarmonAI](https://www.harmonai.org/) for the interesting research discussions. -* [ETH Zurich](https://inf.ethz.ch/) for the resources, [Zhijing Jin](https://zhijing-jin.com/), [Mrinmaya Sachan](http://www.mrinmaya.io/), and [Bernhard Schoelkopf](https://is.mpg.de/~bs) for supervising this Thesis. -* [Phil Wang](https://github.com/lucidrains) for the beautiful open source contributions on [diffusion](https://github.com/lucidrains/denoising-diffusion-pytorch) and [Imagen](https://github.com/lucidrains/imagen-pytorch). -* [Katherine Crowson](https://github.com/crowsonkb) for the experiments with [k-diffusion](https://github.com/crowsonkb/k-diffusion) and the insane collection of samplers. 
- -## Citations - -DDPM -```bibtex -@misc{2006.11239, -Author = {Jonathan Ho and Ajay Jain and Pieter Abbeel}, -Title = {Denoising Diffusion Probabilistic Models}, -Year = {2020}, -Eprint = {arXiv:2006.11239}, -} -``` - -Diffusion inpainting -```bibtex -@misc{2201.09865, -Author = {Andreas Lugmayr and Martin Danelljan and Andres Romero and Fisher Yu and Radu Timofte and Luc Van Gool}, -Title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models}, -Year = {2022}, -Eprint = {arXiv:2201.09865}, -} -``` - -Diffusion weighted loss -```bibtex -@misc{2204.00227, -Author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo Kim and Sungroh Yoon}, -Title = {Perception Prioritized Training of Diffusion Models}, -Year = {2022}, -Eprint = {arXiv:2204.00227}, -} -``` - -Improved UNet architecture -```bibtex -@misc{2205.11487, -Author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David J Fleet and Mohammad Norouzi}, -Title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding}, -Year = {2022}, -Eprint = {arXiv:2205.11487}, -} -``` - -Elucidated diffusion -```bibtex -@misc{2206.00364, -Author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine}, -Title = {Elucidating the Design Space of Diffusion-Based Generative Models}, -Year = {2022}, -Eprint = {arXiv:2206.00364}, -} +pip install -U git+https://github.com/archinetai/audio-diffusion-pytorch.git@nightly ``` diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index 769074c..f2bfe57 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -1,39 +1,14 @@ from audio_encoders_pytorch import Encoder1d, ME1d from .diffusion import ( - ADPM2Sampler, - AEulerSampler, Diffusion, - DiffusionInpainter, - DiffusionSampler, Distribution, - KarrasSampler, - KarrasSchedule, - KDiffusion, LinearSchedule, - LogNormalDistribution, Sampler, Schedule, - SpanBySpanComposer, UniformDistribution, VDiffusion, - VKDiffusion, - VKDistribution, VSampler, - XDiffusion, ) -from .model import ( - AudioDiffusionAE, - AudioDiffusionConditional, - AudioDiffusionModel, - AudioDiffusionUpphaser, - AudioDiffusionUpsampler, - AudioDiffusionVocoder, - DiffusionAE1d, - DiffusionAR1d, - DiffusionUpphaser1d, - DiffusionUpsampler1d, - DiffusionVocoder1d, - Model1d, -) -from .modules import NumberEmbedder, T5Embedder, UNet1d, XUNet1d +from .models import DiffusionAE, DiffusionModel +from .unets import XUNet diff --git a/audio_diffusion_pytorch/diffusion.py b/audio_diffusion_pytorch/diffusion.py index 8d68347..b11ddbe 100644 --- a/audio_diffusion_pytorch/diffusion.py +++ b/audio_diffusion_pytorch/diffusion.py @@ -1,65 +1,29 @@ -from math import atan, cos, pi, sin, sqrt -from typing import Any, Callable, List, Optional, Tuple, Type +from math import pi +from typing import Any, List, Tuple, Type import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, reduce +from einops import rearrange, repeat from torch import Tensor - -from .utils import default, exists - -""" -Diffusion Training -""" +from tqdm import tqdm """ Distributions """ class Distribution: + """Interface used by different distributions""" + def __call__(self, num_samples: int, device: torch.device): raise NotImplementedError() -class 
LogNormalDistribution(Distribution): - def __init__(self, mean: float, std: float): - self.mean = mean - self.std = std - - def __call__( - self, num_samples: int, device: torch.device = torch.device("cpu") - ) -> Tensor: - normal = self.mean + self.std * torch.randn((num_samples,), device=device) - return normal.exp() - - class UniformDistribution(Distribution): def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")): return torch.rand(num_samples, device=device) -class VKDistribution(Distribution): - def __init__( - self, - min_value: float = 0.0, - max_value: float = float("inf"), - sigma_data: float = 1.0, - ): - self.min_value = min_value - self.max_value = max_value - self.sigma_data = sigma_data - - def __call__( - self, num_samples: int, device: torch.device = torch.device("cpu") - ) -> Tensor: - sigma_data = self.sigma_data - min_cdf = atan(self.min_value / sigma_data) * 2 / pi - max_cdf = atan(self.max_value / sigma_data) * 2 / pi - u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf - return torch.tan(u * pi / 2) * sigma_data - - -""" Diffusion Classes """ +""" Diffusion Methods """ def pad_dims(x: Tensor, ndim: int) -> Tensor: @@ -83,222 +47,41 @@ def clip(x: Tensor, dynamic_threshold: float = 0.0): return x -def to_batch( - batch_size: int, - device: torch.device, - x: Optional[float] = None, - xs: Optional[Tensor] = None, -) -> Tensor: - assert exists(x) ^ exists(xs), "Either x or xs must be provided" - # If x provided use the same for all batch items - if exists(x): - xs = torch.full(size=(batch_size,), fill_value=x).to(device) - assert exists(xs) - return xs - - class Diffusion(nn.Module): + """Interface used by different diffusion methods""" - alias: str = "" - - """Base diffusion class""" - - def denoise_fn( - self, - x_noisy: Tensor, - sigmas: Optional[Tensor] = None, - sigma: Optional[float] = None, - **kwargs, - ) -> Tensor: - raise NotImplementedError("Diffusion class missing denoise_fn") - - def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: + def forward(self, *args, **kwargs) -> Tensor: raise NotImplementedError("Diffusion class missing forward function") class VDiffusion(Diffusion): - - alias = "v" - - def __init__(self, net: nn.Module, *, sigma_distribution: Distribution): + def __init__( + self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution() + ): super().__init__() self.net = net self.sigma_distribution = sigma_distribution def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: angle = sigmas * pi / 2 - alpha = torch.cos(angle) - beta = torch.sin(angle) + alpha, beta = torch.cos(angle), torch.sin(angle) return alpha, beta - def denoise_fn( - self, - x_noisy: Tensor, - sigmas: Optional[Tensor] = None, - sigma: Optional[float] = None, - **kwargs, - ) -> Tensor: - batch_size, device = x_noisy.shape[0], x_noisy.device - sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device) - return self.net(x_noisy, sigmas, **kwargs) - - def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: + def forward(self, x: Tensor, **kwargs) -> Tensor: # type: ignore batch_size, device = x.shape[0], x.device - # Sample amount of noise to add for each batch element sigmas = self.sigma_distribution(num_samples=batch_size, device=device) - sigmas_padded = rearrange(sigmas, "b -> b 1 1") - + sigmas_batch = rearrange(sigmas, "b -> b 1 1") # Get noise - noise = default(noise, lambda: torch.randn_like(x)) - + noise = torch.randn_like(x) # Combine input 
and noise weighted by half-circle - alpha, beta = self.get_alpha_beta(sigmas_padded) - x_noisy = x * alpha + noise * beta - x_target = noise * alpha - x * beta - - # Denoise and return loss - x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs) - return F.mse_loss(x_denoised, x_target) - + alphas, betas = self.get_alpha_beta(sigmas_batch) + x_noisy = alphas * x + betas * noise + v_target = alphas * noise - betas * x + # Predict velocity and return loss + v_pred = self.net(x_noisy, sigmas, **kwargs) + return F.mse_loss(v_pred, v_target) -class KDiffusion(Diffusion): - """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364""" - - alias = "k" - - def __init__( - self, - net: nn.Module, - *, - sigma_distribution: Distribution, - sigma_data: float, # data distribution standard deviation - dynamic_threshold: float = 0.0, - ): - super().__init__() - self.net = net - self.sigma_data = sigma_data - self.sigma_distribution = sigma_distribution - self.dynamic_threshold = dynamic_threshold - - def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]: - sigma_data = self.sigma_data - c_noise = torch.log(sigmas) * 0.25 - sigmas = rearrange(sigmas, "b -> b 1 1") - c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2) - c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5 - c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5 - return c_skip, c_out, c_in, c_noise - - def denoise_fn( - self, - x_noisy: Tensor, - sigmas: Optional[Tensor] = None, - sigma: Optional[float] = None, - **kwargs, - ) -> Tensor: - batch_size, device = x_noisy.shape[0], x_noisy.device - sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device) - - # Predict network output and add skip connection - c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas) - x_pred = self.net(c_in * x_noisy, c_noise, **kwargs) - x_denoised = c_skip * x_noisy + c_out * x_pred - - # Clips in [-1,1] range, with dynamic thresholding if provided - return clip(x_denoised, dynamic_threshold=self.dynamic_threshold) - - def loss_weight(self, sigmas: Tensor) -> Tensor: - # Computes weight depending on data distribution - return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2 - - def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: - batch_size, device = x.shape[0], x.device - - # Sample amount of noise to add for each batch element - sigmas = self.sigma_distribution(num_samples=batch_size, device=device) - sigmas_padded = rearrange(sigmas, "b -> b 1 1") - - # Add noise to input - noise = default(noise, lambda: torch.randn_like(x)) - x_noisy = x + sigmas_padded * noise - - # Compute denoised values - x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs) - - # Compute weighted loss - losses = F.mse_loss(x_denoised, x, reduction="none") - losses = reduce(losses, "b ... 
-> b", "mean") - losses = losses * self.loss_weight(sigmas) - loss = losses.mean() - return loss - - -class VKDiffusion(Diffusion): - - alias = "vk" - - def __init__(self, net: nn.Module, *, sigma_distribution: Distribution): - super().__init__() - self.net = net - self.sigma_distribution = sigma_distribution - - def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]: - sigma_data = 1.0 - sigmas = rearrange(sigmas, "b -> b 1 1") - c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2) - c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5 - c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5 - return c_skip, c_out, c_in - - def sigma_to_t(self, sigmas: Tensor) -> Tensor: - return sigmas.atan() / pi * 2 - - def t_to_sigma(self, t: Tensor) -> Tensor: - return (t * pi / 2).tan() - - def denoise_fn( - self, - x_noisy: Tensor, - sigmas: Optional[Tensor] = None, - sigma: Optional[float] = None, - **kwargs, - ) -> Tensor: - batch_size, device = x_noisy.shape[0], x_noisy.device - sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device) - - # Predict network output and add skip connection - c_skip, c_out, c_in = self.get_scale_weights(sigmas) - x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs) - x_denoised = c_skip * x_noisy + c_out * x_pred - return x_denoised - - def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: - batch_size, device = x.shape[0], x.device - - # Sample amount of noise to add for each batch element - sigmas = self.sigma_distribution(num_samples=batch_size, device=device) - sigmas_padded = rearrange(sigmas, "b -> b 1 1") - - # Add noise to input - noise = default(noise, lambda: torch.randn_like(x)) - x_noisy = x + sigmas_padded * noise - - # Compute model output - c_skip, c_out, c_in = self.get_scale_weights(sigmas) - x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs) - - # Compute v-objective target - v_target = (x - c_skip * x_noisy) / (c_out + 1e-7) - - # Compute loss - loss = F.mse_loss(x_pred, v_target) - return loss - - -""" -Diffusion Sampling -""" """ Schedules """ @@ -312,381 +95,56 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor: class LinearSchedule(Schedule): def forward(self, num_steps: int, device: Any) -> Tensor: - sigmas = torch.linspace(1, 0, num_steps + 1)[:-1] - return sigmas - - -class KarrasSchedule(Schedule): - """https://arxiv.org/abs/2206.00364 equation 5""" - - def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0): - super().__init__() - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.rho = rho - - def forward(self, num_steps: int, device: Any) -> Tensor: - rho_inv = 1.0 / self.rho - steps = torch.arange(num_steps, device=device, dtype=torch.float32) - sigmas = ( - self.sigma_max ** rho_inv - + (steps / (num_steps - 1)) - * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv) - ) ** self.rho - sigmas = F.pad(sigmas, pad=(0, 1), value=0.0) - return sigmas + return torch.linspace(1.0, 0.0, num_steps, device=device) """ Samplers """ class Sampler(nn.Module): + """Interface used by different samplers""" - diffusion_types: List[Type[Diffusion]] = [] + diffusion_types: List[Type] = [] - def forward( - self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int - ) -> Tensor: + def forward(*args, **kwargs) -> Tensor: raise NotImplementedError() - def inpaint( - self, - source: Tensor, - mask: Tensor, - fn: Callable, - sigmas: Tensor, - num_steps: int, - num_resamples: int, - ) -> Tensor: - raise 
NotImplementedError("Inpainting not available with current sampler") - class VSampler(Sampler): diffusion_types = [VDiffusion] - def get_alpha_beta(self, sigma: float) -> Tuple[float, float]: - angle = sigma * pi / 2 - alpha = cos(angle) - beta = sin(angle) - return alpha, beta - - def forward( - self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int - ) -> Tensor: - x = sigmas[0] * noise - alpha, beta = self.get_alpha_beta(sigmas[0].item()) - - for i in range(num_steps - 1): - is_last = i == num_steps - 1 - - x_denoised = fn(x, sigma=sigmas[i]) - x_pred = x * alpha - x_denoised * beta - x_eps = x * beta + x_denoised * alpha - - if not is_last: - alpha, beta = self.get_alpha_beta(sigmas[i + 1].item()) - x = x_pred * alpha + x_eps * beta - - return x_pred - - -class KarrasSampler(Sampler): - """https://arxiv.org/abs/2206.00364 algorithm 1""" - - diffusion_types = [KDiffusion, VKDiffusion] - - def __init__( - self, - s_tmin: float = 0, - s_tmax: float = float("inf"), - s_churn: float = 0.0, - s_noise: float = 1.0, - ): - super().__init__() - self.s_tmin = s_tmin - self.s_tmax = s_tmax - self.s_noise = s_noise - self.s_churn = s_churn - - def step( - self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float - ) -> Tensor: - """Algorithm 2 (step)""" - # Select temporarily increased noise level - sigma_hat = sigma + gamma * sigma - # Add noise to move from sigma to sigma_hat - epsilon = self.s_noise * torch.randn_like(x) - x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon - # Evaluate ∂x/∂sigma at sigma_hat - d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat - # Take euler step from sigma_hat to sigma_next - x_next = x_hat + (sigma_next - sigma_hat) * d - # Second order correction - if sigma_next != 0: - model_out_next = fn(x_next, sigma=sigma_next) - d_prime = (x_next - model_out_next) / sigma_next - x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime) - return x_next - - def forward( - self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int - ) -> Tensor: - x = sigmas[0] * noise - # Compute gammas - gammas = torch.where( - (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax), - min(self.s_churn / num_steps, sqrt(2) - 1), - 0.0, - ) - # Denoise to sample - for i in range(num_steps - 1): - x = self.step( - x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa - ) - - return x - - -class AEulerSampler(Sampler): - - diffusion_types = [KDiffusion, VKDiffusion] - - def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]: - sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) - sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2) - return sigma_up, sigma_down - - def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor: - # Sigma steps - sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next) - # Derivative at sigma (∂x/∂sigma) - d = (x - fn(x, sigma=sigma)) / sigma - # Euler method - x_next = x + d * (sigma_down - sigma) - # Add randomness - x_next = x_next + torch.randn_like(x) * sigma_up - return x_next - - def forward( - self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int - ) -> Tensor: - x = sigmas[0] * noise - # Denoise to sample - for i in range(num_steps - 1): - x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa - return x - - -class ADPM2Sampler(Sampler): - """https://www.desmos.com/calculator/jbxjlqd9mb""" - - diffusion_types = [KDiffusion, VKDiffusion] - - def __init__(self, 
rho: float = 1.0): - super().__init__() - self.rho = rho - - def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]: - r = self.rho - sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) - sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2) - sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r - return sigma_up, sigma_down, sigma_mid - - def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor: - # Sigma steps - sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next) - # Derivative at sigma (∂x/∂sigma) - d = (x - fn(x, sigma=sigma)) / sigma - # Denoise to midpoint - x_mid = x + d * (sigma_mid - sigma) - # Derivative at sigma_mid (∂x_mid/∂sigma_mid) - d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid - # Denoise to next - x = x + d_mid * (sigma_down - sigma) - # Add randomness - x_next = x + torch.randn_like(x) * sigma_up - return x_next - - def forward( - self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int - ) -> Tensor: - x = sigmas[0] * noise - # Denoise to sample - for i in range(num_steps - 1): - x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa - return x - - def inpaint( - self, - source: Tensor, - mask: Tensor, - fn: Callable, - sigmas: Tensor, - num_steps: int, - num_resamples: int, - ) -> Tensor: - x = sigmas[0] * torch.randn_like(source) - - for i in range(num_steps - 1): - # Noise source to current noise level - source_noisy = source + sigmas[i] * torch.randn_like(source) - for r in range(num_resamples): - # Merge noisy source and current then denoise - x = source_noisy * mask + x * ~mask - x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa - # Renoise if not last resample step - if r < num_resamples - 1: - sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2) - x = x + sigma * torch.randn_like(x) - - return source * mask + x * ~mask - - -""" Main Classes """ - - -class DiffusionSampler(nn.Module): - def __init__( - self, - diffusion: Diffusion, - *, - sampler: Sampler, - sigma_schedule: Schedule, - num_steps: Optional[int] = None, - clamp: bool = True, - ): - super().__init__() - self.denoise_fn = diffusion.denoise_fn - self.sampler = sampler - self.sigma_schedule = sigma_schedule - self.num_steps = num_steps - self.clamp = clamp - - # Check sampler is compatible with diffusion type - sampler_class = sampler.__class__.__name__ - diffusion_class = diffusion.__class__.__name__ - message = f"{sampler_class} incompatible with {diffusion_class}" - assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message - - @torch.no_grad() - def forward( - self, noise: Tensor, num_steps: Optional[int] = None, **kwargs - ) -> Tensor: - device = noise.device - num_steps = default(num_steps, self.num_steps) # type: ignore - assert exists(num_steps), "Parameter `num_steps` must be provided" - # Compute sigmas using schedule - sigmas = self.sigma_schedule(num_steps, device) - # Append additional kwargs to denoise function (used e.g. 
for conditional unet) - fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa - # Sample using sampler - x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps) - x = x.clamp(-1.0, 1.0) if self.clamp else x - return x - - -class DiffusionInpainter(nn.Module): - def __init__( - self, - diffusion: Diffusion, - *, - num_steps: int, - num_resamples: int, - sampler: Sampler, - sigma_schedule: Schedule, - ): + def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()): super().__init__() - self.denoise_fn = diffusion.denoise_fn - self.num_steps = num_steps - self.num_resamples = num_resamples - self.inpaint_fn = sampler.inpaint - self.sigma_schedule = sigma_schedule - - @torch.no_grad() - def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor: - x = self.inpaint_fn( - source=inpaint, - mask=inpaint_mask, - fn=self.denoise_fn, - sigmas=self.sigma_schedule(self.num_steps, inpaint.device), - num_steps=self.num_steps, - num_resamples=self.num_resamples, - ) - return x - - -def sequential_mask(like: Tensor, start: int) -> Tensor: - length, device = like.shape[2], like.device - mask = torch.ones_like(like, dtype=torch.bool) - mask[:, :, start:] = torch.zeros((length - start,), device=device) - return mask - - -class SpanBySpanComposer(nn.Module): - def __init__( - self, - inpainter: DiffusionInpainter, - *, - num_spans: int, - ): - super().__init__() - self.inpainter = inpainter - self.num_spans = num_spans - - def forward(self, start: Tensor, keep_start: bool = False) -> Tensor: - half_length = start.shape[2] // 2 - - spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else [] - # Inpaint second half from first half - inpaint = torch.zeros_like(start) - inpaint[:, :, :half_length] = start[:, :, half_length:] - inpaint_mask = sequential_mask(like=start, start=half_length) - - for i in range(self.num_spans): - # Inpaint second half - span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask) - # Replace first half with generated second half - second_half = span[:, :, half_length:] - inpaint[:, :, :half_length] = second_half - # Save generated span - spans.append(second_half) - - return torch.cat(spans, dim=2) - - -class XDiffusion(nn.Module): - def __init__(self, type: str, net: nn.Module, **kwargs): - super().__init__() - - diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion] - aliases = [t.alias for t in diffusion_classes] # type: ignore - message = f"type='{type}' must be one of {*aliases,}" - assert type in aliases, message self.net = net + self.schedule = schedule - for XDiffusion in diffusion_classes: - if XDiffusion.alias == type: # type: ignore - self.diffusion = XDiffusion(net=net, **kwargs) + @property + def device(self): + return next(self.net.parameters()).device - def forward(self, *args, **kwargs) -> Tensor: - return self.diffusion(*args, **kwargs) - - def sample( - self, - noise: Tensor, - num_steps: int, - sigma_schedule: Schedule, - sampler: Sampler, - clamp: bool, - **kwargs, - ) -> Tensor: - diffusion_sampler = DiffusionSampler( - diffusion=self.diffusion, - sampler=sampler, - sigma_schedule=sigma_schedule, - num_steps=num_steps, - clamp=clamp, - ) - return diffusion_sampler(noise, **kwargs) + def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: + angle = sigmas * pi / 2 + alpha = torch.cos(angle) + beta = torch.sin(angle) + return alpha, beta + + def forward( # type: ignore + self, noise: Tensor, num_steps: int, show_progress: bool = False, **kwargs + ) -> Tensor: + b = noise.shape[0] + sigmas = 
self.schedule(num_steps + 1, device=self.device) + sigmas = repeat(sigmas, "i -> i b", b=b) + sigmas_batch = rearrange(sigmas, "i b -> i b 1 1") + alphas, betas = self.get_alpha_beta(sigmas_batch) + x_noisy = noise * sigmas_batch[0] + progress_bar = tqdm(range(num_steps), disable=not show_progress) + + for i in progress_bar: + v_pred = self.net(x_noisy, sigmas[i], **kwargs) + x_pred = alphas[i] * x_noisy - betas[i] * v_pred + noise_pred = betas[i] * x_noisy + alphas[i] * v_pred + x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred + progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})") + + return x_noisy diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py deleted file mode 100644 index fedc82a..0000000 --- a/audio_diffusion_pytorch/model.py +++ /dev/null @@ -1,432 +0,0 @@ -from math import pi -from random import randint -from typing import Any, Optional, Sequence, Tuple, Union - -import torch -from audio_encoders_pytorch import Encoder1d -from einops import rearrange -from torch import Tensor, nn -from tqdm import tqdm - -from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion -from .modules import STFT, SinusoidalEmbedding, XUNet1d, rand_bool -from .utils import ( - closest_power_2, - default, - downsample, - exists, - groupby, - to_list, - upsample, -) - -""" -Diffusion Classes (generic for 1d data) -""" - - -class Model1d(nn.Module): - def __init__(self, unet_type: str = "base", **kwargs): - super().__init__() - diffusion_kwargs, kwargs = groupby("diffusion_", kwargs) - self.unet = XUNet1d(type=unet_type, **kwargs) - self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs) - - def forward(self, x: Tensor, **kwargs) -> Tensor: - return self.diffusion(x, **kwargs) - - def sample(self, *args, **kwargs) -> Tensor: - return self.diffusion.sample(*args, **kwargs) - - -class DiffusionUpsampler1d(Model1d): - def __init__( - self, - in_channels: int, - factor: Union[int, Sequence[int]], - factor_features: Optional[int] = None, - *args, - **kwargs, - ): - self.factors = to_list(factor) - self.use_conditioning = exists(factor_features) - - default_kwargs = dict( - in_channels=in_channels, - context_channels=[in_channels], - context_features=factor_features if self.use_conditioning else None, - ) - super().__init__(*args, **{**default_kwargs, **kwargs}) # type: ignore - - if self.use_conditioning: - assert exists(factor_features) - self.to_features = SinusoidalEmbedding(dim=factor_features) - - def random_reupsample(self, x: Tensor) -> Tuple[Tensor, Tensor]: - batch_size, device, factors = x.shape[0], x.device, self.factors - # Pick random factor for each batch element - random_factors = torch.randint(0, len(factors), (batch_size,), device=device) - x = x.clone() - - for i, factor in enumerate(factors): - # Pick random items with current factor, skip if 0 - n = torch.count_nonzero(random_factors == i) - if n > 0: - waveforms = x[random_factors == i] - # Downsample and reupsample items - downsampled = downsample(waveforms, factor=factor) - reupsampled = upsample(downsampled, factor=factor) - # Save reupsampled version in place - x[random_factors == i] = reupsampled - return x, random_factors - - def forward(self, x: Tensor, **kwargs) -> Tensor: - channels, factors = self.random_reupsample(x) - features = self.to_features(factors) if self.use_conditioning else None - return self.diffusion(x, channels_list=[channels], features=features, **kwargs) - - def sample( # type: ignore - self, undersampled: Tensor, factor: 
Optional[int] = None, *args, **kwargs - ): - # Either user provides factor or we pick the first - batch_size, device = undersampled.shape[0], undersampled.device - factor = default(factor, self.factors[0]) - # Upsample channels by interpolation - channels = upsample(undersampled, factor=factor) - # Compute features if conditioning on factor - factors = torch.tensor([factor] * batch_size, device=device) - features = self.to_features(factors) if self.use_conditioning else None - # Diffuse upsampled - noise = torch.randn_like(channels) - default_kwargs = dict(channels_list=[channels], features=features) - return super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore - - -class DiffusionAE1d(Model1d): - """Diffusion Auto Encoder""" - - def __init__( - self, in_channels: int, encoder: Encoder1d, encoder_inject_depth: int, **kwargs - ): - super().__init__( - in_channels=in_channels, - context_channels=[0] * encoder_inject_depth + [encoder.out_channels], - **kwargs, - ) - self.in_channels = in_channels - self.encoder = encoder - - def forward( # type: ignore - self, x: Tensor, with_info: bool = False, **kwargs - ) -> Union[Tensor, Tuple[Tensor, Any]]: - latent, info = self.encode(x, with_info=True) - loss = super().forward(x, channels_list=[latent], **kwargs) - return (loss, info) if with_info else loss - - def encode(self, *args, **kwargs): - return self.encoder(*args, **kwargs) - - def decode(self, latent: Tensor, **kwargs) -> Tensor: - b = latent.shape[0] - length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) - # Compute noise by inferring shape from latent length - noise = torch.randn(b, self.in_channels, length, device=latent.device) - # Compute context from latent - default_kwargs = dict(channels_list=[latent]) - # Decode by sampling while conditioning on latent channels - return super().sample(noise, **{**default_kwargs, **kwargs}) - - -class DiffusionVocoder1d(Model1d): - def __init__( - self, - in_channels: int, - stft_num_fft: int, - **kwargs, - ): - self.frequency_channels = stft_num_fft // 2 + 1 - spectrogram_channels = in_channels * self.frequency_channels - - stft_kwargs, kwargs = groupby("stft_", kwargs) - default_kwargs = dict( - in_channels=spectrogram_channels, context_channels=[spectrogram_channels] - ) - - super().__init__(**{**default_kwargs, **kwargs}) # type: ignore - self.stft = STFT(num_fft=stft_num_fft, **stft_kwargs) - - def forward_wave(self, x: Tensor, **kwargs) -> Tensor: - # Get magnitude and phase of true wave - magnitude, phase = self.stft.encode(x) - return self(magnitude, phase, **kwargs) - - def forward(self, magnitude: Tensor, phase: Tensor, **kwargs) -> Tensor: # type: ignore # noqa - magnitude = rearrange(magnitude, "b c f t -> b (c f) t") - phase = rearrange(phase, "b c f t -> b (c f) t") - # Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range) - return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs) - - def sample(self, magnitude: Tensor, **kwargs): # type: ignore - b, c, f, t, device = *magnitude.shape, magnitude.device - magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t") - noise = torch.randn((b, c * f, t), device=device) - default_kwargs = dict(channels_list=[magnitude_flat]) - phase_flat = super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore # noqa - phase = rearrange(phase_flat, "b (c f) t -> b c f t", c=c) - wave = self.stft.decode(magnitude, phase * pi) - return wave - - -class DiffusionUpphaser1d(DiffusionUpsampler1d): - def __init__(self, 
**kwargs): - stft_kwargs, kwargs = groupby("stft_", kwargs) - super().__init__(**kwargs) - self.stft = STFT(**stft_kwargs) - - def random_rephase(self, x: Tensor) -> Tensor: - magnitude, phase = self.stft.encode(x) - phase_random = (torch.rand_like(phase) - 0.5) * 2 * pi - wave = self.stft.decode(magnitude, phase_random) - return wave - - def forward(self, x: Tensor, **kwargs) -> Tensor: - rephased = self.random_rephase(x) - resampled, factors = self.random_reupsample(rephased) - features = self.to_features(factors) if self.use_conditioning else None - return self.diffusion(x, channels_list=[resampled], features=features, **kwargs) - - -class DiffusionAR1d(Model1d): - def __init__( - self, - in_channels: int, - chunk_length: int, - upsample: int = 0, - dropout: float = 0.05, - verbose: int = 0, - **kwargs, - ): - self.in_channels = in_channels - self.chunk_length = chunk_length - self.dropout = dropout - self.upsample = upsample - self.verbose = verbose - super().__init__( - in_channels=in_channels, - context_channels=[in_channels * (2 if upsample > 0 else 1)], - **kwargs, - ) - - def reupsample(self, x: Tensor) -> Tensor: - x = x.clone() - x = downsample(x, factor=self.upsample) - x = upsample(x, factor=self.upsample) - return x - - def forward(self, x: Tensor, **kwargs) -> Tensor: - b, _, t, device = *x.shape, x.device - cl, num_chunks = self.chunk_length, t // self.chunk_length - assert num_chunks >= 2, "Input tensor length must be >= chunk_length * 2" - - # Get prev and current target chunks - chunk_index = randint(0, num_chunks - 2) - chunk_pos = cl * (chunk_index + 1) - chunk_prev = x[:, :, cl * chunk_index : chunk_pos] - chunk_curr = x[:, :, chunk_pos : cl * (chunk_index + 2)] - - # Randomly dropout source chunks to allow for zero AR start - if self.dropout > 0: - batch_mask = rand_bool(shape=(b, 1, 1), proba=self.dropout, device=device) - chunk_zeros = torch.zeros_like(chunk_prev) - chunk_prev = torch.where(batch_mask, chunk_zeros, chunk_prev) - - # Condition on previous chunk and reupsampled current if required - if self.upsample > 0: - chunk_reupsampled = self.reupsample(chunk_curr) - channels_list = [torch.cat([chunk_prev, chunk_reupsampled], dim=1)] - else: - channels_list = [chunk_prev] - - # Diffuse current current chunk - return self.diffusion(chunk_curr, channels_list=channels_list, **kwargs) - - def sample(self, x: Tensor, start: Optional[Tensor] = None, **kwargs) -> Tensor: # type: ignore # noqa - noise = x - - if self.upsample > 0: - # In this case we assume that x is the downsampled audio instead of noise - upsampled = upsample(x, factor=self.upsample) - noise = torch.randn_like(upsampled) - - b, c, t, device = *noise.shape, noise.device - cl, num_chunks = self.chunk_length, t // self.chunk_length - assert c == self.in_channels - assert t % cl == 0, "noise must be divisible by chunk_length" - - # Initialize previous chunk - if exists(start): - chunk_prev = start[:, :, -cl:] - else: - chunk_prev = torch.zeros(b, c, cl).to(device) - - # Computed chunks - chunks = [] - - for i in tqdm(range(num_chunks), disable=(self.verbose == 0)): - # Chunk noise - chunk_start, chunk_end = cl * i, cl * (i + 1) - noise_curr = noise[:, :, chunk_start:chunk_end] - - # Condition on previous chunk and artifically upsampled current if required - if self.upsample > 0: - chunk_upsampled = upsampled[:, :, chunk_start:chunk_end] - channels_list = [torch.cat([chunk_prev, chunk_upsampled], dim=1)] - else: - channels_list = [chunk_prev] - default_kwargs = dict(channels_list=channels_list) - - # 
Sample current chunk - chunk_curr = super().sample(noise_curr, **{**default_kwargs, **kwargs}) - - # Save chunk and use current as prev - chunks += [chunk_curr] - chunk_prev = chunk_curr - - return rearrange(chunks, "l b c t -> b c (l t)") - - -""" -Audio Diffusion Classes (specific for 1d audio data) -""" - - -def get_default_model_kwargs(): - return dict( - channels=128, - patch_size=16, - multipliers=[1, 2, 4, 4, 4, 4, 4], - factors=[4, 4, 4, 2, 2, 2], - num_blocks=[2, 2, 2, 2, 2, 2], - attentions=[0, 0, 0, 1, 1, 1, 1], - attention_heads=8, - attention_features=64, - attention_multiplier=2, - attention_use_rel_pos=False, - diffusion_type="v", - diffusion_sigma_distribution=UniformDistribution(), - ) - - -def get_default_sampling_kwargs(): - return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True) - - -class AudioDiffusionModel(Model1d): - def __init__(self, **kwargs): - super().__init__(**{**get_default_model_kwargs(), **kwargs}) - - def sample(self, *args, **kwargs): - return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs}) - - -class AudioDiffusionUpsampler(DiffusionUpsampler1d): - def __init__(self, in_channels: int, **kwargs): - default_kwargs = dict( - **get_default_model_kwargs(), - in_channels=in_channels, - context_channels=[in_channels], - ) - super().__init__(**{**default_kwargs, **kwargs}) # type: ignore - - def sample(self, *args, **kwargs): - return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs}) - - -class AudioDiffusionAE(DiffusionAE1d): - def __init__(self, in_channels: int, *args, **kwargs): - default_kwargs = dict( - **get_default_model_kwargs(), - in_channels=in_channels, - encoder=Encoder1d( - in_channels=in_channels, - patch_size=16, - channels=16, - multipliers=[2, 2, 4, 4, 4, 4, 4], - factors=[4, 4, 4, 2, 2, 2], - num_blocks=[2, 2, 2, 2, 2, 2], - out_channels=64, - ), - encoder_inject_depth=6, - ) - super().__init__(*args, **{**default_kwargs, **kwargs}) # type: ignore - - def decode(self, *args, **kwargs): - return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs}) - - -class AudioDiffusionConditional(Model1d): - def __init__( - self, - embedding_features: int, - embedding_max_length: int, - embedding_mask_proba: float = 0.1, - **kwargs, - ): - self.embedding_mask_proba = embedding_mask_proba - default_kwargs = dict( - **get_default_model_kwargs(), - unet_type="cfg", - context_embedding_features=embedding_features, - context_embedding_max_length=embedding_max_length, - ) - super().__init__(**{**default_kwargs, **kwargs}) - - def forward(self, *args, **kwargs): - default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba) - return super().forward(*args, **{**default_kwargs, **kwargs}) - - def sample(self, *args, **kwargs): - default_kwargs = dict( - **get_default_sampling_kwargs(), - embedding_scale=5.0, - ) - return super().sample(*args, **{**default_kwargs, **kwargs}) - - -class AudioDiffusionVocoder(DiffusionVocoder1d): - def __init__(self, in_channels: int, **kwargs): - default_kwargs = dict( - in_channels=in_channels, - stft_num_fft=1023, - stft_hop_length=256, - channels=512, - multipliers=[3, 2, 1, 1, 1, 1, 1, 1], - factors=[1, 2, 2, 2, 2, 2, 2], - num_blocks=[1, 1, 1, 1, 1, 1, 1], - attentions=[0, 0, 0, 0, 1, 1, 1], - attention_heads=8, - attention_features=64, - attention_multiplier=2, - attention_use_rel_pos=False, - diffusion_type="v", - diffusion_sigma_distribution=UniformDistribution(), - ) - super().__init__(**{**default_kwargs, **kwargs}) # type: ignore - - def 
sample(self, *args, **kwargs): - default_kwargs = dict(**get_default_sampling_kwargs()) - return super().sample(*args, **{**default_kwargs, **kwargs}) - - -class AudioDiffusionUpphaser(DiffusionUpphaser1d): - def __init__(self, in_channels: int, **kwargs): - default_kwargs = dict( - **get_default_model_kwargs(), - in_channels=in_channels, - context_channels=[in_channels], - factor=1, - ) - super().__init__(**{**default_kwargs, **kwargs}) # type: ignore - - def sample(self, *args, **kwargs): - return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs}) diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py new file mode 100644 index 0000000..ead9db7 --- /dev/null +++ b/audio_diffusion_pytorch/models.py @@ -0,0 +1,78 @@ +from typing import Any, Callable, Sequence, Tuple, Union + +import torch +from audio_encoders_pytorch import Encoder1d +from torch import Tensor, nn + +from .diffusion import VDiffusion, VSampler +from .unets import UNetV0 +from .utils import closest_power_2, groupby + + +class DiffusionModel(nn.Module): + def __init__( + self, + net_t: Callable = UNetV0, + diffusion_t: Callable = VDiffusion, + sampler_t: Callable = VSampler, + **kwargs + ): + super().__init__() + diffusion_kwargs, kwargs = groupby("diffusion_", kwargs) + sampler_kwargs, kwargs = groupby("sampler_", kwargs) + + self.net = net_t(**kwargs) + self.diffusion = diffusion_t(net=self.net, **diffusion_kwargs) + self.sampler = sampler_t(net=self.net, **sampler_kwargs) + + def forward(self, *args, **kwargs) -> Tensor: + return self.diffusion(*args, **kwargs) + + def sample(self, *args, **kwargs) -> Tensor: + return self.sampler(*args, **kwargs) + + +class DiffusionAE(DiffusionModel): + """Diffusion Auto Encoder""" + + def __init__( + self, + in_channels: int, + channels: Sequence[int], + encoder: Encoder1d, + inject_depth: int, + **kwargs + ): + context_channels = [0] * len(channels) + context_channels[inject_depth] = encoder.out_channels + super().__init__( + in_channels=in_channels, + channels=channels, + context_channels=context_channels, + **kwargs, + ) + self.in_channels = in_channels + self.encoder = encoder + self.inject_depth = inject_depth + + def forward( # type: ignore + self, x: Tensor, with_info: bool = False, **kwargs + ) -> Union[Tensor, Tuple[Tensor, Any]]: + latent, info = self.encode(x, with_info=True) + channels = [None] * self.inject_depth + [latent] + loss = super().forward(x, channels=channels, **kwargs) + return (loss, info) if with_info else loss + + def encode(self, *args, **kwargs): + return self.encoder(*args, **kwargs) + + def decode(self, latent: Tensor, **kwargs) -> Tensor: + b = latent.shape[0] + length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) + # Compute noise by inferring shape from latent length + noise = torch.randn(b, self.in_channels, length, device=latent.device) + # Compute context from latent + channels = [None] * self.inject_depth + [latent] # type: ignore + default_kwargs = dict(channels=channels) + # Decode by sampling while conditioning on latent channels + return super().sample(noise, **{**default_kwargs, **kwargs}) diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py deleted file mode 100644 index a0af226..0000000 --- a/audio_diffusion_pytorch/modules.py +++ /dev/null @@ -1,1407 +0,0 @@ -from math import floor, log, pi -from typing import Any, List, Optional, Sequence, Tuple, Union - -import torch -import torch.nn as nn -from einops import rearrange, reduce, repeat -from 
einops.layers.torch import Rearrange -from torch import Tensor, einsum - -from .utils import closest_power_2, default, exists, groupby - -""" -Utils -""" - - -class ConditionedSequential(nn.Module): - def __init__(self, *modules): - super().__init__() - self.module_list = nn.ModuleList(*modules) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None): - for module in self.module_list: - x = module(x, mapping) - return x - - -""" -Convolutional Blocks -""" - - -def Conv1d(*args, **kwargs) -> nn.Module: - return nn.Conv1d(*args, **kwargs) - - -def ConvTranspose1d(*args, **kwargs) -> nn.Module: - return nn.ConvTranspose1d(*args, **kwargs) - - -def Downsample1d( - in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 -) -> nn.Module: - assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" - - return Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * kernel_multiplier + 1, - stride=factor, - padding=factor * (kernel_multiplier // 2), - ) - - -def Upsample1d( - in_channels: int, out_channels: int, factor: int, use_nearest: bool = False -) -> nn.Module: - - if factor == 1: - return Conv1d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 - ) - - if use_nearest: - return nn.Sequential( - nn.Upsample(scale_factor=factor, mode="nearest"), - Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - ), - ) - else: - return ConvTranspose1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * 2, - stride=factor, - padding=factor // 2 + factor % 2, - output_padding=factor % 2, - ) - - -class ConvBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - kernel_size: int = 3, - stride: int = 1, - padding: int = 1, - dilation: int = 1, - num_groups: int = 8, - use_norm: bool = True, - ) -> None: - super().__init__() - - self.groupnorm = ( - nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) - if use_norm - else nn.Identity() - ) - self.activation = nn.SiLU() - self.project = Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ) - - def forward( - self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None - ) -> Tensor: - x = self.groupnorm(x) - if exists(scale_shift): - scale, shift = scale_shift - x = x * (scale + 1) + shift - x = self.activation(x) - return self.project(x) - - -class MappingToScaleShift(nn.Module): - def __init__( - self, - features: int, - channels: int, - ): - super().__init__() - - self.to_scale_shift = nn.Sequential( - nn.SiLU(), - nn.Linear(in_features=features, out_features=channels * 2), - ) - - def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: - scale_shift = self.to_scale_shift(mapping) - scale_shift = rearrange(scale_shift, "b c -> b c 1") - scale, shift = scale_shift.chunk(2, dim=1) - return scale, shift - - -class ResnetBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - kernel_size: int = 3, - stride: int = 1, - padding: int = 1, - dilation: int = 1, - use_norm: bool = True, - num_groups: int = 8, - context_mapping_features: Optional[int] = None, - ) -> None: - super().__init__() - - self.use_mapping = exists(context_mapping_features) - - self.block1 = ConvBlock1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - 
dilation=dilation, - use_norm=use_norm, - num_groups=num_groups, - ) - - if self.use_mapping: - assert exists(context_mapping_features) - self.to_scale_shift = MappingToScaleShift( - features=context_mapping_features, channels=out_channels - ) - - self.block2 = ConvBlock1d( - in_channels=out_channels, - out_channels=out_channels, - use_norm=use_norm, - num_groups=num_groups, - ) - - self.to_out = ( - Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) - if in_channels != out_channels - else nn.Identity() - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor: - assert_message = "context mapping required if context_mapping_features > 0" - assert not (self.use_mapping ^ exists(mapping)), assert_message - - h = self.block1(x) - - scale_shift = None - if self.use_mapping: - scale_shift = self.to_scale_shift(mapping) - - h = self.block2(h, scale_shift=scale_shift) - - return h + self.to_out(x) - - -class Patcher(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - patch_size: int, - context_mapping_features: Optional[int] = None, - ): - super().__init__() - assert_message = f"out_channels must be divisible by patch_size ({patch_size})" - assert out_channels % patch_size == 0, assert_message - self.patch_size = patch_size - - self.block = ResnetBlock1d( - in_channels=in_channels, - out_channels=out_channels // patch_size, - num_groups=1, - context_mapping_features=context_mapping_features, - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor: - x = self.block(x, mapping) - x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) - return x - - -class Unpatcher(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - patch_size: int, - context_mapping_features: Optional[int] = None, - ): - super().__init__() - assert_message = f"in_channels must be divisible by patch_size ({patch_size})" - assert in_channels % patch_size == 0, assert_message - self.patch_size = patch_size - - self.block = ResnetBlock1d( - in_channels=in_channels // patch_size, - out_channels=out_channels, - num_groups=1, - context_mapping_features=context_mapping_features, - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor: - x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) - x = self.block(x, mapping) - return x - - -""" -Attention Components -""" - - -class RelativePositionBias(nn.Module): - def __init__(self, num_buckets: int, max_distance: int, num_heads: int): - super().__init__() - self.num_buckets = num_buckets - self.max_distance = max_distance - self.num_heads = num_heads - self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) - - @staticmethod - def _relative_position_bucket( - relative_position: Tensor, num_buckets: int, max_distance: int - ): - num_buckets //= 2 - ret = (relative_position >= 0).to(torch.long) * num_buckets - n = torch.abs(relative_position) - - max_exact = num_buckets // 2 - is_small = n < max_exact - - val_if_large = ( - max_exact - + ( - torch.log(n.float() / max_exact) - / log(max_distance / max_exact) - * (num_buckets - max_exact) - ).long() - ) - val_if_large = torch.min( - val_if_large, torch.full_like(val_if_large, num_buckets - 1) - ) - - ret += torch.where(is_small, n, val_if_large) - return ret - - def forward(self, num_queries: int, num_keys: int) -> Tensor: - i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device - q_pos = torch.arange(j - i, j, dtype=torch.long, device=device) - 
k_pos = torch.arange(j, dtype=torch.long, device=device) - rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1") - - relative_position_bucket = self._relative_position_bucket( - rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance - ) - - bias = self.relative_attention_bias(relative_position_bucket) - bias = rearrange(bias, "m n h -> 1 h m n") - return bias - - -def FeedForward(features: int, multiplier: int) -> nn.Module: - mid_features = features * multiplier - return nn.Sequential( - nn.Linear(in_features=features, out_features=mid_features), - nn.GELU(), - nn.Linear(in_features=mid_features, out_features=features), - ) - - -class AttentionBase(nn.Module): - def __init__( - self, - features: int, - *, - head_features: int, - num_heads: int, - use_rel_pos: bool, - rel_pos_num_buckets: Optional[int] = None, - rel_pos_max_distance: Optional[int] = None, - ): - super().__init__() - self.scale = head_features ** -0.5 - self.num_heads = num_heads - self.use_rel_pos = use_rel_pos - mid_features = head_features * num_heads - - if use_rel_pos: - assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance) - self.rel_pos = RelativePositionBias( - num_buckets=rel_pos_num_buckets, - max_distance=rel_pos_max_distance, - num_heads=num_heads, - ) - - self.to_out = nn.Linear(in_features=mid_features, out_features=features) - - def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: - # Split heads - q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads) - k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads) - v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads) - # Compute similarity matrix - sim = einsum("... n d, ... m d -> ... n m", q, k) - sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim - sim = sim * self.scale - # Get attention matrix with softmax - attn = sim.softmax(dim=-1) - # Compute values - out = einsum("... n m, ... m d -> ... 
n d", attn, v) - out = rearrange(out, "b h n d -> b n (h d)") - return self.to_out(out) - - -class Attention(nn.Module): - def __init__( - self, - features: int, - *, - head_features: int, - num_heads: int, - context_features: Optional[int] = None, - use_rel_pos: bool, - rel_pos_num_buckets: Optional[int] = None, - rel_pos_max_distance: Optional[int] = None, - ): - super().__init__() - self.context_features = context_features - mid_features = head_features * num_heads - context_features = default(context_features, features) - - self.norm = nn.LayerNorm(features) - self.norm_context = nn.LayerNorm(context_features) - self.to_q = nn.Linear( - in_features=features, out_features=mid_features, bias=False - ) - self.to_kv = nn.Linear( - in_features=context_features, out_features=mid_features * 2, bias=False - ) - self.attention = AttentionBase( - features, - num_heads=num_heads, - head_features=head_features, - use_rel_pos=use_rel_pos, - rel_pos_num_buckets=rel_pos_num_buckets, - rel_pos_max_distance=rel_pos_max_distance, - ) - - def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor: - assert_message = "You must provide a context when using context_features" - assert not self.context_features or exists(context), assert_message - # Use context if provided - context = default(context, x) - # Normalize then compute q from input and k,v from context - x, context = self.norm(x), self.norm_context(context) - q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) - # Compute and return attention - return self.attention(q, k, v) - - -""" -Transformer Blocks -""" - - -class TransformerBlock(nn.Module): - def __init__( - self, - features: int, - num_heads: int, - head_features: int, - multiplier: int, - use_rel_pos: bool, - rel_pos_num_buckets: Optional[int] = None, - rel_pos_max_distance: Optional[int] = None, - context_features: Optional[int] = None, - ): - super().__init__() - - self.use_cross_attention = exists(context_features) and context_features > 0 - - self.attention = Attention( - features=features, - num_heads=num_heads, - head_features=head_features, - use_rel_pos=use_rel_pos, - rel_pos_num_buckets=rel_pos_num_buckets, - rel_pos_max_distance=rel_pos_max_distance, - ) - - if self.use_cross_attention: - self.cross_attention = Attention( - features=features, - num_heads=num_heads, - head_features=head_features, - context_features=context_features, - use_rel_pos=use_rel_pos, - rel_pos_num_buckets=rel_pos_num_buckets, - rel_pos_max_distance=rel_pos_max_distance, - ) - - self.feed_forward = FeedForward(features=features, multiplier=multiplier) - - def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor: - x = self.attention(x) + x - if self.use_cross_attention: - x = self.cross_attention(x, context=context) + x - x = self.feed_forward(x) + x - return x - - -""" -Transformers -""" - - -class Transformer1d(nn.Module): - def __init__( - self, - num_layers: int, - channels: int, - num_heads: int, - head_features: int, - multiplier: int, - use_rel_pos: bool = False, - rel_pos_num_buckets: Optional[int] = None, - rel_pos_max_distance: Optional[int] = None, - context_features: Optional[int] = None, - ): - super().__init__() - - self.to_in = nn.Sequential( - nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), - Conv1d( - in_channels=channels, - out_channels=channels, - kernel_size=1, - ), - Rearrange("b c t -> b t c"), - ) - - self.blocks = nn.ModuleList( - [ - TransformerBlock( - features=channels, - 
head_features=head_features, - num_heads=num_heads, - multiplier=multiplier, - context_features=context_features, - use_rel_pos=use_rel_pos, - rel_pos_num_buckets=rel_pos_num_buckets, - rel_pos_max_distance=rel_pos_max_distance, - ) - for i in range(num_layers) - ] - ) - - self.to_out = nn.Sequential( - Rearrange("b t c -> b c t"), - Conv1d( - in_channels=channels, - out_channels=channels, - kernel_size=1, - ), - ) - - def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor: - x = self.to_in(x) - for block in self.blocks: - x = block(x, context=context) - x = self.to_out(x) - return x - - -""" -Time Embeddings -""" - - -class SinusoidalEmbedding(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.dim = dim - - def forward(self, x: Tensor) -> Tensor: - device, half_dim = x.device, self.dim // 2 - emb = torch.tensor(log(10000) / (half_dim - 1), device=device) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") - return torch.cat((emb.sin(), emb.cos()), dim=-1) - - -class LearnedPositionalEmbedding(nn.Module): - """Used for continuous time""" - - def __init__(self, dim: int): - super().__init__() - assert (dim % 2) == 0 - half_dim = dim // 2 - self.weights = nn.Parameter(torch.randn(half_dim)) - - def forward(self, x: Tensor) -> Tensor: - x = rearrange(x, "b -> b 1") - freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) - fouriered = torch.cat((x, fouriered), dim=-1) - return fouriered - - -def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: - return nn.Sequential( - LearnedPositionalEmbedding(dim), - nn.Linear(in_features=dim + 1, out_features=out_features), - ) - - -""" -Encoder/Decoder Components -""" - - -class DownsampleBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - factor: int, - num_groups: int, - num_layers: int, - kernel_multiplier: int = 2, - use_pre_downsample: bool = True, - use_skip: bool = False, - extract_channels: int = 0, - context_channels: int = 0, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - attention_use_rel_pos: Optional[bool] = None, - attention_rel_pos_max_distance: Optional[int] = None, - attention_rel_pos_num_buckets: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - ): - super().__init__() - self.use_pre_downsample = use_pre_downsample - self.use_skip = use_skip - self.use_transformer = num_transformer_blocks > 0 - self.use_extract = extract_channels > 0 - self.use_context = context_channels > 0 - - channels = out_channels if use_pre_downsample else in_channels - - self.downsample = Downsample1d( - in_channels=in_channels, - out_channels=out_channels, - factor=factor, - kernel_multiplier=kernel_multiplier, - ) - - self.blocks = nn.ModuleList( - [ - ResnetBlock1d( - in_channels=channels + context_channels if i == 0 else channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - ) - for i in range(num_layers) - ] - ) - - if self.use_transformer: - assert ( - exists(attention_heads) - and exists(attention_features) - and exists(attention_multiplier) - and exists(attention_use_rel_pos) - ) - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, 
- num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features, - use_rel_pos=attention_use_rel_pos, - rel_pos_num_buckets=attention_rel_pos_num_buckets, - rel_pos_max_distance=attention_rel_pos_max_distance, - ) - - if self.use_extract: - num_extract_groups = min(num_groups, extract_channels) - self.to_extracted = ResnetBlock1d( - in_channels=out_channels, - out_channels=extract_channels, - num_groups=num_extract_groups, - ) - - def forward( - self, - x: Tensor, - *, - mapping: Optional[Tensor] = None, - channels: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: - - if self.use_pre_downsample: - x = self.downsample(x) - - if self.use_context and exists(channels): - x = torch.cat([x, channels], dim=1) - - skips = [] - for block in self.blocks: - x = block(x, mapping=mapping) - skips += [x] if self.use_skip else [] - - if self.use_transformer: - x = self.transformer(x, context=embedding) - skips += [x] if self.use_skip else [] - - if not self.use_pre_downsample: - x = self.downsample(x) - - if self.use_extract: - extracted = self.to_extracted(x) - return x, extracted - - return (x, skips) if self.use_skip else x - - -class UpsampleBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - factor: int, - num_layers: int, - num_groups: int, - use_nearest: bool = False, - use_pre_upsample: bool = False, - use_skip: bool = False, - skip_channels: int = 0, - use_skip_scale: bool = False, - extract_channels: int = 0, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - attention_use_rel_pos: Optional[bool] = None, - attention_rel_pos_max_distance: Optional[int] = None, - attention_rel_pos_num_buckets: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - ): - super().__init__() - - self.use_extract = extract_channels > 0 - self.use_pre_upsample = use_pre_upsample - self.use_transformer = num_transformer_blocks > 0 - self.use_skip = use_skip - self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 - - channels = out_channels if use_pre_upsample else in_channels - - self.blocks = nn.ModuleList( - [ - ResnetBlock1d( - in_channels=channels + skip_channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - ) - for _ in range(num_layers) - ] - ) - - if self.use_transformer: - assert ( - exists(attention_heads) - and exists(attention_features) - and exists(attention_multiplier) - and exists(attention_use_rel_pos) - ) - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, - num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features, - use_rel_pos=attention_use_rel_pos, - rel_pos_num_buckets=attention_rel_pos_num_buckets, - rel_pos_max_distance=attention_rel_pos_max_distance, - ) - - self.upsample = Upsample1d( - in_channels=in_channels, - out_channels=out_channels, - factor=factor, - use_nearest=use_nearest, - ) - - if self.use_extract: - num_extract_groups = min(num_groups, extract_channels) - self.to_extracted = ResnetBlock1d( - in_channels=out_channels, - out_channels=extract_channels, - num_groups=num_extract_groups, - ) - - def add_skip(self, x: Tensor, skip: 
Tensor) -> Tensor: - return torch.cat([x, skip * self.skip_scale], dim=1) - - def forward( - self, - x: Tensor, - *, - skips: Optional[List[Tensor]] = None, - mapping: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - ) -> Union[Tuple[Tensor, Tensor], Tensor]: - - if self.use_pre_upsample: - x = self.upsample(x) - - for block in self.blocks: - x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x - x = block(x, mapping=mapping) - - if self.use_transformer: - x = self.transformer(x, context=embedding) - - if not self.use_pre_upsample: - x = self.upsample(x) - - if self.use_extract: - extracted = self.to_extracted(x) - return x, extracted - - return x - - -class BottleneckBlock1d(nn.Module): - def __init__( - self, - channels: int, - *, - num_groups: int, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - attention_use_rel_pos: Optional[bool] = None, - attention_rel_pos_max_distance: Optional[int] = None, - attention_rel_pos_num_buckets: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - ): - super().__init__() - self.use_transformer = num_transformer_blocks > 0 - - self.pre_block = ResnetBlock1d( - in_channels=channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - ) - - if self.use_transformer: - assert ( - exists(attention_heads) - and exists(attention_features) - and exists(attention_multiplier) - and exists(attention_use_rel_pos) - ) - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, - num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features, - use_rel_pos=attention_use_rel_pos, - rel_pos_num_buckets=attention_rel_pos_num_buckets, - rel_pos_max_distance=attention_rel_pos_max_distance, - ) - - self.post_block = ResnetBlock1d( - in_channels=channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - ) - - def forward( - self, - x: Tensor, - *, - mapping: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - ) -> Tensor: - x = self.pre_block(x, mapping=mapping) - if self.use_transformer: - x = self.transformer(x, context=embedding) - x = self.post_block(x, mapping=mapping) - return x - - -""" -UNet -""" - - -class UNet1d(nn.Module): - def __init__( - self, - in_channels: int, - channels: int, - multipliers: Sequence[int], - factors: Sequence[int], - num_blocks: Sequence[int], - attentions: Sequence[int], - patch_size: int = 1, - resnet_groups: int = 8, - use_context_time: bool = True, - kernel_multiplier_downsample: int = 2, - use_nearest_upsample: bool = False, - use_skip_scale: bool = True, - use_stft: bool = False, - use_stft_context: bool = False, - out_channels: Optional[int] = None, - context_features: Optional[int] = None, - context_features_multiplier: int = 4, - context_channels: Optional[Sequence[int]] = None, - context_embedding_features: Optional[int] = None, - **kwargs, - ): - super().__init__() - out_channels = default(out_channels, in_channels) - context_channels = list(default(context_channels, [])) - num_layers = len(multipliers) - 1 - use_context_features = exists(context_features) - use_context_channels = len(context_channels) > 0 - context_mapping_features = None - - attention_kwargs, kwargs = 
groupby("attention_", kwargs, keep_prefix=True) - - self.num_layers = num_layers - self.use_context_time = use_context_time - self.use_context_features = use_context_features - self.use_context_channels = use_context_channels - self.use_stft = use_stft - self.use_stft_context = use_stft_context - - self.context_features = context_features - context_channels_pad_length = num_layers + 1 - len(context_channels) - context_channels = context_channels + [0] * context_channels_pad_length - self.context_channels = context_channels - self.context_embedding_features = context_embedding_features - - if use_context_channels: - has_context = [c > 0 for c in context_channels] - self.has_context = has_context - self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] - - assert ( - len(factors) == num_layers - and len(attentions) >= num_layers - and len(num_blocks) == num_layers - ) - - if use_context_time or use_context_features: - context_mapping_features = channels * context_features_multiplier - - self.to_mapping = nn.Sequential( - nn.Linear(context_mapping_features, context_mapping_features), - nn.GELU(), - nn.Linear(context_mapping_features, context_mapping_features), - nn.GELU(), - ) - - if use_context_time: - assert exists(context_mapping_features) - self.to_time = nn.Sequential( - TimePositionalEmbedding( - dim=channels, out_features=context_mapping_features - ), - nn.GELU(), - ) - - if use_context_features: - assert exists(context_features) and exists(context_mapping_features) - self.to_features = nn.Sequential( - nn.Linear( - in_features=context_features, out_features=context_mapping_features - ), - nn.GELU(), - ) - - if use_stft: - stft_kwargs, kwargs = groupby("stft_", kwargs) - assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" - stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 - in_channels *= stft_channels - out_channels *= stft_channels - context_channels[0] *= stft_channels if use_stft_context else 1 - assert exists(in_channels) and exists(out_channels) - self.stft = STFT(**stft_kwargs) - - assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" - - self.to_in = Patcher( - in_channels=in_channels + context_channels[0], - out_channels=channels * multipliers[0], - patch_size=patch_size, - context_mapping_features=context_mapping_features, - ) - - self.downsamples = nn.ModuleList( - [ - DownsampleBlock1d( - in_channels=channels * multipliers[i], - out_channels=channels * multipliers[i + 1], - context_mapping_features=context_mapping_features, - context_channels=context_channels[i + 1], - context_embedding_features=context_embedding_features, - num_layers=num_blocks[i], - factor=factors[i], - kernel_multiplier=kernel_multiplier_downsample, - num_groups=resnet_groups, - use_pre_downsample=True, - use_skip=True, - num_transformer_blocks=attentions[i], - **attention_kwargs, - ) - for i in range(num_layers) - ] - ) - - self.bottleneck = BottleneckBlock1d( - channels=channels * multipliers[-1], - context_mapping_features=context_mapping_features, - context_embedding_features=context_embedding_features, - num_groups=resnet_groups, - num_transformer_blocks=attentions[-1], - **attention_kwargs, - ) - - self.upsamples = nn.ModuleList( - [ - UpsampleBlock1d( - in_channels=channels * multipliers[i + 1], - out_channels=channels * multipliers[i], - context_mapping_features=context_mapping_features, - context_embedding_features=context_embedding_features, - num_layers=num_blocks[i] + (1 if attentions[i] else 0), - factor=factors[i], - 
use_nearest=use_nearest_upsample, - num_groups=resnet_groups, - use_skip_scale=use_skip_scale, - use_pre_upsample=False, - use_skip=True, - skip_channels=channels * multipliers[i + 1], - num_transformer_blocks=attentions[i], - **attention_kwargs, - ) - for i in reversed(range(num_layers)) - ] - ) - - self.to_out = Unpatcher( - in_channels=channels * multipliers[0], - out_channels=out_channels, - patch_size=patch_size, - context_mapping_features=context_mapping_features, - ) - - def get_channels( - self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 - ) -> Optional[Tensor]: - """Gets context channels at `layer` and checks that shape is correct""" - use_context_channels = self.use_context_channels and self.has_context[layer] - if not use_context_channels: - return None - assert exists(channels_list), "Missing context" - # Get channels index (skipping zero channel contexts) - channels_id = self.channels_ids[layer] - # Get channels - channels = channels_list[channels_id] - message = f"Missing context for layer {layer} at index {channels_id}" - assert exists(channels), message - # Check channels - num_channels = self.context_channels[layer] - message = f"Expected context with {num_channels} channels at idx {channels_id}" - assert channels.shape[1] == num_channels, message - # STFT channels if requested - channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa - return channels - - def get_mapping( - self, time: Optional[Tensor] = None, features: Optional[Tensor] = None - ) -> Optional[Tensor]: - """Combines context time features and features into mapping""" - items, mapping = [], None - # Compute time features - if self.use_context_time: - assert_message = "use_context_time=True but no time features provided" - assert exists(time), assert_message - items += [self.to_time(time)] - # Compute features - if self.use_context_features: - assert_message = "context_features exists but no features provided" - assert exists(features), assert_message - items += [self.to_features(features)] - # Compute joint mapping - if self.use_context_time or self.use_context_features: - mapping = reduce(torch.stack(items), "n b m -> b m", "sum") - mapping = self.to_mapping(mapping) - return mapping - - def forward( - self, - x: Tensor, - time: Optional[Tensor] = None, - *, - features: Optional[Tensor] = None, - channels_list: Optional[Sequence[Tensor]] = None, - embedding: Optional[Tensor] = None, - ) -> Tensor: - channels = self.get_channels(channels_list, layer=0) - # Apply stft if required - x = self.stft.encode1d(x) if self.use_stft else x # type: ignore - # Concat context channels at layer 0 if provided - x = torch.cat([x, channels], dim=1) if exists(channels) else x - # Compute mapping from time and features - mapping = self.get_mapping(time, features) - x = self.to_in(x, mapping) - skips_list = [x] - - for i, downsample in enumerate(self.downsamples): - channels = self.get_channels(channels_list, layer=i + 1) - x, skips = downsample( - x, mapping=mapping, channels=channels, embedding=embedding - ) - skips_list += [skips] - - x = self.bottleneck(x, mapping=mapping, embedding=embedding) - - for i, upsample in enumerate(self.upsamples): - skips = skips_list.pop() - x = upsample(x, skips=skips, mapping=mapping, embedding=embedding) - - x += skips_list.pop() - x = self.to_out(x, mapping) - x = self.stft.decode1d(x) if self.use_stft else x - - return x - - -""" Conditioning Modules """ - - -class FixedEmbedding(nn.Module): - def __init__(self, max_length: 
int, features: int): - super().__init__() - self.max_length = max_length - self.embedding = nn.Embedding(max_length, features) - - def forward(self, x: Tensor) -> Tensor: - batch_size, length, device = *x.shape[0:2], x.device - assert_message = "Input sequence length must be <= max_length" - assert length <= self.max_length, assert_message - position = torch.arange(length, device=device) - fixed_embedding = self.embedding(position) - fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) - return fixed_embedding - - -def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: - if proba == 1: - return torch.ones(shape, device=device, dtype=torch.bool) - elif proba == 0: - return torch.zeros(shape, device=device, dtype=torch.bool) - else: - return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) - - -class UNetCFG1d(UNet1d): - - """UNet1d with Classifier-Free Guidance""" - - def __init__( - self, - context_embedding_max_length: int, - context_embedding_features: int, - **kwargs, - ): - super().__init__( - context_embedding_features=context_embedding_features, **kwargs - ) - self.fixed_embedding = FixedEmbedding( - max_length=context_embedding_max_length, features=context_embedding_features - ) - - def forward( # type: ignore - self, - x: Tensor, - time: Tensor, - *, - embedding: Tensor, - embedding_scale: float = 1.0, - embedding_mask_proba: float = 0.0, - **kwargs, - ) -> Tensor: - b, device = embedding.shape[0], embedding.device - fixed_embedding = self.fixed_embedding(embedding) - - if embedding_mask_proba > 0.0: - # Randomly mask embedding - batch_mask = rand_bool( - shape=(b, 1, 1), proba=embedding_mask_proba, device=device - ) - embedding = torch.where(batch_mask, fixed_embedding, embedding) - - if embedding_scale != 1.0: - # Compute both normal and fixed embedding outputs - out = super().forward(x, time, embedding=embedding, **kwargs) - out_masked = super().forward(x, time, embedding=fixed_embedding, **kwargs) - # Scale conditional output using classifier-free guidance - return out_masked + (out - out_masked) * embedding_scale - else: - return super().forward(x, time, embedding=embedding, **kwargs) - - -class UNetNCCA1d(UNet1d): - - """UNet1d with Noise Channel Conditioning Augmentation""" - - def __init__(self, context_features: int, **kwargs): - super().__init__(context_features=context_features, **kwargs) - self.embedder = NumberEmbedder(features=context_features) - - def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: - x = x if torch.is_tensor(x) else torch.tensor(x) - return x.expand(shape) - - def forward( # type: ignore - self, - x: Tensor, - time: Tensor, - *, - channels_list: Sequence[Tensor], - channels_augmentation: Union[ - bool, Sequence[bool], Sequence[Sequence[bool]], Tensor - ] = False, - channels_scale: Union[ - float, Sequence[float], Sequence[Sequence[float]], Tensor - ] = 0, - **kwargs, - ) -> Tensor: - b, n = x.shape[0], len(channels_list) - channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) - channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) - - # Augmentation (for each channel list item) - for i in range(n): - scale = channels_scale[:, i] * channels_augmentation[:, i] - scale = rearrange(scale, "b -> b 1 1") - item = channels_list[i] - channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa - - # Scale embedding (sum reduction if more than one channel list item) - channels_scale_emb = self.embedder(channels_scale) - 
channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") - - return super().forward( - x=x, - time=time, - channels_list=channels_list, - features=channels_scale_emb, - **kwargs, - ) - - -class UNetAll1d(UNetCFG1d, UNetNCCA1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, *args, **kwargs): # type: ignore - return UNetCFG1d.forward(self, *args, **kwargs) - - -def XUNet1d(type: str = "base", **kwargs) -> UNet1d: - if type == "base": - return UNet1d(**kwargs) - elif type == "all": - return UNetAll1d(**kwargs) - elif type == "cfg": - return UNetCFG1d(**kwargs) - elif type == "ncca": - return UNetNCCA1d(**kwargs) - else: - raise ValueError(f"Unknown XUNet1d type: {type}") - - -class T5Embedder(nn.Module): - def __init__(self, model: str = "t5-base", max_length: int = 64): - super().__init__() - from transformers import AutoTokenizer, T5EncoderModel - - self.tokenizer = AutoTokenizer.from_pretrained(model) - self.transformer = T5EncoderModel.from_pretrained(model) - self.max_length = max_length - - @torch.no_grad() - def forward(self, texts: List[str]) -> Tensor: - - encoded = self.tokenizer( - texts, - truncation=True, - max_length=self.max_length, - padding="max_length", - return_tensors="pt", - ) - - device = next(self.transformer.parameters()).device - input_ids = encoded["input_ids"].to(device) - attention_mask = encoded["attention_mask"].to(device) - - self.transformer.eval() - - embedding = self.transformer( - input_ids=input_ids, attention_mask=attention_mask - )["last_hidden_state"] - - return embedding - - -class NumberEmbedder(nn.Module): - def __init__( - self, - features: int, - dim: int = 256, - ): - super().__init__() - self.features = features - self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) - - def forward(self, x: Union[List[float], Tensor]) -> Tensor: - if not torch.is_tensor(x): - device = next(self.embedding.parameters()).device - x = torch.tensor(x, device=device) - assert isinstance(x, Tensor) - shape = x.shape - x = rearrange(x, "... 
-> (...)") - embedding = self.embedding(x) - x = embedding.view(*shape, self.features) - return x # type: ignore - - -""" -Audio Transforms -""" - - -class STFT(nn.Module): - """Helper for torch stft and istft""" - - def __init__( - self, - num_fft: int = 1023, - hop_length: int = 256, - window_length: Optional[int] = None, - length: Optional[int] = None, - use_complex: bool = False, - ): - super().__init__() - self.num_fft = num_fft - self.hop_length = default(hop_length, floor(num_fft // 4)) - self.window_length = default(window_length, num_fft) - self.length = length - self.register_buffer("window", torch.hann_window(self.window_length)) - self.use_complex = use_complex - - def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: - b = wave.shape[0] - wave = rearrange(wave, "b c t -> (b c) t") - - stft = torch.stft( - wave, - n_fft=self.num_fft, - hop_length=self.hop_length, - win_length=self.window_length, - window=self.window, # type: ignore - return_complex=True, - normalized=True, - ) - - if self.use_complex: - # Returns real and imaginary - stft_a, stft_b = stft.real, stft.imag - else: - # Returns magnitude and phase matrices - magnitude, phase = torch.abs(stft), torch.angle(stft) - stft_a, stft_b = magnitude, phase - - return rearrange(stft_a, "(b c) f l -> b c f l", b=b), rearrange(stft_b, "(b c) f l -> b c f l", b=b) - - def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: - b, l = stft_a.shape[0], stft_a.shape[-1] # noqa - length = closest_power_2(l * self.hop_length) - - stft_a = rearrange(stft_a, "b c f l -> (b c) f l") - stft_b = rearrange(stft_b, "b c f l -> (b c) f l") - - if self.use_complex: - real, imag = stft_a, stft_b - else: - magnitude, phase = stft_a, stft_b - real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) - - stft = torch.stack([real, imag], dim=-1) - - wave = torch.istft( - stft, - n_fft=self.num_fft, - hop_length=self.hop_length, - win_length=self.window_length, - window=self.window, # type: ignore - length=default(self.length, length), - normalized=True, - ) - - return rearrange(wave, "(b c) t -> b c t", b=b) - - def encode1d( - self, wave: Tensor, stacked: bool = True - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - stft_a, stft_b = self.encode(wave) - stft_a = rearrange(stft_a, "b c f l -> b (c f) l") - stft_b = rearrange(stft_b, "b c f l -> b (c f) l") - return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) - - def decode1d(self, stft_pair: Tensor) -> Tensor: - f = self.num_fft // 2 + 1 - stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) - stft_a = rearrange(stft_a, "b (c f) l -> b c f l", f=f) - stft_b = rearrange(stft_b, "b (c f) l -> b c f l", f=f) - return self.decode(stft_a, stft_b) diff --git a/audio_diffusion_pytorch/unets.py b/audio_diffusion_pytorch/unets.py new file mode 100644 index 0000000..37680c1 --- /dev/null +++ b/audio_diffusion_pytorch/unets.py @@ -0,0 +1,155 @@ +from typing import Callable, Optional, Sequence + +from a_unet import ( + ClassifierFreeGuidancePlugin, + Conv, + Module, + TextConditioningPlugin, + TimeConditioningPlugin, + default, + exists, +) +from a_unet.apex import ( + AttentionItem, + CrossAttentionItem, + InjectChannelsItem, + ModulationItem, + ResnetItem, + SkipCat, + SkipModulate, + XBlock, + XUNet, +) +from torch import Tensor, nn + +""" +UNets (built with a-unet: https://github.com/archinetai/a-unet) +""" + + +def UNetV0( + dim: int, + in_channels: int, + channels: Sequence[int], + factors: Sequence[int], + items: Sequence[int], + attentions: Optional[Sequence[int]] = 
None, + cross_attentions: Optional[Sequence[int]] = None, + context_channels: Optional[Sequence[int]] = None, + attention_features: Optional[int] = None, + attention_heads: Optional[int] = None, + embedding_features: Optional[int] = None, + resnet_groups: int = 8, + use_modulation: bool = True, + modulation_features: int = 1024, + embedding_max_length: Optional[int] = None, + use_time_conditioning: bool = True, + use_embedding_cfg: bool = False, + use_text_conditioning: bool = False, + out_channels: Optional[int] = None, +): + # Set defaults and check lengths + num_layers = len(channels) + attentions = default(attentions, [0] * num_layers) + cross_attentions = default(cross_attentions, [0] * num_layers) + context_channels = default(context_channels, [0] * num_layers) + xs = (channels, factors, items, attentions, cross_attentions, context_channels) + assert all(len(x) == num_layers for x in xs) # type: ignore + + # Define UNet type + UNetV0 = XUNet + + if use_embedding_cfg: + msg = "use_embedding_cfg requires embedding_max_length" + assert exists(embedding_max_length), msg + UNetV0 = ClassifierFreeGuidancePlugin(UNetV0, embedding_max_length) + + if use_text_conditioning: + UNetV0 = TextConditioningPlugin(UNetV0) + + if use_time_conditioning: + assert use_modulation, "use_time_conditioning requires use_modulation=True" + UNetV0 = TimeConditioningPlugin(UNetV0) + + # Build + return UNetV0( + dim=dim, + in_channels=in_channels, + out_channels=out_channels, + blocks=[ + XBlock( + channels=channels, + factor=factor, + context_channels=ctx_channels, + items=( + [ResnetItem] + + [ModulationItem] * use_modulation + + [InjectChannelsItem] * (ctx_channels > 0) + + [AttentionItem] * att + + [CrossAttentionItem] * cross + ) + * items, + ) + for channels, factor, items, att, cross, ctx_channels in zip(*xs) # type: ignore # noqa + ], + skip_t=SkipModulate if use_modulation else SkipCat, + attention_features=attention_features, + attention_heads=attention_heads, + embedding_features=embedding_features, + modulation_features=modulation_features, + resnet_groups=resnet_groups, + ) + + +""" +Plugins +""" + + +def LTPlugin( + net_t: Callable, num_filters: int, window_length: int, stride: int +) -> Callable[..., nn.Module]: + """Learned Transform Plugin""" + + def Net( + dim: int, in_channels: int, out_channels: Optional[int] = None, **kwargs + ) -> nn.Module: + out_channels = default(out_channels, in_channels) + in_channel_transform = in_channels * num_filters + out_channel_transform = out_channels * num_filters # type: ignore + + padding = window_length // 2 - stride // 2 + encode = Conv( + dim=dim, + in_channels=in_channels, + out_channels=in_channel_transform, + kernel_size=window_length, + stride=stride, + padding=padding, + padding_mode="reflect", + bias=False, + ) + decode = nn.ConvTranspose1d( + in_channels=out_channel_transform, + out_channels=out_channels, # type: ignore + kernel_size=window_length, + stride=stride, + padding=padding, + bias=False, + ) + net = net_t( # type: ignore + dim=dim, + in_channels=in_channel_transform, + out_channels=out_channel_transform, + **kwargs + ) + + def forward(x: Tensor, *args, **kwargs): + x = encode(x) + x = net(x, *args, **kwargs) + x = decode(x) + return x + + return Module([encode, decode, net], forward) + + return Net diff --git a/setup.py b/setup.py index 9f9a047..b808b57 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ "einops>=0.4", "einops-exts>=0.0.3", "audio-encoders-pytorch", + "a-unet", ], classifiers=[ "Development Status :: 4 - Beta", From 
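The `UNetV0` factory above wires optional a-unet plugins: `use_text_conditioning` wraps the network in `TextConditioningPlugin`, `use_embedding_cfg` together with `embedding_max_length` adds classifier-free guidance, and `LTPlugin` wraps any network constructor in a learned encode/decode transform. A hedged sketch of combining them through `DiffusionModel`; the `text`, `embedding_mask_proba`, and `embedding_scale` keyword names and `embedding_features=768` (T5-base) are assumed from the upstream a-unet plugins, and all sizes are placeholders:

```python
import torch
from audio_diffusion_pytorch import DiffusionModel, LTPlugin, UNetV0

# Optionally wrap UNetV0 in a learned transform (encode -> UNet -> decode).
UNet = LTPlugin(UNetV0, num_filters=128, window_length=64, stride=32)

model = DiffusionModel(
    net_t=UNet,
    dim=1,
    in_channels=2,
    channels=[128, 256, 512, 512],
    factors=[2, 2, 2, 2],
    items=[2, 2, 2, 2],
    attentions=[0, 0, 1, 1],
    cross_attentions=[0, 1, 1, 1],
    attention_heads=8,
    attention_features=64,
    embedding_features=768,       # assumed T5-base embedding size
    use_text_conditioning=True,   # TextConditioningPlugin
    use_embedding_cfg=True,       # ClassifierFreeGuidancePlugin
    embedding_max_length=64,
)

audio = torch.randn(1, 2, 2**15)
# Assumed plugin kwargs: `text` for conditioning, `embedding_mask_proba` for CFG dropout.
loss = model(audio, text=["warm analog synth"], embedding_mask_proba=0.1)

sample = model.sample(
    torch.randn(1, 2, 2**15),
    text=["warm analog synth"],
    embedding_scale=5.0,          # assumed CFG guidance scale kwarg
    num_steps=10,
)
```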
b44f90bf7c3006fa564a4c03b9a7ca6cb430d98a Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Tue, 3 Jan 2023 01:02:53 +0100 Subject: [PATCH 04/23] fix: version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b808b57..50b0e52 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.97", + version="0.0.1+nightly", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From 0539e965fa3c3f1f748301d764b88157146c7c14 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Thu, 5 Jan 2023 11:37:31 +0100 Subject: [PATCH 05/23] feat: add ar diffusion --- audio_diffusion_pytorch/diffusion.py | 169 ++++++++++++++++++++++++--- audio_diffusion_pytorch/models.py | 32 ++++- setup.py | 2 +- 3 files changed, 181 insertions(+), 22 deletions(-) diff --git a/audio_diffusion_pytorch/diffusion.py b/audio_diffusion_pytorch/diffusion.py index b11ddbe..ea9154f 100644 --- a/audio_diffusion_pytorch/diffusion.py +++ b/audio_diffusion_pytorch/diffusion.py @@ -1,5 +1,5 @@ from math import pi -from typing import Any, List, Tuple, Type +from typing import Any, Optional, Tuple import torch import torch.nn as nn @@ -47,11 +47,15 @@ def clip(x: Tensor, dynamic_threshold: float = 0.0): return x +def extend_dim(x: Tensor, dim: int): + # e.g. if dim = 4: shape [b] => [b, 1, 1, 1], + return x.view(*x.shape + (1,) * (dim - x.ndim)) + + class Diffusion(nn.Module): """Interface used by different diffusion methods""" - def forward(self, *args, **kwargs) -> Tensor: - raise NotImplementedError("Diffusion class missing forward function") + pass class VDiffusion(Diffusion): @@ -71,7 +75,7 @@ def forward(self, x: Tensor, **kwargs) -> Tensor: # type: ignore batch_size, device = x.shape[0], x.device # Sample amount of noise to add for each batch element sigmas = self.sigma_distribution(num_samples=batch_size, device=device) - sigmas_batch = rearrange(sigmas, "b -> b 1 1") + sigmas_batch = extend_dim(sigmas, dim=x.ndim) # Get noise noise = torch.randn_like(x) # Combine input and noise weighted by half-circle @@ -83,6 +87,40 @@ def forward(self, x: Tensor, **kwargs) -> Tensor: # type: ignore return F.mse_loss(v_pred, v_target) +class ARVDiffusion(Diffusion): + def __init__(self, net: nn.Module, length: int, num_splits: int): + super().__init__() + assert length % num_splits == 0, "length must be divisible by num_splits" + self.net = net + self.length = length + self.num_splits = num_splits + self.split_length = length // num_splits + + def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: + angle = sigmas * pi / 2 + alpha, beta = torch.cos(angle), torch.sin(angle) + return alpha, beta + + def forward(self, x: Tensor, **kwargs) -> Tensor: + """Returns diffusion loss of v-objective with different noises per split""" + b, _, t, device, dtype = *x.shape, x.device, x.dtype + assert t == self.length, "input length must match length" + # Sample amount of noise to add for each split + sigmas = torch.rand((b, 1, self.num_splits), device=device, dtype=dtype) + sigmas = repeat(sigmas, "b 1 n -> b 1 (n l)", l=self.split_length) + # Get noise + noise = torch.randn_like(x) + # Combine input and noise weighted by half-circle + alphas, betas = self.get_alpha_beta(sigmas) + x_noisy = alphas * x + betas * noise + v_target = alphas * noise - betas * x + # Sigmas will be provided as additional channel + channels = torch.cat([x_noisy, sigmas], dim=1) + # 
Predict velocity and return loss + v_pred = self.net(channels, **kwargs) + return F.mse_loss(v_pred, v_target) + + """ Schedules """ @@ -102,12 +140,7 @@ def forward(self, num_steps: int, device: Any) -> Tensor: class Sampler(nn.Module): - """Interface used by different samplers""" - - diffusion_types: List[Type] = [] - - def forward(*args, **kwargs) -> Tensor: - raise NotImplementedError() + pass class VSampler(Sampler): @@ -119,23 +152,18 @@ def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()): self.net = net self.schedule = schedule - @property - def device(self): - return next(self.net.parameters()).device - def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: angle = sigmas * pi / 2 - alpha = torch.cos(angle) - beta = torch.sin(angle) + alpha, beta = torch.cos(angle), torch.sin(angle) return alpha, beta def forward( # type: ignore self, noise: Tensor, num_steps: int, show_progress: bool = False, **kwargs ) -> Tensor: b = noise.shape[0] - sigmas = self.schedule(num_steps + 1, device=self.device) + sigmas = self.schedule(num_steps + 1, device=noise.device) sigmas = repeat(sigmas, "i -> i b", b=b) - sigmas_batch = rearrange(sigmas, "i b -> i b 1 1") + sigmas_batch = extend_dim(sigmas, dim=noise.ndim + 1) alphas, betas = self.get_alpha_beta(sigmas_batch) x_noisy = noise * sigmas_batch[0] progress_bar = tqdm(range(num_steps), disable=not show_progress) @@ -148,3 +176,108 @@ def forward( # type: ignore progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})") return x_noisy + + +class ARVSampler(Sampler): + def __init__(self, net: nn.Module, in_channels: int, length: int, num_splits: int): + super().__init__() + assert length % num_splits == 0, "length must be divisible by num_splits" + self.length = length + self.in_channels = in_channels + self.num_splits = num_splits + self.split_length = length // num_splits + self.net = net + + @property + def device(self): + return next(self.net.parameters()).device + + def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: + angle = sigmas * pi / 2 + alpha = torch.cos(angle) + beta = torch.sin(angle) + return alpha, beta + + def get_sigmas_ladder(self, num_items: int, num_steps_per_split: int) -> Tensor: + b, n, l, i = num_items, self.num_splits, self.split_length, num_steps_per_split + n_half = n // 2 # Only half ladder, rest is zero, to leave some context + sigmas = torch.linspace(1, 0, i * n_half, device=self.device) + sigmas = repeat(sigmas, "(n i) -> i b 1 (n l)", b=b, l=l, n=n_half) + sigmas = torch.flip(sigmas, dims=[-1]) # Lowest noise level first + sigmas = F.pad(sigmas, pad=[0, 0, 0, 0, 0, 0, 0, 1]) # Add index i+1 + sigmas[-1, :, :, l:] = sigmas[0, :, :, :-l] # Loop back at index i+1 + return torch.cat([torch.zeros_like(sigmas), sigmas], dim=-1) + + def sample_loop( + self, current: Tensor, sigmas: Tensor, show_progress: bool = False, **kwargs + ) -> Tensor: + num_steps = sigmas.shape[0] - 1 + alphas, betas = self.get_alpha_beta(sigmas) + progress_bar = tqdm(range(num_steps), disable=not show_progress) + + for i in progress_bar: + channels = torch.cat([current, sigmas[i]], dim=1) + v_pred = self.net(channels, **kwargs) + x_pred = alphas[i] * current - betas[i] * v_pred + noise_pred = betas[i] * current + alphas[i] * v_pred + current = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred + progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0,0,0]:.2f})") + + return current + + def sample_start(self, num_items: int, num_steps: int, **kwargs) -> Tensor: + b, c, t = num_items, 
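For reference, the half-circle schedule used by `VDiffusion`, `ARVDiffusion`, and the samplers above is, in the notation of the code with $\sigma \in [0, 1]$:

$$\alpha_\sigma = \cos\!\left(\tfrac{\pi}{2}\sigma\right), \qquad \beta_\sigma = \sin\!\left(\tfrac{\pi}{2}\sigma\right), \qquad x_\sigma = \alpha_\sigma\, x + \beta_\sigma\, \epsilon, \qquad v = \alpha_\sigma\, \epsilon - \beta_\sigma\, x,$$

with training loss $\lVert \hat v_\theta(x_\sigma, \sigma) - v \rVert_2^2$. `ARVDiffusion` draws an independent $\sigma$ per split, repeats it along the time axis, and feeds it to the network as an extra input channel instead of through time conditioning.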
self.in_channels, self.length + # Same sigma schedule over all chunks + sigmas = torch.linspace(1, 0, num_steps + 1, device=self.device) + sigmas = repeat(sigmas, "i -> i b 1 t", b=b, t=t) + noise = torch.randn((b, c, t), device=self.device) * sigmas[0] + # Sample start + return self.sample_loop(current=noise, sigmas=sigmas, **kwargs) + + def forward( + self, + num_items: int, + num_chunks: int, + num_steps: int, + start: Optional[Tensor] = None, + show_progress: bool = False, + **kwargs, + ) -> Tensor: + assert_message = f"required at least {self.num_splits} chunks" + assert num_chunks >= self.num_splits, assert_message + + # Sample initial chunks + start = self.sample_start(num_items=num_items, num_steps=num_steps, **kwargs) + # Return start if only num_splits chunks + if num_chunks == self.num_splits: + return start + + # Get sigmas for autoregressive ladder + b, n = num_items, self.num_splits + assert num_steps >= n, "num_steps must be greater than num_splits" + sigmas = self.get_sigmas_ladder( + num_items=b, + num_steps_per_split=num_steps // self.num_splits, + ) + alphas, betas = self.get_alpha_beta(sigmas) + + # Noise start to match ladder and set starting chunks + start_noise = alphas[0] * start + betas[0] * torch.randn_like(start) + chunks = list(start_noise.chunk(chunks=n, dim=-1)) + + # Loop over ladder shifts + num_shifts = num_chunks # - self.num_splits + progress_bar = tqdm(range(num_shifts), disable=not show_progress) + + for j in progress_bar: + # Decrease ladder noise of last n chunks + updated = self.sample_loop( + current=torch.cat(chunks[-n:], dim=-1), sigmas=sigmas, **kwargs + ) + # Update chunks + chunks[-n:] = list(updated.chunk(chunks=n, dim=-1)) + # Add fresh noise chunk + shape = (b, self.in_channels, self.split_length) + chunks += [torch.randn(shape, device=self.device)] + + return torch.cat(chunks[:num_chunks], dim=-1) diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py index ead9db7..60a2c6e 100644 --- a/audio_diffusion_pytorch/models.py +++ b/audio_diffusion_pytorch/models.py @@ -4,7 +4,7 @@ from audio_encoders_pytorch import Encoder1d from torch import Tensor, nn -from .diffusion import VDiffusion, VSampler +from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler from .unets import UNetV0 from .utils import closest_power_2, groupby @@ -15,7 +15,7 @@ def __init__( net_t: Callable = UNetV0, diffusion_t: Callable = VDiffusion, sampler_t: Callable = VSampler, - **kwargs + **kwargs, ): super().__init__() diffusion_kwargs, kwargs = groupby("diffusion_", kwargs) @@ -41,7 +41,7 @@ def __init__( channels: Sequence[int], encoder: Encoder1d, inject_depth: int, - **kwargs + **kwargs, ): context_channels = [0] * len(channels) context_channels[inject_depth] = encoder.out_channels @@ -76,3 +76,29 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor: default_kwargs = dict(channels=channels) # Decode by sampling while conditioning on latent channels return super().sample(noise, **{**default_kwargs, **kwargs}) + + +class DiffusionAR(DiffusionModel): + def __init__( + self, + in_channels: int, + length: int, + num_splits: int, + diffusion_t: Callable = ARVDiffusion, + sampler_t: Callable = ARVSampler, + **kwargs, + ): + super().__init__( + in_channels=in_channels + 1, + out_channels=in_channels, + diffusion_t=diffusion_t, + diffusion_length=length, + diffusion_num_splits=num_splits, + sampler_t=sampler_t, + sampler_in_channels=in_channels, + sampler_length=length, + sampler_num_splits=num_splits, + use_time_conditioning=False, + 
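`DiffusionAR` pairs `ARVDiffusion` with `ARVSampler`: training sees fixed-length windows with one noise level per split (carried as the extra sigma channel), and sampling generates an arbitrary number of chunks by sliding the sigma ladder. A hedged usage sketch; layer sizes and step counts are placeholders:

```python
import torch
from audio_diffusion_pytorch import DiffusionAR, UNetV0

model = DiffusionAR(
    net_t=UNetV0,
    dim=1,
    in_channels=1,        # the model internally adds one channel for the sigmas
    length=2**16,         # training window length
    num_splits=4,         # noise levels per window
    channels=[64, 128, 256, 512],
    factors=[4, 4, 4, 4],
    items=[2, 2, 2, 2],
)

# Training: inputs must match `length` exactly.
audio = torch.randn(1, 1, 2**16)
loss = model(audio)
loss.backward()

# Autoregressive generation: 8 chunks of length/num_splits samples each.
sample = model.sample(num_items=1, num_chunks=8, num_steps=64)
# sample shape: [1, 1, 8 * (2**16 // 4)]
```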
use_modulation=False, + **kwargs, + ) diff --git a/setup.py b/setup.py index 50b0e52..2827a22 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.1+nightly", + version="0.0.2+nightly", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From 663b57d43555307941353d58c50916780e6d8e5a Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Tue, 10 Jan 2023 23:30:24 +0100 Subject: [PATCH 06/23] feat: add diffusion upsampler --- audio_diffusion_pytorch/__init__.py | 4 ++-- audio_diffusion_pytorch/models.py | 35 +++++++++++++++++++++++++++-- audio_diffusion_pytorch/unets.py | 22 ++++++++++++++++++ setup.py | 2 +- 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index f2bfe57..a4e897b 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -10,5 +10,5 @@ VDiffusion, VSampler, ) -from .models import DiffusionAE, DiffusionModel -from .unets import XUNet +from .models import DiffusionAE, DiffusionAR, DiffusionModel, DiffusionUpsampler +from .unets import LTPlugin, UNetV0, XUNet diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py index 60a2c6e..5959b10 100644 --- a/audio_diffusion_pytorch/models.py +++ b/audio_diffusion_pytorch/models.py @@ -5,8 +5,8 @@ from torch import Tensor, nn from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler -from .unets import UNetV0 -from .utils import closest_power_2, groupby +from .unets import AppendChannelsPlugin, UNetV0 +from .utils import closest_power_2, downsample, groupby, upsample class DiffusionModel(nn.Module): @@ -78,6 +78,37 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor: return super().sample(noise, **{**default_kwargs, **kwargs}) +class DiffusionUpsampler(DiffusionModel): + def __init__( + self, + in_channels: int, + upsample_factor: int, + net_t: Callable = UNetV0, + **kwargs, + ): + self.upsample_factor = upsample_factor + super().__init__( + net_t=AppendChannelsPlugin(net_t, channels=in_channels), + in_channels=in_channels, + **kwargs, + ) + + def reupsample(self, x: Tensor) -> Tensor: + x = x.clone() + x = downsample(x, factor=self.upsample_factor) + x = upsample(x, factor=self.upsample_factor) + return x + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # type: ignore + reupsampled = self.reupsample(x) + return super().forward(x, *args, append_channels=reupsampled, **kwargs) + + def sample(self, downsampled: Tensor, *args, **kwargs) -> Tensor: # type: ignore + reupsampled = upsample(downsampled, factor=self.upsample_factor) + noise = torch.randn_like(reupsampled) + return super().sample(noise, *args, append_channels=reupsampled, **kwargs) + + class DiffusionAR(DiffusionModel): def __init__( self, diff --git a/audio_diffusion_pytorch/unets.py b/audio_diffusion_pytorch/unets.py index 37680c1..2307c85 100644 --- a/audio_diffusion_pytorch/unets.py +++ b/audio_diffusion_pytorch/unets.py @@ -1,5 +1,6 @@ from typing import Callable, Optional, Sequence +import torch from a_unet import ( ClassifierFreeGuidancePlugin, Conv, @@ -153,3 +154,24 @@ def forward(x: Tensor, *args, **kwargs): return Module([encode, decode, net], forward) return Net + + +def AppendChannelsPlugin( + net_t: Callable, + channels: int, +): + def Net( + in_channels: int, out_channels: Optional[int] = None, **kwargs + ) -> nn.Module: + out_channels = 
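`DiffusionUpsampler` trains by downsampling and re-upsampling the target, then conditioning on the degraded copy through `AppendChannelsPlugin`; at inference it upsamples a low-rate input and denoises around it. A hedged sketch of the intended usage (rates and sizes are placeholders):

```python
import torch
from audio_diffusion_pytorch import DiffusionUpsampler, UNetV0

upsampler = DiffusionUpsampler(
    net_t=UNetV0,
    dim=1,
    in_channels=2,
    upsample_factor=8,           # e.g. 6 kHz -> 48 kHz
    channels=[8, 32, 64, 128],
    factors=[4, 4, 4, 2],
    items=[1, 2, 2, 2],
)

# Training: the model internally downsamples x by `upsample_factor`,
# upsamples it back, and appends the result as conditioning channels.
audio = torch.randn(1, 2, 2**16)
loss = upsampler(audio)
loss.backward()

# Inference: provide the low-resolution signal directly.
downsampled = torch.randn(1, 2, 2**16 // 8)
upsampled = upsampler.sample(downsampled, num_steps=10)   # [1, 2, 2**16]
```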
default(out_channels, in_channels) + net = net_t( # type: ignore + in_channels=in_channels + channels, out_channels=out_channels, **kwargs + ) + + def forward(x: Tensor, *args, append_channels: Tensor, **kwargs): + x = torch.cat([x, append_channels], dim=1) + return net(x, *args, **kwargs) + + return Module([net], forward) + + return Net diff --git a/setup.py b/setup.py index 2827a22..ccda867 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.2+nightly", + version="0.0.3+nightly", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From 637f22f3efc453a0daf81bdcf16cf9dc773871ed Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Wed, 11 Jan 2023 15:05:18 +0100 Subject: [PATCH 07/23] feat: add generators, mel, refactor unets->components --- audio_diffusion_pytorch/__init__.py | 2 +- .../{unets.py => components.py} | 59 +++++++++++++++++++ audio_diffusion_pytorch/models.py | 30 ++++++---- audio_diffusion_pytorch/utils.py | 10 +++- setup.py | 5 +- 5 files changed, 92 insertions(+), 14 deletions(-) rename audio_diffusion_pytorch/{unets.py => components.py} (74%) diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index a4e897b..272377b 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -1,5 +1,6 @@ from audio_encoders_pytorch import Encoder1d, ME1d +from .components import LTPlugin, MelSpectrogram, UNetV0, XUNet from .diffusion import ( Diffusion, Distribution, @@ -11,4 +12,3 @@ VSampler, ) from .models import DiffusionAE, DiffusionAR, DiffusionModel, DiffusionUpsampler -from .unets import LTPlugin, UNetV0, XUNet diff --git a/audio_diffusion_pytorch/unets.py b/audio_diffusion_pytorch/components.py similarity index 74% rename from audio_diffusion_pytorch/unets.py rename to audio_diffusion_pytorch/components.py index 2307c85..bd9dc40 100644 --- a/audio_diffusion_pytorch/unets.py +++ b/audio_diffusion_pytorch/components.py @@ -1,6 +1,7 @@ from typing import Callable, Optional, Sequence import torch +import torch.nn.functional as F from a_unet import ( ClassifierFreeGuidancePlugin, Conv, @@ -21,7 +22,9 @@ XBlock, XUNet, ) +from einops import pack, unpack from torch import Tensor, nn +from torchaudio import transforms """ UNets (built with a-unet: https://github.com/archinetai/a-unet) @@ -175,3 +178,59 @@ def forward(x: Tensor, *args, append_channels: Tensor, **kwargs): return Module([net], forward) return Net + + +""" +Other +""" + + +class MelSpectrogram(nn.Module): + def __init__( + self, + n_fft: int, + hop_length: int, + win_length: int, + sample_rate: int, + n_mel_channels: int, + center: bool = False, + normalize: bool = False, + normalize_log: bool = False, + ): + super().__init__() + self.padding = (n_fft - hop_length) // 2 + self.normalize = normalize + self.normalize_log = normalize_log + self.hop_length = hop_length + + self.to_spectrogram = transforms.Spectrogram( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + center=center, + power=None, + ) + + self.to_mel_scale = transforms.MelScale( + n_mels=n_mel_channels, n_stft=n_fft // 2 + 1, sample_rate=sample_rate + ) + + def forward(self, waveform: Tensor) -> Tensor: + # Pack non-time dimension + waveform, ps = pack([waveform], "* t") + # Pad waveform + waveform = F.pad(waveform, [self.padding] * 2, mode="reflect") + # Compute STFT + spectrogram = self.to_spectrogram(waveform) + # Compute magnitude + 
spectrogram = torch.abs(spectrogram) + # Convert to mel scale + mel_spectrogram = self.to_mel_scale(spectrogram) + # Normalize + if self.normalize: + mel_spectrogram = mel_spectrogram / torch.max(mel_spectrogram) + mel_spectrogram = 2 * torch.pow(mel_spectrogram, 0.25) - 1 + if self.normalize_log: + mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5)) + # Unpack non-spectrogram dimension + return unpack(mel_spectrogram, ps, "* f l")[0] diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py index 5959b10..f4417af 100644 --- a/audio_diffusion_pytorch/models.py +++ b/audio_diffusion_pytorch/models.py @@ -1,17 +1,18 @@ -from typing import Any, Callable, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union import torch from audio_encoders_pytorch import Encoder1d -from torch import Tensor, nn +from torch import Generator, Tensor, nn +from .components import AppendChannelsPlugin, UNetV0 from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler -from .unets import AppendChannelsPlugin, UNetV0 -from .utils import closest_power_2, downsample, groupby, upsample +from .utils import closest_power_2, downsample, groupby, randn_like, upsample class DiffusionModel(nn.Module): def __init__( self, + dim: int = 1, net_t: Callable = UNetV0, diffusion_t: Callable = VDiffusion, sampler_t: Callable = VSampler, @@ -21,7 +22,7 @@ def __init__( diffusion_kwargs, kwargs = groupby("diffusion_", kwargs) sampler_kwargs, kwargs = groupby("sampler_", kwargs) - self.net = net_t(**kwargs) + self.net = net_t(dim=dim, **kwargs) self.diffusion = diffusion_t(net=self.net, **diffusion_kwargs) self.sampler = sampler_t(net=self.net, **sampler_kwargs) @@ -66,11 +67,18 @@ def forward( # type: ignore def encode(self, *args, **kwargs): return self.encoder(*args, **kwargs) - def decode(self, latent: Tensor, **kwargs) -> Tensor: + def decode( + self, latent: Tensor, generator: Optional[Generator] = None, **kwargs + ) -> Tensor: b = latent.shape[0] length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) # Compute noise by inferring shape from latent length - noise = torch.randn(b, self.in_channels, length, device=latent.device) + noise = torch.randn( + (b, self.in_channels, length), + device=latent.device, + dtype=latent.dtype, + generator=generator, + ) # Compute context from latent channels = [None] * self.inject_depth + [latent] # type: ignore default_kwargs = dict(channels=channels) @@ -103,10 +111,12 @@ def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # type: ignore reupsampled = self.reupsample(x) return super().forward(x, *args, append_channels=reupsampled, **kwargs) - def sample(self, downsampled: Tensor, *args, **kwargs) -> Tensor: # type: ignore + def sample( # type: ignore + self, downsampled: Tensor, generator: Optional[Generator] = None, **kwargs + ) -> Tensor: reupsampled = upsample(downsampled, factor=self.upsample_factor) - noise = torch.randn_like(reupsampled) - return super().sample(noise, *args, append_channels=reupsampled, **kwargs) + noise = randn_like(reupsampled, generator=generator) + return super().sample(noise, append_channels=reupsampled, **kwargs) class DiffusionAR(DiffusionModel): diff --git a/audio_diffusion_pytorch/utils.py b/audio_diffusion_pytorch/utils.py index 749015d..62aaa9b 100644 --- a/audio_diffusion_pytorch/utils.py +++ b/audio_diffusion_pytorch/utils.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F from einops import rearrange -from torch import Tensor +from 
torch import Generator, Tensor from typing_extensions import TypeGuard T = TypeVar("T") @@ -115,3 +115,11 @@ def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor: def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor: return resample(waveforms, factor_in=1, factor_out=factor, **kwargs) + + +""" Torch Utils """ + + +def randn_like(tensor: Tensor, *args, generator: Optional[Generator] = None, **kwargs): + """randn_like that supports generator""" + return torch.randn(tensor.shape, *args, generator=generator, **kwargs).to(tensor) diff --git a/setup.py b/setup.py index ccda867..8f7c2ad 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.3+nightly", + version="0.0.4+nightly", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", @@ -14,8 +14,9 @@ install_requires=[ "tqdm", "torch>=1.6", + "torchaudio", "data-science-types>=0.2", - "einops>=0.4", + "einops>=0.6", "einops-exts>=0.0.3", "audio-encoders-pytorch", "a-unet", From da43e54348cd6281cfcbfeda6f01efab9c1701f2 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 13 Jan 2023 12:33:40 +0100 Subject: [PATCH 08/23] feat: add diffusion vocoder, add readme examples --- README.md | 76 +++++++++++++++++++++++++++++ audio_diffusion_pytorch/__init__.py | 8 ++- audio_diffusion_pytorch/models.py | 68 +++++++++++++++++++++++++- setup.py | 2 +- 4 files changed, 150 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 52d06ac..723a5ae 100644 --- a/README.md +++ b/README.md @@ -7,3 +7,79 @@ Nightly branch. ```bash pip install -U git+https://github.com/archinetai/audio-diffusion-pytorch.git@nightly ``` + +[![PyPI - Python Version](https://img.shields.io/pypi/v/audio-diffusion-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/audio-diffusion-pytorch/) +[![Downloads](https://static.pepy.tech/personalized-badge/audio-diffusion-pytorch?period=total&units=international_system&left_color=black&right_color=black&left_text=Downloads)](https://pepy.tech/project/audio-diffusion-pytorch) + + +## Usage + +### Unconditional Generation + +```py +from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler + +model = DiffusionModel( + net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case) + in_channels=2, # U-Net: number of input/output (audio) channels + channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer + factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer + items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer + attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer + attention_heads=8, # U-Net: number of attention heads per attention item + attention_features=64, # U-Net: number of attention features per attention item + diffusion_t=VDiffusion, # The diffusion method used + sampler_t=VSampler, # The diffusion sampler used +) + +# Train model with audio waveforms +audio = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length] +loss = model(audio) +loss.backward() + +# Turn noise into new audio sample with diffusion +noise = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length] +sample = model.sample(noise, num_steps=10) # Suggested num_steps 10-50 +``` + +### Text-Conditional Generation + +```py +from audio_diffusion_pytorch.models import DiffusionModel, UNetV0, 
VDiffusion, VSampler + +model = DiffusionModel( + net_t=UNetV0, # The model type used for diffusion + in_channels=2, # U-Net: number of input/output (audio) channels + channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer + factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer + items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer + attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer + attention_heads=8, # U-Net: number of attention heads per attention block + attention_features=64, # U-Net: number of attention features per attention block, + diffusion_t=VDiffusion, # The diffusion method used + sampler_t=VSampler, # The diffusion sampler used + # Additional for text conditioning + use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base) + use_embedding_cfg=True # U-Net: enables classifier free guidance + embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base) + embedding_features=768, # U-Net: text mbedding features (default for T5-base) + cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer +) + +# Train model with audio +audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length] +loss = model( + audio_wave, + text=['The audio description'], # Text conditioning, one element per batch + embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask) +) +loss.backward() + +noise = torch.randn(1, 2, 2**18) +sample = model.sample( + noise, + text=['The audio description'], + embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale) + num_steps=2 # Higher for better quality, suggested num_steps: 10-50 +) +``` diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index 272377b..a4a380d 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -11,4 +11,10 @@ VDiffusion, VSampler, ) -from .models import DiffusionAE, DiffusionAR, DiffusionModel, DiffusionUpsampler +from .models import ( + DiffusionAE, + DiffusionAR, + DiffusionModel, + DiffusionUpsampler, + DiffusionVocoder, +) diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py index f4417af..77ce66a 100644 --- a/audio_diffusion_pytorch/models.py +++ b/audio_diffusion_pytorch/models.py @@ -1,12 +1,14 @@ +from math import floor from typing import Any, Callable, Optional, Sequence, Tuple, Union import torch from audio_encoders_pytorch import Encoder1d +from einops import pack, rearrange, unpack from torch import Generator, Tensor, nn -from .components import AppendChannelsPlugin, UNetV0 +from .components import AppendChannelsPlugin, MelSpectrogram, UNetV0 from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler -from .utils import closest_power_2, downsample, groupby, randn_like, upsample +from .utils import closest_power_2, default, downsample, groupby, randn_like, upsample class DiffusionModel(nn.Module): @@ -29,6 +31,7 @@ def __init__( def forward(self, *args, **kwargs) -> Tensor: return self.diffusion(*args, **kwargs) + @torch.no_grad() def sample(self, *args, **kwargs) -> Tensor: return self.sampler(*args, **kwargs) @@ -67,6 +70,7 @@ def forward( # type: ignore def encode(self, *args, **kwargs): return self.encoder(*args, **kwargs) + @torch.no_grad() def decode( self, latent: 
Tensor, generator: Optional[Generator] = None, **kwargs ) -> Tensor: @@ -111,6 +115,7 @@ def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # type: ignore reupsampled = self.reupsample(x) return super().forward(x, *args, append_channels=reupsampled, **kwargs) + @torch.no_grad() def sample( # type: ignore self, downsampled: Tensor, generator: Optional[Generator] = None, **kwargs ) -> Tensor: @@ -119,6 +124,65 @@ def sample( # type: ignore return super().sample(noise, append_channels=reupsampled, **kwargs) +class DiffusionVocoder(DiffusionModel): + def __init__( + self, + mel_channels: int, + mel_n_fft: int, + mel_hop_length: Optional[int] = None, + mel_win_length: Optional[int] = None, + in_channels: int = 1, # Ignored: channels are automatically batched. + net_t: Callable = UNetV0, + **kwargs, + ): + mel_hop_length = default(mel_hop_length, floor(mel_n_fft) // 4) + mel_win_length = default(mel_win_length, mel_n_fft) + mel_kwargs, kwargs = groupby("mel_", kwargs) + super().__init__( + net_t=AppendChannelsPlugin(net_t, channels=1), + in_channels=1, + **kwargs, + ) + self.to_spectrogram = MelSpectrogram( + n_fft=mel_n_fft, + hop_length=mel_hop_length, + win_length=mel_win_length, + n_mel_channels=mel_channels, + **mel_kwargs, + ) + self.to_flat = nn.ConvTranspose1d( + in_channels=mel_channels, + out_channels=1, + kernel_size=mel_win_length, + stride=mel_hop_length, + padding=(mel_win_length - mel_hop_length) // 2, + bias=False, + ) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # type: ignore + # Get spectrogram, pack channels and flatten + spectrogram = rearrange(self.to_spectrogram(x), "b c f l -> (b c) f l") + spectrogram_flat = self.to_flat(spectrogram) + # Pack wave channels + x = rearrange(x, "b c t -> (b c) 1 t") + return super().forward(x, *args, append_channels=spectrogram_flat, **kwargs) + + @torch.no_grad() + def sample( # type: ignore + self, spectrogram: Tensor, generator: Optional[Generator] = None, **kwargs + ) -> Tensor: # type: ignore + # Pack channels and flatten spectrogram + spectrogram, ps = pack([spectrogram], "* f l") + spectrogram_flat = self.to_flat(spectrogram) + # Get start noise and sample + noise = randn_like(spectrogram_flat, generator=generator) + waveform = super().sample(noise, append_channels=spectrogram_flat, **kwargs) + # Unpack wave channels + waveform = rearrange(waveform, "... 1 t -> ... 
t") + waveform = unpack(waveform, ps, "* t")[0] + return waveform + + class DiffusionAR(DiffusionModel): def __init__( self, diff --git a/setup.py b/setup.py index 8f7c2ad..77dc7c2 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.4+nightly", + version="0.0.5+nightly", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From 0e1e237f3db88c8bd307ebe4b10d4e62de4331e3 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 13 Jan 2023 12:35:58 +0100 Subject: [PATCH 09/23] feat: shorten example --- README.md | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/README.md b/README.md index 723a5ae..4e35b23 100644 --- a/README.md +++ b/README.md @@ -48,17 +48,7 @@ sample = model.sample(noise, num_steps=10) # Suggested num_steps 10-50 from audio_diffusion_pytorch.models import DiffusionModel, UNetV0, VDiffusion, VSampler model = DiffusionModel( - net_t=UNetV0, # The model type used for diffusion - in_channels=2, # U-Net: number of input/output (audio) channels - channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer - factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer - items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer - attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer - attention_heads=8, # U-Net: number of attention heads per attention block - attention_features=64, # U-Net: number of attention features per attention block, - diffusion_t=VDiffusion, # The diffusion method used - sampler_t=VSampler, # The diffusion sampler used - # Additional for text conditioning + # ... same as unconditional model use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base) use_embedding_cfg=True # U-Net: enables classifier free guidance embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base) From 51362fe0bca0cabb742366e7fd7528b5a687affb Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 13 Jan 2023 13:04:44 +0100 Subject: [PATCH 10/23] feat: add readme intro, vocoder example --- README.md | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4e35b23..07af7d4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -Nightly branch. +A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models work on waveforms, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable. 
## Install @@ -15,7 +15,6 @@ pip install -U git+https://github.com/archinetai/audio-diffusion-pytorch.git@nig ## Usage ### Unconditional Generation - ```py from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler @@ -43,7 +42,6 @@ sample = model.sample(noise, num_steps=10) # Suggested num_steps 10-50 ``` ### Text-Conditional Generation - ```py from audio_diffusion_pytorch.models import DiffusionModel, UNetV0, VDiffusion, VSampler @@ -56,7 +54,7 @@ model = DiffusionModel( cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer ) -# Train model with audio +# Train model with audio waveforms audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length] loss = model( audio_wave, @@ -65,6 +63,7 @@ loss = model( ) loss.backward() +# Turn noise into new audio sample with diffusion noise = torch.randn(1, 2, 2**18) sample = model.sample( noise, @@ -73,3 +72,55 @@ sample = model.sample( num_steps=2 # Higher for better quality, suggested num_steps: 10-50 ) ``` + +### Upsampling +```py +from audio_diffusion_pytorch.models import DiffusionUpsampler, UNetV0, VDiffusion, VSampler + +upsampler = DiffusionUpsampler( + net_t=UNetV0, # The model type used for diffusion + upsample_factor=16, # The upsample factor (e.g. 16 can be used for 3kHz to 48kHz) + in_channels=2, # U-Net: number of input/output (audio) channels + channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer + factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer + items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer + diffusion_t=VDiffusion, # The diffusion method used + sampler_t=VSampler, # The diffusion sampler used +) + +# Train model with high sample rate audio waveforms +audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length] +loss = upsampler(audio) +loss.backward() + +# Turn low sample rate audio into high sample rate +downsampled_audio = torch.randn(1, 2, 2**14) # [batch, in_channels, length] +sample = upsampler.sample(downsampled_audio, num_steps=10) # Output has shape: [1, 2, 2**18] +``` + +### Vocoding +```py +from audio_diffusion_pytorch.models import DiffusionVocoder, UNetV0, VDiffusion, VSampler + +vocoder = DiffusionVocoder( + mel_n_fft=1024, # Mel-spectrogram n_fft + mel_channels=80, # Mel-spectrogram channels + mel_sample_rate=48000, # Mel-spectrogram sample rate + mel_normalize_log=True, # Mel-spectrogram log normalization (alternative is mel_normalize=True for [-1,1] power normalization) + net_t=UNetV0, # The model type used for diffusion vocoding + channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer + factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer + items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer + diffusion_t=VDiffusion, # The diffusion method used + sampler_t=VSampler, # The diffusion sampler used +) + +# Train model on waveforms (automatically converted to mel internally) +audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length] +loss = vocoder(audio) +loss.backward() + +# Turn mel spectrogram into waveform +mel_spectrogram = torch.randn(1, 2, 80, 1024) # [batch, in_channels, mel_channels, mel_length] +sample = vocoder.sample(mel_spectrogram, num_steps=10) # Output has shape: [1, 2, 2**18] +``` From eb7711e7371cf41319ea63a3003717a34f7194ef Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 13 Jan 2023 
15:00:20 +0100 Subject: [PATCH 11/23] feat: remove encoder from requirement, add readme diffae, improve readme --- README.md | 113 +++++++++++++++++++++++++--- audio_diffusion_pytorch/__init__.py | 2 - audio_diffusion_pytorch/models.py | 20 +++-- setup.py | 3 +- 4 files changed, 113 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 07af7d4..64573c7 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models work on waveforms, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable. +A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable to work on other formats. ## Install @@ -14,7 +14,8 @@ pip install -U git+https://github.com/archinetai/audio-diffusion-pytorch.git@nig ## Usage -### Unconditional Generation +### Unconditional Generator + ```py from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler @@ -38,12 +39,13 @@ loss.backward() # Turn noise into new audio sample with diffusion noise = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length] -sample = model.sample(noise, num_steps=10) # Suggested num_steps 10-50 +sample = model.sample(noise, num_steps=10) # Suggested num_steps 10-100 ``` -### Text-Conditional Generation +### Text-Conditional Generator +A text-to-audio diffusion model that conditions the generation with `t5-base` text embeddings, requires `pip install transformers`. ```py -from audio_diffusion_pytorch.models import DiffusionModel, UNetV0, VDiffusion, VSampler +from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler model = DiffusionModel( # ... same as unconditional model @@ -69,13 +71,14 @@ sample = model.sample( noise, text=['The audio description'], embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale) - num_steps=2 # Higher for better quality, suggested num_steps: 10-50 + num_steps=2 # Higher for better quality, suggested num_steps: 10-100 ) ``` -### Upsampling +### Diffusion Upsampler +Upsample audio from a lower sample rate to higher sample rate using diffusion, e.g. 
3kHz to 48kHz: ```py -from audio_diffusion_pytorch.models import DiffusionUpsampler, UNetV0, VDiffusion, VSampler +from audio_diffusion_pytorch import DiffusionUpsampler, UNetV0, VDiffusion, VSampler upsampler = DiffusionUpsampler( net_t=UNetV0, # The model type used for diffusion @@ -98,9 +101,10 @@ downsampled_audio = torch.randn(1, 2, 2**14) # [batch, in_channels, length] sample = upsampler.sample(downsampled_audio, num_steps=10) # Output has shape: [1, 2, 2**18] ``` -### Vocoding +### Diffusion Vocoder +Convert a mel-spectrogram to wavefrom using diffusion: ```py -from audio_diffusion_pytorch.models import DiffusionVocoder, UNetV0, VDiffusion, VSampler +from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler vocoder = DiffusionVocoder( mel_n_fft=1024, # Mel-spectrogram n_fft @@ -124,3 +128,92 @@ loss.backward() mel_spectrogram = torch.randn(1, 2, 80, 1024) # [batch, in_channels, mel_channels, mel_length] sample = vocoder.sample(mel_spectrogram, num_steps=10) # Output has shape: [1, 2, 2**18] ``` + +## Diffusion Autoencoder +Autoencode audio into a compressed latent using diffusion. Any encoder can be provided as long as it has `out_channels` and `downsample_factor` attributes that can be used to infer the original audio length from the latent. +```py +from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler +from audio_encoders_pytorch import MelE1d, TanhBottleneck + +autoencoder = DiffusionAE( + encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder + in_channels=2, + channels=512, + multipliers=[1, 1], + factors=[2], + num_blocks=[12], + out_channels=32, + mel_channels=80, + mel_sample_rate=48000, + mel_normalize_log=True, + bottleneck=TanhBottleneck(), + ), + inject_depth=6, + net_t=UNetV0, # The model type used for diffusion upsampling + in_channels=2, # U-Net: number of input/output (audio) channels + channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer + factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer + items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer + diffusion_t=VDiffusion, # The diffusion method used + sampler_t=VSampler, # The diffusion sampler used +) + +# Train autoencoder with audio samples +audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length] +loss = autoencoder(audio) +loss.backward() + +# Encode/decode audio +audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length] +latent = autoencoder.encode(audio) # Encode +sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling diffusion model conditioning on latent +``` + +## Appreciation + +* [StabilityAI](https://stability.ai/) for the compute, [Zach Evans](https://github.com/zqevans) and everyone else from [HarmonAI](https://www.harmonai.org/) for the interesting research discussions. +* [ETH Zurich](https://inf.ethz.ch/) for the resources, [Zhijing Jin](https://zhijing-jin.com/), [Bernhard Schoelkopf](https://is.mpg.de/~bs), and [Mrinmaya Sachan](http://www.mrinmaya.io/) for supervising this Thesis. +* [Phil Wang](https://github.com/lucidrains) for the beautiful open source contributions on [diffusion](https://github.com/lucidrains/denoising-diffusion-pytorch) and [Imagen](https://github.com/lucidrains/imagen-pytorch). +* [Katherine Crowson](https://github.com/crowsonkb) for the experiments with [k-diffusion](https://github.com/crowsonkb/k-diffusion) and the insane collection of samplers. 
+ +## Citations + +DDPM Diffusion +```bibtex +@misc{2006.11239, +Author = {Jonathan Ho and Ajay Jain and Pieter Abbeel}, +Title = {Denoising Diffusion Probabilistic Models}, +Year = {2020}, +Eprint = {arXiv:2006.11239}, +} +``` + +DDIM (V-Sampler) +```bibtex +@misc{2010.02502, +Author = {Jiaming Song and Chenlin Meng and Stefano Ermon}, +Title = {Denoising Diffusion Implicit Models}, +Year = {2020}, +Eprint = {arXiv:2010.02502}, +} +``` + +V-Diffusion +```bibtex +@misc{2202.00512, +Author = {Tim Salimans and Jonathan Ho}, +Title = {Progressive Distillation for Fast Sampling of Diffusion Models}, +Year = {2022}, +Eprint = {arXiv:2202.00512}, +} +``` + +Imagen (T5 Text Conditioning) +```bibtex +@misc{2205.11487, +Author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David J Fleet and Mohammad Norouzi}, +Title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding}, +Year = {2022}, +Eprint = {arXiv:2205.11487}, +} +``` diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index a4a380d..932a44c 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -1,5 +1,3 @@ -from audio_encoders_pytorch import Encoder1d, ME1d - from .components import LTPlugin, MelSpectrogram, UNetV0, XUNet from .diffusion import ( Diffusion, diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py index 77ce66a..42d896e 100644 --- a/audio_diffusion_pytorch/models.py +++ b/audio_diffusion_pytorch/models.py @@ -2,11 +2,10 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union import torch -from audio_encoders_pytorch import Encoder1d from einops import pack, rearrange, unpack from torch import Generator, Tensor, nn -from .components import AppendChannelsPlugin, MelSpectrogram, UNetV0 +from .components import AppendChannelsPlugin, MelSpectrogram from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler from .utils import closest_power_2, default, downsample, groupby, randn_like, upsample @@ -14,10 +13,10 @@ class DiffusionModel(nn.Module): def __init__( self, - dim: int = 1, - net_t: Callable = UNetV0, + net_t: Callable, diffusion_t: Callable = VDiffusion, sampler_t: Callable = VSampler, + dim: int = 1, **kwargs, ): super().__init__() @@ -43,12 +42,12 @@ def __init__( self, in_channels: int, channels: Sequence[int], - encoder: Encoder1d, + encoder: nn.Module, inject_depth: int, **kwargs, ): context_channels = [0] * len(channels) - context_channels[inject_depth] = encoder.out_channels + context_channels[inject_depth] = encoder.out_channels # type: ignore super().__init__( in_channels=in_channels, channels=channels, @@ -75,7 +74,7 @@ def decode( self, latent: Tensor, generator: Optional[Generator] = None, **kwargs ) -> Tensor: b = latent.shape[0] - length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) + length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) # type: ignore # noqa # Compute noise by inferring shape from latent length noise = torch.randn( (b, self.in_channels, length), @@ -85,9 +84,8 @@ def decode( ) # Compute context from latent channels = [None] * self.inject_depth + [latent] # type: ignore - default_kwargs = dict(channels=channels) # Decode by sampling while conditioning on latent channels - return super().sample(noise, 
**{**default_kwargs, **kwargs}) + return super().sample(noise, channels=channels, **kwargs) class DiffusionUpsampler(DiffusionModel): @@ -95,7 +93,7 @@ def __init__( self, in_channels: int, upsample_factor: int, - net_t: Callable = UNetV0, + net_t: Callable, **kwargs, ): self.upsample_factor = upsample_factor @@ -127,12 +125,12 @@ def sample( # type: ignore class DiffusionVocoder(DiffusionModel): def __init__( self, + net_t: Callable, mel_channels: int, mel_n_fft: int, mel_hop_length: Optional[int] = None, mel_win_length: Optional[int] = None, in_channels: int = 1, # Ignored: channels are automatically batched. - net_t: Callable = UNetV0, **kwargs, ): mel_hop_length = default(mel_hop_length, floor(mel_n_fft) // 4) diff --git a/setup.py b/setup.py index 77dc7c2..a932a2c 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.5+nightly", + version="0.0.6+nightly", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", @@ -18,7 +18,6 @@ "data-science-types>=0.2", "einops>=0.6", "einops-exts>=0.0.3", - "audio-encoders-pytorch", "a-unet", ], classifiers=[ From 7819a152b9715c482ef56ec9511ad0de2c53426f Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 13 Jan 2023 15:08:53 +0100 Subject: [PATCH 12/23] feat: provide abstract class for encoder --- README.md | 2 +- audio_diffusion_pytorch/models.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 64573c7..40d8a87 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ sample = vocoder.sample(mel_spectrogram, num_steps=10) # Output has shape: [1, 2 ``` ## Diffusion Autoencoder -Autoencode audio into a compressed latent using diffusion. Any encoder can be provided as long as it has `out_channels` and `downsample_factor` attributes that can be used to infer the original audio length from the latent. +Autoencode audio into a compressed latent using diffusion. Any encoder can be provided as long as it subclasses the `EncoderBase` class or contains an `out_channels` and `downsample_factor` field. 
```py from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler from audio_encoders_pytorch import MelE1d, TanhBottleneck diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py index 42d896e..6247187 100644 --- a/audio_diffusion_pytorch/models.py +++ b/audio_diffusion_pytorch/models.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from math import floor from typing import Any, Callable, Optional, Sequence, Tuple, Union @@ -35,6 +36,16 @@ def sample(self, *args, **kwargs) -> Tensor: return self.sampler(*args, **kwargs) +class EncoderBase(nn.Module, ABC): + """Abstract class for DiffusionAE encoder""" + + @abstractmethod + def __init__(self): + super().__init__() + self.out_channels = None + self.downsample_factor = None + + class DiffusionAE(DiffusionModel): """Diffusion Auto Encoder""" @@ -42,12 +53,12 @@ def __init__( self, in_channels: int, channels: Sequence[int], - encoder: nn.Module, + encoder: EncoderBase, inject_depth: int, **kwargs, ): context_channels = [0] * len(channels) - context_channels[inject_depth] = encoder.out_channels # type: ignore + context_channels[inject_depth] = encoder.out_channels super().__init__( in_channels=in_channels, channels=channels, @@ -74,7 +85,7 @@ def decode( self, latent: Tensor, generator: Optional[Generator] = None, **kwargs ) -> Tensor: b = latent.shape[0] - length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) # type: ignore # noqa + length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) # Compute noise by inferring shape from latent length noise = torch.randn( (b, self.in_channels, length), From 7f3a73c7314e25f9d902f1da3cfc3cdcfab7db3b Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 13 Jan 2023 15:27:49 +0100 Subject: [PATCH 13/23] fix: setup and readme --- README.md | 4 ++-- setup.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 40d8a87..89631ad 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ sample = model.sample( ``` ### Diffusion Upsampler -Upsample audio from a lower sample rate to higher sample rate using diffusion, e.g. 3kHz to 48kHz: +Upsample audio from a lower sample rate to higher sample rate using diffusion, e.g. 3kHz to 48kHz. ```py from audio_diffusion_pytorch import DiffusionUpsampler, UNetV0, VDiffusion, VSampler @@ -102,7 +102,7 @@ sample = upsampler.sample(downsampled_audio, num_steps=10) # Output has shape: [ ``` ### Diffusion Vocoder -Convert a mel-spectrogram to wavefrom using diffusion: +Convert a mel-spectrogram to wavefrom using diffusion. ```py from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler diff --git a/setup.py b/setup.py index a932a2c..93a79fc 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,6 @@ "torchaudio", "data-science-types>=0.2", "einops>=0.6", - "einops-exts>=0.0.3", "a-unet", ], classifiers=[ From 5d398952bb3d6ffb2f1e1ed9401f28699e70e239 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 13 Jan 2023 15:30:06 +0100 Subject: [PATCH 14/23] feat: v0.0.1 --- README.md | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 89631ad..6feac64 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ A fully featured audio diffusion library, for PyTorch. 
Includes models for uncon ## Install ```bash -pip install -U git+https://github.com/archinetai/audio-diffusion-pytorch.git@nightly +pip install audio-diffusion-pytorch ``` [![PyPI - Python Version](https://img.shields.io/pypi/v/audio-diffusion-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/audio-diffusion-pytorch/) diff --git a/setup.py b/setup.py index 93a79fc..23b7b19 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.6+nightly", + version="0.1.0", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From e6f5f4bb7cde13e1d0eb167a73abebe6b1ef2bbf Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 13 Jan 2023 15:30:33 +0100 Subject: [PATCH 15/23] feat: v0.1.0 --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 6feac64..8a717ef 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable to work on other formats. + ## Install ```bash From 7517b9f75f6121b99b2429d296981d63d1a6e342 Mon Sep 17 00:00:00 2001 From: Eleiber Date: Wed, 18 Jan 2023 09:03:29 -0400 Subject: [PATCH 16/23] fix: readme SyntaxError due to missing comma in example (#42) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8a717ef..70452b9 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler model = DiffusionModel( # ... 
same as unconditional model use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base) - use_embedding_cfg=True # U-Net: enables classifier free guidance + use_embedding_cfg=True, # U-Net: enables classifier free guidance embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base) embedding_features=768, # U-Net: text mbedding features (default for T5-base) cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer From a34014f7c7a748879dac4d2541ee1e692dad14ce Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 20 Jan 2023 16:50:42 +0100 Subject: [PATCH 17/23] feat: add parameters linear schedule, uniform distribution --- audio_diffusion_pytorch/diffusion.py | 22 +++++++++++++++------- setup.py | 2 +- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/audio_diffusion_pytorch/diffusion.py b/audio_diffusion_pytorch/diffusion.py index ea9154f..63e3115 100644 --- a/audio_diffusion_pytorch/diffusion.py +++ b/audio_diffusion_pytorch/diffusion.py @@ -19,8 +19,13 @@ def __call__(self, num_samples: int, device: torch.device): class UniformDistribution(Distribution): + def __init__(self, vmin: float = 0.0, vmax: float = 1.0): + super().__init__() + self.vmin, self.vmax = vmin, vmax + def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")): - return torch.rand(num_samples, device=device) + vmax, vmin = self.vmax, self.vmin + return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin """ Diffusion Methods """ @@ -132,8 +137,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor: class LinearSchedule(Schedule): + def __init__(self, start: float = 1.0, end: float = 0.0): + super().__init__() + self.start, self.end = start, end + def forward(self, num_steps: int, device: Any) -> Tensor: - return torch.linspace(1.0, 0.0, num_steps, device=device) + return torch.linspace(self.start, self.end, num_steps, device=device) """ Samplers """ @@ -158,14 +167,13 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: return alpha, beta def forward( # type: ignore - self, noise: Tensor, num_steps: int, show_progress: bool = False, **kwargs + self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs ) -> Tensor: - b = noise.shape[0] - sigmas = self.schedule(num_steps + 1, device=noise.device) + b = x_noisy.shape[0] + sigmas = self.schedule(num_steps + 1, device=x_noisy.device) sigmas = repeat(sigmas, "i -> i b", b=b) - sigmas_batch = extend_dim(sigmas, dim=noise.ndim + 1) + sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1) alphas, betas = self.get_alpha_beta(sigmas_batch) - x_noisy = noise * sigmas_batch[0] progress_bar = tqdm(range(num_steps), disable=not show_progress) for i in progress_bar: diff --git a/setup.py b/setup.py index 23b7b19..3c98e62 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.1.0", + version="0.1.1", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From 514b8f72a4fc868e52b98ebdee7f64ea1b2b7bb0 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 27 Jan 2023 12:17:52 +0100 Subject: [PATCH 18/23] feat: add adapter option to DiffusionAE --- audio_diffusion_pytorch/models.py | 39 +++++++++++++++++++++++++++---- setup.py | 2 +- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/audio_diffusion_pytorch/models.py 
b/audio_diffusion_pytorch/models.py index 6247187..c3f9cc2 100644 --- a/audio_diffusion_pytorch/models.py +++ b/audio_diffusion_pytorch/models.py @@ -8,7 +8,15 @@ from .components import AppendChannelsPlugin, MelSpectrogram from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler -from .utils import closest_power_2, default, downsample, groupby, randn_like, upsample +from .utils import ( + closest_power_2, + default, + downsample, + exists, + groupby, + randn_like, + upsample, +) class DiffusionModel(nn.Module): @@ -46,6 +54,18 @@ def __init__(self): self.downsample_factor = None +class AdapterBase(nn.Module, ABC): + """Abstract class for DiffusionAE encoder""" + + @abstractmethod + def encode(self, x: Tensor) -> Tensor: + pass + + @abstractmethod + def decode(self, x: Tensor) -> Tensor: + pass + + class DiffusionAE(DiffusionModel): """Diffusion Auto Encoder""" @@ -55,6 +75,8 @@ def __init__( channels: Sequence[int], encoder: EncoderBase, inject_depth: int, + latent_factor: Optional[int] = None, + adapter: Optional[AdapterBase] = None, **kwargs, ): context_channels = [0] * len(channels) @@ -68,12 +90,19 @@ def __init__( self.in_channels = in_channels self.encoder = encoder self.inject_depth = inject_depth + # Optional custom latent factor and adapter + self.latent_factor = default(latent_factor, self.encoder.downsample_factor) + self.adapter = adapter.requires_grad_(False) if exists(adapter) else None def forward( # type: ignore self, x: Tensor, with_info: bool = False, **kwargs ) -> Union[Tensor, Tuple[Tensor, Any]]: + # Encode input to latent channels latent, info = self.encode(x, with_info=True) channels = [None] * self.inject_depth + [latent] + # Adapt input to diffusion if adapter provided + x = self.adapter.encode(x) if exists(self.adapter) else x + # Compute diffusion loss loss = super().forward(x, channels=channels, **kwargs) return (loss, info) if with_info else loss @@ -85,10 +114,10 @@ def decode( self, latent: Tensor, generator: Optional[Generator] = None, **kwargs ) -> Tensor: b = latent.shape[0] - length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) + noise_length = closest_power_2(latent.shape[2] * self.latent_factor) # Compute noise by inferring shape from latent length noise = torch.randn( - (b, self.in_channels, length), + (b, self.in_channels, noise_length), device=latent.device, dtype=latent.dtype, generator=generator, @@ -96,7 +125,9 @@ def decode( # Compute context from latent channels = [None] * self.inject_depth + [latent] # type: ignore # Decode by sampling while conditioning on latent channels - return super().sample(noise, channels=channels, **kwargs) + out = super().sample(noise, channels=channels, **kwargs) + # Decode output with adapter if provided + return self.adapter.decode(out) if exists(self.adapter) else out class DiffusionUpsampler(DiffusionModel): diff --git a/setup.py b/setup.py index 3c98e62..a31a95a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.1.1", + version="0.1.2", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From a16e83526ed0e1fd9d99d8fbc22ac378d706575b Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 27 Jan 2023 12:32:07 +0100 Subject: [PATCH 19/23] fix: readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 70452b9..ca83576 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ 
mel_spectrogram = torch.randn(1, 2, 80, 1024) # [batch, in_channels, mel_channel sample = vocoder.sample(mel_spectrogram, num_steps=10) # Output has shape: [1, 2, 2**18] ``` -## Diffusion Autoencoder +### Diffusion Autoencoder Autoencode audio into a compressed latent using diffusion. Any encoder can be provided as long as it subclasses the `EncoderBase` class or contains an `out_channels` and `downsample_factor` field. ```py from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler From eafa972e27d332ec6f53dd616ac9a0cd466fc42f Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Wed, 8 Feb 2023 21:13:24 +0100 Subject: [PATCH 20/23] feat: update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ca83576..82e82c9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable to work on other formats. +A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable to work on other formats. 
**Note: no pre-trained models are provided here, this library is meant for research purposes.** ## Install From bcbb510b86dc0be97a193e18fc43a2ad04df2d90 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Sun, 26 Feb 2023 21:18:31 +0100 Subject: [PATCH 21/23] feat: add new v-inpainter --- README.md | 31 ++++++++++++++ audio_diffusion_pytorch/__init__.py | 2 + audio_diffusion_pytorch/diffusion.py | 62 ++++++++++++++++++++++++++++ setup.py | 2 +- 4 files changed, 96 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ca83576..61fc4cd 100644 --- a/README.md +++ b/README.md @@ -170,6 +170,37 @@ latent = autoencoder.encode(audio) # Encode sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling diffusion model conditioning on latent ``` +## Other + +### Inpainting +```py +from audio_diffusion_pytorch import UNetV0, VInpainter + +# The diffusion UNetV0 (this is an example, the net must be trained to work) +net = UNetV0( + dim=1, + in_channels=2, # U-Net: number of input/output (audio) channels + channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer + factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer + items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer + attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer + attention_heads=8, # U-Net: number of attention heads per attention block + attention_features=64, # U-Net: number of attention features per attention block, +) + +# Instantiate inpainter with trained net +inpainter = VInpainter(net=net) + +# Inpaint source +y = inpainter( + source=torch.randn(1, 2, 2**18), # Start source + mask=torch.randint(0, 2, (1, 2, 2 ** 18), dtype=torch.bool), # Set to `True` the parts you want to keep + num_steps=10, # Number of inpainting steps + num_resamples=2, # Number of resampling steps + show_progress=True, +) # [1, 2, 2 ** 18] +``` + ## Appreciation * [StabilityAI](https://stability.ai/) for the compute, [Zach Evans](https://github.com/zqevans) and everyone else from [HarmonAI](https://www.harmonai.org/) for the interesting research discussions. 
diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index 932a44c..3b4cad7 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -7,6 +7,7 @@ Schedule, UniformDistribution, VDiffusion, + VInpainter, VSampler, ) from .models import ( @@ -15,4 +16,5 @@ DiffusionModel, DiffusionUpsampler, DiffusionVocoder, + EncoderBase, ) diff --git a/audio_diffusion_pytorch/diffusion.py b/audio_diffusion_pytorch/diffusion.py index 63e3115..ee6e0ba 100644 --- a/audio_diffusion_pytorch/diffusion.py +++ b/audio_diffusion_pytorch/diffusion.py @@ -8,6 +8,8 @@ from torch import Tensor from tqdm import tqdm +from .utils import default + """ Distributions """ @@ -166,6 +168,7 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: alpha, beta = torch.cos(angle), torch.sin(angle) return alpha, beta + @torch.no_grad() def forward( # type: ignore self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs ) -> Tensor: @@ -242,6 +245,7 @@ def sample_start(self, num_items: int, num_steps: int, **kwargs) -> Tensor: # Sample start return self.sample_loop(current=noise, sigmas=sigmas, **kwargs) + @torch.no_grad() def forward( self, num_items: int, @@ -289,3 +293,61 @@ def forward( chunks += [torch.randn(shape, device=self.device)] return torch.cat(chunks[:num_chunks], dim=-1) + + +""" Inpainters """ + + +class Inpainter(nn.Module): + pass + + +class VInpainter(Inpainter): + + diffusion_types = [VDiffusion] + + def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()): + super().__init__() + self.net = net + self.schedule = schedule + + def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: + angle = sigmas * pi / 2 + alpha, beta = torch.cos(angle), torch.sin(angle) + return alpha, beta + + @torch.no_grad() + def forward( # type: ignore + self, + source: Tensor, + mask: Tensor, + num_steps: int, + num_resamples: int, + show_progress: bool = False, + x_noisy: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + x_noisy = default(x_noisy, lambda: torch.randn_like(source)) + b = x_noisy.shape[0] + sigmas = self.schedule(num_steps + 1, device=x_noisy.device) + sigmas = repeat(sigmas, "i -> i b", b=b) + sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1) + alphas, betas = self.get_alpha_beta(sigmas_batch) + progress_bar = tqdm(range(num_steps), disable=not show_progress) + + for i in progress_bar: + for r in range(num_resamples): + v_pred = self.net(x_noisy, sigmas[i], **kwargs) + x_pred = alphas[i] * x_noisy - betas[i] * v_pred + noise_pred = betas[i] * x_noisy + alphas[i] * v_pred + # Renoise to current noise level if resampling + j = r == num_resamples - 1 + x_noisy = alphas[i + j] * x_pred + betas[i + j] * noise_pred + s_noisy = alphas[i + j] * source + betas[i + j] * torch.randn_like( + source + ) + x_noisy = s_noisy * mask + x_noisy * ~mask + + progress_bar.set_description(f"Inpainting (noise={sigmas[i+1,0]:.2f})") + + return x_noisy diff --git a/setup.py b/setup.py index a31a95a..56a5618 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.1.2", + version="0.1.3", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", From 70a24bba19de8278ca80e101e3da88ebde452624 Mon Sep 17 00:00:00 2001 From: "Dr. 
Tristan Behrens" <32195399+AI-Guru@users.noreply.github.com> Date: Tue, 25 Apr 2023 11:24:52 +0200 Subject: [PATCH 22/23] feat: add custom loss functionality (#65) --- audio_diffusion_pytorch/diffusion.py | 11 ++++---- audio_diffusion_pytorch/models.py | 3 ++- tests/testcustomloss.py | 39 ++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 6 deletions(-) create mode 100644 tests/testcustomloss.py diff --git a/audio_diffusion_pytorch/diffusion.py b/audio_diffusion_pytorch/diffusion.py index ee6e0ba..bca143a 100644 --- a/audio_diffusion_pytorch/diffusion.py +++ b/audio_diffusion_pytorch/diffusion.py @@ -67,11 +67,12 @@ class Diffusion(nn.Module): class VDiffusion(Diffusion): def __init__( - self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution() + self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution(), loss_fn: Any = F.mse_loss ): super().__init__() self.net = net self.sigma_distribution = sigma_distribution + self.loss_fn = loss_fn def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: angle = sigmas * pi / 2 @@ -91,17 +92,18 @@ def forward(self, x: Tensor, **kwargs) -> Tensor: # type: ignore v_target = alphas * noise - betas * x # Predict velocity and return loss v_pred = self.net(x_noisy, sigmas, **kwargs) - return F.mse_loss(v_pred, v_target) + return self.loss_fn(v_pred, v_target) class ARVDiffusion(Diffusion): - def __init__(self, net: nn.Module, length: int, num_splits: int): + def __init__(self, net: nn.Module, length: int, num_splits: int, loss_fn: Any = F.mse_loss): super().__init__() assert length % num_splits == 0, "length must be divisible by num_splits" self.net = net self.length = length self.num_splits = num_splits self.split_length = length // num_splits + self.loss_fn = loss_fn def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: angle = sigmas * pi / 2 @@ -125,8 +127,7 @@ def forward(self, x: Tensor, **kwargs) -> Tensor: channels = torch.cat([x_noisy, sigmas], dim=1) # Predict velocity and return loss v_pred = self.net(channels, **kwargs) - return F.mse_loss(v_pred, v_target) - + return self.loss_fn(v_pred, v_target) """ Schedules """ diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py index c3f9cc2..04c3c47 100644 --- a/audio_diffusion_pytorch/models.py +++ b/audio_diffusion_pytorch/models.py @@ -25,6 +25,7 @@ def __init__( net_t: Callable, diffusion_t: Callable = VDiffusion, sampler_t: Callable = VSampler, + loss_fn: Callable = torch.nn.functional.mse_loss, dim: int = 1, **kwargs, ): @@ -33,7 +34,7 @@ def __init__( sampler_kwargs, kwargs = groupby("sampler_", kwargs) self.net = net_t(dim=dim, **kwargs) - self.diffusion = diffusion_t(net=self.net, **diffusion_kwargs) + self.diffusion = diffusion_t(net=self.net, loss_fn=loss_fn, **diffusion_kwargs) self.sampler = sampler_t(net=self.net, **sampler_kwargs) def forward(self, *args, **kwargs) -> Tensor: diff --git a/tests/testcustomloss.py b/tests/testcustomloss.py new file mode 100644 index 0000000..619eb33 --- /dev/null +++ b/tests/testcustomloss.py @@ -0,0 +1,39 @@ +import torch +import torch.nn.functional as F +from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler +from audio_encoders_pytorch import MelE1d, TanhBottleneck +from auraloss.freq import MultiResolutionSTFTLoss + +autoencoder = DiffusionAE( + encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder + in_channels=2, + channels=512, + multipliers=[1, 1], + factors=[2], + num_blocks=[12], + out_channels=32, + 
mel_channels=80, + mel_sample_rate=48000, + mel_normalize_log=True, + bottleneck=TanhBottleneck(), + ), + inject_depth=6, + net_t=UNetV0, # The model type used for diffusion upsampling + in_channels=2, # U-Net: number of input/output (audio) channels + channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer + factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer + items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer + diffusion_t=VDiffusion, # The diffusion method used + sampler_t=VSampler, # The diffusion sampler used + loss_fn=MultiResolutionSTFTLoss(), # The loss function used +) + +# Train autoencoder with audio samples +audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length] +loss = autoencoder(audio) +loss.backward() + +# Encode/decode audio +audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length] +latent = autoencoder.encode(audio) # Encode +sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling diffusion model conditioning on latent \ No newline at end of file From f4052e321c820e467e51e14d005f3d0077997278 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Tue, 25 Apr 2023 11:29:31 +0200 Subject: [PATCH 23/23] feat: update notes --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 432aabc..69530a9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable to work on other formats. **Note: no pre-trained models are provided here, this library is meant for research purposes.** +A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based, however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are both generic to any dimension and highly customizable to work on other formats. **Notes: (1) no pre-trained models are provided here, (2) the configs shown are indicative and untested, see [Moûsai](https://arxiv.org/abs/2301.11757) for the configs used in the paper.** ## Install
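
A minimal usage sketch tying together the custom `loss_fn` hook from PATCH 22 and the parameterized `UniformDistribution` from PATCH 17. The U-Net settings are copied from the README's unconditional example and, like the configs in the README, are indicative and untested; `F.l1_loss` is only an illustrative stand-in for a real custom loss (the included `tests/testcustomloss.py` uses `auraloss.freq.MultiResolutionSTFTLoss`, an extra dependency). Any callable with the signature `loss_fn(v_pred, v_target)` works, and keyword arguments prefixed with `diffusion_` are forwarded to `VDiffusion` through the `groupby` mechanism in `DiffusionModel`.

```py
import torch
import torch.nn.functional as F

from audio_diffusion_pytorch import (
    DiffusionModel,
    UNetV0,
    UniformDistribution,
    VDiffusion,
    VSampler,
)

model = DiffusionModel(
    net_t=UNetV0,  # U-Net config below copied from the unconditional README example
    in_channels=2,
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
    attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1],
    attention_heads=8,
    attention_features=64,
    diffusion_t=VDiffusion,
    sampler_t=VSampler,
    loss_fn=F.l1_loss,  # Custom loss, called as loss_fn(v_pred, v_target); default is F.mse_loss
    diffusion_sigma_distribution=UniformDistribution(vmin=0.0, vmax=1.0),  # Forwarded to VDiffusion
)

# Training step: the velocity loss is now computed with the custom loss_fn
audio = torch.randn(1, 2, 2**18)  # [batch_size, in_channels, length]
loss = model(audio)
loss.backward()

# Sampling is unchanged
noise = torch.randn(1, 2, 2**18)
sample = model.sample(noise, num_steps=10)
```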