This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Use Activation Hooks failing with AotAutograd for dynamic linear #223

Closed

@drisspg

Description

Summary

See this test:

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
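
For context, a rough sketch of the pattern that compile test exercises is below; the module structure, shapes, and the commented-out swap helper are assumptions for illustration, not the exact test code.

import copy
import torch
import torch.nn as nn

# Build a small bf16 linear stack on CUDA and a copy to be converted to float8.
m = nn.Sequential(nn.Linear(16, 32, device="cuda", dtype=torch.bfloat16))
m_fp8 = copy.deepcopy(m)
# swap_linear_with_float8_linear(m_fp8, Float8DynamicLinear)  # repo helper, assumed

# Compile with the aot_eager backend; the failure below is raised while
# AOTAutograd traces the joint forward/backward graph.
m_fp8 = torch.compile(m_fp8, backend="aot_eager")

x = torch.randn(32, 16, device="cuda", dtype=torch.bfloat16)
y_fp8 = m_fp8(x)
y_fp8.sum().backward()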

This PR has landed: pytorch/pytorch#118191. We expected it to resolve the current issue; however, removing the xfail produces:

test/test_compile.py:44: in _test_compile_base
    y_fp8 = m_fp8(x)
../pytorch/torch/nn/modules/module.py:1529: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../pytorch/torch/nn/modules/module.py:1538: in _call_impl
    return forward_call(*args, **kwargs)
../pytorch/torch/_dynamo/eval_frame.py:455: in _fn
    return fn(*args, **kwargs)
../pytorch/torch/nn/modules/module.py:1529: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../pytorch/torch/nn/modules/module.py:1579: in _call_impl
    result = forward_call(*args, **kwargs)
../pytorch/torch/_dynamo/convert_frame.py:912: in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
../pytorch/torch/_dynamo/convert_frame.py:398: in _convert_frame_assert
    return _compile(
../../miniconda3/envs/dev/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
../pytorch/torch/_dynamo/convert_frame.py:669: in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
../pytorch/torch/_dynamo/utils.py:256: in time_wrapper
    r = func(*args, **kwargs)
../pytorch/torch/_dynamo/convert_frame.py:542: in compile_inner
    out_code = transform_code_object(code, transform)
../pytorch/torch/_dynamo/bytecode_transformation.py:1033: in transform_code_object
    transformations(instructions, code_options)
../pytorch/torch/_dynamo/convert_frame.py:163: in _fn
    return fn(*args, **kwargs)
../pytorch/torch/_dynamo/convert_frame.py:507: in transform
    tracer.run()
../pytorch/torch/_dynamo/symbolic_convert.py:2128: in run
    super().run()
../pytorch/torch/_dynamo/symbolic_convert.py:791: in run
    and self.step()
../pytorch/torch/_dynamo/symbolic_convert.py:754: in step
    getattr(self, inst.opname)(inst)
../pytorch/torch/_dynamo/symbolic_convert.py:2247: in RETURN_VALUE
    self.output.compile_subgraph(
../pytorch/torch/_dynamo/output_graph.py:931: in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
../../miniconda3/envs/dev/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
../pytorch/torch/_dynamo/output_graph.py:1102: in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
../pytorch/torch/_dynamo/utils.py:256: in time_wrapper
    r = func(*args, **kwargs)
../pytorch/torch/_dynamo/output_graph.py:1175: in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
../pytorch/torch/_dynamo/output_graph.py:1156: in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
../pytorch/torch/_dynamo/repro/after_dynamo.py:117: in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
../pytorch/torch/_dynamo/repro/after_dynamo.py:117: in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
../pytorch/torch/__init__.py:1769: in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
../pytorch/torch/_dynamo/backends/common.py:57: in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
../pytorch/torch/_functorch/aot_autograd.py:879: in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
../pytorch/torch/_dynamo/utils.py:256: in time_wrapper
    r = func(*args, **kwargs)
../pytorch/torch/_functorch/aot_autograd.py:604: in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
../pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py:434: in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
../pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py:639: in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
../pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:152: in aot_dispatch_autograd
    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(  # type: ignore[misc]
../pytorch/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:169: in aot_dispatch_autograd_graph
    fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
../pytorch/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:40: in _create_graph
    fx_g = make_fx(
../pytorch/torch/fx/experimental/proxy_tensor.py:1099: in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
../pytorch/torch/_compile.py:24: in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
../pytorch/torch/_dynamo/eval_frame.py:455: in _fn
    return fn(*args, **kwargs)
../pytorch/torch/_dynamo/external_utils.py:25: in inner
    return fn(*args, **kwargs)
../pytorch/torch/fx/experimental/proxy_tensor.py:550: in dispatch_trace
    graph = tracer.trace(root, concrete_args)
../pytorch/torch/_dynamo/eval_frame.py:455: in _fn
    return fn(*args, **kwargs)
../pytorch/torch/_dynamo/external_utils.py:25: in inner
    return fn(*args, **kwargs)
../pytorch/torch/fx/_symbolic_trace.py:793: in trace
    (self.create_arg(fn(*args)),),
../pytorch/torch/fx/_symbolic_trace.py:652: in flatten_fn
    tree_out = root_fn(*tree_args)
../pytorch/torch/fx/experimental/proxy_tensor.py:577: in wrapped
    out = f(*tensors)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:537: in joint_fn
    return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:519: in inner_fn
    wrapped_outs = fn(*all_args)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:466: in joint_helper
    return _functionalized_f_helper(primals, tangents)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:357: in _functionalized_f_helper
    f_outs = fn(*f_args)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:252: in inner_fn_with_anomaly
    return inner_fn(*args)
../pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py:237: in inner_fn
    backward_out = torch.autograd.grad(
../pytorch/torch/autograd/__init__.py:412: in grad
    result = _engine_run_backward(
../pytorch/torch/autograd/graph.py:744: in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
float8_experimental/float8_tensor.py:172: in __torch_dispatch__
    return FLOAT8_OPS_TABLE[func](func, args, kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

aten_op = <OpOverload(op='aten.mm', overload='default')>
args = (FunctionalTensor(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(32, 16), dtype=torch.bfloat16),
       ..., please use it with a corresponding FunctionalTensorMode()') raised in repr()] Float8Tensor object at 0x7f2f833fed50>)
kwargs = {}

    @implements([aten.mm.default])
    def float8_mm(aten_op, args, kwargs=None):
>       assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
E       torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
E       AssertionError: 
E       
E       Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E       
E       
E       You can suppress this exception and fall back to eager by setting:
E           import torch._dynamo
E           torch._dynamo.config.suppress_errors = True

float8_experimental/float8_ops.py:81: BackendCompilerFailed
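
The assertion fires because, while AOTAutograd traces the functionalized backward, one of the mm operands reaches __torch_dispatch__ as a FunctionalTensor wrapping a FakeTensor rather than as a Float8Tensor (see the args repr above). For reference, here is a minimal sketch of the subclass dispatch pattern involved, simplified from float8_tensor.py and float8_ops.py; bodies and registration details are assumptions, not the real implementation.

import torch

aten = torch.ops.aten

# Table mapping aten op overloads to float8-aware handlers, populated by @implements.
FLOAT8_OPS_TABLE = {}

def implements(aten_ops):
    # Register fn as the handler for each listed aten op overload.
    def decorator(fn):
        for op in aten_ops:
            FLOAT8_OPS_TABLE[op] = fn
        return fn
    return decorator

class Float8Tensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # The traceback above enters here for aten.mm during the traced backward.
        if func in FLOAT8_OPS_TABLE:
            return FLOAT8_OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"no float8 handler for {func}")

@implements([aten.mm.default])
def float8_mm(aten_op, args, kwargs=None):
    # Assumes both operands are already Float8Tensor; under the functionalized
    # backward trace one operand is a FunctionalTensor(FakeTensor), so this raises.
    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)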

Metadata

    Labels

    Compile (issues related with subclass compilation)
