If we use an HF transformer model with DDP and torch.compile with backend "eager", applying float8 training to it leads to an error. The stack trace points to https://github.com/pytorch/pytorch/blob/5998cd4eaaf50d5a427f0b0ec14f2135e4a46723/torch/_dynamo/backends/distributed.py#L572 .
Note that:
- no error if I use a toy linear layer instead of the HF model
- no error if I turn off compile, turn off float8, or use FSDP instead of DDP
Repro script: https://gist.github.com/vkuzo/9a1154fe08b654abcc9628f8a4834e83
TLParse rank 0: https://gist.github.com/vkuzo/a7f342cdbd549e2f8bf13b18e1012da8
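For reference, here is a minimal sketch of the configuration the repro exercises. The model name is a placeholder I picked to match the shapes in the trace below, and the authoritative version is the gist above; it assumes torchao's `convert_to_float8_training` API and is launched with something like `torchrun --nproc_per_node=2 repro.py`.

```python
# Minimal sketch of the failing configuration (placeholder model name;
# see the gist above for the actual repro).
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.float8 import convert_to_float8_training
from transformers import AutoModelForCausalLM

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16
).cuda()
convert_to_float8_training(model)      # swap nn.Linear -> Float8Linear
model = DDP(model, device_ids=[rank])  # DDP wrapping (FSDP does not repro)
model = torch.compile(model, backend="eager")

# the error surfaces while Dynamo compiles the first forward pass
batch = torch.randint(0, 128256, (2, 256), device="cuda")
out = model(input_ids=batch, attention_mask=torch.ones_like(batch))
```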
Stack trace excerpt:

```
V0723 07:24:02.662000 4188664 site-packages/torch/_dynamo/convert_frame.py:1081] {"artifact": {"name": "dynamo_error", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "e2174b3b108a3f1588fe64a49464b362"}
Traceback (most recent call last):
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
    transformations(instructions, code_options)
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in transform
    tracer.run()
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3498, in run
    super().run()
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3699, in RETURN_VALUE
    self._return(inst)
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3684, in _return
    self.output.compile_subgraph(
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1179, in compile_subgraph
    self.compile_and_call_fx_graph(
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1437, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1487, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1544, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1519, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/backends/distributed.py", line 548, in compile_fn
    submod_compiler.run(*example_inputs)
  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/fx/interpreter.py", line 185, in run
    raise RuntimeError(*e.args) from e
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
RuntimeError: val
While executing %submod_8 : [num_users=2] = call_module[target=submod_8](args = (%l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_,), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, L_kwargs_input_ids_: "i64[2, 256][256, 1]", L_self_modules_model_modules_embed_tokens_parameters_weight_: "bf16[128256, 2048][2048, 1]", L_kwargs_attention_mask_: "i64[2, 256][256, 1]", L_self_modules_model_modules_rotary_emb_buffers_inv_freq_: "f32[32][1]", L_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_: "bf16[2048][1]", L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_: "bf16[2048, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_: "bf16[512, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_: "bf16[512, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_: "bf16[2048, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_: "bf16[2048][1]", L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_: "bf16[8192, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_: "bf16[8192, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_: "bf16[2048, 8192][8192, 1]", L_self_modules_model_modules_norm_parameters_weight_: "bf16[2048][1]"):
        l_kwargs_input_ids_ = L_kwargs_input_ids_
        l_self_modules_model_modules_embed_tokens_parameters_weight_ = L_self_modules_model_modules_embed_tokens_parameters_weight_
        l_kwargs_attention_mask_ = L_kwargs_attention_mask_
        l_self_modules_model_modules_rotary_emb_buffers_inv_freq_ = L_self_modules_model_modules_rotary_emb_buffers_inv_freq_
        l_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_
        l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_
        l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_
        l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_
        l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_
        l_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_
        l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_
        l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_
        l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_
        l_self_modules_model_modules_norm_parameters_weight_ = L_self_modules_model_modules_norm_parameters_weight_

        # No stacktrace found for following nodes
        submod_8 = self.submod_8(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_)
        getitem = submod_8[0]
        getitem_1 = submod_8[1]; submod_8 = None
        submod_5 = self.submod_5(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_)
        getitem_2 = submod_5[0]
        getitem_3 = submod_5[1]; submod_5 = None
        submod_2 = self.submod_2(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_)
        getitem_4 = submod_2[0]
        getitem_5 = submod_2[1]; submod_2 = None
        submod_0 = self.submod_0(l_kwargs_input_ids_, l_self_modules_model_modules_embed_tokens_parameters_weight_, l_kwargs_attention_mask_, l_self_modules_model_modules_rotary_emb_buffers_inv_freq_, l_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_, l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_, l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_, l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_); l_kwargs_attention_mask_ = l_self_modules_model_modules_rotary_emb_buffers_inv_freq_ = l_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_ = l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_ = l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_ = l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_ = None
        getitem_6 = submod_0[0]
        getitem_7 = submod_0[1]
        getitem_8 = submod_0[2]
        getitem_9 = submod_0[3]; submod_0 = None
        submod_1 = self.submod_1(l_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_, getitem_6, l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_); l_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ = getitem_6 = None
        submod_6 = self.submod_6(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_, getitem_2, getitem_3, submod_1); getitem_2 = getitem_3 = None
        submod_3 = self.submod_3(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_, getitem_4, getitem_5, submod_1); l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_ = getitem_4 = getitem_5 = submod_1 = None
        submod_4 = self.submod_4(submod_3, l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_); submod_3 = l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_ = None
        submod_7 = self.submod_7(submod_4, submod_6, l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_); submod_4 = submod_6 = None
        submod_9 = self.submod_9(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_, getitem, getitem_1, submod_7, getitem_7, l_self_modules_model_modules_norm_parameters_weight_); l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_ = getitem = getitem_1 = submod_7 = getitem_7 = l_self_modules_model_modules_norm_parameters_weight_ = None
        submod_10 = self.submod_10(submod_9, l_self_modules_model_modules_embed_tokens_parameters_weight_, l_kwargs_input_ids_); submod_9 = l_self_modules_model_modules_embed_tokens_parameters_weight_ = l_kwargs_input_ids_ = None
        getitem_10 = submod_10[0]
        getitem_11 = submod_10[1]; submod_10 = None
        return (getitem_8, getitem_9, getitem_10, getitem_11)

    class submod_8(torch.nn.Module):
        def forward(self, l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_: "bf16[2048, 8192][8192, 1]"):
            # No stacktrace found for following nodes
            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None

            # File: /data/users/vasiliy/ao/torchao/float8/float8_utils.py:65 in tensor_to_amax, code: amax = torch.max(torch.abs(x))
            abs_1: "bf16[2048, 8192][8192, 1]" = torch.abs(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_); l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_ = None
            max_1: "bf16[][]" = torch.max(abs_1); abs_1 = None

            # File: /data/users/vasiliy/ao/torchao/float8/float8_utils.py:45 in amax_to_scale, code: amax = amax.to(torch.float64)
            to: "f64[][]" = max_1.to(torch.float64); max_1 = None

            # File: /data/users/vasiliy/ao/torchao/float8/float8_utils.py:47 in amax_to_scale, code: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
            clamp: "f64[][]" = torch.clamp(to, min = 1e-12); to = None
            truediv: "f64[][]" = 448.0 / clamp; clamp = None

            # File: /data/users/vasiliy/ao/torchao/float8/float8_utils.py:48 in amax_to_scale, code: res = res.to(torch.float32)
            to_1: "f32[][]" = truediv.to(torch.float32); truediv = None

            # No stacktrace found for following nodes
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
            scaled_mmconfig = torchao_float8_float8_training_tensor_ScaledMMConfig(False, True, False, False)
            scaled_mmconfig_1 = torchao_float8_float8_training_tensor_ScaledMMConfig(False, False, False, False)
            scaled_mmconfig_2 = torchao_float8_float8_training_tensor_ScaledMMConfig(False, False, False, False)
            linear_mmconfig = torchao_float8_float8_training_tensor_LinearMMConfig(scaled_mmconfig, scaled_mmconfig_1, scaled_mmconfig_2); scaled_mmconfig = scaled_mmconfig_1 = scaled_mmconfig_2 = None
            return (to_1, linear_mmconfig)
...
```
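A note on the dump above: `submod_8` computes the float8 scale for the down_proj weight and returns a `LinearMMConfig` alongside the scale tensor, so one of the split submodule's outputs is not a tensor. My read (unverified) is that DDPOptimizer's fake-tensor propagation over the split submodules fails when it looks up the `val` metadata for that non-tensor output, which would explain the bare `RuntimeError: val`. For context, here is a standalone sketch of the amax-to-scale computation that `submod_8` traces, mirroring `tensor_to_amax` / `amax_to_scale` in torchao/float8/float8_utils.py, with the constants taken verbatim from the trace:

```python
# Standalone sketch of the scaling math traced into submod_8; constants
# (448.0, 1e-12) are taken from the trace above.
import torch

EPS = 1e-12
F8E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
    amax = amax.to(torch.float64)  # upcast so the division runs in fp64
    res = F8E4M3_MAX / torch.clamp(amax, min=EPS)
    return res.to(torch.float32)

# tensor_to_amax: absolute max of the bf16 down_proj weight
weight = torch.randn(2048, 8192, dtype=torch.bfloat16)
amax = torch.max(torch.abs(weight))
scale = amax_to_scale(amax)  # f32 scalar scale applied before the fp8 cast
```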