[compile] Fix graphbreaks in moe split; scale_grad #2771

Merged
5 commits merged on May 30, 2025
5 changes: 4 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -342,6 +342,9 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile_loss = compile.get("loss", True)
         self._compile_optimizer_step = compile.get("optimizer_step", False)
         self._compile_scale_grads = compile.get("scale_grads", True)
+        if self._compile_model:
+            # Capture scalar outputs is required to compile MoE
+            torch._dynamo.config.capture_scalar_outputs = True
Comment on lines +345 to +347
Contributor

We may also want to add this to lora_finetune_distributed.py too (I think the logic should be the same there)


         # This indirection is needed to apply torch.compile to scale_grads step.
         self._grad_scaler = training.scale_grads_
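For readers unfamiliar with the flag: a minimal standalone sketch (not the recipe's actual MoE code) of why `capture_scalar_outputs` is needed. MoE routing typically reads a data-dependent token count out of a tensor with `.item()`, which Dynamo graph-breaks on by default; enabling the flag captures it as an unbacked symbolic int instead.

```python
# Hedged sketch, not torchtune code: weight_by_routed_count is a made-up function.
import torch

# Without this flag, Dynamo graph-breaks (or errors under fullgraph=True)
# on the .item() call below.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def weight_by_routed_count(scores: torch.Tensor) -> torch.Tensor:
    # Data-dependent scalar: captured as an unbacked SymInt rather than
    # forcing a graph break.
    num_routed = (scores > 0.5).sum().item()
    return scores * num_routed


print(weight_by_routed_count(torch.rand(8)))
```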
@@ -941,7 +944,7 @@ def train(self) -> None:

             # Manually scale the gradients from unnormalized loss by total # of tokens
             self._grad_scaler(
-                self._model.parameters(),
+                list(self._model.parameters()),
                 self.world_size / num_tokens,
                 False if self.parallel_dims.tp_enabled else None,
             )
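A hedged sketch of why materializing `parameters()` helps here (the `scale_grads_` below is a hypothetical stand-in, not torchtune's `training.scale_grads_`): `Module.parameters()` returns a one-shot generator, while a plain list gives the compiled grad-scaling call a concrete, re-iterable input.

```python
# Hedged sketch; scale_grads_ here is illustrative, not torchtune's helper.
import torch
from torch import nn


def scale_grads_(params: list[nn.Parameter], scaler: float) -> None:
    # Scale gradients in place over an explicit list of parameters.
    for p in params:
        if p.grad is not None:
            p.grad.mul_(scaler)


model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

compiled_scale = torch.compile(scale_grads_)
# Materialize the generator, mirroring the recipe change above.
compiled_scale(list(model.parameters()), 0.5)
```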
2 changes: 2 additions & 0 deletions recipes/lora_finetune_distributed.py
@@ -282,6 +282,8 @@ def setup(self, cfg: DictConfig) -> None:
         checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
 
         self._compile = cfg.get("compile", False)
+        # Capture scalar outputs is required to compile MoE
+        torch._dynamo.config.capture_scalar_outputs = True
 
         self._model = self._setup_model(
             cfg_model=cfg.model,
11 changes: 6 additions & 5 deletions torchtune/modules/attention_utils.py
@@ -209,11 +209,12 @@ def _attention_call(
     # This will use flash attention under the hood with support for custom masks.
     # Currently, it is used when sample packing is enabled (see torchtune.datasets.PackedDataset)
     if isinstance(mask, BlockMask):
-        log_once(
-            _log,
-            "Using flex attention for attention computation since a BlockMask was passed in.",
-            level=logging.DEBUG,
-        )
+        if not torch.compiler.is_compiling():
Contributor

Noob q: why do we only want to log this when we're not compiling?

Contributor Author

Dynamo graph_breaks on log(), so this is only to avoid the graph break. But it's safe to log in normal non-compiling execution :)

Graph break in user code at /data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py:105
Graph Break Reason: Logger not supported for non-export cases. To avoid graph breaks caused by logger in compile-mode, it is recommended to disable logging by adding logging methods to config.ignore_logger_methods
User code traceback:
  File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1204, in <module>
    sys.exit(recipe_main())
  File "/data/users/ivankobzarev/b/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
  File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1199, in recipe_main
    recipe.train()
  File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 1034, in train
    current_loss = self._loss_step(batch) * current_num_tokens
  File "/data/users/ivankobzarev/b/torchtune/recipes/full_finetune_distributed.py", line 927, in _loss_step
    outputs = self._model(**batch)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1873, in _call_impl
    return inner()
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1821, in inner
    result = forward_call(*args, **kwargs)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/model_fusion/_early_fusion.py", line 287, in forward
    output = self.decoder(
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/transformer.py", line 661, in forward
    h = layer(
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1873, in _call_impl
    return inner()
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1821, in inner
    result = forward_call(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/data/users/ivankobzarev/b/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/utils/checkpoint.py", line 495, in checkpoint
    ret = function(*args, **kwargs)
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1765, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/data/users/ivankobzarev/b/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/transformer.py", line 134, in forward
    attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/attention.py", line 292, in forward
    output = self._attention_call(
  File "/data/users/ivankobzarev/b/torchtune/torchtune/modules/attention_utils.py", line 214, in _attention_call
    log_once(
  File "/data/users/ivankobzarev/b/pytorch/torch/_dynamo/polyfills/__init__.py", line 193, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py", line 55, in log_once
    log_rank_zero(logger=logger, msg=msg, level=level)
  File "/data/users/ivankobzarev/b/torchtune/torchtune/utils/_logging.py", line 105, in log_rank_zero
    logger.log(level, msg, stacklevel=2)

+            log_once(
+                _log,
+                "Using flex attention for attention computation since a BlockMask was passed in.",
+                level=logging.DEBUG,
+            )
         if dropout_p > 0.0:
             raise ValueError(
                 "Flex attention does not support dropout. Please set dropout to 0.0."
2 changes: 0 additions & 2 deletions torchtune/modules/moe/experts.py
@@ -50,8 +50,6 @@ def reset_parameters(self) -> None:
     # TODO: force no inference mode as a hack to get around
     # "Cannot set version_counter for inference tensor"
     @torch.inference_mode(mode=False)
-    # TODO: remove once compilation is fixed
-    @torch._dynamo.disable(recursive=False)
     def forward(
         self,
         x: torch.Tensor,
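For reference, a standalone sketch (illustrative functions, not the experts module) of what the removed decorator did: with `recursive=False`, Dynamo skips tracing only the decorated frame but may still compile the functions it calls, which was the stop-gap before these MoE graph breaks were fixed.

```python
# Hedged sketch of the removed workaround; forward_stub and inner are made up.
import torch


def inner(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@torch._dynamo.disable(recursive=False)
def forward_stub(x: torch.Tensor) -> torch.Tensor:
    # This frame runs eagerly under torch.compile; inner() can still be traced.
    return inner(x) + 1


compiled = torch.compile(lambda x: forward_stub(x))
print(compiled(torch.ones(3)))
```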