
float8 training + HF transformer model + torch.compile + DDP does not work #2586

Description

@vkuzo

If we use an HF transformer model with DDP and torch.compile (backend "eager"), applying float8 training to it leads to an error. The stack trace points to https://github.com/pytorch/pytorch/blob/5998cd4eaaf50d5a427f0b0ec14f2135e4a46723/torch/_dynamo/backends/distributed.py#L572 . A minimal sketch of the setup follows the repro links below.

Note that:

  • no error if I use a toy linear layer instead of the HF model
  • no error if I turn off compile, turn off float8, or use FSDP instead of DDP

Repro script: https://gist.github.com/vkuzo/9a1154fe08b654abcc9628f8a4834e83
TLParse rank 0: https://gist.github.com/vkuzo/a7f342cdbd549e2f8bf13b18e1012da8
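
For context, here is a minimal sketch of the failing configuration (the gist above is the authoritative repro; the checkpoint name below is an assumption inferred from the tensor shapes in the graph, which match a Llama-3.2-1B-sized config):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.float8 import convert_to_float8_training
from transformers import AutoModelForCausalLM

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# assumption: any small HF causal LM reproduces; the shapes in the trace
# (vocab 128256, hidden 2048, bf16) match a Llama-3.2-1B-sized model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16
).cuda()

# swap nn.Linear modules for Float8Linear (dynamic scaling) before DDP
convert_to_float8_training(model)

model = DDP(model, device_ids=[rank])
model = torch.compile(model, backend="eager")

input_ids = torch.randint(0, 128256, (2, 256), device="cuda")
attention_mask = torch.ones_like(input_ids)
# dynamo traces on the first call; the BackendCompilerFailed below is
# raised here
out = model(input_ids=input_ids, attention_mask=attention_mask)
out.logits.sum().backward()
```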

Stack trace excerpt:

V0723 07:24:02.662000 4188664 site-packages/torch/_dynamo/convert_frame.py:1081] {"artifact": {"name": "dynamo_error", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "e2174b3b108a3f1588fe64a49464b362"}
	Traceback (most recent call last):
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
	    guarded_code = compile_inner(code, one_graph, hooks, transform)
	                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
	    return function(*args, **kwargs)
	           ^^^^^^^^^^^^^^^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
	    return _compile_inner(code, one_graph, hooks, transform)
	           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
	    out_code = transform_code_object(code, transform)
	               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
	    transformations(instructions, code_options)
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
	    return fn(*args, **kwargs)
	           ^^^^^^^^^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in transform
	    tracer.run()
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3498, in run
	    super().run()
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
	    while self.step():
	          ^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
	    self.dispatch_table[inst.opcode](self, inst)
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3699, in RETURN_VALUE
	    self._return(inst)
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3684, in _return
	    self.output.compile_subgraph(
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1179, in compile_subgraph
	    self.compile_and_call_fx_graph(
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1437, in compile_and_call_fx_graph
	    compiled_fn = self.call_user_compiler(gm)
	                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1487, in call_user_compiler
	    return self._call_user_compiler(gm)
	           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1544, in _call_user_compiler
	    raise BackendCompilerFailed(
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1519, in _call_user_compiler
	    compiled_fn = compiler_fn(gm, self.example_inputs())
	                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/_dynamo/backends/distributed.py", line 548, in compile_fn
	    submod_compiler.run(*example_inputs)
	  File "/home/vasiliy/.conda/envs/pytorch_nightly/lib/python3.11/site-packages/torch/fx/interpreter.py", line 185, in run
	    raise RuntimeError(*e.args) from e
	torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
	RuntimeError: val
	
	While executing %submod_8 : [num_users=2] = call_module[target=submod_8](args = (%l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_,), kwargs = {})
	GraphModule: class GraphModule(torch.nn.Module):
	    def forward(self, L_kwargs_input_ids_: "i64[2, 256][256, 1]", L_self_modules_model_modules_embed_tokens_parameters_weight_: "bf16[128256, 2048][2048, 1]", L_kwargs_attention_mask_: "i64[2, 256][256, 1]", L_self_modules_model_modules_rotary_emb_buffers_inv_freq_: "f32[32][1]", L_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_: "bf16[2048][1]", L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_: "bf16[2048, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_: "bf16[512, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_: "bf16[512, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_: "bf16[2048, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_: "bf16[2048][1]", L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_: "bf16[8192, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_: "bf16[8192, 2048][2048, 1]", L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_: "bf16[2048, 8192][8192, 1]", L_self_modules_model_modules_norm_parameters_weight_: "bf16[2048][1]"):
	        l_kwargs_input_ids_ = L_kwargs_input_ids_
	        l_self_modules_model_modules_embed_tokens_parameters_weight_ = L_self_modules_model_modules_embed_tokens_parameters_weight_
	        l_kwargs_attention_mask_ = L_kwargs_attention_mask_
	        l_self_modules_model_modules_rotary_emb_buffers_inv_freq_ = L_self_modules_model_modules_rotary_emb_buffers_inv_freq_
	        l_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_
	        l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_
	        l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_
	        l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_
	        l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_
	        l_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_
	        l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_
	        l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_
	        l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_ = L_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_
	        l_self_modules_model_modules_norm_parameters_weight_ = L_self_modules_model_modules_norm_parameters_weight_
	        
	        # No stacktrace found for following nodes
	        submod_8 = self.submod_8(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_)
	        getitem = submod_8[0]
	        getitem_1 = submod_8[1];  submod_8 = None
	        submod_5 = self.submod_5(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_)
	        getitem_2 = submod_5[0]
	        getitem_3 = submod_5[1];  submod_5 = None
	        submod_2 = self.submod_2(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_)
	        getitem_4 = submod_2[0]
	        getitem_5 = submod_2[1];  submod_2 = None
	        submod_0 = self.submod_0(l_kwargs_input_ids_, l_self_modules_model_modules_embed_tokens_parameters_weight_, l_kwargs_attention_mask_, l_self_modules_model_modules_rotary_emb_buffers_inv_freq_, l_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_, l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_, l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_, l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_);  l_kwargs_attention_mask_ = l_self_modules_model_modules_rotary_emb_buffers_inv_freq_ = l_self_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_ = l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_ = l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_ = l_self_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_ = None
	        getitem_6 = submod_0[0]
	        getitem_7 = submod_0[1]
	        getitem_8 = submod_0[2]
	        getitem_9 = submod_0[3];  submod_0 = None
	        submod_1 = self.submod_1(l_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_, getitem_6, l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_);  l_self_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ = getitem_6 = None
	        submod_6 = self.submod_6(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_, getitem_2, getitem_3, submod_1);  getitem_2 = getitem_3 = None
	        submod_3 = self.submod_3(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_, getitem_4, getitem_5, submod_1);  l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_ = getitem_4 = getitem_5 = submod_1 = None
	        submod_4 = self.submod_4(submod_3, l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_);  submod_3 = l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_ = None
	        submod_7 = self.submod_7(submod_4, submod_6, l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_);  submod_4 = submod_6 = None
	        submod_9 = self.submod_9(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_, getitem, getitem_1, submod_7, getitem_7, l_self_modules_model_modules_norm_parameters_weight_);  l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_ = getitem = getitem_1 = submod_7 = getitem_7 = l_self_modules_model_modules_norm_parameters_weight_ = None
	        submod_10 = self.submod_10(submod_9, l_self_modules_model_modules_embed_tokens_parameters_weight_, l_kwargs_input_ids_);  submod_9 = l_self_modules_model_modules_embed_tokens_parameters_weight_ = l_kwargs_input_ids_ = None
	        getitem_10 = submod_10[0]
	        getitem_11 = submod_10[1];  submod_10 = None
	        return (getitem_8, getitem_9, getitem_10, getitem_11)
	        
	    class submod_8(torch.nn.Module):
	        def forward(self, l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_: "bf16[2048, 8192][8192, 1]"):
	            # No stacktrace found for following nodes
	            _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
	            
	             # File: /data/users/vasiliy/ao/torchao/float8/float8_utils.py:65 in tensor_to_amax, code: amax = torch.max(torch.abs(x))
	            abs_1: "bf16[2048, 8192][8192, 1]" = torch.abs(l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_);  l_self_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_ = None
	            max_1: "bf16[][]" = torch.max(abs_1);  abs_1 = None
	            
	             # File: /data/users/vasiliy/ao/torchao/float8/float8_utils.py:45 in amax_to_scale, code: amax = amax.to(torch.float64)
	            to: "f64[][]" = max_1.to(torch.float64);  max_1 = None
	            
	             # File: /data/users/vasiliy/ao/torchao/float8/float8_utils.py:47 in amax_to_scale, code: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
	            clamp: "f64[][]" = torch.clamp(to, min = 1e-12);  to = None
	            truediv: "f64[][]" = 448.0 / clamp;  clamp = None
	            
	             # File: /data/users/vasiliy/ao/torchao/float8/float8_utils.py:48 in amax_to_scale, code: res = res.to(torch.float32)
	            to_1: "f32[][]" = truediv.to(torch.float32);  truediv = None
	            
	            # No stacktrace found for following nodes
	            _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
	            scaled_mmconfig = torchao_float8_float8_training_tensor_ScaledMMConfig(False, True, False, False)
	            scaled_mmconfig_1 = torchao_float8_float8_training_tensor_ScaledMMConfig(False, False, False, False)
	            scaled_mmconfig_2 = torchao_float8_float8_training_tensor_ScaledMMConfig(False, False, False, False)
	            linear_mmconfig = torchao_float8_float8_training_tensor_LinearMMConfig(scaled_mmconfig, scaled_mmconfig_1, scaled_mmconfig_2);  scaled_mmconfig = scaled_mmconfig_1 = scaled_mmconfig_2 = None
	            return (to_1, linear_mmconfig)
...
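
For readability, the failing `submod_8` reduces the down_proj weight to an amax and converts it to a float8 scale, then returns the scale together with a non-tensor `LinearMMConfig`. A paraphrase of the graph body (a sketch, not torchao's code verbatim):

```python
import torch

EPS = 1e-12  # the clamp min in the graph above

def weight_to_scale(w: torch.Tensor) -> torch.Tensor:
    # tensor_to_amax: amax = max(|w|)
    amax = torch.max(torch.abs(w)).to(torch.float64)
    # amax_to_scale: scale = finfo(float8_e4m3fn).max / clamp(amax, EPS);
    # the literal 448.0 in the graph is torch.finfo(torch.float8_e4m3fn).max
    scale = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=EPS)
    return scale.to(torch.float32)
```

The bare `RuntimeError: val` is `raise RuntimeError(*e.args)` re-wrapping what looks like a `KeyError('val')`, i.e. a node missing the `val` entry in its meta. Since `submod_8`'s second output is the non-tensor `LinearMMConfig`, one guess is that the DDPOptimizer's interpreter has no example value for that output (speculation from the trace, not confirmed).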
