[core] CogVideoX memory optimizations in VAE encode #9340

Merged
1 commit merged on Sep 2, 2024
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py (105 changes: 101 additions & 4 deletions)
@@ -999,6 +999,7 @@ def __init__(
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
# number of temporal frames.
self.num_latent_frames_batch_size = 2
self.num_sample_frames_batch_size = 8

Collaborator
Can this be made configurable? Why change to a different hardcoded value?

Member Author
This can be configured to any multiple of temporal_compression_ratio, which is 4. For the moment we can configure it with vae.num_sample_frames_batch_size = 16, and I think it's best to do it that way. The value 8 here corresponds to 1 second of video, because Cog was trained on 8 FPS videos. I would leave the configuration option out of enable_tiling and keep it for power users who have looked through the code and understand its usage, tbh.
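
For illustration, the user-side override described above might look like the following minimal sketch. It assumes the public AutoencoderKLCogVideoX class from this file and a CogVideoX checkpoint id that are not part of this diff:

import torch
from diffusers import AutoencoderKLCogVideoX

# Checkpoint id assumed for illustration.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
)

# Power-user knob discussed above: encode more sample frames per pass.
# Keep it a multiple of temporal_compression_ratio (4); the default of 8
# corresponds to 1 second of video at the 8 FPS Cog was trained on.
vae.num_sample_frames_batch_size = 16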


# We make the minimum height and width of sample for tiling half that of the generally supported
self.tile_sample_min_height = sample_height // 2
@@ -1081,6 +1082,29 @@ def disable_slicing(self) -> None:
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)

frame_batch_size = self.num_sample_frames_batch_size
enc = []
for i in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
x_intermediate = x[:, :, start_frame:end_frame]
x_intermediate = self.encoder(x_intermediate)
if self.quant_conv is not None:
x_intermediate = self.quant_conv(x_intermediate)
enc.append(x_intermediate)

self._clear_fake_context_parallel_cache()
enc = torch.cat(enc, dim=2)

return enc
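
As an illustrative aside (not part of the PR), the chunking arithmetic above can be checked in isolation; the 49-frame clip length is just an assumed example typical for CogVideoX:

# Standalone check of the start/end frame arithmetic used in `_encode`,
# assuming num_frames=49 and the default num_sample_frames_batch_size=8.
num_frames = 49
frame_batch_size = 8
chunks = []
for i in range(num_frames // frame_batch_size):
    remaining_frames = num_frames % frame_batch_size
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    chunks.append((start_frame, end_frame))
print(chunks)
# [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
# The remainder frame (49 % 8 == 1) is folded into the first chunk, so every
# frame is encoded exactly once with contiguous chunk boundaries.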

@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -1094,13 +1118,17 @@ def encode(
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)

posterior = DiagonalGaussianDistribution(h)

if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
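
As a usage sketch of the new encode path (illustrative only; the checkpoint id, the CUDA device, and the input shape are assumptions, not part of this diff):

import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

vae.enable_slicing()  # encode each video in the batch separately via _encode
vae.enable_tiling()   # large spatial sizes fall back to tiled_encode

# Assumed input: 2 videos, 3 channels, 49 frames, 480x720 (B, C, F, H, W).
video = torch.randn(2, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
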
@@ -1172,6 +1200,75 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
)
return b

def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.

When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.

Args:
x (`torch.Tensor`): Input batch of videos.

Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
# For a rough memory estimate, take a look at the `tiled_decode` method.
batch_size, num_channels, num_frames, height, width = x.shape

overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_latent_min_height - blend_extent_height
row_limit_width = self.tile_latent_min_width - blend_extent_width
frame_batch_size = self.num_sample_frames_batch_size

# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
time = []
for k in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = x[
:,
:,
start_frame:end_frame,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
if self.quant_conv is not None:
tile = self.quant_conv(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)

result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))

enc = torch.cat(result_rows, dim=3)
return enc
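
For an illustrative sanity check (not part of the PR), the spatial tile grid implied by the loops above can be computed by hand; the tile sizes and overlap factors below are assumed defaults for a 480x720 CogVideoX sample:

# Assumed defaults: half-size tiles, overlap factors of 1/6 (height) and 1/5 (width).
height, width = 480, 720
tile_sample_min_height, tile_sample_min_width = 240, 360
tile_overlap_factor_height, tile_overlap_factor_width = 1 / 6, 1 / 5

# Stride between tile origins, as in the loops over `i` and `j` above.
overlap_height = int(tile_sample_min_height * (1 - tile_overlap_factor_height))  # 200
overlap_width = int(tile_sample_min_width * (1 - tile_overlap_factor_width))     # 288

row_starts = list(range(0, height, overlap_height))  # [0, 200, 400]
col_starts = list(range(0, width, overlap_width))    # [0, 288, 576]
print(len(row_starts), "x", len(col_starts), "=", len(row_starts) * len(col_starts), "tiles")
# 3 x 3 = 9 overlapping tiles per frame batch; each tile is encoded on its own,
# then blended with its top and left neighbours before cropping to row_limit_*.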

def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.