
cudnnex SDPA checker function doesn't handle FloatProxy when cache="symbolic values" #2120


Description

@kiya00


🐛 Bug

Occurred when testing the Starcoder Hugging Face model.

import torch
import thunder
from thunder.dynamo import thunderfx
from thunder.core.transforms import grad

torch.manual_seed(0)
shape = (4, 28, 4096, 128)
q = torch.rand(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.rand(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.rand(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)


def func(q, k, v, dropout_p):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)


cf = thunder.jit(func, cache="symbolic values")
cf(q, k, v, dropout_p=0.5)

Output:

/wayan/lightning-thunder/thunder/core/options.py:78: UserWarning: The 'symbolic values' cache option is highly experimental and for development only.
  warnings.warn("The 'symbolic values' cache option is highly experimental and for development only.")
Traceback (most recent call last):
  File "/wayan/lightning-thunder/bug/t1.py", line 18, in <module>
    cf(q,k,v, dropout_p=0.5)
  File "/wayan/lightning-thunder/thunder/__init__.py", line 815, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/__init__.py", line 855, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/__init__.py", line 794, in wrapped
    cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/__init__.py", line 236, in cache_info_wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/__init__.py", line 760, in get_computation_and_inputs
    cache_entry = apply_transforms_and_build_cache_entry(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/__init__.py", line 559, in apply_transforms_and_build_cache_entry
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *computation_trc.args)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/executors/torch_autograd.py", line 253, in split_forward_backward
    fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/transforms/autodiff.py", line 478, in forward_and_backward_from_trace
    joint_trace = grad_transform_on_trace(trace)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/transforms/autodiff.py", line 302, in grad_transform_on_trace
    trace, _ = AugmentedForwardProcessor(trace)()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/trace_interpreter.py", line 369, in __call__
    self.process_bsym(bsym)
  File "/wayan/lightning-thunder/thunder/transforms/autodiff.py", line 182, in process_bsym
    joint_forward_backward, _ = _get_gradfn_and_executor(bsym)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 1499, in _get_gradfn_and_executor
    if ex.can_execute_or_fuse(bsym):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/extend/__init__.py", line 87, in can_execute_or_fuse
    return self.can_execute(bsym) or self.can_fuse(bsym)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/extend/__init__.py", line 99, in can_execute
    return impl.checker(*bsym.args, **bsym.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/executors/cudnnex.py", line 403, in _cudnn_sdpa_checker
    _make_cudnn_sdpa_forward_graph(
  File "/wayan/lightning-thunder/thunder/executors/cudnnex.py", line 164, in _make_cudnn_sdpa_forward_graph
    O, softmax_stats = graph.scaled_dot_product_flash_attention(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unable to cast Python instance of type <class 'thunder.core.proxies.FloatProxy'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)

Since we are moving toward supporting the "symbolic values" cache option, the checker function should be able to handle NumberProxy inputs.
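
A minimal sketch of one possible direction, assuming NumberProxy exposes its wrapped Python value via a .value attribute (the helper name and the checker snippet below are hypothetical, not an actual patch): unwrap concrete values where they are known, and have the checker decline the claim when only a symbolic value is available, so the cuDNN graph builder never receives a proxy object.

from numbers import Number

from thunder.core.proxies import NumberProxy


def _maybe_concrete(x):
    # Under cache="symbolic values", scalars such as dropout_p arrive as
    # NumberProxy instances. If the proxy still carries a concrete Python
    # number, unwrap it; otherwise signal "unknown" with None.
    if isinstance(x, NumberProxy):
        return x.value if isinstance(x.value, Number) else None
    return x


# Inside _cudnn_sdpa_checker, before calling _make_cudnn_sdpa_forward_graph,
# something along these lines:
#
#     dropout_p = _maybe_concrete(dropout_p)
#     if dropout_p is None:
#         return False  # decline: cuDNN needs a concrete float here

Declining when the value is purely symbolic keeps the executor correct, at the cost of falling back to another SDPA implementation; an alternative would be to build the cuDNN graph with a placeholder scalar, but that only works if the scalar does not affect graph validity.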
