This repository was archived by the owner on Aug 7, 2024. It is now read-only.
Use Activation Hooks failing with AotAutograd for dynamic linear #223
Closed
Description
Summary
See this test: float8_experimental/test/test_compile.py, line 60 at commit 9cce2b9.
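For context, the failure shows up in roughly the following shape of repro (a minimal sketch only; the import path, `Float8DynamicLinear.from_float` constructor, and tensor shapes are assumptions, not the exact test code):

```python
# Minimal sketch of the kind of repro the test exercises; module path,
# conversion API, and shapes are assumptions, not the exact test code.
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear  # assumed import path

m = nn.Linear(16, 32, dtype=torch.bfloat16, device="cuda")
m_fp8 = Float8DynamicLinear.from_float(m)  # assumed conversion API for the dynamic linear
m_fp8 = torch.compile(m_fp8, backend="aot_eager")

x = torch.randn(4, 16, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y_fp8 = m_fp8(x)  # compiling the joint forward/backward graph fails here
y_fp8.sum().backward()
```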
This PR has landed: pytorch/pytorch#118191. We expected it to resolve the current issue; however, removing the xfail produces:
test/test_compile.py:44: in _test_compile_base
y_fp8 = m_fp8(x)
../pytorch/torch/nn/modules/module.py:1529: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../pytorch/torch/nn/modules/module.py:1538: in _call_impl
return forward_call(*args, **kwargs)
../pytorch/torch/_dynamo/eval_frame.py:455: in _fn
return fn(*args, **kwargs)
../pytorch/torch/nn/modules/module.py:1529: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../pytorch/torch/nn/modules/module.py:1579: in _call_impl
result = forward_call(*args, **kwargs)
../pytorch/torch/_dynamo/convert_frame.py:912: in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
../pytorch/torch/_dynamo/convert_frame.py:398: in _convert_frame_assert
return _compile(
../../miniconda3/envs/dev/lib/python3.10/contextlib.py:79: in inner
return func(*args, **kwds)
../pytorch/torch/_dynamo/convert_frame.py:669: in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
../pytorch/torch/_dynamo/utils.py:256: in time_wrapper
r = func(*args, **kwargs)
../pytorch/torch/_dynamo/convert_frame.py:542: in compile_inner
out_code = transform_code_object(code, transform)
../pytorch/torch/_dynamo/bytecode_transformation.py:1033: in transform_code_object
transformations(instructions, code_options)
../pytorch/torch/_dynamo/convert_frame.py:163: in _fn
return fn(*args, **kwargs)
../pytorch/torch/_dynamo/convert_frame.py:507: in transform
tracer.run()
../pytorch/torch/_dynamo/symbolic_convert.py:2128: in run
super().run()
../pytorch/torch/_dynamo/symbolic_convert.py:791: in run
and self.step()
../pytorch/torch/_dynamo/symbolic_convert.py:754: in step
getattr(self, inst.opname)(inst)
../pytorch/torch/_dynamo/symbolic_convert.py:2247: in RETURN_VALUE
self.output.compile_subgraph(
../pytorch/torch/_dynamo/output_graph.py:931: in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
../../miniconda3/envs/dev/lib/python3.10/contextlib.py:79: in inner
return func(*args, **kwds)
../pytorch/torch/_dynamo/output_graph.py:1102: in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
../pytorch/torch/_dynamo/utils.py:256: in time_wrapper
r = func(*args, **kwargs)
../pytorch/torch/_dynamo/output_graph.py:1175: in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
../pytorch/torch/_dynamo/output_graph.py:1156: in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
../pytorch/torch/_dynamo/repro/after_dynamo.py:117: in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
../pytorch/torch/_dynamo/repro/after_dynamo.py:117: in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
../pytorch/torch/__init__.py:1769: in __call__
return self.compiler_fn(model_, inputs_, **self.kwargs)
../pytorch/torch/_dynamo/backends/common.py:57: in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
../pytorch/torch/_functorch/aot_autograd.py:879: in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
../pytorch/torch/_dynamo/utils.py:256: in time_wrapper
r = func(*args, **kwargs)
../pytorch/torch/_functorch/aot_autograd.py:604: in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
../pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py:434: in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
../pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py:639: in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
../pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:152: in aot_dispatch_autograd
fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph( # type: ignore[misc]
../pytorch/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:169: in aot_dispatch_autograd_graph
fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
../pytorch/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:40: in _create_graph
fx_g = make_fx(
../pytorch/torch/fx/experimental/proxy_tensor.py:1099: in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
../pytorch/torch/_compile.py:24: in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
../pytorch/torch/_dynamo/eval_frame.py:455: in _fn
return fn(*args, **kwargs)
../pytorch/torch/_dynamo/external_utils.py:25: in inner
return fn(*args, **kwargs)
../pytorch/torch/fx/experimental/proxy_tensor.py:550: in dispatch_trace
graph = tracer.trace(root, concrete_args)
../pytorch/torch/_dynamo/eval_frame.py:455: in _fn
return fn(*args, **kwargs)
../pytorch/torch/_dynamo/external_utils.py:25: in inner
return fn(*args, **kwargs)
../pytorch/torch/fx/_symbolic_trace.py:793: in trace
(self.create_arg(fn(*args)),),
../pytorch/torch/fx/_symbolic_trace.py:652: in flatten_fn
tree_out = root_fn(*tree_args)
../pytorch/torch/fx/experimental/proxy_tensor.py:577: in wrapped
out = f(*tensors)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:537: in joint_fn
return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:519: in inner_fn
wrapped_outs = fn(*all_args)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:466: in joint_helper
return _functionalized_f_helper(primals, tangents)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:357: in _functionalized_f_helper
f_outs = fn(*f_args)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:252: in inner_fn_with_anomaly
return inner_fn(*args)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:237: in inner_fn
backward_out = torch.autograd.grad(
../pytorch/torch/autograd/__init__.py:412: in grad
result = _engine_run_backward(
../pytorch/torch/autograd/graph.py:744: in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
float8_experimental/float8_tensor.py:172: in __torch_dispatch__
return FLOAT8_OPS_TABLE[func](func, args, kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
aten_op = <OpOverload(op='aten.mm', overload='default')>
args = (FunctionalTensor(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(32, 16), dtype=torch.bfloat16),
..., please use it with a corresponding FunctionalTensorMode()') raised in repr()] Float8Tensor object at 0x7f2f833fed50>)
kwargs = {}
@implements([aten.mm.default])
def float8_mm(aten_op, args, kwargs=None):
> assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
E torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
E AssertionError:
E
E Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E
E
E You can suppress this exception and fall back to eager by setting:
E import torch._dynamo
E torch._dynamo.config.suppress_errors = True
float8_experimental/float8_ops.py:81: BackendCompilerFailed
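For reference, the assertion comes from the float8 op-override machinery visible in the traceback: `Float8Tensor.__torch_dispatch__` (float8_tensor.py:172) looks up a handler for `aten.mm.default` in `FLOAT8_OPS_TABLE`, and that handler requires both mm operands to be `Float8Tensor`. Under AOTAutograd's joint tracing, the backward mm instead receives a `FunctionalTensor`-wrapped `FakeTensor` as `args[0]`, so the assert fires. A simplified sketch of that pattern (not the library's exact code):

```python
# Simplified sketch of the override pattern seen in the traceback
# (float8_tensor.py / float8_ops.py); not the library's exact code.
import torch

aten = torch.ops.aten
FLOAT8_OPS_TABLE = {}

def implements(aten_ops):
    """Register a handler function for a list of aten ops."""
    def decorator(fn):
        for op in aten_ops:
            FLOAT8_OPS_TABLE[op] = fn
        return fn
    return decorator

class Float8Tensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every aten op that touches a Float8Tensor is routed through the table.
        if func in FLOAT8_OPS_TABLE:
            return FLOAT8_OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"no float8 handler for {func}")

@implements([aten.mm.default])
def float8_mm(aten_op, args, kwargs=None):
    # This is the assertion that fires in the report: during AOTAutograd's
    # joint tracing, args[0] of the backward mm is a FunctionalTensor-wrapped
    # FakeTensor rather than a Float8Tensor.
    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
    ...
```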