Open
Description
Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.
🐛 Bug
Occurred when testing the StarCoder Hugging Face model.
import torch
from thunder.dynamo import thunderfx
import thunder
# Fix the RNG seed so the random inputs are the same on every run of the repro.
torch.manual_seed(0)
# NOTE(review): presumably (batch, heads, sequence length, head dim) -- confirm.
shape = (4, 28, 4096, 128)
# Query/key/value tensors: bfloat16, on the CUDA device, with requires_grad=True
# (the traceback below goes through the forward/backward split).
q = torch.rand(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.rand(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.rand(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
def func(q, k, v, dropout_p):
    """Run scaled dot-product attention on ``q``, ``k`` and ``v``.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        dropout_p: Dropout probability forwarded to
            ``torch.nn.functional.scaled_dot_product_attention``.

    Returns:
        The attention output tensor.
    """
    # The body must be indented under ``def``; at column 0 the snippet
    # does not parse and the repro cannot be run as pasted.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
from thunder.core.transforms import grad
# Compile func with the "symbolic values" cache option (thunder warns that it
# is highly experimental, see the output below).
cf = thunder.jit(func, cache="symbolic values")
# This call raises the RuntimeError shown in the traceback below: the cuDNN
# SDPA checker receives a FloatProxy that cannot be cast to a C++ type.
# NOTE(review): presumably the FloatProxy is dropout_p -- confirm.
cf(q,k,v, dropout_p=0.5)
output:
/wayan/lightning-thunder/thunder/core/options.py:78: UserWarning: The 'symbolic values' cache option is highly experimental and for development only.
warnings.warn("The 'symbolic values' cache option is highly experimental and for development only.")
Traceback (most recent call last):
File "/wayan/lightning-thunder/bug/t1.py", line 18, in <module>
cf(q,k,v, dropout_p=0.5)
File "/wayan/lightning-thunder/thunder/__init__.py", line 815, in wrapped
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/__init__.py", line 855, in fn_
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/__init__.py", line 794, in wrapped
cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/__init__.py", line 236, in cache_info_wrapper
res = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/__init__.py", line 760, in get_computation_and_inputs
cache_entry = apply_transforms_and_build_cache_entry(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/__init__.py", line 559, in apply_transforms_and_build_cache_entry
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *computation_trc.args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/executors/torch_autograd.py", line 253, in split_forward_backward
fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/transforms/autodiff.py", line 478, in forward_and_backward_from_trace
joint_trace = grad_transform_on_trace(trace)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/transforms/autodiff.py", line 302, in grad_transform_on_trace
trace, _ = AugmentedForwardProcessor(trace)()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/core/trace_interpreter.py", line 369, in __call__
self.process_bsym(bsym)
File "/wayan/lightning-thunder/thunder/transforms/autodiff.py", line 182, in process_bsym
joint_forward_backward, _ = _get_gradfn_and_executor(bsym)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/core/transforms.py", line 1499, in _get_gradfn_and_executor
if ex.can_execute_or_fuse(bsym):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/extend/__init__.py", line 87, in can_execute_or_fuse
return self.can_execute(bsym) or self.can_fuse(bsym)
^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/extend/__init__.py", line 99, in can_execute
return impl.checker(*bsym.args, **bsym.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/wayan/lightning-thunder/thunder/executors/cudnnex.py", line 403, in _cudnn_sdpa_checker
_make_cudnn_sdpa_forward_graph(
File "/wayan/lightning-thunder/thunder/executors/cudnnex.py", line 164, in _make_cudnn_sdpa_forward_graph
O, softmax_stats = graph.scaled_dot_product_flash_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unable to cast Python instance of type <class 'thunder.core.proxies.FloatProxy'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
Since we are moving toward supporting the "symbolic values" cache option, the checker function should be able to handle NumberProxy inputs.