feat: add CFG wrapper, fix fixed embedding
flavioschneider committed Dec 20, 2022
1 parent feed052 commit 9aa1b1b
Showing 3 changed files with 57 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -14,7 +14,7 @@ pip install a-unet

### Basic UNet

-<details> <summary> (Code): A convolutional only UNet generic to any dimension, using A-UNet blocks. </summary>
+<details> <summary> (Code): A convolutional only UNet generic to any dimension. </summary>

```py
from typing import List
Expand Down
56 changes: 55 additions & 1 deletion a_unet/blocks.py
@@ -208,7 +208,7 @@ def forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
def FixedEmbedding(max_length: int, features: int):
embedding = nn.Embedding(max_length, features)

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(x: Tensor) -> Tensor:
batch_size, length, device = *x.shape[0:2], x.device
assert_message = "Input sequence length must be <= max_length"
assert length <= max_length, assert_message
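
The fix in this hunk removes a stray `self`: `forward` here is a plain closure captured by the library's functional `Module` helper, not a bound method, so the input tensor was being bound to `self` and the call broke. A minimal standalone sketch of the pattern follows; the position-embedding body is an assumption for illustration, not the commit's hidden lines:

```py
import torch
from torch import Tensor, nn

def FixedEmbeddingSketch(max_length: int, features: int):
    embedding = nn.Embedding(max_length, features)

    def forward(x: Tensor) -> Tensor:  # a plain closure: no `self` parameter
        batch_size, length, device = *x.shape[0:2], x.device
        assert length <= max_length, "Input sequence length must be <= max_length"
        # Assumed continuation: embed positions 0..length-1, expand over batch
        positions = torch.arange(length, device=device)
        return embedding(positions).expand(batch_size, -1, -1)

    return forward
```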
@@ -280,3 +280,57 @@ def FeedForward(features: int, multiplier: int) -> nn.Module:
nn.GELU(),
nn.Linear(in_features=mid_features, out_features=features),
)


def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
if proba == 1:
return torch.ones(shape, device=device, dtype=torch.bool)
elif proba == 0:
return torch.zeros(shape, device=device, dtype=torch.bool)
else:
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


def CFG(
net_t: Type[nn.Module],
embedding_max_length: int,
) -> Callable[..., nn.Module]:
"""Classifier-Free Guidance -> CFG(UNet, embedding_max_length=512)(...)"""

def CFGNet(embedding_features: int, **kwargs) -> nn.Module:
fixed_embedding = FixedEmbedding(
max_length=embedding_max_length,
features=embedding_features,
)
net = net_t(embedding_features=embedding_features, **kwargs) # type: ignore

def forward(
x: Tensor,
embedding: Optional[Tensor] = None,
embedding_scale: float = 1.0,
embedding_mask_proba: float = 0.0,
**kwargs,
):
assert exists(embedding), "embedding required when using CFG"
b, device = embedding.shape[0], embedding.device
embedding_mask = fixed_embedding(embedding)

if embedding_mask_proba > 0.0:
# Randomly mask embedding
batch_mask = rand_bool(
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
)
embedding = torch.where(batch_mask, embedding_mask, embedding)

if embedding_scale != 1.0:
# Compute both normal and fixed embedding outputs
out = net(x, embedding=embedding, **kwargs)
out_masked = net(x, embedding=embedding_mask, **kwargs)
# Scale conditional output using classifier-free guidance
return out_masked + (out - out_masked) * embedding_scale
else:
return net(x, embedding=embedding, **kwargs)

return Module([fixed_embedding, net], forward)

return CFGNet
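
For reference, a minimal usage sketch of the new wrapper. `DummyNet` and all shapes here are hypothetical stand-ins, not part of the library; any `net_t` whose constructor accepts `embedding_features` and whose forward takes an `embedding=` keyword should work:

```py
from typing import Optional

import torch
from torch import Tensor, nn
from a_unet.blocks import CFG

class DummyNet(nn.Module):  # hypothetical stand-in for a real UNet
    def __init__(self, embedding_features: int):
        super().__init__()

    def forward(self, x: Tensor, embedding: Optional[Tensor] = None, **kwargs) -> Tensor:
        return x  # a real UNet would condition on `embedding` here

net = CFG(DummyNet, embedding_max_length=64)(embedding_features=32)

x = torch.randn(2, 8, 16)           # arbitrary input
embedding = torch.randn(2, 10, 32)  # (batch, length <= 64, embedding_features)

# Training: drop conditioning 10% of the time (replaced by the fixed embedding)
out = net(x, embedding=embedding, embedding_mask_proba=0.1)

# Sampling: guided output = out_masked + (out - out_masked) * embedding_scale
out = net(x, embedding=embedding, embedding_scale=3.0)
```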
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="a-unet",
packages=find_packages(exclude=[]),
version="0.0.4",
version="0.0.5",
license="MIT",
description="A-UNet",
long_description_content_type="text/markdown",
