diff --git a/README.md b/README.md
index 3d861b3..bbfec86 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # A-UNet
 
-A toolbox that provides hackable building blocks for (1D/2D/3D) UNets, in PyTorch.
+A toolbox that provides hackable building blocks for generic 1D/2D/3D UNets, in PyTorch.
 
 ## Install
 ```bash
diff --git a/a_unet/blocks.py b/a_unet/blocks.py
index ef54f4f..35806f0 100644
--- a/a_unet/blocks.py
+++ b/a_unet/blocks.py
@@ -1,7 +1,9 @@
+from math import pi
 from typing import Any, Callable, Optional, Sequence, Type, TypeVar, Union
 
 import torch
-from einops import pack, rearrange, repeat, unpack
+import torch.nn.functional as F
+from einops import pack, rearrange, reduce, repeat, unpack
 from torch import Tensor, einsum, nn
 from typing_extensions import TypeGuard
 
@@ -351,13 +353,79 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
     return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
 
 
-def CFG(
+"""
+Embedders
+"""
+
+
+class NumberEmbedder(nn.Module):
+    def __init__(self, features: int, dim: int = 256):
+        super().__init__()
+        assert dim % 2 == 0, f"dim must be divisible by 2, found {dim}"
+        self.features = features
+        self.weights = nn.Parameter(torch.randn(dim // 2))
+        self.to_out = nn.Linear(in_features=dim + 1, out_features=features)
+
+    def to_embedding(self, x: Tensor) -> Tensor:
+        x = rearrange(x, "b -> b 1")
+        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
+        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
+        fouriered = torch.cat((x, fouriered), dim=-1)
+        return self.to_out(fouriered)
+
+    def forward(self, x: Union[Sequence[float], Tensor]) -> Tensor:
+        if not torch.is_tensor(x):
+            x = torch.tensor(x, device=self.weights.device)
+        assert isinstance(x, Tensor)
+        shape = x.shape
+        x = rearrange(x, "... -> (...)")
+        return self.to_embedding(x).view(*shape, self.features)  # type: ignore
+
+
+class T5Embedder(nn.Module):
+    def __init__(self, model: str = "t5-base", max_length: int = 64):
+        super().__init__()
+        from transformers import AutoTokenizer, T5EncoderModel
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.transformer = T5EncoderModel.from_pretrained(model)
+        self.max_length = max_length
+
+    @torch.no_grad()
+    def forward(self, texts: Sequence[str]) -> Tensor:
+        encoded = self.tokenizer(
+            texts,
+            truncation=True,
+            max_length=self.max_length,
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+        device = next(self.transformer.parameters()).device
+        input_ids = encoded["input_ids"].to(device)
+        attention_mask = encoded["attention_mask"].to(device)
+
+        self.transformer.eval()
+
+        embedding = self.transformer(
+            input_ids=input_ids, attention_mask=attention_mask
+        )["last_hidden_state"]
+
+        return embedding
+
+
+"""
+Plugins
+"""
+
+
+def ClassifierFreeGuidancePlugin(
     net_t: Type[nn.Module],
     embedding_max_length: int,
 ) -> Callable[..., nn.Module]:
     """Classifier-Free Guidance -> CFG(UNet, embedding_max_length=512)(...)"""
 
-    def CFGNet(embedding_features: int, **kwargs) -> nn.Module:
+    def Net(embedding_features: int, **kwargs) -> nn.Module:
         fixed_embedding = FixedEmbedding(
             max_length=embedding_max_length,
             features=embedding_features,
@@ -371,7 +439,8 @@ def forward(
             embedding_mask_proba: float = 0.0,
             **kwargs,
         ):
-            assert exists(embedding), "embedding required when using CFG"
+            msg = "ClassifierFreeGuidancePlugin requires embedding"
+            assert exists(embedding), msg
             b, device = embedding.shape[0], embedding.device
             embedding_mask = fixed_embedding(embedding)
 
@@ -393,4 +462,46 @@ def forward(
 
         return Module([fixed_embedding, net], forward)
 
-    return CFGNet
+    return Net
+
+
+def TimeConditioningPlugin(
+    net_t: Type[nn.Module],
+    num_layers: int = 2,
+) -> Callable[..., nn.Module]:
+    """Adds time conditioning (e.g. for diffusion)"""
+
+    def Net(modulation_features: Optional[int] = None, **kwargs) -> nn.Module:
+        msg = "TimeConditioningPlugin requires modulation_features"
+        assert exists(modulation_features), msg
+
+        embedder = NumberEmbedder(features=modulation_features)
+        mlp = Repeat(
+            nn.Sequential(
+                nn.Linear(modulation_features, modulation_features), nn.GELU()
+            ),
+            times=num_layers,
+        )
+        net = net_t(modulation_features=modulation_features, **kwargs)  # type: ignore
+
+        def forward(
+            x: Tensor,
+            time: Optional[Tensor] = None,
+            features: Optional[Tensor] = None,
+            **kwargs,
+        ):
+            msg = "TimeConditioningPlugin requires time in forward"
+            assert exists(time), msg
+            # Process time to time_features
+            time_features = F.gelu(embedder(time))
+            time_features = mlp(time_features)
+            # Overlap features if more than one per batch
+            if time_features.ndim == 3:
+                time_features = reduce(time_features, "b n d -> b d", "sum")
+            # Merge time features with features if provided
+            features = features + time_features if exists(features) else time_features
+            return net(x, features=features, **kwargs)
+
+        return Module([embedder, mlp, net], forward)
+
+    return Net
diff --git a/a_unet/unet/apex.py b/a_unet/unet/apex.py
index d4e94e9..4cb6afb 100644
--- a/a_unet/unet/apex.py
+++ b/a_unet/unet/apex.py
@@ -42,7 +42,7 @@ def DownsampleItem(
     factor: Optional[int] = None,
     in_channels: Optional[int] = None,
     channels: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ) -> nn.Module:
     msg = "DownsampleItem requires dim, factor, in_channels, channels"
     assert (
@@ -59,7 +59,7 @@ def UpsampleItem(
     factor: Optional[int] = None,
     channels: Optional[int] = None,
     out_channels: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ) -> nn.Module:
     msg = "UpsampleItem requires dim, factor, channels, out_channels"
     assert (
@@ -78,7 +78,7 @@ def ResnetItem(
     dim: Optional[int] = None,
     channels: Optional[int] = None,
     resnet_groups: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ) -> nn.Module:
     msg = "ResnetItem requires dim, channels, and resnet_groups"
     assert exists(dim) and exists(channels) and exists(resnet_groups), msg
@@ -93,7 +93,7 @@ def AttentionItem(
     channels: Optional[int] = None,
     attention_features: Optional[int] = None,
     attention_heads: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ) -> nn.Module:
     msg = "AttentionItem requires channels, attention_features, attention_heads"
     assert (
@@ -114,7 +114,7 @@ def CrossAttentionItem(
     attention_features: Optional[int] = None,
     attention_heads: Optional[int] = None,
     embedding_features: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ) -> nn.Module:
     msg = "CrossAttentionItem requires channels, embedding_features, attention_*"
     assert (
@@ -149,7 +149,7 @@ def LinearAttentionItem(
     channels: Optional[int] = None,
     attention_features: Optional[int] = None,
     attention_heads: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ) -> nn.Module:
     msg = "LinearAttentionItem requires attention_features and attention_heads"
     assert (
@@ -170,7 +170,7 @@ def LinearCrossAttentionItem(
     attention_features: Optional[int] = None,
     attention_heads: Optional[int] = None,
     embedding_features: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ) -> nn.Module:
     msg = "LinearCrossAttentionItem requires channels, embedding_features, attention_*"
     assert (
@@ -208,7 +208,7 @@ def SkipAdapterItem(
     dim: Optional[int] = None,
     in_channels: Optional[int] = None,
     out_channels: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ):
     msg = "SkipAdapterItem requires dim, in_channels, out_channels"
     assert exists(dim) and exists(in_channels) and exists(out_channels), msg
@@ -244,7 +244,7 @@ def SkipModulateItem(
     dim: Optional[int] = None,
     out_channels: Optional[int] = None,
     modulation_features: Optional[int] = None,
-    **kwargs
+    **kwargs,
 ) -> nn.Module:
     msg = "SkipModulateItem requires dim, out_channels, modulation_features"
     assert exists(dim) and exists(out_channels) and exists(modulation_features), msg
@@ -268,7 +268,7 @@ def __init__(
         items_up: Optional[Sequence[Callable]] = None,
         out_channels: Optional[int] = None,
         inner_block: Optional[nn.Module] = None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         out_channels = default(out_channels, in_channels)
@@ -316,7 +316,7 @@ def __init__(
         in_channels: int,
         blocks: Sequence,
         out_channels: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         num_layers = len(blocks)
@@ -330,7 +330,11 @@ def Net(i: int) -> Optional[nn.Module]:
 
             out_ch = out_channels if i == 0 else in_ch
             return block_t(
-                in_channels=in_ch, out_channels=out_ch, inner_block=Net(i + 1), **kwargs
+                in_channels=in_ch,
+                out_channels=out_ch,
+                depth=i,
+                inner_block=Net(i + 1),
+                **kwargs,
             )
 
         self.net = Net(0)
@@ -338,6 +342,7 @@ def Net(i: int) -> Optional[nn.Module]:
     def forward(
         self,
         x: Tensor,
+        *,
         features: Optional[Tensor] = None,
         embedding: Optional[Tensor] = None,
         channels: Optional[Sequence[Tensor]] = None,
diff --git a/setup.py b/setup.py
index 123f345..4988306 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="a-unet",
     packages=find_packages(exclude=[]),
-    version="0.0.6",
+    version="0.0.7",
     license="MIT",
     description="A-UNet",
     long_description_content_type="text/markdown",
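
Below are a few usage sketches for the components this patch introduces; they are illustrative only and not part of the diff. First, `NumberEmbedder` maps a batch of scalars (e.g. diffusion timesteps) to learned random-Fourier-feature embeddings. A minimal sketch, assuming a-unet 0.0.7 is installed and importing from `a_unet.blocks` as in the patch:

```py
import torch
from a_unet.blocks import NumberEmbedder

embedder = NumberEmbedder(features=128)  # dim=256 Fourier features by default
t = torch.rand(4)                        # one scalar per batch element
emb = embedder(t)                        # shape: [4, 128]
emb2 = embedder([0.25, 0.75])            # plain floats are converted internally, shape: [2, 128]
```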
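
`T5Embedder` wraps a frozen Hugging Face T5 encoder to produce text embeddings for cross-attention and classifier-free guidance. A sketch, assuming the `transformers` package is available; with the default `t5-base` checkpoint the last hidden state has 768 features:

```py
from a_unet.blocks import T5Embedder

embedder = T5Embedder(model="t5-base", max_length=64)
embedding = embedder(["an owl hooting", "rain on a tin roof"])
print(embedding.shape)  # torch.Size([2, 64, 768]) -- [batch, max_length, features]
```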
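
The two plugins wrap a UNet constructor (`net_t`) and inject extra keyword arguments: `TimeConditioningPlugin` adds a `time` argument that is embedded with `NumberEmbedder` and merged into `features`, while `ClassifierFreeGuidancePlugin` adds `embedding`/`embedding_mask_proba` and randomly replaces the embedding with a learned fixed one. The sketch below composes both around a dummy network; `DummyNet`, its arguments, and the chosen plugin ordering are assumptions for illustration, not part of a-unet:

```py
from typing import Optional

import torch
from torch import Tensor, nn

from a_unet.blocks import ClassifierFreeGuidancePlugin, TimeConditioningPlugin


class DummyNet(nn.Module):
    """Stand-in for a real UNet; it only shows which kwargs the plugins inject."""

    def __init__(self, in_channels: int, modulation_features: int, embedding_features: int):
        super().__init__()
        self.to_bias = nn.Linear(modulation_features, in_channels)
        self.to_scale = nn.Linear(embedding_features, in_channels)

    def forward(
        self,
        x: Tensor,
        features: Optional[Tensor] = None,    # injected by TimeConditioningPlugin
        embedding: Optional[Tensor] = None,   # injected by ClassifierFreeGuidancePlugin
        **kwargs,
    ) -> Tensor:
        bias = self.to_bias(features).unsqueeze(-1)                 # [b, c, 1]
        scale = self.to_scale(embedding.mean(dim=1)).unsqueeze(-1)  # [b, c, 1]
        return x * scale + bias


# Plugins compose: the outermost wrapper's forward runs first.
NetT = TimeConditioningPlugin(ClassifierFreeGuidancePlugin(DummyNet, embedding_max_length=64))
net = NetT(in_channels=8, modulation_features=256, embedding_features=768)

x = torch.randn(2, 8, 32)            # [batch, channels, length]
embedding = torch.randn(2, 64, 768)  # e.g. the T5Embedder output from above
y = net(x, time=torch.rand(2), embedding=embedding, embedding_mask_proba=0.1)
print(y.shape)                       # torch.Size([2, 8, 32])
```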