Enabling gradient checkpointing in eval() mode #9878

Merged (7 commits) on Nov 8, 2024
Changes from 1 commit
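
For context: the files below apply the same one-line change to the gradient-checkpointing guard inside `forward`. A minimal sketch of the pattern (a toy module, not the actual diffusers classes; `ToyBlockStack` and its sizes are made up for illustration) shows why the old `self.training` check matters: calling `model.eval()` used to silently disable checkpointing even when gradients still flow through the module, e.g. when a frozen base model is kept in eval mode while an adapter on top of it is being trained.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlockStack(nn.Module):
    """Toy stand-in for the UNet/transformer block loops touched in this PR."""

    def __init__(self, depth: int = 4, dim: int = 32):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Old guard: `if self.training and self.gradient_checkpointing:`
            # -> checkpointing was skipped entirely once the module was in eval() mode.
            # This commit drops the `self.training` check; the review thread below
            # discusses adding a torch.is_grad_enabled() condition back.
            if self.gradient_checkpointing:
                hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = ToyBlockStack()
model.gradient_checkpointing = True
model.eval()  # e.g. a frozen base model used while training an adapter on top

x = torch.randn(2, 32, requires_grad=True)  # gradients still flow from upstream
model(x).sum().backward()  # activations are recomputed in backward, saving memory
```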
8 changes: 4 additions & 4 deletions examples/community/matryoshka.py
@@ -868,7 +868,7 @@ def forward(
 blocks = list(zip(self.resnets, self.attentions))

 for i, (resnet, attn) in enumerate(blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -1029,7 +1029,7 @@ def forward(

 hidden_states = self.resnets[0](hidden_states, temb)
 for attn, resnet in zip(self.attentions, self.resnets[1:]):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -1191,7 +1191,7 @@ def forward(

 hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -1364,7 +1364,7 @@ def forward(

 # Blocks
 for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -215,7 +215,7 @@ def forward(

 # 2. Blocks
 for block_index, block in enumerate(self.transformer.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:
 # rc todo: for training and gradient checkpointing
 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
 exit(1)
10 changes: 5 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -420,7 +420,7 @@ def forward(
 for i, resnet in enumerate(self.resnets):
 conv_cache_key = f"resnet_{i}"

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def create_forward(*inputs):
@@ -522,7 +522,7 @@ def forward(
 for i, resnet in enumerate(self.resnets):
 conv_cache_key = f"resnet_{i}"

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def create_forward(*inputs):
@@ -636,7 +636,7 @@ def forward(
 for i, resnet in enumerate(self.resnets):
 conv_cache_key = f"resnet_{i}"

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def create_forward(*inputs):
@@ -773,7 +773,7 @@ def forward(

 hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
@@ -939,7 +939,7 @@ def forward(

 hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
10 changes: 5 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -206,7 +206,7 @@ def forward(
 for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
 conv_cache_key = f"resnet_{i}"

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def create_forward(*inputs):
@@ -311,7 +311,7 @@ def forward(
 for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
 conv_cache_key = f"resnet_{i}"

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def create_forward(*inputs):
@@ -392,7 +392,7 @@ def forward(
 for i, resnet in enumerate(self.resnets):
 conv_cache_key = f"resnet_{i}"

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def create_forward(*inputs):
@@ -529,7 +529,7 @@ def forward(
 hidden_states = self.proj_in(hidden_states)
 hidden_states = hidden_states.permute(0, 4, 1, 2, 3)

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def create_forward(*inputs):
@@ -646,7 +646,7 @@ def forward(
 hidden_states = self.conv_in(hidden_states)

 # 1. Mid
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def create_forward(*inputs):
@@ -95,7 +95,7 @@ def forward(
 sample = self.conv_in(sample)

 upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
10 changes: 5 additions & 5 deletions src/diffusers/models/autoencoders/vae.py
@@ -142,7 +142,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

 sample = self.conv_in(sample)

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
@@ -291,7 +291,7 @@ def forward(
 sample = self.conv_in(sample)

 upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
@@ -544,7 +544,7 @@ def forward(
 sample = self.conv_in(sample)

 upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
@@ -876,7 +876,7 @@ def __init__(

 def forward(self, x: torch.Tensor) -> torch.Tensor:
 r"""The forward method of the `EncoderTiny` class."""
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
@@ -962,7 +962,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 # Clamp.
 x = torch.tanh(x / 3) * 3

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
4 changes: 2 additions & 2 deletions src/diffusers/models/controlnet_flux.py
@@ -329,7 +329,7 @@ def forward(

 block_samples = ()
 for index_block, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -363,7 +363,7 @@ def custom_forward(*inputs):

 single_block_samples = ()
 for index_block, block in enumerate(self.single_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/controlnet_sd3.py
@@ -324,7 +324,7 @@ def forward(
 block_res_samples = ()

 for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
6 changes: 3 additions & 3 deletions src/diffusers/models/controlnet_xs.py
@@ -1465,7 +1465,7 @@ def custom_forward(*inputs):
 h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)

 # apply base subblock
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:
 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
 h_base = torch.utils.checkpoint.checkpoint(
 create_custom_forward(b_res),
@@ -1488,7 +1488,7 @@ def custom_forward(*inputs):

 # apply ctrl subblock
 if apply_control:
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:
 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
 h_ctrl = torch.utils.checkpoint.checkpoint(
 create_custom_forward(c_res),
@@ -1897,7 +1897,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
 hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
 hidden_states = torch.cat([hidden_states, res_h_base], dim=1)

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:
 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
 hidden_states = torch.utils.checkpoint.checkpoint(
 create_custom_forward(resnet),
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -466,7 +466,7 @@ def forward(

 # MMDiT blocks.
 for index_block, block in enumerate(self.joint_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -497,7 +497,7 @@ def custom_forward(*inputs):
 combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

 for index_block, block in enumerate(self.single_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -452,7 +452,7 @@ def forward(

 # 3. Transformer blocks
 for i, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:
Collaborator (inline review comment):
Oh thanks! Why did we also remove the torch.is_grad_enabled() check? Gradient checkpointing isn't meaningful without gradients being computed, no?

Contributor (author):
Added it back, thanks for pointing it out. It does not break anything, but I found that it throws an annoying warning when use_reentrant=True.
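
To make the restored check concrete: a sketch of the combined guard, assuming the later commits settle on `torch.is_grad_enabled() and self.gradient_checkpointing` as this thread suggests (`should_checkpoint` and the bare `nn.Linear` are just illustration, not diffusers API):

```python
import torch
import torch.nn as nn


def should_checkpoint(module: nn.Module) -> bool:
    # Only checkpoint when autograd can actually record the recomputation
    # and the user has opted in (diffusers toggles the flag via
    # enable_gradient_checkpointing()).
    return torch.is_grad_enabled() and getattr(module, "gradient_checkpointing", False)


layer = nn.Linear(4, 4)
layer.gradient_checkpointing = True

layer.train()
print(should_checkpoint(layer))      # True: training with autograd enabled

layer.eval()
print(should_checkpoint(layer))      # True: eval() alone no longer disables it

with torch.no_grad():
    print(should_checkpoint(layer))  # False: no gradients, so checkpointing is pointless
```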

Collaborator:
> but found that it throws an annoying warning when use_reentrant=True

What do you mean by that?

Contributor (author):
use_reentrant is an argument passed to torch.utils.checkpoint.checkpoint.

If it is True, one of the internal checks will print this to stderr:

warnings.warn(
    "None of the inputs have requires_grad=True. Gradients will be None"
)

But diffusers is using use_reentrant=False anyway.

Collaborator:
Oh got it, thanks. So the warning is specific to using gradient checkpointing when gradients are not enabled.
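
To make the exchange above concrete, a small repro of the warning (hypothetical `layer`/`x` names; behavior as observed with recent PyTorch releases):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)  # requires_grad=False, as in a plain eval()/inference forward

# The reentrant implementation inspects the checkpointed *inputs* and warns:
#   UserWarning: None of the inputs have requires_grad=True. Gradients will be None
out = checkpoint(layer, x, use_reentrant=True)

# The non-reentrant path, which diffusers passes explicitly, stays silent.
out = checkpoint(layer, x, use_reentrant=False)
```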


 def create_custom_forward(module):
 def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/dit_transformer_2d.py
@@ -184,7 +184,7 @@ def forward(

 # 2. Blocks
 for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/latte_transformer_3d.py
@@ -238,7 +238,7 @@ def forward(
 for i, (spatial_block, temp_block) in enumerate(
 zip(self.transformer_blocks, self.temporal_transformer_blocks)
 ):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:
 hidden_states = torch.utils.checkpoint.checkpoint(
 spatial_block,
 hidden_states,
@@ -271,7 +271,7 @@ def forward(
 if i == 0 and num_frame > 1:
 hidden_states = hidden_states + self.temp_pos_embed

-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:
 hidden_states = torch.utils.checkpoint.checkpoint(
 temp_block,
 hidden_states,
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -386,7 +386,7 @@ def forward(

 # 2. Blocks
 for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -414,7 +414,7 @@ def forward(
 attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)

 for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_2d.py
@@ -415,7 +415,7 @@ def forward(

 # 2. Blocks
 for block in self.transformer_blocks:
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -341,7 +341,7 @@ def forward(
 hidden_states = hidden_states[:, text_seq_length:]

 for index_block, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -480,7 +480,7 @@ def forward(
 image_rotary_emb = self.pos_embed(ids)

 for index_block, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
@@ -525,7 +525,7 @@ def custom_forward(*inputs):
 hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

 for index_block, block in enumerate(self.single_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_mochi.py
@@ -350,7 +350,7 @@ def forward(
 )

 for i, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module):
 def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_sd3.py
@@ -317,7 +317,7 @@ def forward(
 encoder_hidden_states = self.context_embedder(encoder_hidden_states)

 for index_block, block in enumerate(self.transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:

 def create_custom_forward(module, return_dict=None):
 def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_temporal.py
@@ -340,7 +340,7 @@ def forward(

 # 2. Blocks
 for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
-if self.training and self.gradient_checkpointing:
+if self.gradient_checkpointing:
 hidden_states = torch.utils.checkpoint.checkpoint(
 block,
 hidden_states,