diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py index e9b8dd8..6a134cf 100644 --- a/imagen_pytorch/imagen_pytorch.py +++ b/imagen_pytorch/imagen_pytorch.py @@ -1,6 +1,6 @@ import math from random import random -from beartype.typing import List, Union +from beartype.typing import List, Union, Optional from beartype import beartype from tqdm.auto import tqdm from functools import partial, wraps @@ -2288,7 +2288,7 @@ def p_sample_loop( @beartype def sample( self, - texts: List[str] = None, + texts: Optional[List[str]] = None, text_masks = None, text_embeds = None, video_frames = None, @@ -2637,7 +2637,7 @@ def forward( self, images, # rename to images or video unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None, - texts: List[str] = None, + texts: Optional[List[str]] = None, text_embeds = None, text_masks = None, unet_number = None, diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py index 8dd3026..22db8a6 100644 --- a/imagen_pytorch/version.py +++ b/imagen_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.26.2' +__version__ = '1.26.3'