[PPDiffusers]add t2v model (#5529)
* add t2v model

* update

* update

* add hf copyright

* update it all

* update
westfish committed Apr 12, 2023
1 parent d7336d9 commit d6e316a
Showing 19 changed files with 2,119 additions and 11 deletions.
2 changes: 2 additions & 0 deletions ppdiffusers/ppdiffusers/__init__.py
@@ -62,6 +62,7 @@
    UNet1DModel,
    UNet2DConditionModel,
    UNet2DModel,
    UNet3DConditionModel,
    VQModel,
)
from .optimization import (
@@ -157,6 +158,7 @@
    StableDiffusionUpscalePipeline,
    StableUnCLIPImg2ImgPipeline,
    StableUnCLIPPipeline,
    TextToVideoSDPipeline,
    UnCLIPImageVariationPipeline,
    UnCLIPPipeline,
    VersatileDiffusionDualGuidedPipeline,
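
The two new exports above are the user-facing surface of this commit. A minimal end-to-end sketch of driving the pipeline (the checkpoint id, step count, and frame handling below are illustrative assumptions, not part of this diff):

from ppdiffusers import TextToVideoSDPipeline

# Hypothetical checkpoint id for illustration; substitute real text-to-video SD weights.
pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")

prompt = "An astronaut riding a horse"
# Mirrors the diffusers text-to-video call signature this port follows;
# `.frames` holds the decoded video frames.
result = pipe(prompt, num_inference_steps=25, num_frames=16)
print(len(result.frames))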
1 change: 1 addition & 0 deletions ppdiffusers/ppdiffusers/models/__init__.py
@@ -28,4 +28,5 @@
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .vq_model import VQModel
11 changes: 8 additions & 3 deletions ppdiffusers/ppdiffusers/models/attention.py
@@ -171,6 +171,10 @@ class BasicTransformerBlock(nn.Layer):
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training. See
            `Transformer2DModel`.
@@ -189,6 +193,7 @@ def __init__(
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
@@ -220,10 +225,10 @@ def __init__(
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 2. Cross-Attn
-        if cross_attention_dim is not None:
+        if cross_attention_dim is not None or double_self_attention:
            self.attn2 = CrossAttention(
                query_dim=dim,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
@@ -245,7 +250,7 @@ def __init__(
        else:
            self.norm1 = nn.LayerNorm(dim, **norm_kwargs)

-        if cross_attention_dim is not None:
+        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
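
Taken together, these edits let `attn2` double as a second self-attention layer: with `double_self_attention=True` it is constructed with `cross_attention_dim=None`, so the block needs no encoder states at all. A quick shape-check sketch (all sizes arbitrary, chosen only for illustration):

import paddle
from ppdiffusers.models.attention import BasicTransformerBlock

# With double_self_attention=True, attn2 is built with cross_attention_dim=None,
# so both attention layers attend over the input sequence itself.
block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=None,
    double_self_attention=True,
)
x = paddle.randn([2, 64, 320])  # (batch, sequence length, dim)
out = block(x)  # no encoder_hidden_states required
print(out.shape)  # [2, 64, 320]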
56 changes: 56 additions & 0 deletions ppdiffusers/ppdiffusers/models/resnet.py
@@ -20,6 +20,7 @@
import paddle.nn as nn
import paddle.nn.functional as F

from ..initializer import zeros_
from .attention import AdaGroupNorm


@@ -806,3 +807,58 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.reshape([-1, channel, out_h, out_w])


class TemporalConvLayer(nn.Layer):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # All convolutions use a (3, 1, 1) kernel: 3 taps along the frame axis, none spatially.
        self.conv1 = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=in_dim),
            nn.Silu(),
            nn.Conv3D(in_channels=in_dim, out_channels=out_dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=out_dim),
            nn.Silu(),
            nn.Dropout(p=dropout),
            nn.Conv3D(in_channels=out_dim, out_channels=in_dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=out_dim),
            nn.Silu(),
            nn.Dropout(p=dropout),
            nn.Conv3D(in_channels=out_dim, out_channels=in_dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=out_dim),
            nn.Silu(),
            nn.Dropout(p=dropout),
            nn.Conv3D(in_channels=out_dim, out_channels=in_dim, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
        )

        # Zero-init the last conv so the residual branch starts at zero and the
        # layer is an identity mapping at initialization.
        zeros_(self.conv4[-1].weight)
        zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        # (batch * frames, channels, height, width) -> (batch, channels, frames, height, width)
        hidden_states = (
            hidden_states[None, :]
            .reshape((-1, num_frames) + tuple(hidden_states.shape[1:]))
            .transpose(perm=[0, 2, 1, 3, 4])
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)
        hidden_states = identity + hidden_states

        # Back to the flattened (batch * frames, channels, height, width) layout.
        hidden_states = hidden_states.transpose(perm=[0, 2, 1, 3, 4]).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + tuple(hidden_states.shape[3:])
        )
        return hidden_states
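
The (3, 1, 1) kernels mix information only across frames, and because `conv4` is zero-initialized the layer can be dropped into a pretrained 2D UNet without disturbing its outputs. A shape-check sketch (sizes arbitrary):

import paddle
from ppdiffusers.models.resnet import TemporalConvLayer

batch, frames, channels = 2, 8, 32
layer = TemporalConvLayer(in_dim=channels)

# The UNet passes video around as a flat (batch * frames, C, H, W) batch.
x = paddle.randn([batch * frames, channels, 16, 16])
y = layer(x, num_frames=frames)
print(y.shape)  # [16, 32, 16, 16], same flattened layout in and out

Right after construction, `y` equals `x`; that identity start is the point of the zero init.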
162 changes: 162 additions & 0 deletions ppdiffusers/ppdiffusers/models/transformer_temporal.py
@@ -0,0 +1,162 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

import paddle
import paddle.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .modeling_utils import ModelMixin


@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    Args:
        sample (`paddle.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`)
            Hidden states conditioned on `encoder_hidden_states` input.
    """

    sample: paddle.Tensor


class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    Transformer model for video-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        double_self_attention (`bool`, *optional*):
            Configure if each TransformerBlock should contain two self-attention layers.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-06)
        self.proj_in = nn.Linear(in_features=in_channels, out_features=inner_dim)

        self.transformer_blocks = nn.LayerList(
            sublayers=[
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for d in range(num_layers)
            ]
        )

        self.proj_out = nn.Linear(in_features=inner_dim, out_features=in_channels)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        class_labels=None,
        num_frames=1,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states (`paddle.Tensor` of shape `(batch size, num latent pixels)` when discrete, or of shape
                `(batch size, channel, height, width)` when continuous): Input hidden_states.
            encoder_hidden_states (`paddle.Tensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`paddle.int64`, *optional*):
                Optional timestep to be applied as an embedding in `AdaLayerNorm`. Used to indicate the denoising step.
            class_labels (`paddle.Tensor` of shape `(batch size, num classes)`, *optional*):
                Optional class labels to be applied as an embedding in `AdaLayerNormZero`. Used to indicate class-label
                conditioning.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of
                a plain tuple.
        Returns:
            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                [`~models.transformer_temporal.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        batch_frames, channel, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        residual = hidden_states

        # (batch * frames, C, H, W) -> (batch, C, frames, H, W)
        hidden_states = hidden_states[None, :].reshape((batch_size, num_frames, channel, height, width))
        hidden_states = hidden_states.transpose(perm=[0, 2, 1, 3, 4])

        hidden_states = self.norm(hidden_states)
        # Fold the spatial positions into the batch so attention runs along the frame axis.
        hidden_states = hidden_states.transpose(perm=[0, 3, 4, 2, 1]).reshape(
            (batch_size * height * width, num_frames, channel)
        )
        hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        hidden_states = self.proj_out(hidden_states)
        # Restore the flattened (batch * frames, C, H, W) layout and add the residual.
        hidden_states = (
            hidden_states[None, None, :]
            .reshape((batch_size, height, width, channel, num_frames))
            .transpose(perm=[0, 3, 4, 1, 2])
        )
        hidden_states = hidden_states.reshape((batch_frames, channel, height, width))

        output = hidden_states + residual
        if not return_dict:
            return (output,)
        return TransformerTemporalModelOutput(sample=output)
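
Because the spatial positions are folded into the batch, attention here runs purely across frames at each spatial location. A standalone shape-check sketch (all sizes arbitrary; `in_channels` must be divisible by `norm_num_groups`):

import paddle
from ppdiffusers.models.transformer_temporal import TransformerTemporalModel

batch, frames, channels = 2, 8, 64
model = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=8,  # inner_dim = 8 * 8 = 64
    in_channels=channels,
    norm_num_groups=32,
)

# Same flat (batch * frames, C, H, W) layout as TemporalConvLayer.
x = paddle.randn([batch * frames, channels, 16, 16])
out = model(x, num_frames=frames).sample
print(out.shape)  # [16, 64, 16, 16]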