diff --git a/ppdiffusers/ppdiffusers/__init__.py b/ppdiffusers/ppdiffusers/__init__.py index 2a07c3c08e5a..c28d1de70d3b 100644 --- a/ppdiffusers/ppdiffusers/__init__.py +++ b/ppdiffusers/ppdiffusers/__init__.py @@ -62,6 +62,7 @@ UNet1DModel, UNet2DConditionModel, UNet2DModel, + UNet3DConditionModel, VQModel, ) from .optimization import ( @@ -157,6 +158,7 @@ StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, + TextToVideoSDPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline, diff --git a/ppdiffusers/ppdiffusers/models/__init__.py b/ppdiffusers/ppdiffusers/models/__init__.py index e09e7b1f9640..8f177fe293a1 100644 --- a/ppdiffusers/ppdiffusers/models/__init__.py +++ b/ppdiffusers/ppdiffusers/models/__init__.py @@ -28,4 +28,5 @@ from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel + from .unet_3d_condition import UNet3DConditionModel from .vq_model import VQModel diff --git a/ppdiffusers/ppdiffusers/models/attention.py b/ppdiffusers/ppdiffusers/models/attention.py index c51d5c6a2664..fb66a74bd2e6 100644 --- a/ppdiffusers/ppdiffusers/models/attention.py +++ b/ppdiffusers/ppdiffusers/models/attention.py @@ -171,6 +171,10 @@ class BasicTransformerBlock(nn.Layer): attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. @@ -189,6 +193,7 @@ def __init__( num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, + double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", @@ -220,10 +225,10 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) # 2. Cross-Attn - if cross_attention_dim is not None: + if cross_attention_dim is not None or double_self_attention: self.attn2 = CrossAttention( query_dim=dim, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, @@ -245,7 +250,7 @@ def __init__( else: self.norm1 = nn.LayerNorm(dim, **norm_kwargs) - if cross_attention_dim is not None: + if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. 
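A minimal sketch of what the new `double_self_attention` flag does in `BasicTransformerBlock`, assuming the block is driven standalone (the dimensions below are illustrative, and the constructor is called the same way the new `TransformerTemporalModel` further down calls it): with the flag set, `attn2` is built with `cross_attention_dim=None`, so it runs as a second self-attention and no encoder states are required.

```py
import paddle
from ppdiffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    320,                         # dim (= heads * head_dim)
    8,                           # num_attention_heads
    40,                          # attention_head_dim
    cross_attention_dim=None,    # no text conditioning
    double_self_attention=True,  # attn2 becomes a second self-attention
)
hidden_states = paddle.randn([2, 16, 320])  # (batch, seq_len, dim)
out = block(hidden_states)                  # encoder_hidden_states can stay None
print(out.shape)                            # [2, 16, 320]
```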
diff --git a/ppdiffusers/ppdiffusers/models/resnet.py b/ppdiffusers/ppdiffusers/models/resnet.py index fe85e59ace6f..f52610af17a9 100644 --- a/ppdiffusers/ppdiffusers/models/resnet.py +++ b/ppdiffusers/ppdiffusers/models/resnet.py @@ -20,6 +20,7 @@ import paddle.nn as nn import paddle.nn.functional as F +from ..initializer import zeros_ from .attention import AdaGroupNorm @@ -806,3 +807,58 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 return out.reshape([-1, channel, out_h, out_w]) + + +class TemporalConvLayer(nn.Layer): + """ + Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__(self, in_dim, out_dim=None, dropout=0.0): + super().__init__() + out_dim = out_dim or in_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.conv1 = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=in_dim), + nn.Silu(), + nn.Conv3D(in_channels=in_dim, out_channels=out_dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)), + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=out_dim), + nn.Silu(), + nn.Dropout(p=dropout), + nn.Conv3D(in_channels=out_dim, out_channels=in_dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=out_dim), + nn.Silu(), + nn.Dropout(p=dropout), + nn.Conv3D(in_channels=out_dim, out_channels=in_dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=out_dim), + nn.Silu(), + nn.Dropout(p=dropout), + nn.Conv3D(in_channels=out_dim, out_channels=in_dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)), + ) + zeros_(self.conv4[-1].weight) + zeros_(self.conv4[-1].bias) + + def forward(self, hidden_states, num_frames=1): + hidden_states = ( + hidden_states[(None), :] + .reshape((-1, num_frames) + tuple(hidden_states.shape[1:])) + .transpose(perm=[0, 2, 1, 3, 4]) + ) + identity = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.conv3(hidden_states) + hidden_states = self.conv4(hidden_states) + hidden_states = identity + hidden_states + hidden_states = hidden_states.transpose(perm=[0, 2, 1, 3, 4]).reshape( + (hidden_states.shape[0] * hidden_states.shape[2], -1) + tuple(hidden_states.shape[3:]) + ) + return hidden_states diff --git a/ppdiffusers/ppdiffusers/models/transformer_temporal.py b/ppdiffusers/ppdiffusers/models/transformer_temporal.py new file mode 100644 index 000000000000..483f3fa28ddb --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/transformer_temporal.py @@ -0,0 +1,162 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
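A short, illustrative sketch (assumed shapes, not taken from the diff) of the `TemporalConvLayer` contract added to resnet.py: frames arrive folded into the batch dimension, the layer reshapes to `(batch, channels, frames, height, width)` internally, and because `conv4` is zero-initialized the block starts out as an identity mapping.

```py
import paddle
from ppdiffusers.models.resnet import TemporalConvLayer

layer = TemporalConvLayer(in_dim=64)           # channel count must be divisible by 32 (GroupNorm)
batch, frames, height, width = 2, 8, 16, 16
x = paddle.randn([batch * frames, 64, height, width])
y = layer(x, num_frames=frames)
print(y.shape)                                 # [16, 64, 16, 16] -- same layout as the input
# conv4 starts with zero weights and bias, so at initialization y equals x.
```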
+# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +import paddle +import paddle.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .attention import BasicTransformerBlock +from .modeling_utils import ModelMixin + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + Args: + sample (`paddle.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`) + Hidden states conditioned on `encoder_hidden_states` input. + """ + + sample: paddle.Tensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. 
+        double_self_attention (`bool`, *optional*):
+            Configure if each TransformerBlock should contain two self-attention layers.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        sample_size: Optional[int] = None,
+        activation_fn: str = "geglu",
+        norm_elementwise_affine: bool = True,
+        double_self_attention: bool = True,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        inner_dim = num_attention_heads * attention_head_dim
+        self.in_channels = in_channels
+        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-06)
+        self.proj_in = nn.Linear(in_features=in_channels, out_features=inner_dim)
+        self.transformer_blocks = nn.LayerList(
+            sublayers=[
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    cross_attention_dim=cross_attention_dim,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    double_self_attention=double_self_attention,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                )
+                for d in range(num_layers)
+            ]
+        )
+        self.proj_out = nn.Linear(in_features=inner_dim, out_features=in_channels)
+
+    def forward(
+        self,
+        hidden_states,
+        encoder_hidden_states=None,
+        timestep=None,
+        class_labels=None,
+        num_frames=1,
+        cross_attention_kwargs=None,
+        return_dict: bool = True,
+    ):
+        """
+        Args:
+            hidden_states (When discrete, `paddle.Tensor` of shape `(batch size, num latent pixels)`.
+                When continuous, `paddle.Tensor` of shape `(batch size, channel, height, width)`): Input
+                hidden_states.
+            encoder_hidden_states (`paddle.Tensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep (`paddle.int64`, *optional*):
+                Optional timestep to be applied as an embedding in AdaLayerNorm. Used to indicate the denoising step.
+            class_labels (`paddle.Tensor` of shape `(batch size, num classes)`, *optional*):
+                Optional class labels to be applied as an embedding in AdaLayerNormZero. Used to indicate class-label
+                conditioning.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+            [`~models.transformer_temporal.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is the sample tensor.
+ """ + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + residual = hidden_states + hidden_states = hidden_states[(None), :].reshape((batch_size, num_frames, channel, height, width)) + hidden_states = hidden_states.transpose(perm=[0, 2, 1, 3, 4]) + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.transpose(perm=[0, 3, 4, 2, 1]).reshape( + (batch_size * height * width, num_frames, channel) + ) + hidden_states = self.proj_in(hidden_states) + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[(None), (None), :] + .reshape((batch_size, height, width, channel, num_frames)) + .transpose(perm=[0, 3, 4, 1, 2]) + ) + hidden_states = hidden_states.reshape((batch_frames, channel, height, width)) + output = hidden_states + residual + if not return_dict: + return (output,) + return TransformerTemporalModelOutput(sample=output) diff --git a/ppdiffusers/ppdiffusers/models/unet_3d_blocks.py b/ppdiffusers/ppdiffusers/models/unet_3d_blocks.py new file mode 100644 index 000000000000..47bcdae5c7f7 --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/unet_3d_blocks.py @@ -0,0 +1,595 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn + +from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +from .transformer_2d import Transformer2DModel +from .transformer_temporal import TransformerTemporalModel + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class 
UNetMidBlock3DCrossAttn(nn.Layer): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-06, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [TemporalConvLayer(in_channels, in_channels, dropout=0.1)] + attentions = [] + temp_attentions = [] + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append(TemporalConvLayer(in_channels, in_channels, dropout=0.1)) + self.resnets = nn.LayerList(sublayers=resnets) + self.temp_convs = nn.LayerList(sublayers=temp_convs) + self.attentions = nn.LayerList(sublayers=attentions) + self.temp_attentions = nn.LayerList(sublayers=temp_attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + return hidden_states + + +class CrossAttnDownBlock3D(nn.Layer): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-06, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, 
+ output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1)) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.LayerList(sublayers=resnets) + self.temp_convs = nn.LayerList(sublayers=temp_convs) + self.attentions = nn.LayerList(sublayers=attentions) + self.temp_attentions = nn.LayerList(sublayers=temp_attentions) + if add_downsample: + self.downsamplers = nn.LayerList( + sublayers=[ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + output_states = () + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class DownBlock3D(nn.Layer): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-06, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + 
eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1)) + self.resnets = nn.LayerList(sublayers=resnets) + self.temp_convs = nn.LayerList(sublayers=temp_convs) + if add_downsample: + self.downsamplers = nn.LayerList( + sublayers=[ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + output_states += (hidden_states,) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states += (hidden_states,) + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Layer): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-06, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1)) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.LayerList(sublayers=resnets) + self.temp_convs = nn.LayerList(sublayers=temp_convs) + self.attentions = nn.LayerList(sublayers=attentions) + self.temp_attentions = nn.LayerList(sublayers=temp_attentions) + if add_upsample: + self.upsamplers = nn.LayerList( + sublayers=[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + 
self.upsamplers = None + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = paddle.concat(x=[hidden_states, res_hidden_states], axis=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class UpBlock3D(nn.Layer): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-06, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + for i in range(num_layers): + res_skip_channels = in_channels if i == num_layers - 1 else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1)) + self.resnets = nn.LayerList(sublayers=resnets) + self.temp_convs = nn.LayerList(sublayers=temp_convs) + if add_upsample: + self.upsamplers = nn.LayerList( + sublayers=[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = paddle.concat(x=[hidden_states, res_hidden_states], axis=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states diff --git a/ppdiffusers/ppdiffusers/models/unet_3d_condition.py b/ppdiffusers/ppdiffusers/models/unet_3d_condition.py new file mode 100644 index 000000000000..e42924f70d91 --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/unet_3d_condition.py @@ -0,0 +1,408 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. 
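Before the full UNet below, a hedged sketch of the `get_down_block` factory from unet_3d_blocks.py (all sizes here are illustrative): a `CrossAttnDownBlock3D` interleaves a spatial ResNet, a temporal convolution, spatial cross-attention and temporal attention, and its forward takes the folded `(batch * frames, ...)` layout plus `num_frames`.

```py
import paddle
from ppdiffusers.models.unet_3d_blocks import get_down_block

block = get_down_block(
    "CrossAttnDownBlock3D",
    num_layers=1,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=8,
    resnet_groups=32,
    cross_attention_dim=1024,
    downsample_padding=1,
)
batch_frames = 2 * 4                           # batch * num_frames
x = paddle.randn([batch_frames, 320, 32, 32])
temb = paddle.randn([batch_frames, 1280])      # time embeddings, already repeated per frame
text = paddle.randn([batch_frames, 77, 1024])  # encoder hidden states, repeated per frame
x, skips = block(x, temb=temb, encoder_hidden_states=text, num_frames=4)
print(x.shape)     # [8, 320, 16, 16] -- spatially downsampled
print(len(skips))  # 2: one per layer plus the downsampler output
```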
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import paddle
+import paddle.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, logging
+from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .transformer_temporal import TransformerTemporalModel
+from .unet_3d_blocks import (
+    CrossAttnDownBlock3D,
+    CrossAttnUpBlock3D,
+    DownBlock3D,
+    UNetMidBlock3DCrossAttn,
+    UpBlock3D,
+    get_down_block,
+    get_up_block,
+)
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+    """
+    Args:
+        sample (`paddle.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
+            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: paddle.Tensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
+    """
+    UNet3DConditionModel is a conditional 3D UNet model that takes in a noisy sample, conditional state, and a timestep
+    and returns sample shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the models (such as downloading or saving, etc.)
+
+    Parameters:
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
+            The tuple of downsample blocks to use.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D",)`):
+            The tuple of upsample blocks to use.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, it will skip the normalization and activation layers in post-processing.
+        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-05, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + ): + super().__init__() + self.sample_size = sample_size + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2D( + in_channels=in_channels, + out_channels=block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + self.down_blocks = nn.LayerList(sublayers=[]) + self.up_blocks = nn.LayerList(sublayers=[]) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + 
attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + self.num_upsamplers = 0 + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=False, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=block_out_channels[0], + epsilon=norm_eps, + weight_attr=None, + bias_attr=None, + ) + self.conv_act = nn.Silu() + else: + self.conv_norm_out = None + self.conv_act = None + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2D( + in_channels=block_out_channels[0], + out_channels=out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + ) + + def set_attention_slice(self, slice_size): + """ + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: nn.Layer): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + num_slicable_layers = len(sliceable_head_dims) + if slice_size == "auto": + slice_size = [(dim // 2) for dim in sliceable_head_dims] + elif slice_size == "max": + slice_size = num_slicable_layers * [1] + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
+            )
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        def fn_recursive_set_attention_slice(module: nn.Layer, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        sample: paddle.Tensor,
+        timestep: Union[paddle.Tensor, float, int],
+        encoder_hidden_states: paddle.Tensor,
+        class_labels: Optional[paddle.Tensor] = None,
+        timestep_cond: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        down_block_additional_residuals: Optional[Tuple[paddle.Tensor]] = None,
+        mid_block_additional_residual: Optional[paddle.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNet3DConditionOutput, Tuple]:
+        """
+        Args:
+            sample (`paddle.Tensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
+            timestep (`paddle.Tensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`paddle.Tensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in ppdiffusers.cross_attention.
+
+        Returns:
+            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+            [`~models.unet_3d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+ """ + default_overall_up_factor = 2**self.num_upsamplers + forward_upsample_size = False + upsample_size = None + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + if attention_mask is not None: + attention_mask = (1 - attention_mask.astype(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(axis=1) + timesteps = timestep + if not paddle.is_tensor(x=timesteps): + if isinstance(timestep, float): + dtype = "float64" + else: + dtype = "int64" + timesteps = paddle.to_tensor([timesteps], dtype=dtype) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + num_frames = sample.shape[2] + timesteps = timesteps.expand( + [ + sample.shape[0], + ] + ) + t_emb = self.time_proj(timesteps) + t_emb = t_emb.astype(dtype=self.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, axis=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, axis=0) + sample = sample.transpose(perm=[0, 2, 1, 3, 4]).reshape( + (sample.shape[0] * num_frames, -1) + tuple(sample.shape[3:]) + ) + sample = self.conv_in(sample) + sample = self.transformer_in(sample, num_frames=num_frames).sample + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + if self.conv_norm_out: + 
sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + sample = sample[(None), :].reshape((-1, num_frames) + tuple(sample.shape[1:])).transpose(perm=[0, 2, 1, 3, 4]) + if not return_dict: + return (sample,) + return UNet3DConditionOutput(sample=sample) diff --git a/ppdiffusers/ppdiffusers/pipelines/__init__.py b/ppdiffusers/ppdiffusers/pipelines/__init__.py index 981035513a2b..c2ec680f6e75 100644 --- a/ppdiffusers/ppdiffusers/pipelines/__init__.py +++ b/ppdiffusers/ppdiffusers/pipelines/__init__.py @@ -84,6 +84,7 @@ StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .text_to_video_synthesis import TextToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, diff --git a/ppdiffusers/ppdiffusers/pipelines/text_to_video_synthesis/__init__.py b/ppdiffusers/ppdiffusers/pipelines/text_to_video_synthesis/__init__.py new file mode 100644 index 000000000000..4a2da9e9a689 --- /dev/null +++ b/ppdiffusers/ppdiffusers/pipelines/text_to_video_synthesis/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import paddle + +from ...utils import ( + BaseOutput, + OptionalDependencyNotAvailable, + is_paddle_available, + is_paddlenlp_available, +) + + +@dataclass +class TextToVideoSDPipelineOutput(BaseOutput): + """ + Output class for text to video pipelines. + + Args: + frames (`List[np.ndarray]` or `paddle.Tensor`) + List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as + a `paddle` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list + denotes the video length i.e., the number of frames. + """ + + frames: Union[List[np.ndarray], paddle.Tensor] + + +try: + if not (is_paddlenlp_available() and is_paddle_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_paddle_and_paddlenlp_objects import * +else: + from .pipeline_text_to_video_synth import TextToVideoSDPipeline diff --git a/ppdiffusers/ppdiffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/ppdiffusers/ppdiffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py new file mode 100644 index 000000000000..195bfd47514c --- /dev/null +++ b/ppdiffusers/ppdiffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -0,0 +1,435 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
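A compact sketch of the `UNet3DConditionModel` call contract defined above, using a deliberately tiny configuration (the channel counts and head dim are illustrative, chosen so the GroupNorm groups divide evenly): the noisy sample is 5-D with the frame axis third, i.e. `(batch, channels, num_frames, height, width)`, and the denoised output keeps that layout.

```py
import paddle
from ppdiffusers import UNet3DConditionModel

unet = UNet3DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    layers_per_block=1,
    cross_attention_dim=32,
    attention_head_dim=8,
)
sample = paddle.randn([1, 4, 8, 32, 32])           # (batch, in_channels, num_frames, H, W)
timestep = paddle.to_tensor([1])
encoder_hidden_states = paddle.randn([1, 77, 32])  # repeated per frame inside forward()
with paddle.no_grad():
    out = unet(sample, timestep, encoder_hidden_states).sample
print(out.shape)                                   # [1, 4, 8, 32, 32]
```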
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import paddle + +from paddlenlp.transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet3DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from . import TextToVideoSDPipelineOutput + +logger = logging.get_logger(__name__) +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import paddle + >>> from ppdiffusers import TextToVideoSDPipeline + >>> from ppdiffusers.utils import export_to_video + + >>> pipe = TextToVideoSDPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", + ... ) + + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +def tensor2vid(video: paddle.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + mean = paddle.to_tensor(mean).reshape((1, -1, 1, 1, 1)) + std = paddle.to_tensor(std).reshape((1, -1, 1, 1, 1)) + video = video.multiply(std) + video = video.add(mean) + + video.clip_(min=0, max=1) + i, c, f, h, w = video.shape + images = video.transpose(perm=[2, 3, 0, 4, 1]).reshape((f, h, i * w, c)) + images = images.unbind(axis=0) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] + return images + + +class TextToVideoSDPipeline(DiffusionPipeline): + """ + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Same as Stable Diffusion 2. + tokenizer (`CLIPTokenizer`): + Tokenizer of class CLIPTokenizer. + unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def _encode_prompt( + self, + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[paddle.Tensor] = None, + negative_prompt_embeds: Optional[paddle.Tensor] = None, + ): + """ + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`paddle.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`paddle.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pd", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because CLIP can only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask + else: + attention_mask = None + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.astype(self.text_encoder.dtype) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.tile(repeat_times=[1, num_images_per_prompt, 1]) + prompt_embeds = prompt_embeds.reshape((bs_embed * num_images_per_prompt, seq_len, -1)) + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} != {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`: {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pd" + ) + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask + else: + attention_mask = None + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids, attention_mask=attention_mask) + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.astype(self.text_encoder.dtype) + negative_prompt_embeds = negative_prompt_embeds.tile(repeat_times=[1, num_images_per_prompt, 1]) + negative_prompt_embeds = negative_prompt_embeds.reshape((batch_size * num_images_per_prompt, seq_len, -1)) + prompt_embeds = paddle.concat(x=[negative_prompt_embeds, prompt_embeds]) + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.transpose(perm=[0, 2, 1, 3, 4]).reshape((batch_size * num_frames, channels, height, width)) + image = self.vae.decode(latents).sample + video = ( + image[(None), :] + .reshape((batch_size, num_frames, -1) + tuple(image.shape[2:])) + .transpose(perm=[0, 2, 1, 3, 4]) + ) + video = video.astype(dtype="float32") + return video + + def prepare_extra_step_kwargs(self, generator, eta): + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if ( + callback_steps is None + or callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type {type(callback_steps)}." + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+            )
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds` {negative_prompt_embeds.shape}."
+                )
+
+    def prepare_latents(
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            num_frames,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, dtype=dtype)
+
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    @paddle.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: int = 16,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 9.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        eta: float = 0.0,
+        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
+        latents: Optional[paddle.Tensor] = None,
+        prompt_embeds: Optional[paddle.Tensor] = None,
+        negative_prompt_embeds: Optional[paddle.Tensor] = None,
+        output_type: Optional[str] = "np",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated video.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated video.
+            num_frames (`int`, *optional*, defaults to 16):
+                The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+                amounts to 2 seconds of video.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 9.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate videos that are closely linked to the
+                text `prompt`, usually at the expense of lower video quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], and will be ignored for other schedulers.
+            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
+                One or a list of paddle generator(s) to make generation deterministic.
+            latents (`paddle.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`. Latents should be of shape
+                `(batch_size, num_channel, num_frames, height, width)`.
+            prompt_embeds (`paddle.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`paddle.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
+                input argument.
+            output_type (`str`, *optional*, defaults to `"np"`):
+                The output format of the generated video. Choose between `paddle.Tensor` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in ppdiffusers.cross_attention.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated frames.
+ """ + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + num_images_per_prompt = 1 + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = paddle.concat(x=[latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(chunks=2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + bsz, channel, frames, width, height = latents.shape + latents = latents.transpose(perm=[0, 2, 1, 3, 4]).reshape((bsz * frames, channel, width, height)) + noise_pred = noise_pred.transpose(perm=[0, 2, 1, 3, 4]).reshape((bsz * frames, channel, width, height)) + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = ( + latents[None, :].reshape((bsz, frames, channel, width, height)).transpose(perm=[0, 2, 1, 3, 4]) + ) + if i == len(timesteps) - 1 or i + 1 > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + video_tensor = self.decode_latents(latents) + if output_type == "pd": + video = video_tensor + else: + video = tensor2vid(video_tensor) + if not return_dict: + return (video,) + return TextToVideoSDPipelineOutput(frames=video) diff --git a/ppdiffusers/ppdiffusers/utils/__init__.py b/ppdiffusers/ppdiffusers/utils/__init__.py index c3f0e6d20058..72bb6d9f976f 100644 --- a/ppdiffusers/ppdiffusers/utils/__init__.py +++ b/ppdiffusers/ppdiffusers/utils/__init__.py @@ -113,6 +113,9 @@ def apply_forward_hook(method): return method +from .testing_utils import export_to_video + + def check_min_version(min_version): if version.parse(__version__) < version.parse(min_version): if "dev" in min_version: diff --git a/ppdiffusers/ppdiffusers/utils/dummy_paddle_and_paddlenlp_objects.py b/ppdiffusers/ppdiffusers/utils/dummy_paddle_and_paddlenlp_objects.py index 7653874d2bc6..8bfc1ea35e33 100644 --- a/ppdiffusers/ppdiffusers/utils/dummy_paddle_and_paddlenlp_objects.py +++ 
b/ppdiffusers/ppdiffusers/utils/dummy_paddle_and_paddlenlp_objects.py @@ -495,3 +495,18 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["paddle", "paddlenlp"]) + + +class TextToVideoSDPipeline(metaclass=DummyObject): + _backends = ["paddle", "paddlenlp"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["paddle", "paddlenlp"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["paddle", "paddlenlp"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["paddle", "paddlenlp"]) diff --git a/ppdiffusers/ppdiffusers/utils/dummy_paddle_objects.py b/ppdiffusers/ppdiffusers/utils/dummy_paddle_objects.py index f741047e7094..3141acbdd599 100644 --- a/ppdiffusers/ppdiffusers/utils/dummy_paddle_objects.py +++ b/ppdiffusers/ppdiffusers/utils/dummy_paddle_objects.py @@ -137,6 +137,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["paddle"]) +class UNet3DConditionModel(metaclass=DummyObject): + _backends = ["paddle"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["paddle"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["paddle"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["paddle"]) + + class VQModel(metaclass=DummyObject): _backends = ["paddle"] diff --git a/ppdiffusers/ppdiffusers/utils/import_utils.py b/ppdiffusers/ppdiffusers/utils/import_utils.py index 13c206e790a2..0b8b8bcf6175 100644 --- a/ppdiffusers/ppdiffusers/utils/import_utils.py +++ b/ppdiffusers/ppdiffusers/utils/import_utils.py @@ -138,6 +138,15 @@ except importlib_metadata.PackageNotFoundError: _paddlenlp_available = False +# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. +# _opencv_available = importlib.util.find_spec("opencv-python") is not None +try: + _opencv_version = importlib_metadata.version("opencv-python") + _opencv_available = True + logger.debug(f"Successfully imported cv2 version {_opencv_version}") +except importlib_metadata.PackageNotFoundError: + _opencv_available = False + _scipy_available = importlib.util.find_spec("scipy") is not None try: _scipy_version = importlib_metadata.version("scipy") @@ -228,6 +237,10 @@ def is_unidecode_available(): return _unidecode_available +def is_opencv_available(): + return _opencv_available + + def is_scipy_available(): return _scipy_available @@ -300,6 +313,11 @@ def is_tensorboard_available(): installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. """ +# docstyle-ignore +OPENCV_IMPORT_ERROR = """ +{0} requires the OpenCV library but it was not found in your environment. 
You can install it with pip: `pip +install opencv-python` +""" # docstyle-ignore SCIPY_IMPORT_ERROR = """ @@ -351,6 +369,7 @@ def is_tensorboard_available(): ("paddlenlp", (is_paddlenlp_available, PADDLENLP_IMPORT_ERROR)), ("visualdl", (is_visualdl_available, VISUALDL_IMPORT_ERROR)), ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), diff --git a/ppdiffusers/ppdiffusers/utils/testing_utils.py b/ppdiffusers/ppdiffusers/utils/testing_utils.py index ca99a4a3b948..c09c26130d58 100644 --- a/ppdiffusers/ppdiffusers/utils/testing_utils.py +++ b/ppdiffusers/ppdiffusers/utils/testing_utils.py @@ -18,12 +18,13 @@ import os import random import re +import tempfile import unittest import urllib.parse from distutils.util import strtobool from io import BytesIO, StringIO from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union import numpy as np import PIL.Image @@ -31,7 +32,9 @@ import requests from .import_utils import ( + BACKENDS_MAPPING, is_fastdeploy_available, + is_opencv_available, is_paddle_available, is_torch_available, ) @@ -267,6 +270,23 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: return image +def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + return output_video_path + + def load_hf_numpy(path) -> np.ndarray: if not path.startswith("http://") or path.startswith("https://"): path = os.path.join( diff --git a/ppdiffusers/ppdiffusers_test/test_pipelines_common.py b/ppdiffusers/ppdiffusers_test/test_pipelines_common.py index ef2be7c5ccfb..1449a024cd8c 100644 --- a/ppdiffusers/ppdiffusers_test/test_pipelines_common.py +++ b/ppdiffusers/ppdiffusers_test/test_pipelines_common.py @@ -30,6 +30,13 @@ from ppdiffusers.utils.testing_utils import require_paddle +def to_np(tensor): + if isinstance(tensor, paddle.Tensor): + tensor = tensor.detach().cpu().numpy() + + return tensor + + @require_paddle class PipelineTesterMixin: """ @@ -124,7 +131,7 @@ def test_save_load_local(self): pipe_loaded.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs() output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(output - output_loaded).max() + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 0.0001) def test_pipeline_call_signature(self): @@ -269,7 +276,7 @@ def test_dict_tuple_outputs_equivalent(self): output = pipe(**self.get_dummy_inputs())[0] output_tuple = pipe(**self.get_dummy_inputs(), return_dict=False)[0] - max_diff = np.abs(output - output_tuple).max() + max_diff = np.abs(to_np(output) - to_np(output_tuple)).max() self.assertLess(max_diff, 0.0001) def test_components_function(self): @@ -290,7 +297,7 @@ def test_float16_inference(self): pipe_fp16.set_progress_bar_config(disable=None) output = 
pipe(**self.get_dummy_inputs())[0] output_fp16 = pipe_fp16(**self.get_dummy_inputs())[0] - max_diff = np.abs(output - output_fp16).max() + max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() self.assertLess(max_diff, 0.01, "The outputs of the fp16 and fp32 pipelines are too different.") def test_save_load_float16(self): @@ -318,7 +325,7 @@ def test_save_load_float16(self): # ) # inputs = self.get_dummy_inputs() # output_loaded = pipe_loaded(**inputs)[0] - # max_diff = np.abs(output - output_loaded).max() + # max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() # self.assertLess(max_diff, 5, "The output of the fp16 pipeline changed after saving and loading.") def test_save_load_optional_components(self): @@ -344,7 +351,7 @@ def test_save_load_optional_components(self): ) inputs = self.get_dummy_inputs() output_loaded = pipe_loaded(**inputs)[0] - max_diff = np.abs(output - output_loaded).max() + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 0.0001) # def test_to_device(self): @@ -361,7 +368,7 @@ def test_save_load_optional_components(self): # model_devices = [str(component.device) for component in components.values() if hasattr(component, "device")] # self.assertTrue(all(device == "Place(gpu:0)" for device in model_devices)) # output_cuda = pipe(**self.get_dummy_inputs())[0] - # self.assertTrue(np.isnan(output_cuda).sum() == 0) + # self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass() @@ -380,7 +387,7 @@ def _test_attention_slicing_forward_pass(self, test_max_difference=True, expecte inputs = self.get_dummy_inputs() output_with_slicing = pipe(**inputs)[0] if test_max_difference: - max_diff = np.abs(output_with_slicing - output_without_slicing).max() + max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max() self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results") assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0]) diff --git a/ppdiffusers/tests/models/test_models_unet_3d_condition.py b/ppdiffusers/tests/models/test_models_unet_3d_condition.py new file mode 100644 index 000000000000..df3dd97411d3 --- /dev/null +++ b/ppdiffusers/tests/models/test_models_unet_3d_condition.py @@ -0,0 +1,131 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import paddle +from ppdiffusers_test.test_modeling_common import ModelTesterMixin + +from ppdiffusers.models import UNet3DConditionModel +from ppdiffusers.utils import floats_tensor, logging +from ppdiffusers.utils.import_utils import is_ppxformers_available + +logger = logging.get_logger(__name__) + + +class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet3DConditionModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = 32, 32 + noise = floats_tensor((batch_size, num_channels, num_frames) + sizes) + time_step = paddle.to_tensor([10]) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)) + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return 4, 4, 32, 32 + + @property + def output_shape(self): + return 4, 4, 32, 32 + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 64, 64), + "down_block_types": ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + "cross_attention_dim": 32, + "attention_head_dim": 4, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf( + not is_ppxformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed" + ) + def test_xformers_enable_works(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.enable_xformers_memory_efficient_attention() + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" + + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict["norm_num_groups"] = 32 + init_dict["block_out_channels"] = 32, 64, 64, 64 + model = self.model_class(**init_dict) + model.eval() + with paddle.no_grad(): + output = model(**inputs_dict) + if isinstance(output, dict): + output = output.sample + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_determinism(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.eval() + with paddle.no_grad(): + first = model(**inputs_dict) + if isinstance(first, dict): + first = first.sample + second = model(**inputs_dict) + if isinstance(second, dict): + second = second.sample + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-05) + + def test_model_attention_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict["attention_head_dim"] = 8 + model = self.model_class(**init_dict) + model.eval() + model.set_attention_slice("auto") + with paddle.no_grad(): + output = model(**inputs_dict) + assert output is not None + model.set_attention_slice("max") + with paddle.no_grad(): + output = model(**inputs_dict) + 
assert output is not None + model.set_attention_slice(2) + with paddle.no_grad(): + output = model(**inputs_dict) + assert output is not None diff --git a/ppdiffusers/tests/pipelines/text_to_video/__init__.py b/ppdiffusers/tests/pipelines/text_to_video/__init__.py new file mode 100644 index 000000000000..595add0aed9e --- /dev/null +++ b/ppdiffusers/tests/pipelines/text_to_video/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ppdiffusers/tests/pipelines/text_to_video/test_text_to_video.py b/ppdiffusers/tests/pipelines/text_to_video/test_text_to_video.py new file mode 100644 index 000000000000..d88c7297c4ea --- /dev/null +++ b/ppdiffusers/tests/pipelines/text_to_video/test_text_to_video.py @@ -0,0 +1,169 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import numpy as np
+import paddle
+from ppdiffusers_test.pipeline_params import (
+    TEXT_TO_IMAGE_BATCH_PARAMS,
+    TEXT_TO_IMAGE_PARAMS,
+)
+from ppdiffusers_test.test_pipelines_common import PipelineTesterMixin
+
+from paddlenlp.transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from ppdiffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    TextToVideoSDPipeline,
+    UNet3DConditionModel,
+)
+from ppdiffusers.utils import load_numpy, slow
+
+
+class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = TextToVideoSDPipeline
+    params = TEXT_TO_IMAGE_PARAMS
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    required_optional_params = frozenset(
+        ["num_inference_steps", "generator", "latents", "return_dict", "callback", "callback_steps"]
+    )
+
+    def get_dummy_components(self):
+        paddle.seed(0)
+        unet = UNet3DConditionModel(
+            block_out_channels=(32, 64, 64, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"),
+            up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
+            cross_attention_dim=32,
+            attention_head_dim=4,
+        )
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        paddle.seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+            sample_size=128,
+        )
+        paddle.seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            hidden_act="gelu",
+            projection_dim=512,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, seed=0):
+        generator = paddle.Generator().manual_seed(seed)
+        # "output_type": "pd" is problematic
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "pd",
+        }
+        return inputs
+
+    def test_text_to_video_default_case(self):
+        components = self.get_dummy_components()
+        sd_pipe = TextToVideoSDPipeline(**components)
+        sd_pipe.set_progress_bar_config(disable=None)
+        inputs = self.get_dummy_inputs()
+        inputs["output_type"] = "np"
+        frames = sd_pipe(**inputs).frames
+        image_slice = frames[0][-3:, -3:, -1]
+        assert frames[0].shape == (64, 64, 3)
+        expected_slice = np.array([65, 138, 97, 105, 157, 113, 78, 111, 69])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 0.01
+
+    # def test_attention_slicing_forward_pass(self):
+    #     self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
+
+    @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
+    def test_inference_batch_consistent(self):
+        pass
+
+    @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
+    def test_inference_batch_single_identical(self):
+        pass
+
+    @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
+    def test_num_images_per_prompt(self):
+        pass
+
+
+@slow
+class TextToVideoSDPipelineSlowTests(unittest.TestCase):
+    def test_full_model(self):
+        expected_video = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video.npy"
+        )
+        pipe = TextToVideoSDPipeline.from_pretrained(
+            "damo-vilab/text-to-video-ms-1.7b", from_hf_hub=True, from_diffusers=True
+        )
+        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+        prompt = "Spiderman is surfing"
+        generator = paddle.Generator().manual_seed(0)
+        video_frames = pipe(prompt, generator=generator, num_inference_steps=25, output_type="pd").frames
+        video = video_frames.cpu().numpy()
+        assert np.abs(expected_video - video).mean() < 0.8
+
+    def test_two_step_model(self):
+        expected_video = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy"
+        )
+        pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
+        prompt = "Spiderman is surfing"
+        generator = paddle.Generator().manual_seed(0)
+        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pd").frames
+        video = video_frames.cpu().numpy()
+        assert np.abs(expected_video - video).mean() < 0.8
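For reference, a minimal end-to-end sketch of how the pieces added in this patch fit together: the `TextToVideoSDPipeline` produces a list of `uint8` numpy frames (via `tensor2vid`) and the new `export_to_video` helper writes them to an `.mp4` with OpenCV. This mirrors the EXAMPLE_DOC_STRING and the slow tests above; it assumes the `damo-vilab/text-to-video-ms-1.7b` weights are reachable and `opencv-python` is installed, and is a usage sketch rather than part of the patch.

```py
import paddle

from ppdiffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
from ppdiffusers.utils import export_to_video

# Load the text-to-video pipeline (weights assumed to be available locally or from the hub).
pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
# Optionally swap in the multistep DPM-Solver scheduler, as the slow test does.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

generator = paddle.Generator().manual_seed(0)
# With the default output_type="np", `.frames` is a list of HxWx3 uint8 numpy arrays.
frames = pipe(
    "Spiderman is surfing",
    num_frames=16,
    num_inference_steps=25,
    generator=generator,
).frames

# export_to_video writes the frames to an .mp4 via OpenCV and returns the file path.
video_path = export_to_video(frames)
print(video_path)
```

A second, shape-only sketch of the reshaping that the denoising loop in `__call__` performs: the schedulers operate on 4D `(batch, channels, height, width)` tensors, so the frame axis is folded into the batch before `scheduler.step` and unfolded afterwards (the same trick `TemporalConvLayer` and `decode_latents` use). Shapes below are illustrative, not taken from the patch.

```py
import paddle

# 5D video latents: (batch, channels, frames, height, width)
latents = paddle.randn([2, 4, 16, 32, 32])
bsz, channel, frames, h, w = latents.shape

# Fold frames into the batch axis for the 4D scheduler step.
flat = latents.transpose(perm=[0, 2, 1, 3, 4]).reshape((bsz * frames, channel, h, w))

# Unfold back to the 5D video layout afterwards.
restored = flat[None, :].reshape((bsz, frames, channel, h, w)).transpose(perm=[0, 2, 1, 3, 4])
assert restored.shape == latents.shape
```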