Skip to content

Commit

Permalink
[Type Hints] VAE models (open-mmlab#344)
Browse files Browse the repository at this point in the history
* [Type Hints] VAE models

* apply suggestions from code review

apply suggestions to also return the return type
  • Loading branch information
daspartho authored Sep 4, 2022
1 parent 878af0e commit 5791f4a
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions src/diffusers/models/vae.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -293,7 +295,7 @@ def __init__(self, parameters, deterministic=False):
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

def sample(self, generator=None):
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
return x

Expand Down Expand Up @@ -327,16 +329,16 @@ class VQModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=1,
act_fn="silu",
latent_channels=3,
sample_size=32,
num_vq_embeddings=256,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 3,
sample_size: int = 32,
num_vq_embeddings: int = 256,
):
super().__init__()

Expand Down Expand Up @@ -382,7 +384,7 @@ def decode(self, h, force_not_quantize=False):
dec = self.decoder(quant)
return dec

def forward(self, sample):
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
x = sample
h = self.encode(x)
dec = self.decode(h)
Expand All @@ -393,15 +395,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=1,
act_fn="silu",
latent_channels=4,
sample_size=32,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
sample_size: int = 32,
):
super().__init__()

Expand Down Expand Up @@ -440,7 +442,7 @@ def decode(self, z):
dec = self.decoder(z)
return dec

def forward(self, sample, sample_posterior=False):
def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
x = sample
posterior = self.encode(x)
if sample_posterior:
Expand Down

0 comments on commit 5791f4a

Please sign in to comment.