Conversation

@sayakpaul (Member) commented on Nov 28, 2025

What does this PR do?

  • Introduces a dedicated test suite for the Z-Image DiT.
  • Adds is_flaky decorator to test_inference() in the Z-Image pipeline test suite.
  • Adds a return_dict argument to the forward() of Z-Image DiT, following other models in the library.
    • As a consequence, I followed the usual return pattern, i.e., the model returns a Transformer2DModelOutput when return_dict=True and a tuple like (out,) otherwise (a short sketch follows below).
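
For reference, a minimal sketch of the return pattern this follows (the body of forward() is elided, _compute_sample is a hypothetical placeholder, and the import path is the usual diffusers one):

    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    def forward(self, x, t, cap_feats, return_dict: bool = True):
        # hypothetical helper standing in for the actual DiT computation
        out = self._compute_sample(x, t, cap_feats)
        if not return_dict:
            # tuple return, matching the rest of the library
            return (out,)
        return Transformer2DModelOutput(sample=out)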

Notes

  • The model accepts the hidden states as a list[torch.Tensor], which differs from other models, and the output follows the same type. This is why I had to modify a couple of tests (where it was reasonably easy) to allow this. Tests where it was not relatively easy were skipped (such as test_training, test_ema_training, etc.).
  • The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence, the inputs recorded for the block vary during compilation, and full compilation with fullgraph=True triggers recompilation at least thrice.
  • Some of the group offloading tests were skipped because of states that interfered in between the tests (as also noted here).
  • The x_pad_token and cap_pad_token params within the DiT are initialized with torch.empty(), possibly for memory efficiency, but they interfere with the tests in very weird ways because torch.empty() can yield NaNs. To prevent this from creeping into the tests, I tried adding is_flaky() to some of the affected tests, but that didn't help (see this). @JerryWu-code, would it be safe to initialize x_pad_token and cap_pad_token deterministically, maybe with something like torch.ones() (a sketch follows right after this list)? Or do you think it would have memory implications?
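
For clarity, this is the kind of change I mean. The shapes, the dim value, and the use of nn.Parameter are assumptions on my part, not the model's actual code:

    import torch
    import torch.nn as nn

    class PadTokens(nn.Module):  # stand-in for the DiT's __init__, not the real class
        def __init__(self, dim: int = 64):
            super().__init__()
            # current pattern: torch.empty() leaves the memory uninitialized, so the
            # starting values (possibly NaN) are arbitrary:
            # self.x_pad_token = nn.Parameter(torch.empty(1, dim))
            # deterministic alternative: same shape and dtype, reproducible values
            self.x_pad_token = nn.Parameter(torch.ones(1, dim))
            self.cap_pad_token = nn.Parameter(torch.ones(1, dim))

As far as I can tell, torch.ones() allocates the same amount of memory as torch.empty(); the only difference is a negligible fill at init time, and the values are overwritten anyway once a checkpoint is loaded.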

Minor nits

  • We usually avoid raw assert statements inside model implementations in favor of properly raised errors. Should we follow something similar here, too? (A sketch follows at the end of this section.)
  • There is a self.scheduler.sigma_min = 0.0 assignment inside the Z-Image pipeline. Maybe I am missing something, but that seems like an antipattern to me.
  • The signature of forward() of the DiT uses shorthand variable names (x, t, cap_feats) instead of hidden_states, timestep, and encoder_hidden_states.
  • Should _cfg_normalization and _cfg_truncation inside the pipeline be turned into properties like guidance_scale?

Maybe we could consider revisiting them (though perhaps not a priority).
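
To illustrate the first nit, the condition below is made up, but the pattern would be something like:

    # instead of:
    assert len(x) == len(cap_feats), "mismatched lengths"

    # prefer an explicit error, which survives `python -O` and gives a clearer message:
    if len(x) != len(cap_feats):
        raise ValueError(
            f"`x` and `cap_feats` must have the same length, but got {len(x)} and {len(cap_feats)}."
        )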

Cc: @JerryWu-code

@sayakpaul sayakpaul requested review from dg845 and yiyixuxu November 28, 2025 13:59
Comment on lines -661 to +636
- return x, {}
+ if not return_dict:
+     return (x,)
+
+ return Transformer2DModelOutput(sample=x)
sayakpaul (Member, Author):
Should be a very safe change?

Collaborator:
Yes indeed. Is this the only actual change in this file? (The others seem to be just formatting changes.)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member, Author) commented on Nov 28, 2025

The tests that fail in "Fast tests for PRs / Fast PyTorch Models & Schedulers CPU tests (pull_request)" pass even when run locally with CUDA_VISIBLE_DEVICES="" pytest tests/models/transformers/test_models_transformer_z_image.py.

Edit: the failure likely reproduces when CUDA_VISIBLE_DEVICES="" pytest tests/models/ is run.

@yiyixuxu (Collaborator) left a comment:
thanks!

@dg845 (Collaborator) commented on Dec 2, 2025

Not blocking, but I think we should think about how to best support and test the List[torch.Tensor] hidden_states pattern going forward.

My current understanding is that the motivation for using a List[torch.Tensor] rather than a single batched torch.Tensor is to support ragged tensors. This makes it easier to support a batch of hidden_states that would naturally have different shapes (for example, images of different resolutions) without needing extra memory and logic to pad them to the same shape. I'm not sure I fully understand the use cases and advantages/disadvantages for this pattern, so I would greatly appreciate it if someone could shed some more light on it. In particular, what is seen as the main reason(s) to use this pattern for current models?

I could see several support strategies here:

  1. Insist that all model inputs are single batched tensors.
  2. Allow lists of tensors, and try to support them as much as possible within the existing model tests (e.g. ModelTesterMixin).
  3. Create a separate test suite to handle models with List[torch.Tensor] inputs (e.g. a ragged tensor version of ModelTesterMixin).

My current personal preference is for (3), as it allows us to tailor the tests to actual List[torch.Tensor] use cases. (1) seems unnecessarily restrictive, and (2) seems like it might struggle to handle truly ragged inputs; I'm also wary that it would make the shared test suite more complicated and result in a lot of skipping for tests in which lists of tensors are hard to support with the current implementation. A downside of (3) is that it likely requires the most work of the three.

It seems like PyTorch will probably eventually support ragged tensors with torch.nested. It's not obvious to me which solution is most "future-proof", but that may be a consideration as well.
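
To make the ragged case concrete, here is a rough sketch (shapes are illustrative and this is not the actual Z-Image calling convention) contrasting the list-of-tensors pattern with padding, plus the torch.nested direction:

    import torch

    # Ragged "batch": two token sequences of different lengths (e.g., images at
    # different resolutions after patchification), kept as a plain Python list.
    hidden_states = [
        torch.randn(256, 64),   # 256 tokens, 64 channels
        torch.randn(1024, 64),  # 1024 tokens, 64 channels
    ]

    # Padded alternative: pad to the longest sequence and carry a mask around.
    padded = torch.zeros(2, 1024, 64)
    padded[0, :256] = hidden_states[0]
    padded[1] = hidden_states[1]
    mask = torch.zeros(2, 1024, dtype=torch.bool)
    mask[0, :256] = True
    mask[1, :] = True

    # torch.nested can express the same ragged batch without manual padding; the
    # jagged layout packs the data and tracks per-sample lengths (still a prototype
    # feature, so operator coverage is limited).
    nested = torch.nested.nested_tensor(hidden_states, layout=torch.jagged)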

super().test_group_offloading_with_disk()


class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
Collaborator:

Suggested change:
- class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):

nit: naming

Comment on lines +2165 to +2166
elif self.model_class.__name__ == "ZImageTransformer2DModel":
    recompile_limit = 3
Collaborator:

Not blocking, but I think it makes more sense to refactor this test to have a recompile_limit argument:

    def test_torch_compile_repeated_blocks(self, recompile_limit: int = 1):
        ...

and then override the test as follows:

    def test_torch_compile_repeated_blocks(self):
        super().test_torch_compile_repeated_blocks(recompile_limit=3)

IMO it's more clear this way that the Z-Image model is using special testing logic.

Collaborator:

Or perhaps making recompile_limit a class attribute might make sense, since it seems like we could reuse it in test_torch_compile_recompilation_and_graph_break. Unless I'm missing something, if self.recompile_limit > 1, test_torch_compile_recompilation_and_graph_break should always fail due to

torch._dynamo.config.patch(error_on_recompile=True).
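
A rough sketch of what the class-attribute variant could look like (the mixin internals are elided; only the structure is shown, and the class names follow the existing test classes):

    import unittest


    class TorchCompileTesterMixin:
        # number of distinct compilations tolerated for the repeated block
        recompile_limit = 1

        def test_torch_compile_repeated_blocks(self):
            # ... compile the repeated block, run the forward passes, and assert that
            # the number of observed recompilations stays within self.recompile_limit
            ...

        def test_torch_compile_recompilation_and_graph_break(self):
            # could consult self.recompile_limit here as well instead of relying
            # solely on torch._dynamo.config.patch(error_on_recompile=True)
            ...


    class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
        # ZImageTransformerBlock backs noise_refiner, context_refiner, and layers,
        # so up to three specializations are expected during compilation.
        recompile_limit = 3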

def prepare_dummy_input(self, height, width):
    return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)

@unittest.skip("Fullgraph is broken")
Collaborator:

Should the skip reason here reflect this?

The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence, the inputs recorded for the block vary during compilation, and full compilation with fullgraph=True triggers recompilation at least thrice.

Or is there another reason why we would expect fullgraph=True to fail?
