optimized whisper.generate() hangs forever, including in tests #213

Closed
TheExGenesis opened this issue Dec 14, 2022 · 30 comments
Labels: question (Further information is requested)

Comments

@TheExGenesis

Running test_whisper_hf("optimized") hangs forever. I'm on Ubuntu with CUDA enabled and the kernl requirements installed. Congrats on the Whisper support today, by the way; hopefully this is a quick fix.

@TheExGenesis
Author

I poked around with the debugger and traced it as far as I could; this is the most relevant section of the error log:

benchmark = <kernl.benchmark.benchmark_fixture.BenchmarkFixture object at 0x7f4aa09a7190>
implementation = 'optimized'

    @setup_dynamo()
    @pytest.mark.parametrize("implementation", ["reference", "optimized"])
    def test_whisper_hf(benchmark, implementation):
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(
            "cuda"
        )
        if implementation == "optimized":
            optimize_model(model.model.encoder)
            optimize_model(model.model.decoder)
    
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        inputs = torch.load("test/data/whisper_input.pt")
        with torch.inference_mode(), torch.autocast(
            dtype=torch.float16, cache_enabled=True, device_type="cuda"
        ):
>           predicted_ids = benchmark(
                model.generate,
                inputs,
                min_length=25,
                max_length=25,
                num_beams=2,
                do_sample=False,
            )

/root/francisco-long-trial/kernl/test/test_torchdynamo.py:199: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/benchmark/benchmark_fixture.py:53: in __call__
    function_to_benchmark(*args, **kwargs)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
    return func(*args, **kwargs)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py:1367: in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py:601: in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1423: in _call_impl
    return forward_call(*input, **kwargs)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:163: in _fn
    backend_ctx = backend_ctx_ctor()
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:163: in _fn
    backend_ctx = backend_ctx_ctor()
/root/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_frame.py:987: in trace_dispatch
    self.do_wait_suspend(thread, frame, event, arg)
/root/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_frame.py:164: in do_wait_suspend
    self._args[0].do_wait_suspend(*args, **kwargs)
/root/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py:2062: in do_wait_suspend
    keep_suspended = self._do_wait_suspend(thread, frame, event, arg, suspend_type, from_this_thread, frames_tracker)
/root/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/pydevd.py:2112: in _do_wait_suspend
    self.set_trace_for_frame_and_parents(frame)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:249: in catch_errors
    return callback(frame, cache_size)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:452: in _convert_frame
    result = inner_convert(frame, cache_size)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:118: in _fn
    return fn(*args, **kwargs)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/utils.py:87: in time_wrapper
    r = func(*args, **kwargs)
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:325: in _convert_frame_assert
    return _compile(

TheExGenesis changed the title from "optimized whisper.generate() hangs forever, including in tests" to "optimized whisper.generate() hangs forever, including in tests #bug" on Dec 14, 2022
TheExGenesis changed the title back to "optimized whisper.generate() hangs forever, including in tests" on Dec 14, 2022
@TheExGenesis
Author

TheExGenesis commented Dec 14, 2022

I noticed there were warmup steps in the tutorials, so I ran Whisper tiny.en with a long warmup step and then saw a 3x speed-up in subsequent translations.
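
For reference, this is roughly what my warmup run looked like, following the test in this repo (tiny.en instead of tiny; the input file and generate arguments are just what the test uses):

import torch
from transformers import WhisperForConditionalGeneration
from kernl.model_optimization import optimize_model

# load and optimize the model the same way test_whisper_hf does
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to("cuda")
optimize_model(model.model.encoder)
optimize_model(model.model.decoder)

inputs = torch.load("test/data/whisper_input.pt")  # precomputed input features

with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
    # first call triggers tracing, kernel tuning and CUDA graph capture: this is the long warmup
    model.generate(inputs, min_length=25, max_length=25, num_beams=2, do_sample=False)
    # later calls with the same shapes reuse the captured graphs and were ~3x faster for me
    predicted_ids = model.generate(inputs, min_length=25, max_length=25, num_beams=2, do_sample=False)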

Then I tried running Whisper large-v2 on a single batch of size [16, 80, 3000], but it ran for about 5 minutes and then I ran out of memory:

[2022-12-14 01:52:04,574] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT run /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py line 65 
 67           0 LOAD_DEREF               0 (original_model)
              2 LOAD_ATTR                0 (forward2)
              4 LOAD_FAST                0 (args)
              6 BUILD_MAP                0
              8 LOAD_FAST                1 (kwargs)
             10 DICT_MERGE               1
             12 CALL_FUNCTION_EX         1
             14 RETURN_VALUE

 ========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 403, in clone_input
    result.copy_(x.clone())
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 60.00 MiB (GPU 0; 79.21 GiB total capacity; 68.57 GiB already allocated; 36.31 MiB free; 74.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 380, in _compile
    out_code = transform_code_object(code, transform)
  File "/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
...
    return _compile(
  File "/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 442, in _compile
    raise InternalTorchDynamoError() from e
torch._dynamo.exc.InternalTorchDynamoError
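
The OOM message above suggests trying max_split_size_mb. I haven't verified that it helps here, but for the record it would be set like this, before any CUDA work happens (128 is just an example value):

import os

# must be set before the CUDA caching allocator is initialized
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # imported after setting the env var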

@TheExGenesis
Author

I'm curious what warmup times usually are for models of this size.

@gaetansnl
Contributor

Hello,
The warmup time is currently expected, but we may have ways to improve it in the future (extended caching, etc.).

It is caused by two things:

  • CUDA graphs generate one graph for each decoder shape (in the example the decoding algorithm creates a new shape at each step, so 25 shapes; see the sketch below). For each graph, every kernel is tuned for maximum performance, and that takes time because we need to run the kernel to benchmark the best settings.
  • It also takes time because we need to warm up the CUDA graphs (because of some lazily loaded values, for example in the torch model).

So the warmup can take time, but after the warmup you will have maximum performance.
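
To make the shape count concrete, here is a small illustration (not kernl code, just plain Python showing why a 25-token generation means 25 decoder shapes, each of which gets its own CUDA graph and kernel tuning):

# Illustration only: the decoding loop produces a new decoder shape at every
# step, so with max_length=25 CUDA graphs see 25 distinct shapes.
batch_size, num_beams = 1, 2
shapes_seen = set()
for step in range(1, 26):
    decoder_shape = (batch_size * num_beams, step)  # one plausible shape per step
    shapes_seen.add(decoder_shape)
print(len(shapes_seen))  # 25 -> one graph captured and tuned per shape during warmup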

We have found ways to improve Whisper performance further, but we are waiting for fixes in a third-party project (torchdynamo). We are also waiting for an improvement to CUDA graphs, #207.

The memory problem is a known issue, probably caused by torchdynamo (pytorch/torchdynamo#1955); we hope to fix it soon. It could also be caused by CUDA graph memory allocation; we are investigating this too.

pommedeterresautee added the question (Further information is requested) label on Dec 14, 2022
@TheExGenesis
Author

TheExGenesis commented Dec 18, 2022

I wonder if pytorch/torchdynamo#1950 resolves it.

@pommedeterresautee
Member

FYI, we have pushed a script to the experimental folder; we are currently working on making the warmup much faster in #235.

@TheExGenesis
Author

Thanks for letting me know, I'll test it. Warming up the large model should take about 9 minutes, right (20% of 45 minutes)?

I was also wondering whether you, @pommedeterresautee, think Whisper could be optimized by following the instructions from transformer-deploy, or if you see any obvious reasons why it wouldn't work.

@TheExGenesis
Author

TheExGenesis commented Jan 16, 2023

Running the experimental script as-is throws this error:

BackendCompilerFailed: _compiler raised Exception: Please convert all Tensors to FakeTensors first or instantiate 
FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-1.1314e-02, -8.0719e-03,  1.6571e-02,  ...,  4.3564e-03,
         -7.1289e-02,  1.1063e-03],
        [ 3.3493e-03, -1.7075e-02, -1.3550e-02,  ..., -1.0681e-02,
         -7.5256e-02,  1.9623e-02],
        [ 8.7929e-04,  6.5689e-03, -1.9608e-02,  ...,  2.2621e-03,
         -6.8542e-02,  4.3274e-02],
        ...,
        [-6.1417e-03, -7.8583e-03, -1.1024e-03,  ...,  1.2100e-05,
         -7.0312e-02,  3.1082e-02],
        [-4.5700e-03,  1.3704e-03,  8.8425e-03,  ...,  8.1253e-03,
         -6.9458e-02,  3.2257e-02],
        [ 5.9242e-03,  2.5421e-02, -8.7891e-03,  ...,  1.9779e-03,
         -6.7505e-02,  4.9072e-02]], device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., 
device='meta', size=(5, 1), dtype=torch.int64), cuda:0), 50257), **{}) 


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

I also tried running under with FakeTensorMode(allow_non_fake_inputs=True): but just got another error, DynamicOutputShapeException: aten.repeat_interleave.Tensor.
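
For reference, this is roughly how I wrapped that second attempt (model and inputs are the optimized Whisper model and input features from earlier):

from torch._subclasses.fake_tensor import FakeTensorMode

# run generation under a FakeTensorMode that allows real (non-fake) tensor
# inputs, as the exception message suggests
with FakeTensorMode(allow_non_fake_inputs=True):
    predicted_ids = model.generate(
        inputs, min_length=25, max_length=25, num_beams=2, do_sample=False
    )
# -> raises DynamicOutputShapeException: aten.repeat_interleave.Tensor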

@pommedeterresautee
Member

Are you on main with up-to-date dependencies?
It seems to me we have fixed those errors. If so, can you post the whole logs?

@TheExGenesis
Author

Yup, I'm on main and just ran pip install 'git+https://github.com/ELS-RD/kernl' --extra-index-url https://download.pytorch.org/whl/nightly/cu117 without any issues.

Logs:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:680 │
│ in call_user_compiler                                                                            │
│                                                                                                  │
│   677 │   │   │   elif config.DO_NOT_USE_legacy_non_fake_example_inputs:                         │
│   678 │   │   │   │   compiled_fn = compiler_fn(gm, self.example_inputs())                       │
│   679 │   │   │   else:                                                                          │
│ ❱ 680 │   │   │   │   compiled_fn = compiler_fn(gm, self.fake_example_inputs())                  │
│   681 │   │   │   _step_logger()(logging.INFO, f"done compiler function {name}")                 │
│   682 │   │   │   assert callable(compiled_fn), "compiler_fn did not return callable"            │
│   683 │   │   except Exception as e:                                                             │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py:1032 │
│ in debug_wrapper                                                                                 │
│                                                                                                  │
│   1029 │   │   │   │   │   )                                                                     │
│   1030 │   │   │   │   │   raise                                                                 │
│   1031 │   │   else:                                                                             │
│ ❱ 1032 │   │   │   compiled_gm = compiler_fn(gm, example_inputs, **kwargs)                       │
│   1033 │   │                                                                                     │
│   1034 │   │   return compiled_gm                                                                │
│   1035                                                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py:33 in │
│ _compiler                                                                                        │
│                                                                                                  │
│   30 │   # https://github.com/pytorch/torchdynamo/issues/1816                                   │
│   31 │   def _compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):           │
│   32 │   │   dynamo_backend_ofi(gm)                                                             │
│ ❱ 33 │   │   return cuda_graphs_wrapper(gm, example_inputs, pool=_pool)                         │
│   34                                                                                             │
│   35                                                                                             │
│   36 def optimize_model(original_model: PreTrainedModel) -> None:                                │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/implementations/cuda_graph. │
│ py:40 in cuda_graphs_wrapper                                                                     │
│                                                                                                  │
│   37 │   │   # 2 rounds, 1 to build the model (triton kernels, casting, etc.),                  │
│   38 │   │   # and 1 for warmup                                                                 │
│   39 │   │   for _ in range(2):                                                                 │
│ ❱ 40 │   │   │   model(*inputs)                                                                 │
│   41 │   │   stream.synchronize()                                                               │
│   42 │   │   torch.cuda.current_stream().wait_stream(stream)                                    │
│   43 │   │   torch.cuda.synchronize()                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:660 in   │
│ call_wrapped                                                                                     │
│                                                                                                  │
│   657 │   │   │   cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined  │
│   658 │   │                                                                                     │
│   659 │   │   def call_wrapped(self, *args, **kwargs):                                           │
│ ❱ 660 │   │   │   return self._wrapped_call(self, *args, **kwargs)                               │
│   661 │   │                                                                                      │
│   662 │   │   cls.__call__ = call_wrapped                                                        │
│   663                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:279 in   │
│ __call__                                                                                         │
│                                                                                                  │
│   276 │   │   │   │   │     file=sys.stderr)                                                     │
│   277 │   │   │   │   raise e.with_traceback(None)                                               │
│   278 │   │   │   else:                                                                          │
│ ❱ 279 │   │   │   │   raise e                                                                    │
│   280                                                                                            │
│   281 @compatibility(is_backward_compatible=True)                                                │
│   282 class GraphModule(torch.nn.Module):                                                        │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:269 in   │
│ __call__                                                                                         │
│                                                                                                  │
│   266 │   │   │   if self.cls_call is not None:                                                  │
│   267 │   │   │   │   return self.cls_call(obj, *args, **kwargs)                                 │
│   268 │   │   │   else:                                                                          │
│ ❱ 269 │   │   │   │   return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[mi  │
│   270 │   │   except Exception as e:                                                            │
│   271 │   │   │   assert e.__traceback__                                                         │
│   272 │   │   │   topmost_framesummary: traceback.FrameSummary = \                               │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                         │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                            │
│   1485 │   │   backward_pre_hooks = []                                                           │
│ <eval_with_key>.86:8 in forward                                                                  │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                         │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                            │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/sparse.py:162 in │
│ forward                                                                                          │
│                                                                                                  │
│   159 │   │   │   │   self.weight[self.padding_idx].fill_(0)                                     │
│   160 │                                                                                          │
│   161 │   def forward(self, input: Tensor) -> Tensor:                                           │
│ ❱ 162 │   │   return F.embedding(                                                                │
│   163 │   │   │   input, self.weight, self.padding_idx, self.max_norm,                           │
│   164 │   │   │   self.norm_type, self.scale_grad_by_freq, self.sparse)                          │
│   165                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/functional.py:2210 in    │
│ embedding                                                                                        │
│                                                                                                  │
│   2207 │   │   #   torch.embedding_renorm_                                                      │
│   2208 │   │   # remove once script supports set_grad_enabled                                   │
│   2209 │   │   _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)                   │
│ ❱ 2210return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)        │
│   2211                                                                                           │
│   2212                                                                                           │
│   2213 def embedding_bag(                                                                        │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 639 in __torch_dispatch__                                                                        │
│                                                                                                  │
│    636 │   │                                                                                     │
│    637 │   │   assert fake_mode is not None                                                      │
│    638 │   │   with fake_mode:  # type: ignore[attr-defined]                                     │
│ ❱  639 │   │   │   return func(*args, **kwargs)                                                  │
│    640 │                                                                                         │
│    641 │   @staticmethod                                                                         │
│    642 │   def _find_common_device(func, args, kwargs) -> Tuple[torch.device, bool]:            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_ops.py:284 in __call__     │
│                                                                                                  │
│   281 │   │   )                                                                                  │
│   282 │                                                                                          │
│   283 │   def __call__(self, *args, **kwargs):                                                  │
│ ❱ 284 │   │   return self._op(*args, **kwargs or {})                                             │
│   285 │                                                                                          │
│   286 │   def __hash__(self):                                                                   │
│   287 │   │   return hash(self._op)                                                              │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 818 in __torch_dispatch__                                                                        │
│                                                                                                  │
│    815 │   │   │   ), f"{args} {kwargs}"                                                         │
│    816 │   │   │   return converter(self, args[0])                                               │
│    817 │   │                                                                                     │
│ ❱  818 │   │   args, kwargs = self.validate_and_convert_non_fake_tensors(                        │
│    819 │   │   │   func, converter, args, kwargs                                                 │
│    820 │   │   )                                                                                 │
│    821                                                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 966 in validate_and_convert_non_fake_tensors                                                     │
│                                                                                                  │
│    963 │   │   │   │   return converter(self, x)                                                 │
│    964 │   │   │   return x                                                                      │
│    965 │   │                                                                                     │
│ ❱  966 │   │   return tree_map_only(                                                             │
│    967 │   │   │   torch.Tensor,                                                                 │
│    968 │   │   │   validate,                                                                     │
│    969 │   │   │   (args, kwargs),                                                               │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:259 in     │
│ tree_map_only                                                                                    │
│                                                                                                  │
│   256 │   ...                                                                                    │
│   257                                                                                            │
│   258 def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:                  │
│ ❱ 259 │   return tree_map(map_only(ty)(fn), pytree)                                             │
│   260                                                                                            │
│   261 def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:                         │
│   262 │   flat_args, _ = tree_flatten(pytree)                                                   │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:195 in     │
│ tree_map                                                                                         │
│                                                                                                  │
│   192                                                                                            │
│   193 def tree_map(fn: Any, pytree: PyTree) -> PyTree:                                           │
│   194 │   flat_args, spec = tree_flatten(pytree)                                                │
│ ❱ 195 │   return tree_unflatten([fn(i) for i in flat_args], spec)                               │
│   196                                                                                            │
│   197 Type2 = Tuple[Type[T], Type[S]]                                                            │
│   198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:195 in     │
│ <listcomp>                                                                                       │
│                                                                                                  │
│   192                                                                                            │
│   193 def tree_map(fn: Any, pytree: PyTree) -> PyTree:                                           │
│   194 │   flat_args, spec = tree_flatten(pytree)                                                │
│ ❱ 195 │   return tree_unflatten([fn(i) for i in flat_args], spec)                               │
│   196                                                                                            │
│   197 Type2 = Tuple[Type[T], Type[S]]                                                            │
│   198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:244 in     │
│ inner                                                                                            │
│                                                                                                  │
│   241 │   │   @functools.wraps(f)                                                                │
│   242 │   │   def inner(x: T) -> Any:                                                            │
│   243 │   │   │   if isinstance(x, ty):                                                          │
│ ❱ 244 │   │   │   │   return f(x)                                                                │
│   245 │   │   │   else:                                                                          │
│   246 │   │   │   │   return x                                                                   │
│   247 │   │   return inner                                                                       │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 958 in validate                                                                                  │
│                                                                                                  │
│    955 │   │   │   │   │   │   f"Can't call metadata mutating ops on non-Fake Tensor inputs. Fo  │
│    956 │   │   │   │   │   )                                                                     │
│    957 │   │   │   │   if not self.allow_non_fake_inputs:                                        │
│ ❱  958 │   │   │   │   │   raise Exception(                                                      │
│    959 │   │   │   │   │   │   f"Please convert all Tensors to FakeTensors first or instantiate  │
│    960 │   │   │   │   │   │   f"with 'allow_non_fake_inputs'. Found in {func}(*{args}, **{kwar  │
│    961 │   │   │   │   │   )                                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 
'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-1.1314e-02, -8.0719e-03,  1.6571e-02,  ...,  4.3564e-03,
         -7.1289e-02,  1.1063e-03],
        [ 3.3493e-03, -1.7075e-02, -1.3550e-02,  ..., -1.0681e-02,
         -7.5256e-02,  1.9623e-02],
        [ 8.7929e-04,  6.5689e-03, -1.9608e-02,  ...,  2.2621e-03,
         -6.8542e-02,  4.3274e-02],
        ...,
        [-6.1417e-03, -7.8583e-03, -1.1024e-03,  ...,  1.2100e-05,
         -7.0312e-02,  3.1082e-02],
        [-4.5700e-03,  1.3704e-03,  8.8425e-03,  ...,  8.1253e-03,
         -6.9458e-02,  3.2257e-02],
        [ 5.9242e-03,  2.5421e-02, -8.7891e-03,  ...,  1.9779e-03,
         -6.7505e-02,  4.9072e-02]], device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., 
device='meta', size=(5, 1), dtype=torch.int64), cuda:0), 50257), **{}) 

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-3-5a2c5bbc2633>:5 in <module>                                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py: │
│ 1608 in generate                                                                                 │
│                                                                                                  │
│   1605 │   │   │   │   **model_kwargs,                                                           │
│   1606 │   │   │   )                                                                             │
│   1607 │   │   │   # 12. run beam search                                                         │
│ ❱ 1608 │   │   │   return self.beam_search(                                                      │
│   1609 │   │   │   │   input_ids,                                                                │
│   1610 │   │   │   │   beam_scorer,                                                              │
│   1611 │   │   │   │   logits_processor=logits_processor,                                        │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py: │
│ 2799 in beam_search                                                                              │
│                                                                                                  │
│   2796 │   │   │                                                                                 │
│   2797 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)  │
│   2798 │   │   │                                                                                 │
│ ❱ 2799 │   │   │   outputs = self(                                                               │
│   2800 │   │   │   │   **model_inputs,                                                           │
│   2801 │   │   │   │   return_dict=True,                                                         │
│   2802 │   │   │   │   output_attentions=output_attentions,                                      │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/model │
│ ing_whisper.py:1337 in forward                                                                   │
│                                                                                                  │
│   1334 │   │   │   │   │   labels, self.config.pad_token_id, self.config.decoder_start_token_id  │
│   1335 │   │   │   │   )                                                                         │
│   1336 │   │                                                                                     │
│ ❱ 1337 │   │   outputs = self.model(                                                             │
│   1338 │   │   │   input_features,                                                               │
│   1339 │   │   │   decoder_input_ids=decoder_input_ids,                                          │
│   1340 │   │   │   encoder_outputs=encoder_outputs,                                              │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/model │
│ ing_whisper.py:1202 in forward                                                                   │
│                                                                                                  │
│   1199 │   │   │   )                                                                             │
│   1200 │   │                                                                                     │
│   1201 │   │   # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_att  │
│ ❱ 1202 │   │   decoder_outputs = self.decoder(                                                   │
│   1203 │   │   │   input_ids=decoder_input_ids,                                                  │
│   1204 │   │   │   attention_mask=decoder_attention_mask,                                        │
│   1205 │   │   │   encoder_hidden_states=encoder_outputs[0],                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:212   │
│ in _fn                                                                                           │
│                                                                                                  │
│   209 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic)                                     │
│   210 │   │   │   dynamic_ctx.__enter__()                                                        │
│   211 │   │   │   try:                                                                           │
│ ❱ 212 │   │   │   │   return fn(*args, **kwargs)                                                 │
│   213 │   │   │   finally:                                                                       │
│   214 │   │   │   │   set_eval_frame(prior)                                                      │
│   215 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py:67 in │
│ run                                                                                              │
│                                                                                                  │
│   64 │                                                                                           │
│   65 │   @torchdynamo.optimize(_compiler)                                                        │
│   66 │   def run(*args, **kwargs):                                                               │
│ ❱ 67 │   │   return original_model.forward2(*args, **kwargs)                                     │
│   68 │                                                                                           │
│   69 │   original_model.forward = run                                                            │
│   70                                                                                             │
│ <ipython-input-1-09da0ab26e20>:45 in wrapper_stride                                              │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:333   │
│ in catch_errors                                                                                  │
│                                                                                                  │
│   330 │   │   │   │   │   return hijacked_callback(frame, cache_size, hooks)                     │
│   331 │   │                                                                                      │
│   332 │   │   with compile_lock:                                                                 │
│ ❱ 333 │   │   │   return callback(frame, cache_size, hooks)                                      │
│   334 │                                                                                          │
│   335 │   catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]       │
│   336 │   return catch_errors                                                                    │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:48 │
│ 0 in _convert_frame                                                                              │
│                                                                                                  │
│   477 │   def _convert_frame(frame: types.FrameType, cache_size: int, hooks: Hooks):             │
│   478 │   │   counters["frames"]["total"] += 1                                                   │
│   479 │   │   try:                                                                               │
│ ❱ 480 │   │   │   result = inner_convert(frame, cache_size, hooks)                               │
│   481 │   │   │   counters["frames"]["ok"] += 1                                                  │
│   482 │   │   │   return result                                                                  │
│   483 │   │   except (NotImplementedError, Unsupported):                                         │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:10 │
│ 3 in _fn                                                                                         │
│                                                                                                  │
│   100 │   │   prior_fwd_from_src = torch.fx.graph_module._forward_from_src                       │
│   101 │   │   torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result          │
│   102 │   │   try:                                                                               │
│ ❱ 103 │   │   │   return fn(*args, **kwargs)                                                     │
│   104 │   │   finally:                                                                           │
│   105 │   │   │   torch._C._set_grad_enabled(prior_grad_mode)                                    │
│   106 │   │   │   torch.random.set_rng_state(rng_state)                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/utils.py:88 in      │
│ time_wrapper                                                                                     │
│                                                                                                  │
│     85 │   │   if key not in compilation_metrics:                                                │
│     86 │   │   │   compilation_metrics[key] = []                                                 │
│     87 │   │   t0 = time.time()                                                                  │
│ ❱   88 │   │   r = func(*args, **kwargs)                                                         │
│     89 │   │   latency = time.time() - t0                                                        │
│     90 │   │   # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")                    │
│     91 │   │   compilation_metrics[key].append(latency)                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:33 │
│ 9 in _convert_frame_assert                                                                       │
│                                                                                                  │
│   336 │   │   global initial_grad_state                                                          │
│   337 │   │   initial_grad_state = torch.is_grad_enabled()                                       │
│   338 │   │                                                                                      │
│ ❱ 339 │   │   return _compile(                                                                   │
│   340 │   │   │   frame.f_code,                                                                  │
│   341 │   │   │   frame.f_globals,                                                               │
│   342 │   │   │   frame.f_locals,                                                                │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:40 │
│ 0 in _compile                                                                                    │
│                                                                                                  │
│   397 │   try:                                                                                   │
│   398 │   │   for attempt in itertools.count():                                                  │
│   399 │   │   │   try:                                                                           │
│ ❱ 400 │   │   │   │   out_code = transform_code_object(code, transform)                          │
│   401 │   │   │   │   orig_code_map[out_code] = code                                             │
│   402 │   │   │   │   break                                                                      │
│   403 │   │   │   except exc.RestartAnalysis:                                                    │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/bytecode_transforma │
│ tion.py:341 in transform_code_object                                                             │
│                                                                                                  │
│   338 │   instructions = cleaned_instructions(code, safe)                                        │
│   339 │   propagate_line_nums(instructions)                                                      │
│   340 │                                                                                          │
│ ❱ 341 │   transformations(instructions, code_options)                                            │
│   342 │                                                                                          │
│   343 │   fix_vars(instructions, code_options)                                                   │
│   344                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:38 │
│ 7 in transform                                                                                   │
│                                                                                                  │
│   384 │   │   │   export,                                                                        │
│   385 │   │   │   mutated_closure_cell_contents,                                                 │
│   386 │   │   )                                                                                  │
│ ❱ 387 │   │   tracer.run()                                                                       │
│   388 │   │   output = tracer.output                                                             │
│   389 │   │   assert output is not None                                                          │
│   390 │   │   assert output.output_instructions                                                  │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :1684 in run                                                                                     │
│                                                                                                  │
│   1681 │                                                                                         │
│   1682 │   def run(self):                                                                        │
│   1683 │   │   _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")  │
│ ❱ 1684 │   │   super().run()                                                                     │
│   1685 │                                                                                         │
│   1686 │   def match_nested_cell(self, name, cell):                                              │
│   1687 │   │   """Match a cell in this method to one in a function we are inlining"""            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :538 in run                                                                                      │
│                                                                                                  │
│    535 │   │   │   while (                                                                       │
│    536 │   │   │   │   self.instruction_pointer is not None                                      │
│    537 │   │   │   │   and not self.output.should_exit                                           │
│ ❱  538 │   │   │   │   and self.step()                                                           │
│    539 │   │   │   ):                                                                            │
│    540 │   │   │   │   pass                                                                      │
│    541 │   │   except BackendCompilerFailed:                                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :501 in step                                                                                     │
│                                                                                                  │
│    498 │   │   try:                                                                              │
│    499 │   │   │   if not hasattr(self, inst.opname):                                            │
│    500 │   │   │   │   unimplemented(f"missing: {inst.opname}")                                  │
│ ❱  501 │   │   │   getattr(self, inst.opname)(inst)                                              │
│    502 │   │   │                                                                                 │
│    503 │   │   │   return inst.opname != "RETURN_VALUE"                                          │
│    504 │   │   except BackendCompilerFailed:                                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :1750 in RETURN_VALUE                                                                            │
│                                                                                                  │
│   1747 │   │   │   f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",             │
│   1748 │   │   )                                                                                 │
│   1749 │   │   log.debug("RETURN_VALUE triggered compile")                                       │
│ ❱ 1750 │   │   self.output.compile_subgraph(self)                                                │
│   1751 │   │   self.output.add_output_instructions([create_instruction("RETURN_VALUE")])         │
│   1752                                                                                           │
│   1753                                                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:557 │
│ in compile_subgraph                                                                              │
│                                                                                                  │
│   554 │   │   │   output = []                                                                    │
│   555 │   │   │   if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:              │
│   556 │   │   │   │   output.extend(                                                             │
│ ❱ 557 │   │   │   │   │   self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)    │
│   558 │   │   │   │   )                                                                          │
│   559 │   │   │   │                                                                              │
│   560 │   │   │   │   if len(pass2.graph_outputs) != 0:                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:604 │
│ in compile_and_call_fx_graph                                                                     │
│                                                                                                  │
│   601 │   │                                                                                      │
│   602 │   │   assert_no_fake_params_or_buffers(gm)                                               │
│   603 │   │   with tracing(self.tracing_context):                                                │
│ ❱ 604 │   │   │   compiled_fn = self.call_user_compiler(gm)                                      │
│   605 │   │   compiled_fn = disable(compiled_fn)                                                 │
│   606 │   │                                                                                      │
│   607 │   │   counters["stats"]["unique_graphs"] += 1                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:685 │
│ in call_user_compiler                                                                            │
│                                                                                                  │
│   682 │   │   │   assert callable(compiled_fn), "compiler_fn did not return callable"            │
│   683 │   │   except Exception as e:                                                             │
│   684 │   │   │   compiled_fn = gm.forward                                                       │
│ ❱ 685 │   │   │   raise BackendCompilerFailed(self.compiler_fn, e) from e                        │
│   686 │   │   return compiled_fn                                                                 │
│   687 │                                                                                          │
│   688 │   def fake_example_inputs(self) -> List[torch.Tensor]:                                   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BackendCompilerFailed: _compiler raised Exception: Please convert all Tensors to FakeTensors first or instantiate 
FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-1.1314e-02, -8.0719e-03,  1.6571e-02,  ...,  4.3564e-03,
         -7.1289e-02,  1.1063e-03],
        [ 3.3493e-03, -1.7075e-02, -1.3550e-02,  ..., -1.0681e-02,
         -7.5256e-02,  1.9623e-02],
        [ 8.7929e-04,  6.5689e-03, -1.9608e-02,  ...,  2.2621e-03,
         -6.8542e-02,  4.3274e-02],
        ...,
        [-6.1417e-03, -7.8583e-03, -1.1024e-03,  ...,  1.2100e-05,
         -7.0312e-02,  3.1082e-02],
        [-4.5700e-03,  1.3704e-03,  8.8425e-03,  ...,  8.1253e-03,
         -6.9458e-02,  3.2257e-02],
        [ 5.9242e-03,  2.5421e-02, -8.7891e-03,  ...,  1.9779e-03,
         -6.7505e-02,  4.9072e-02]], device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., 
device='meta', size=(5, 1), dtype=torch.int64), cuda:0), 50257), **{}) 


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
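For what it's worth, the fallback the message points to is just a module-level flag set before calling generate(); a minimal sketch below (this only makes the failing subgraph run eagerly, i.e. un-optimized, so it hides the compiler error rather than fixing it):

import torch._dynamo

# Workaround sketch only: when a backend compiler (here kernl's _compiler) raises,
# TorchDynamo falls back to eager execution for that subgraph instead of failing.
torch._dynamo.config.suppress_errors = True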

@TheExGenesis
Author

Should I make a new issue for this?

@pommedeterresautee
Member

Can you try with this branch?
https://github.com/ELS-RD/kernl/tree/feat/whisper_notbook
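For example, installing it straight from the branch should be something like this (assuming a plain pip install from git works for your setup):

pip install -U "git+https://github.com/ELS-RD/kernl.git@feat/whisper_notbook"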

@TheExGenesis
Author

TheExGenesis commented Jan 25, 2023

Similar thing, I'm afraid, running on an A100 with Python 3.9.15.

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:680 │
│ in call_user_compiler                                                                            │
│                                                                                                  │
│   677 │   │   │   elif config.DO_NOT_USE_legacy_non_fake_example_inputs:                         │
│   678 │   │   │   │   compiled_fn = compiler_fn(gm, self.example_inputs())                       │
│   679 │   │   │   else:                                                                          │
│ ❱ 680 │   │   │   │   compiled_fn = compiler_fn(gm, self.fake_example_inputs())                  │
│   681 │   │   │   _step_logger()(logging.INFO, f"done compiler function {name}")                 │
│   682 │   │   │   assert callable(compiled_fn), "compiler_fn did not return callable"            │
│   683 │   │   except Exception as e:                                                             │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py:1032 │
│ in debug_wrapper                                                                                 │
│                                                                                                  │
│   1029 │   │   │   │   │   )                                                                     │
│   1030 │   │   │   │   │   raise                                                                 │
│   1031 │   │   else:                                                                             │
│ ❱ 1032 │   │   │   compiled_gm = compiler_fn(gm, example_inputs, **kwargs)                       │
│   1033 │   │                                                                                     │
│   1034 │   │   return compiled_gm                                                                │
│   1035                                                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py:33 in │
│ _compiler                                                                                        │
│                                                                                                  │
│   30 # https://github.com/pytorch/torchdynamo/issues/1816
│   31 def _compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
│   32 │   dynamo_backend_ofi(gm)
│ ❱ 33 │   return cuda_graphs_wrapper(gm, example_inputs, pool=_pool)
│   34                                                                                             │
│   35                                                                                             │
│   36 def optimize_model(original_model: PreTrainedModel) -> None:                                │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/implementations/cuda_graph. │
│ py:40 in cuda_graphs_wrapper                                                                     │
│                                                                                                  │
│   37 │   │   # 2 rounds, 1 to build the model (triton kernels, casting, etc.),
│   38 │   │   # and 1 for warmup
│   39 │   │   for _ in range(2):
│ ❱ 40 │   │   │   model(*inputs)                                                                  │
│   41 │   stream.synchronize()
│   42 │   torch.cuda.current_stream().wait_stream(stream)
│   43 │   torch.cuda.synchronize()
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:660 in   │
│ call_wrapped                                                                                     │
│                                                                                                  │
│   657 │   │   │   cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined
│   658 │   │
│   659 │   │   def call_wrapped(self, *args, **kwargs):                                           │
│ ❱ 660 │   │   │   return self._wrapped_call(self, *args, **kwargs)                               │
│   661 │   │                                                                                      │
│   662 │   │   cls.__call__ = call_wrapped                                                        │
│   663                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:279 in   │
│ __call__                                                                                         │
│                                                                                                  │
│   276 │   │   │   │   │     file=sys.stderr)                                                     │
│   277 │   │   │   │   raise e.with_traceback(None)                                               │
│   278 │   │   │   else:                                                                          │
│ ❱ 279 │   │   │   │   raise e                                                                    │
│   280                                                                                            │
│   281 @compatibility(is_backward_compatible=True)                                                │
│   282 class GraphModule(torch.nn.Module):                                                        │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:269 in   │
│ __call__                                                                                         │
│                                                                                                  │
│   266 │   │   │   if self.cls_call is not None:                                                  │
│   267 │   │   │   │   return self.cls_call(obj, *args, **kwargs)                                 │
│   268 │   │   │   else:                                                                          │
│ ❱ 269 │   │   │   │   return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[mi
│   270 │   │   except Exception as e:
│   271 │   │   │   assert e.__traceback__                                                         │
│   272 │   │   │   topmost_framesummary: traceback.FrameSummary = \                               │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []
│   1485 │   │   backward_pre_hooks = []                                                           │
│ <eval_with_key>.680:8 in forward                                                                 │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/sparse.py:162 in │
│ forward                                                                                          │
│                                                                                                  │
│   159 │   │   │   │   self.weight[self.padding_idx].fill_(0)                                     │
│   160 │                                                                                          │
│   161 │   def forward(self, input: Tensor) -> Tensor:
│ ❱ 162 │   │   return F.embedding(                                                                │
│   163 │   │   │   input, self.weight, self.padding_idx, self.max_norm,                           │
│   164 │   │   │   self.norm_type, self.scale_grad_by_freq, self.sparse)                          │
│   165                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/functional.py:2210 in    │
│ embedding                                                                                        │
│                                                                                                  │
│   2207 │   │   #   torch.embedding_renorm_
│   2208 │   │   # remove once script supports set_grad_enabled
│   2209 │   │   _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
│ ❱ 2210 │   return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
│   2211                                                                                           │
│   2212                                                                                           │
│   2213 def embedding_bag(                                                                        │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 639 in __torch_dispatch__                                                                        │
│                                                                                                  │
│    636 │   │                                                                                     │
│    637 │   │   assert fake_mode is not None                                                      │
│    638 │   │   with fake_mode:  # type: ignore[attr-defined]                                     │
│ ❱  639 │   │   │   return func(*args, **kwargs)                                                  │
│    640 │                                                                                         │
│    641 │   @staticmethod                                                                         │
│    642 │   def _find_common_device(func, args, kwargs) -> Tuple[torch.device, bool]:
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_ops.py:284 in __call__     │
│                                                                                                  │
│   281 │   │   )                                                                                  │
│   282 │                                                                                          │
│   283 │   def __call__(self, *args, **kwargs):
│ ❱ 284 │   │   return self._op(*args, **kwargs or {})                                             │
│   285 │                                                                                          │
│   286 │   def __hash__(self):
│   287 │   │   return hash(self._op)                                                              │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 818 in __torch_dispatch__                                                                        │
│                                                                                                  │
│    815 │   │   │   ), f"{args} {kwargs}"                                                         │
│    816 │   │   │   return converter(self, args[0])                                               │
│    817 │   │                                                                                     │
│ ❱  818 │   │   args, kwargs = self.validate_and_convert_non_fake_tensors(                        │
│    819 │   │   │   func, converter, args, kwargs                                                 │
│    820 │   │   )                                                                                 │
│    821                                                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 966 in validate_and_convert_non_fake_tensors                                                     │
│                                                                                                  │
│    963 │   │   │   │   return converter(self, x)                                                 │
│    964 │   │   │   return x                                                                      │
│    965 │   │                                                                                     │
│ ❱  966 │   │   return tree_map_only(                                                             │
│    967 │   │   │   torch.Tensor,                                                                 │
│    968 │   │   │   validate,                                                                     │
│    969 │   │   │   (args, kwargs),                                                               │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:259 in     │
│ tree_map_only                                                                                    │
│                                                                                                  │
│   256 │   ...                                                                                    │
│   257                                                                                            │
│   258 def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:                  │
│ ❱ 259 │   return tree_map(map_only(ty)(fn), pytree)
│   260                                                                                            │
│   261 def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:                         │
│   262 │   flat_args, _ = tree_flatten(pytree)
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:195 in     │
│ tree_map                                                                                         │
│                                                                                                  │
│   192                                                                                            │
│   193 def tree_map(fn: Any, pytree: PyTree) -> PyTree:                                           │
│   194 │   flat_args, spec = tree_flatten(pytree)
│ ❱ 195 │   return tree_unflatten([fn(i) for i in flat_args], spec)
│   196                                                                                            │
│   197 Type2 = Tuple[Type[T], Type[S]]                                                            │
│   198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:195 in     │
│ <listcomp>                                                                                       │
│                                                                                                  │
│   192                                                                                            │
│   193 def tree_map(fn: Any, pytree: PyTree) -> PyTree:                                           │
│   194 │   flat_args, spec = tree_flatten(pytree)
│ ❱ 195 │   return tree_unflatten([fn(i) for i in flat_args], spec)
│   196                                                                                            │
│   197 Type2 = Tuple[Type[T], Type[S]]                                                            │
│   198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:244 in     │
│ inner                                                                                            │
│                                                                                                  │
│   241 │   │   @functools.wraps(f)                                                                │
│   242 │   │   def inner(x: T) -> Any:                                                            │
│   243 │   │   │   if isinstance(x, ty):                                                          │
│ ❱ 244 │   │   │   │   return f(x)                                                                │
│   245 │   │   │   else:                                                                          │
│   246 │   │   │   │   return x                                                                   │
│   247 │   │   return inner                                                                       │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 958 in validate                                                                                  │
│                                                                                                  │
│    955 │   │   │   │   │   │   f"Can't call metadata mutating ops on non-Fake Tensor inputs. Fo  │
│    956 │   │   │   │   │   )                                                                     │
│    957 │   │   │   │   if not self.allow_non_fake_inputs:                                        │
│ ❱  958 │   │   │   │   │   raise Exception(                                                      │
│    959 │   │   │   │   │   │   f"Please convert all Tensors to FakeTensors first or instantiate  │
│    960 │   │   │   │   │   │   f"with 'allow_non_fake_inputs'. Found in {func}(*{args}, **{kwar  │
│    961 │   │   │   │   │   )                                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 
'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[ 0.0347, -0.0053, -0.0164,  ...,  0.0153, -0.0056,  0.0097],
        [ 0.0349, -0.0392, -0.0045,  ..., -0.0051, -0.0129,  0.0196],
        [ 0.0395,  0.0132, -0.0138,  ...,  0.0297, -0.0105,  0.0384],
        ...,
        [ 0.0358, -0.0039, -0.0058,  ...,  0.0160, -0.0051,  0.0229],
        [ 0.0351, -0.0034, -0.0079,  ...,  0.0159, -0.0065,  0.0217],
        [ 0.0357, -0.0059, -0.0096,  ...,  0.0106, -0.0128,  0.0225]],
       device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., device='meta', size=(5, 1), 
dtype=torch.int64), cuda:0), 50256), **{}) 

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-6-12f3c3cb93e9>:10 in <module>                                                    │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py: │
│ 1608 in generate                                                                                 │
│                                                                                                  │
│   1605 │   │   │   │   **model_kwargs,                                                           │
│   1606 │   │   │   )                                                                             │
│   1607 │   │   │   # 12. run beam search                                                         │
│ ❱ 1608 │   │   │   return self.beam_search(                                                      │
│   1609 │   │   │   │   input_ids,                                                                │
│   1610 │   │   │   │   beam_scorer,                                                              │
│   1611 │   │   │   │   logits_processor=logits_processor,                                        │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py: │
│ 2799 in beam_search                                                                              │
│                                                                                                  │
│   2796 │   │   │                                                                                 │
│   2797 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)  │
│   2798 │   │   │                                                                                 │
│ ❱ 2799 │   │   │   outputs = self(                                                               │
│   2800 │   │   │   │   **model_inputs,                                                           │
│   2801 │   │   │   │   return_dict=True,                                                         │
│   2802 │   │   │   │   output_attentions=output_attentions,                                      │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/model │
│ ing_whisper.py:1337 in forward                                                                   │
│                                                                                                  │
│   1334 │   │   │   │   │   labels, self.config.pad_token_id, self.config.decoder_start_token_id  │
│   1335 │   │   │   │   )                                                                         │
│   1336 │   │                                                                                     │
│ ❱ 1337 │   │   outputs = self.model(                                                             │
│   1338 │   │   │   input_features,                                                               │
│   1339 │   │   │   decoder_input_ids=decoder_input_ids,                                          │
│   1340 │   │   │   encoder_outputs=encoder_outputs,                                              │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/model │
│ ing_whisper.py:1202 in forward                                                                   │
│                                                                                                  │
│   1199 │   │   │   )                                                                             │
│   1200 │   │                                                                                     │
│   1201 │   │   # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_att  │
│ ❱ 1202 │   │   decoder_outputs = self.decoder(                                                   │
│   1203 │   │   │   input_ids=decoder_input_ids,                                                  │
│   1204 │   │   │   attention_mask=decoder_attention_mask,                                        │
│   1205 │   │   │   encoder_hidden_states=encoder_outputs[0],                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482   │
│ in _call_impl                                                                                    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:212   │
│ in _fn                                                                                           │
│                                                                                                  │
│   209 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic)                                     │
│   210 │   │   │   dynamic_ctx.__enter__()                                                        │
│   211 │   │   │   try:                                                                           │
│ ❱ 212 │   │   │   │   return fn(*args, **kwargs)                                                 │
│   213 │   │   │   finally:                                                                       │
│   214 │   │   │   │   set_eval_frame(prior)                                                      │
│   215 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py:67 in │
│ run                                                                                              │
│                                                                                                  │
│   64 │                                                                                           │
│   65 │   @torchdynamo.optimize(_compiler)                                                        │
│   66 │   def run(*args, **kwargs):                                                               │
│ ❱ 67 │   │   return original_model.forward2(*args, **kwargs)                                     │
│   68 │                                                                                           │
│   69 │   original_model.forward = run                                                            │
│   70                                                                                             │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:333   │
│ in catch_errors                                                                                  │
│                                                                                                  │
│   330 │   │   │   │   │   return hijacked_callback(frame, cache_size, hooks)                     │
│   331 │   │                                                                                      │
│   332 │   │   with compile_lock:                                                                 │
│ ❱ 333 │   │   │   return callback(frame, cache_size, hooks)                                      │
│   334 │                                                                                          │
│   335 │   catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]       │
│   336 │   return catch_errors                                                                    │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:48 │
│ 0 in _convert_frame                                                                              │
│                                                                                                  │
│   477 │   def _convert_frame(frame: types.FrameType, cache_size: int, hooks: Hooks):             │
│   478 │   │   counters["frames"]["total"] += 1                                                   │
│   479 │   │   try:                                                                               │
│ ❱ 480 │   │   │   result = inner_convert(frame, cache_size, hooks)                               │
│   481 │   │   │   counters["frames"]["ok"] += 1                                                  │
│   482 │   │   │   return result                                                                  │
│   483 │   │   except (NotImplementedError, Unsupported):                                         │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:10 │
│ 3 in _fn                                                                                         │
│                                                                                                  │
│   100 │   │   prior_fwd_from_src = torch.fx.graph_module._forward_from_src                       │
│   101 │   │   torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result          │
│   102 │   │   try:                                                                               │
│ ❱ 103 │   │   │   return fn(*args, **kwargs)                                                     │
│   104 │   │   finally:                                                                           │
│   105 │   │   │   torch._C._set_grad_enabled(prior_grad_mode)                                    │
│   106 │   │   │   torch.random.set_rng_state(rng_state)                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/utils.py:88 in      │
│ time_wrapper                                                                                     │
│                                                                                                  │
│     85 │   │   if key not in compilation_metrics:                                                │
│     86 │   │   │   compilation_metrics[key] = []                                                 │
│     87 │   │   t0 = time.time()                                                                  │
│ ❱   88 │   │   r = func(*args, **kwargs)                                                         │
│     89 │   │   latency = time.time() - t0                                                        │
│     90 │   │   # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")                    │
│     91 │   │   compilation_metrics[key].append(latency)                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:33 │
│ 9 in _convert_frame_assert                                                                       │
│                                                                                                  │
│   336 │   │   global initial_grad_state                                                          │
│   337 │   │   initial_grad_state = torch.is_grad_enabled()                                       │
│   338 │   │                                                                                      │
│ ❱ 339 │   │   return _compile(                                                                   │
│   340 │   │   │   frame.f_code,                                                                  │
│   341 │   │   │   frame.f_globals,                                                               │
│   342 │   │   │   frame.f_locals,                                                                │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:40 │
│ 0 in _compile                                                                                    │
│                                                                                                  │
│   397 │   try:                                                                                   │
│   398 │   │   for attempt in itertools.count():                                                  │
│   399 │   │   │   try:                                                                           │
│ ❱ 400 │   │   │   │   out_code = transform_code_object(code, transform)                          │
│   401 │   │   │   │   orig_code_map[out_code] = code                                             │
│   402 │   │   │   │   break                                                                      │
│   403 │   │   │   except exc.RestartAnalysis:                                                    │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/bytecode_transforma │
│ tion.py:341 in transform_code_object                                                             │
│                                                                                                  │
│   338 │   instructions = cleaned_instructions(code, safe)                                        │
│   339 │   propagate_line_nums(instructions)                                                      │
│   340 │                                                                                          │
│ ❱ 341 │   transformations(instructions, code_options)                                            │
│   342 │                                                                                          │
│   343 │   fix_vars(instructions, code_options)                                                   │
│   344                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:38 │
│ 7 in transform                                                                                   │
│                                                                                                  │
│   384 │   │   │   export,                                                                        │
│   385 │   │   │   mutated_closure_cell_contents,                                                 │
│   386 │   │   )                                                                                  │
│ ❱ 387 │   │   tracer.run()                                                                       │
│   388 │   │   output = tracer.output                                                             │
│   389 │   │   assert output is not None                                                          │
│   390 │   │   assert output.output_instructions                                                  │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :1684 in run                                                                                     │
│                                                                                                  │
│   1681 │                                                                                         │
│   1682 │   def run(self):                                                                        │
│   1683 │   │   _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")  │
│ ❱ 1684 │   │   super().run()                                                                     │
│   1685 │                                                                                         │
│   1686 │   def match_nested_cell(self, name, cell):                                              │
│   1687 │   │   """Match a cell in this method to one in a function we are inlining"""            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :538 in run                                                                                      │
│                                                                                                  │
│    535 │   │   │   while (                                                                       │
│    536 │   │   │   │   self.instruction_pointer is not None                                      │
│    537 │   │   │   │   and not self.output.should_exit                                           │
│ ❱  538 │   │   │   │   and self.step()                                                           │
│    539 │   │   │   ):                                                                            │
│    540 │   │   │   │   pass                                                                      │
│    541 │   │   except BackendCompilerFailed:                                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :501 in step                                                                                     │
│                                                                                                  │
│    498 │   │   try:                                                                              │
│    499 │   │   │   if not hasattr(self, inst.opname):                                            │
│    500 │   │   │   │   unimplemented(f"missing: {inst.opname}")                                  │
│ ❱  501 │   │   │   getattr(self, inst.opname)(inst)                                              │
│    502 │   │   │                                                                                 │
│    503 │   │   │   return inst.opname != "RETURN_VALUE"                                          │
│    504 │   │   except BackendCompilerFailed:                                                     │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :1750 in RETURN_VALUE                                                                            │
│                                                                                                  │
│   1747 │   │   │   f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",             │
│   1748 │   │   )                                                                                 │
│   1749 │   │   log.debug("RETURN_VALUE triggered compile")                                       │
│ ❱ 1750 │   │   self.output.compile_subgraph(self)                                                │
│   1751 │   │   self.output.add_output_instructions([create_instruction("RETURN_VALUE")])         │
│   1752                                                                                           │
│   1753                                                                                           │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:557 │
│ in compile_subgraph                                                                              │
│                                                                                                  │
│   554 │   │   │   output = []                                                                    │
│   555 │   │   │   if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:              │
│   556 │   │   │   │   output.extend(                                                             │
│ ❱ 557 │   │   │   │   │   self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)    │
│   558 │   │   │   │   )                                                                          │
│   559 │   │   │   │                                                                              │
│   560 │   │   │   │   if len(pass2.graph_outputs) != 0:                                          │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:604 │
│ in compile_and_call_fx_graph                                                                     │
│                                                                                                  │
│   601 │   │                                                                                      │
│   602 │   │   assert_no_fake_params_or_buffers(gm)                                               │
│   603 │   │   with tracing(self.tracing_context):                                                │
│ ❱ 604 │   │   │   compiled_fn = self.call_user_compiler(gm)                                      │
│   605 │   │   compiled_fn = disable(compiled_fn)                                                 │
│   606 │   │                                                                                      │
│   607 │   │   counters["stats"]["unique_graphs"] += 1                                            │
│                                                                                                  │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:685 │
│ in call_user_compiler                                                                            │
│                                                                                                  │
│   682 │   │   │   assert callable(compiled_fn), "compiler_fn did not return callable"            │
│   683 │   │   except Exception as e:                                                             │
│   684 │   │   │   compiled_fn = gm.forward                                                       │
│ ❱ 685 │   │   │   raise BackendCompilerFailed(self.compiler_fn, e) from e                        │
│   686 │   │   return compiled_fn                                                                 │
│   687 │                                                                                          │
│   688 │   def fake_example_inputs(self) -> List[torch.Tensor]:                                   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BackendCompilerFailed: _compiler raised Exception: Please convert all Tensors to FakeTensors first or instantiate 
FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[ 0.0347, -0.0053, -0.0164,  ...,  0.0153, -0.0056,  0.0097],
        [ 0.0349, -0.0392, -0.0045,  ..., -0.0051, -0.0129,  0.0196],
        [ 0.0395,  0.0132, -0.0138,  ...,  0.0297, -0.0105,  0.0384],
        ...,
        [ 0.0358, -0.0039, -0.0058,  ...,  0.0160, -0.0051,  0.0229],
        [ 0.0351, -0.0034, -0.0079,  ...,  0.0159, -0.0065,  0.0217],
        [ 0.0357, -0.0059, -0.0096,  ...,  0.0106, -0.0128,  0.0225]],
       device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., device='meta', size=(5, 1), 
dtype=torch.int64), cuda:0), 50256), **{}) 


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

@pommedeterresautee
Member

pommedeterresautee commented Jan 25, 2023

Mmmhhh, this is strange, we are testing this script on different kinds of machines without any issue (or at least not that one :-) )
The dependencies are all the ones from requirements, right? In particular, is this the PyTorch nightly declared in that file?

And you just run python experimental/...?

Also, just asking, can you test with a plain pip/venv environment instead of Anaconda? That's the only obvious difference I see.

@gaetansnl
Contributor

Also, you could try building the Docker image we have in the repo and testing with Docker. That would ensure we have the exact same setup.

@TheExGenesis
Author

TheExGenesis commented Jan 26, 2023

I tried the Docker image and it did run. However, on a batch of 32, the optimized version runs slower than vanilla HF.

difference between original and optimized model:
time to warmup: 652.45s
timings
[original] average: 2.637366771697998s / complete: 2.637366771697998s
[optimized] average: 3.123673915863037s / complete: 3.123673915863037s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 36.140648GB
torch.cuda.memory_reserved: 49.056641GB
torch.cuda.max_memory_reserved: 49.056641GB

code to replicate:

# %%
import time

import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor


from kernl.model_optimization import optimize_model

torch._dynamo.config.verbose = True
torch.set_float32_matmul_precision("high")
# torchdynamo.config.cache_size_limit = 512
# torchdynamo.config.dynamic_shapes = True
max_len = 50
num_beams = 1
# model_name = "openai/whisper-tiny.en"
model_name = "openai/whisper-large-v2"
model = (
    WhisperForConditionalGeneration.from_pretrained(model_name).half().to("cuda").eval()
)

audio_dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)

# audio_dataset = load_dataset("librispeech_asr", "clean", split="test")


def get_tokens(item: dict[str, dict]) -> torch.Tensor:
    tensor = processor(
        item["audio"]["array"], return_tensors="pt", sampling_rate=16_000
    ).input_features
    return tensor.cuda()


processor = WhisperProcessor.from_pretrained(model_name)
# %%
# inputs = get_tokens(audio_dataset[0]).half()
batch_size = 32
inputs = torch.cat(
    [get_tokens(audio_dataset[i]) for i in range(batch_size)], dim=0
).half()


# %%
timings_original = list()
transcriptions = list()
with torch.inference_mode():
    #  with torch.inference_mode(), torch.autocast(
    #     dtype=torch.float16, cache_enabled=True, device_type="cuda"
    # ):
    # warmup
    model.generate(
        inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    # torch.cuda.synchronize()
    # for audio in audio_dataset:
    # inputs = get_tokens(audio).half()
    torch.cuda.synchronize()
    start = time.time()
    predicted_ids = model.generate(
        inputs,
        min_length=1,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
        logits_processor=[WhisperTimeStampLogitsProcessor()],
    )
    torch.cuda.synchronize()
    timings_original.append(time.time() - start)
    transcription = processor.batch_decode(
        predicted_ids, skip_special_tokens=True, normalize=True
    )
    # )[0]
    transcriptions.extend(transcription)

print(f"timings_original {timings_original}")
# assert len(audio_dataset) == len(transcriptions)


# apply efficiency fix to HuggingFace implementation of whisper to limit memory footprint
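# only the self-attention key/values (layer_past[:2]) are reordered per beam; the
# cross-attention cache (layer_past[2:]) is left untouched, so the large 1500-position
# encoder keys/values are not copied at every beam reorder step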
@staticmethod
def fix_reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2])
            + layer_past[2:],
        )
    return reordered_past


WhisperForConditionalGeneration._reorder_cache = fix_reorder_cache
# %%
optimize_model(model.model.decoder)
model.model.decoder.forward_before = model.model.decoder.forward
# %%
nb_diff = 0
timings_optimized = list()
print("difference between original and optimized model:")
with torch.inference_mode():
    #  with torch.inference_mode(), torch.autocast(
    #     dtype=torch.float16, cache_enabled=True, device_type="cuda"
    # ):
    start = time.time()
    model.generate(
        inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    # torch.cuda.synchronize()
    print(f"time to warmup: {time.time() - start:.2f}s")
    # for original_modem_transcription, audio in zip(transcriptions, audio_dataset):
    # inputs = get_tokens(audio)
    torch.cuda.synchronize()
    start = time.time()
    predicted_ids = model.generate(
        inputs,
        min_length=1,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
        # logits_processor=[WhisperTimeStampLogitsProcessor()],
    )
    torch.cuda.synchronize()
    timings_optimized.append(time.time() - start)
    optimized_transcription = processor.batch_decode(
        predicted_ids, skip_special_tokens=True, normalize=True
    )
    # nb_diff += original_modem_transcription != optimized_transcription

print("timings")
print(
    f"[original] average: {sum(timings_original) / len(timings_original)}s / complete: {sum(timings_original)}s"
)
print(
    f"[optimized] average: {sum(timings_optimized) / len(timings_optimized)}s / complete: {sum(timings_optimized)}s"
)
print(
    f"output differences: {nb_diff}/{len(audio_dataset)} ({nb_diff / len(audio_dataset) * 100:.2f}%)"
)

print("memory footprint")
print(
    "torch.cuda.memory_allocated: %fGB"
    % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024)
)
print(
    "torch.cuda.memory_reserved: %fGB"
    % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024)
)
print(
    "torch.cuda.max_memory_reserved: %fGB"
    % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024)
)

# before modification without stride fix
# difference between original and optimzed model:
# optimized model modifies 31.83% (# 834 examples) of the dataset
# [original] average: 1.0774035485646196s / complete: 2822.7972972393036s
# [optimized] average: 0.45561475826583747s / complete: 1193.7106666564941s
# torch.cuda.memory_allocated: 10.874530GB
# torch.cuda.memory_reserved: 14.064453GB
# torch.cuda.max_memory_reserved: 14.064453GB
# after modification
# difference between original and optimzed model:
# time to warmup: 694.1382637023926
# optimized model modifies 2.06% (# 54/2620 examples)
# [original] average: 1.0491925114893732s / complete: 2748.8843801021576s
# [optimized] average: 0.4339728889574531s / complete: 1137.0089690685272s
# torch.cuda.memory_allocated: 10.873960GB
# torch.cuda.memory_reserved: 13.365234GB
# torch.cuda.max_memory_reserved: 13.853516GB

@pommedeterresautee
Member

pommedeterresautee commented Jan 26, 2023

I am on my phone and maybe I am missing something, but I only see the batch of 32 used during the creation of the input variable used for warmup.

If so, the warmup is done with a batch of 32, and then the first audio of the evaluation dataset to be transcribed triggers a new warmup for a batch of 1, which may explain the timings you are seeing.

In our benchmark we use 5 beams, meaning an effective batch size of 5 plus a reorder operation. The model being fairly large (the GPU is already quite busy with matmuls), I would not expect the speed-up to change significantly when increasing the batch size (but I didn't try it, as beam 1 with multiple audios in parallel does not match our use case).
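To make the point concrete, a minimal sketch reusing the names from the script above (not a change to kernl itself): the warmup is shape-specific, so the timed call should see exactly the shapes used during warmup.

    # warm up with the SAME shapes that will be timed afterwards
    warmup_inputs = inputs  # same batch size / dtype as the measured generate() call
    model.generate(
        warmup_inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )  # first call with these shapes triggers capture/compilation
    # later generate() calls with the same input shapes reuse the compiled graphs;
    # a call with a different batch size (e.g. a single audio) triggers a new warmup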

@pommedeterresautee
Member

Also, I just noticed that the inputs variable is overridden.

@TheExGenesis
Author

TheExGenesis commented Jan 26, 2023

My apologies, I edited the post above. The code I pasted was wrong, but the results are still the same. Also, from experience, increasing the batch size speeds up inference quite a bit, at least on big GPUs.

The tiny model shows speedup, but not large-v2.

@pommedeterresautee
Member

pommedeterresautee commented Jan 26, 2023

You are right, increasing the batch size generally increases GPU busyness (better use of the silicon) on most models.

I have launched your script with a smaller batch size (10), which fills the 24 GB of memory on the 3090.
With that, the GPU is 95-100% busy and the optimized version is still faster, around 1.5x.

I wonder if the verbose mode enabled in the code you posted is the cause of the slowdown. Can you try without it?

Second point: we have a special kernel for cases where there is little opportunity to optimize attention (not enough parallelization opportunity; it's specific to Whisper). It uses a kind of split-K strategy like in matmul kernels (https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/attention_skinny.py).
We do not check the batch size before enabling it (which I realize is a mistake): with a large batch size it is certainly counterproductive and should be disabled (with a batch of 32, there is more than enough parallelization opportunity).

Code is here https://github.com/ELS-RD/kernl/blob/main/src/kernl/optimizer/attention.py#L45

Could you please replace attention_vec_mat_forward with attention_reference in this block and redo the measurement? (The signatures are the same.)
If you report that it works, we can improve things further.
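For illustration, a hypothetical sketch of what such a batch-size guard in that dispatch block could look like (the function names come from the linked kernl files, the threshold is made up and would need tuning; this is not the actual fix):

    def attention_dispatch(q, k, v, output, sm_scale, is_causal=False, attention_mask=None):
        # "skinny" case: a single query position attending to a long key/value sequence
        skinny = q.size(-2) == 1 and k.size(-2) > 50
        # hypothetical guard: the split-K / vec-mat kernel only pays off when batch * heads
        # does not already provide enough blocks to keep the GPU busy
        enough_parallelism = q.size(0) * q.size(1) >= 128  # threshold to be tuned per GPU
        if skinny and not enough_parallelism and attention_mask is None and not is_causal:
            attention_vec_mat_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
        else:
            attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
        return output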

FWIW, we are replacing the script with a notebook so results will be easier to compare:

https://github.com/ELS-RD/kernl/blob/b87532bfba45554d2882404813a905df6f4b9c7d/experimental/whisper/speedup.ipynb

Thank you for your help

@TheExGenesis
Author

You're welcome, thanks for building kernl! Sadly, with attention_vec_mat_forward, the largest batch_size I can use is 8; in terms of tokens/s, it's still slower than vanilla HF on batches of 32. I think with Whisper the bottleneck is memory, due to the huge length (1500 positions per 30s of audio) of the encoder embeddings in the KV cache.

difference between original and optimized model:
time to warmup: 603.65s
timings
[original] average: 1.6842162609100342s / complete: 1.6842162609100342s
[optimized] average: 0.9564313888549805s / complete: 0.9564313888549805s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 10.535777GB
torch.cuda.memory_reserved: 13.009766GB
torch.cuda.max_memory_reserved: 13.009766GB
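As a rough sanity check on the memory point above, a back-of-the-envelope estimate of the cross-attention part of the KV cache for whisper-large-v2 (assuming the public config values of 32 decoder layers, 20 heads, head dim 64, fp16; these numbers are my own, not from this thread):

    layers, heads, head_dim, enc_len, bytes_fp16 = 32, 20, 64, 1500, 2
    per_sample = layers * 2 * heads * enc_len * head_dim * bytes_fp16  # K and V per layer
    print(per_sample / 2**20)        # ~234 MiB per sample
    print(32 * per_sample / 2**30)   # ~7.3 GiB for a batch of 32 (beams multiply this further)

so the cross-attention cache alone accounts for several of the GB in the memory footprints reported above.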

@pommedeterresautee
Member

Just to confirm, you are comparing 2.6/32=0.08 and 0.9/8=0.11, right?

Second confirmation: you wrote "with attention_vec_mat_forward", but it's without, right?

Your code should look like this:

    if q.size(-2) == 1 and k.size(-2) > 50:
        if (attention_mask is None) and (not is_causal):
            attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
        else:
            attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
    else:
        attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)

Of course you can simplify it to:

    if q.size(-2) == 1 and k.size(-2) > 50:
        attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
    else:
        attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)

If that was not the case, can you post measurements with this code?
If this was already the code, can you comment out this line: https://github.com/ELS-RD/kernl/blob/main/src/kernl/model_optimization.py#L28

The idea is to check the timings using just CUDA graphs, without any custom kernel.

Third point: regarding the memory footprint, using one kernel instead of another does not change it significantly (or at all, depending on the kernel); it's CUDA graphs that have this effect.
In your first timings post (#213 (comment)) it was batch 32 for kernl too, am I right?

What batch sizes have you tried that didn't work? (Have you tried 16 or 24, for instance?)

Fourth, if you want, you can comment out this line https://github.com/ELS-RD/kernl/blob/main/src/kernl/model_optimization.py#L28 to check whether the custom kernels make things slower at batch 32 (only CUDA graphs would be applied).

@TheExGenesis
Author

TheExGenesis commented Jan 26, 2023

My bad again, it did run without attention_vec_mat_forward at batch_size=32.

difference between original and optimized model:
time to warmup: 792.89s
timings
[original] average: 2.330775022506714s / complete: 2.330775022506714s
[optimized] average: 2.410492420196533s / complete: 2.410492420196533s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 33.665029GB
torch.cuda.memory_reserved: 41.916016GB
torch.cuda.max_memory_reserved: 41.916016GB

Commenting out the ofi line does seem to make it faster:

difference between original and optimized model:
time to warmup: 322.83s
timings
[original] average: 2.3458523750305176s / complete: 2.3458523750305176s
[optimized] average: 2.314619541168213s / complete: 2.314619541168213s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 33.664189GB
torch.cuda.memory_reserved: 41.882812GB
torch.cuda.max_memory_reserved: 41.882812GB

@pommedeterresautee
Member

pommedeterresautee commented Jan 26, 2023

I took an A100 with 40 GB of memory and ran the experiments on the Docker image.
I get an OOM with batch 32.

On batch 24 + attention_forward:

timings_original [1.9437637329101562]
difference between original and optimized model:
time to warmup: 800.87s
timings
[original] average: 1.9437637329101562s / complete: 1.9437637329101562s
[optimized] average: 1.9220001697540283s / complete: 1.9220001697540283s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 30.002132GB
torch.cuda.memory_reserved: 35.810547GB
torch.cuda.max_memory_reserved: 35.810547GB

On batch 16 + vec mat replaced with attention_reference, I get:

timings_original [1.6224250793457031]
difference between original and optimized model:
time to warmup: 657.44s
timings
[original] average: 1.6224250793457031s / complete: 1.6224250793457031s
[optimized] average: 1.4375851154327393s / complete: 1.4375851154327393s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 18.383415GB
torch.cuda.memory_reserved: 22.869141GB
torch.cuda.max_memory_reserved: 22.869141GB

Same but using attention_forward (Flash attention) instead of attention_reference:

timings_original [1.6304872035980225]
difference between original and optimized model:
time to warmup: 799.15s
timings
[original] average: 1.6304872035980225s / complete: 1.6304872035980225s
[optimized] average: 1.3849904537200928s / complete: 1.3849904537200928s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 18.129509GB
torch.cuda.memory_reserved: 22.130859GB
torch.cuda.max_memory_reserved: 22.130859GB

Still batch 16, and no more Triton kernels:

timings_original [1.55999755859375]
difference between original and optimized model:
time to warmup: 348.98s
timings
[original] average: 1.55999755859375s / complete: 1.55999755859375s
[optimized] average: 1.4344558715820312s / complete: 1.4344558715820312s
memory footprint
torch.cuda.memory_allocated: 18.382996GB
torch.cuda.memory_reserved: 22.794922GB
torch.cuda.max_memory_reserved: 22.794922GB

And for comparison, on batch 10 (main branch, no change):

timings_original [1.337660551071167]
difference between original and optimized model:
time to warmup: 688.54s
timings
[original] average: 1.337660551071167s / complete: 1.337660551071167s
[optimized] average: 1.160353660583496s / complete: 1.160353660583496s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 15.609025GB
torch.cuda.memory_reserved: 19.558594GB
torch.cuda.max_memory_reserved: 19.558594GB

At least in both cases, Kernl doesn't make things worse than the baseline :-)
There may be small opportunities to better tweak the autotuning for A100, but it will probably not make a big difference versus eager at batch sizes >= 24. I'm still not sure many people / companies use A100s on large cloud providers for inference, but that's another story ($$$). Also, let's not forget these numbers are only indicative of a trend, because the measurements were performed on a very small sample of inputs.

Drivers, CUDA version & co (screenshot attached in the original comment)

@pommedeterresautee
Member

Unit test outputs for Whisper large-batch shapes (the comparison should be done on CUDA time, as we use CUDA graphs on the model).
Triton seems to be consistently slightly faster than PyTorch.

self-attention

root@e9354d46b52b:/kernl# pytest test/test_attention.py -k "benchmark and masked and non-causal and no-mask and fp16 and x64 and (24x or 32x)" --benchmark-group-by param:shape
================================================================================= test session starts =================================================================================
platform linux -- Python 3.9.16, pytest-7.2.1, pluggy-1.0.0
rootdir: /kernl
collected 1259 items / 1213 deselected / 46 selected                                                                                                                                  

test/test_attention.py ..............................................                                                                                                           [100%]
shape=(1, 20, 32, 64)
Name                                                                                               Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median        Mean          Min            Max
-------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  ------------  ------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=1x20x32x64]   0.0195 (1.0)     0.0193 (1.0)   0.0184 (1.0)   0.0215 (1.0)   0.091 (1.31)  0.093 (1.3)   0.0883 (1.32)  0.2563 (1.05)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=1x20x32x64]  0.0102 (1.9)     0.0099 (1.94)  0.0082 (2.25)  0.0113 (1.91)  0.1191 (1.0)  0.1209 (1.0)  0.1162 (1.0)   0.2701 (1.0)

shape=(24, 20, 128, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)    Max (CUDA)    Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  ------------  ------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x128x64]   0.1126 (1.0)     0.113 (1.0)    0.1106 (1.0)  0.1157 (1.0)  0.1419 (1.0)   0.1428 (1.0)   0.1396 (1.0)   0.2731 (1.01)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x128x64]  0.0461 (2.44)    0.0463 (2.44)  0.044 (2.51)  0.0481 (2.4)  0.1355 (1.05)  0.1369 (1.04)  0.1325 (1.05)  0.2757 (1.0)

shape=(24, 20, 16, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min           Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  ------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x16x64]   0.0246 (1.0)     0.0249 (1.0)   0.0236 (1.0)   0.0266 (1.0)   0.0914 (1.29)  0.0934 (1.28)  0.0885 (1.3)  0.2695 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x16x64]  0.0133 (1.85)    0.013 (1.91)   0.0113 (2.09)  0.0143 (1.86)  0.1178 (1.0)   0.1196 (1.0)   0.1148 (1.0)  0.2432 (1.11)

shape=(24, 20, 256, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x256x64]   0.3758 (1.0)     0.3763 (1.0)   0.3727 (1.0)   0.3809 (1.0)   0.4059 (1.0)   0.4068 (1.0)   0.4017 (1.0)   0.5145 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x256x64]  0.0993 (3.78)    0.0997 (3.77)  0.0973 (3.83)  0.1024 (3.72)  0.1891 (2.15)  0.1904 (2.14)  0.1856 (2.16)  0.3078 (1.67)

shape=(24, 20, 257, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x257x64]   0.6349 (1.0)     0.6346 (1.0)   0.6318 (1.0)   0.639 (1.0)    0.6637 (1.0)   0.6651 (1.0)   0.6612 (1.0)   0.7592 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x257x64]  0.2304 (2.76)    0.2307 (2.75)  0.2284 (2.77)  0.2324 (2.75)  0.3197 (2.08)  0.3209 (2.07)  0.3169 (2.09)  0.4299 (1.77)

shape=(24, 20, 32, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median        Mean          Min           Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  ------------  ------------  ------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x32x64]   0.0297 (1.0)     0.03 (1.0)     0.0287 (1.0)   0.0328 (1.0)   0.092 (1.28)  0.094 (1.27)  0.09 (1.27)   0.4296 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x32x64]  0.0164 (1.81)    0.0166 (1.81)  0.0154 (1.87)  0.0184 (1.78)  0.1178 (1.0)  0.1197 (1.0)  0.1146 (1.0)  0.2575 (1.67)

shape=(24, 20, 33, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)    Max (CUDA)    Median         Mean           Min            Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  ------------  ------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x33x64]   0.0451 (1.0)     0.045 (1.0)    0.043 (1.0)   0.0471 (1.0)  0.0935 (1.27)  0.0954 (1.26)  0.0911 (1.27)  0.4218 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x33x64]  0.0215 (2.1)     0.0216 (2.08)  0.0205 (2.1)  0.0236 (2.0)  0.1188 (1.0)   0.1206 (1.0)   0.1159 (1.0)   0.2362 (1.79)

shape=(24, 20, 384, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)    Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  ------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x384x64]   0.7629 (1.0)     0.7632 (1.0)   0.7578 (1.0)   0.769 (1.0)   0.7936 (1.0)   0.7948 (1.0)   0.7883 (1.0)   0.8939 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x384x64]  0.1802 (4.23)    0.1799 (4.24)  0.1772 (4.28)  0.1833 (4.2)  0.2699 (2.94)  0.2717 (2.93)  0.2655 (2.97)  0.4168 (2.14)

shape=(24, 20, 512, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x512x64]   1.2626 (1.0)     1.2627 (1.0)   1.2513 (1.0)   1.2739 (1.0)   1.2912 (1.0)   1.3122 (1.0)   1.2827 (1.0)   2.72 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x512x64]  0.2857 (4.42)    0.2862 (4.41)  0.2836 (4.41)  0.2888 (4.41)  0.3769 (3.43)  0.3786 (3.47)  0.3735 (3.43)  0.5031 (5.41)

shape=(24, 20, 64, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)    Median        Mean          Min            Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  ------------  ------------  ------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x64x64]   0.0481 (1.0)     0.0485 (1.0)   0.0471 (1.0)   0.0512 (1.0)  0.092 (1.28)  0.094 (1.28)  0.0895 (1.29)  0.2728 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x64x64]  0.0236 (2.04)    0.024 (2.02)   0.0225 (2.09)  0.0256 (2.0)  0.1181 (1.0)  0.1201 (1.0)  0.1152 (1.0)   0.2637 (1.03)

shape=(24, 20, 8, 64)
Name                                                                                               Median (CUDA)    Mean (CUDA)    Min (CUDA)    Max (CUDA)     Median         Mean           Min           Max
-------------------------------------------------------------------------------------------------  ---------------  -------------  ------------  -------------  -------------  -------------  ------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x8x64]   0.0225 (1.0)     0.0226 (1.0)   0.0215 (1.0)  0.0256 (1.0)   0.0917 (1.29)  0.0935 (1.28)  0.089 (1.3)   0.263 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=24x20x8x64]  0.0113 (2.0)     0.011 (2.05)   0.0102 (2.1)  0.0123 (2.08)  0.1181 (1.0)   0.1199 (1.0)   0.1154 (1.0)  0.2554 (1.03)

shape=(32, 20, 128, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x128x64]   0.1556 (1.0)     0.1561 (1.0)   0.1526 (1.0)   0.1608 (1.0)   0.185 (1.0)    0.186 (1.0)    0.1817 (1.0)   0.2868 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x128x64]  0.0563 (2.76)    0.0568 (2.75)  0.0553 (2.76)  0.0594 (2.71)  0.1465 (1.26)  0.1479 (1.26)  0.1439 (1.26)  0.2565 (1.12)

shape=(32, 20, 16, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean          Min            Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  ------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x16x64]   0.0266 (1.0)     0.0264 (1.0)   0.0256 (1.0)   0.0287 (1.0)   0.0924 (1.28)  0.094 (1.28)  0.0896 (1.29)  0.2496 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x16x64]  0.0143 (1.86)    0.0138 (1.91)  0.0123 (2.08)  0.0154 (1.87)  0.1182 (1.0)   0.1201 (1.0)  0.1152 (1.0)   0.2461 (1.01)

shape=(32, 20, 256, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x256x64]   0.4915 (1.0)     0.4912 (1.0)   0.4874 (1.0)   0.4956 (1.0)   0.5207 (1.0)   0.5238 (1.0)   0.5167 (1.0)   1.0259 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x256x64]  0.1229 (4.0)     0.1225 (4.01)  0.1198 (4.07)  0.1249 (3.97)  0.2115 (2.46)  0.2127 (2.46)  0.2079 (2.49)  0.3236 (3.17)

shape=(32, 20, 257, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min           Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  ------------  ------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x257x64]   0.8315 (1.0)     0.8315 (1.0)   0.8284 (1.0)   0.8346 (1.0)   0.8613 (1.0)   0.8641 (1.0)   0.8578 (1.0)  1.1375 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x257x64]  0.2939 (2.83)    0.2935 (2.83)  0.2918 (2.84)  0.2949 (2.83)  0.3839 (2.24)  0.3855 (2.24)  0.381 (2.25)  0.482 (2.36)

shape=(32, 20, 32, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min           Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  ------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x32x64]   0.0338 (1.0)     0.0338 (1.0)   0.0328 (1.0)   0.0358 (1.0)   0.0917 (1.28)  0.0934 (1.28)  0.089 (1.29)  0.2314 (1.39)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x32x64]  0.0174 (1.94)    0.0179 (1.89)  0.0174 (1.88)  0.0195 (1.84)  0.1173 (1.0)   0.1191 (1.0)   0.1145 (1.0)  0.3228 (1.0)

shape=(32, 20, 33, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean          Min            Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  ------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x33x64]   0.0512 (1.0)     0.0513 (1.0)   0.0492 (1.0)   0.0543 (1.0)   0.0943 (1.25)  0.096 (1.25)  0.0918 (1.26)  0.278 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x33x64]  0.0236 (2.17)    0.0238 (2.16)  0.0225 (2.18)  0.0256 (2.12)  0.1182 (1.0)   0.1201 (1.0)  0.1152 (1.0)   0.2425 (1.15)

shape=(32, 20, 384, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x384x64]   0.9994 (1.0)     0.9997 (1.0)   0.9933 (1.0)   1.0066 (1.0)   1.0302 (1.0)   1.0357 (1.0)   1.0232 (1.0)   1.4487 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x384x64]  0.2284 (4.38)    0.2285 (4.38)  0.2263 (4.39)  0.2304 (4.37)  0.3183 (3.24)  0.3197 (3.24)  0.3153 (3.25)  0.4173 (3.47)

shape=(32, 20, 512, 64)
Name                                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
---------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x512x64]   1.6579 (1.0)     1.6577 (1.0)   1.6476 (1.0)   1.666 (1.0)    1.6898 (1.0)   1.7175 (1.0)   1.6794 (1.0)   3.0504 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x512x64]  0.5284 (3.14)    0.5281 (3.14)  0.5243 (3.14)  0.5315 (3.13)  0.6204 (2.72)  0.6223 (2.76)  0.6159 (2.73)  0.7189 (4.24)

shape=(32, 20, 64, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)    Median         Mean           Min            Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  ------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x64x64]   0.0573 (1.0)     0.0576 (1.0)   0.0563 (1.0)   0.0604 (1.0)  0.0931 (1.28)  0.0954 (1.27)  0.0907 (1.29)  0.2718 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x64x64]  0.0297 (1.93)    0.0295 (1.95)  0.0287 (1.96)  0.0317 (1.9)  0.1196 (1.0)   0.1213 (1.0)   0.1168 (1.0)   0.2517 (1.08)

shape=(32, 20, 8, 64)
Name                                                                                               Median (CUDA)    Mean (CUDA)    Min (CUDA)    Max (CUDA)    Median         Mean          Min            Max
-------------------------------------------------------------------------------------------------  ---------------  -------------  ------------  ------------  -------------  ------------  -------------  ------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x8x64]   0.0236 (1.0)     0.0238 (1.0)   0.0225 (1.0)  0.0266 (1.0)  0.0921 (1.28)  0.094 (1.27)  0.0896 (1.28)  0.275 (1.07)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=32x20x8x64]  0.0123 (1.92)    0.0122 (1.96)  0.0102 (2.2)  0.0133 (2.0)  0.1174 (1.0)   0.1196 (1.0)  0.1146 (1.0)   0.2956 (1.0)

shape=(64, 20, 32, 64)
Name                                                                                                Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=64x20x32x64]   0.0481 (1.0)     0.0482 (1.0)   0.0461 (1.0)   0.0512 (1.0)   0.0921 (1.28)  0.0938 (1.27)  0.0895 (1.28)  0.265 (1.0)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=64x20x32x64]  0.0276 (1.74)    0.0276 (1.74)  0.0266 (1.73)  0.0287 (1.79)  0.1177 (1.0)   0.1195 (1.0)   0.1146 (1.0)   0.2547 (1.04)

shape=(8, 20, 32, 64)
Name                                                                                               Median (CUDA)    Mean (CUDA)    Min (CUDA)    Max (CUDA)     Median        Mean           Min            Max
-------------------------------------------------------------------------------------------------  ---------------  -------------  ------------  -------------  ------------  -------------  -------------  -------------
test_benchmark_masked[torch-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=8x20x32x64]   0.0225 (1.0)     0.0226 (1.0)   0.0215 (1.0)  0.0246 (1.0)   0.0911 (1.3)  0.0931 (1.29)  0.0886 (1.31)  0.2549 (1.12)
test_benchmark_masked[triton-no-mask-non-causal-fp16-shape(batch,heads,seq_len,dhead)=8x20x32x64]  0.0123 (1.83)    0.0119 (1.9)   0.0102 (2.1)  0.0133 (1.85)  0.1186 (1.0)  0.1206 (1.0)   0.1157 (1.0)   0.2852 (1.0)


=================================================================== 46 passed, 1213 deselected in 83.17s (0:01:23) ====================================================================
root@e9354d46b52b:/kernl# 

cross-attention

root@e9354d46b52b:/kernl# pytest test/test_attention.py -k "benchmark and not masked" --benchmark-group-by param:shape
================================================================================= test session starts =================================================================================
platform linux -- Python 3.9.16, pytest-7.2.1, pluggy-1.0.0
rootdir: /kernl
collected 1259 items / 1227 deselected / 32 selected                                                                                                                                  

test/test_attention.py ................................                                                                                                                         [100%]
shape=(1, 16, 1500, 64)
Name                                                            Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_skinny_cross_attention[flash-attention-shape2]   0.0829 (1.0)     0.0827 (1.0)   0.0819 (1.0)   0.084 (1.0)    0.123 (1.0)    0.1243 (1.0)   0.1204 (1.0)   0.2375 (1.0)
test_benchmark_skinny_cross_attention[split-k-parallel-shape2]  0.0461 (1.8)     0.0458 (1.81)  0.0451 (1.82)  0.0471 (1.78)  0.0852 (1.44)  0.0865 (1.44)  0.0828 (1.45)  0.2184 (1.09)
test_benchmark_skinny_cross_attention[torch-shape2]             0.0512 (1.62)    0.0507 (1.63)  0.0492 (1.67)  0.0522 (1.61)  0.09 (1.37)    0.0912 (1.36)  0.087 (1.38)   0.205 (1.16)
test_benchmark_skinny_cross_attention[vec-mat-mul-shape2]       0.0389 (2.13)    0.0389 (2.13)  0.0369 (2.22)  0.041 (2.05)   0.0816 (1.51)  0.0831 (1.5)   0.0786 (1.53)  0.2197 (1.08)

shape=(1, 20, 1500, 64)
Name                                                            Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_skinny_cross_attention[flash-attention-shape4]   0.0922 (1.0)     0.0924 (1.0)   0.0911 (1.0)   0.0942 (1.0)   0.1258 (1.0)   0.1273 (1.0)   0.1232 (1.0)   0.2422 (1.0)
test_benchmark_skinny_cross_attention[split-k-parallel-shape4]  0.0563 (1.64)    0.0564 (1.64)  0.0553 (1.65)  0.0584 (1.61)  0.0887 (1.42)  0.0901 (1.41)  0.0863 (1.43)  0.2193 (1.1)
test_benchmark_skinny_cross_attention[torch-shape4]             0.0604 (1.53)    0.0606 (1.52)  0.0594 (1.53)  0.0625 (1.51)  0.0937 (1.34)  0.0951 (1.34)  0.0911 (1.35)  0.2119 (1.14)
test_benchmark_skinny_cross_attention[vec-mat-mul-shape4]       0.0481 (1.91)    0.0483 (1.91)  0.0461 (1.98)  0.0492 (1.92)  0.0847 (1.49)  0.0861 (1.48)  0.0819 (1.5)   0.2197 (1.1)

shape=(1, 6, 1500, 64)
Name                                                            Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_skinny_cross_attention[flash-attention-shape0]   0.0829 (1.0)     0.0828 (1.0)   0.0819 (1.0)   0.084 (1.0)    0.1257 (1.0)   0.1271 (1.0)   0.1236 (1.0)   0.2472 (1.0)
test_benchmark_skinny_cross_attention[split-k-parallel-shape0]  0.0451 (1.84)    0.0452 (1.83)  0.044 (1.86)   0.0471 (1.78)  0.0883 (1.42)  0.0897 (1.42)  0.0857 (1.44)  0.2156 (1.15)
test_benchmark_skinny_cross_attention[torch-shape0]             0.0492 (1.69)    0.0519 (1.6)   0.0471 (1.74)  0.0604 (1.39)  0.0908 (1.38)  0.0921 (1.38)  0.0884 (1.4)   0.212 (1.17)
test_benchmark_skinny_cross_attention[vec-mat-mul-shape0]       0.0369 (2.25)    0.0367 (2.26)  0.0348 (2.35)  0.0379 (2.22)  0.0846 (1.49)  0.0861 (1.48)  0.0815 (1.52)  0.229 (1.08)

shape=(24, 20, 1500, 64)
Name                                                            Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_skinny_cross_attention[flash-attention-shape7]   0.4895 (1.0)     0.4891 (1.0)   0.4833 (1.0)   0.4946 (1.0)   0.5214 (1.0)   0.5222 (1.0)   0.5149 (1.0)   0.6177 (1.0)
test_benchmark_skinny_cross_attention[split-k-parallel-shape7]  0.4598 (1.06)    0.4598 (1.06)  0.4577 (1.06)  0.4628 (1.07)  0.4911 (1.06)  0.4922 (1.06)  0.4872 (1.06)  0.564 (1.1)
test_benchmark_skinny_cross_attention[torch-shape7]             0.4854 (1.01)    0.4849 (1.01)  0.4833 (1.0)   0.4874 (1.01)  0.5164 (1.01)  0.5176 (1.01)  0.5131 (1.0)   0.5909 (1.05)
test_benchmark_skinny_cross_attention[vec-mat-mul-shape7]       0.4618 (1.06)    0.4618 (1.06)  0.4588 (1.05)  0.4649 (1.06)  0.4938 (1.06)  0.4948 (1.06)  0.4903 (1.05)  0.5614 (1.1)

shape=(32, 20, 1500, 64)
Name                                                            Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_skinny_cross_attention[flash-attention-shape6]   0.6001 (1.04)    0.6007 (1.04)  0.596 (1.04)   0.6042 (1.04)  0.6333 (1.04)  0.6345 (1.04)  0.6285 (1.04)  0.7306 (1.0)
test_benchmark_skinny_cross_attention[split-k-parallel-shape6]  0.5939 (1.05)    0.594 (1.05)   0.5908 (1.05)  0.599 (1.05)   0.6251 (1.05)  0.6263 (1.05)  0.621 (1.05)   0.7 (1.05)
test_benchmark_skinny_cross_attention[torch-shape6]             0.6246 (1.0)     0.6246 (1.0)   0.6226 (1.0)   0.6267 (1.0)   0.6561 (1.0)   0.6575 (1.0)   0.6531 (1.0)   0.7331 (1.0)
test_benchmark_skinny_cross_attention[vec-mat-mul-shape6]       0.599 (1.04)     0.5989 (1.04)  0.596 (1.04)   0.6021 (1.04)  0.6313 (1.04)  0.6323 (1.04)  0.6268 (1.04)  0.7084 (1.03)

shape=(5, 16, 1500, 64)
Name                                                            Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_skinny_cross_attention[flash-attention-shape3]   0.1352 (1.0)     0.1349 (1.0)   0.1331 (1.0)   0.1382 (1.0)   0.1668 (1.0)   0.1674 (1.0)   0.1632 (1.0)   0.2501 (1.0)
test_benchmark_skinny_cross_attention[split-k-parallel-shape3]  0.1096 (1.23)    0.1097 (1.23)  0.1065 (1.25)  0.1137 (1.22)  0.1403 (1.19)  0.1413 (1.19)  0.137 (1.19)   0.2323 (1.08)
test_benchmark_skinny_cross_attention[torch-shape3]             0.1208 (1.12)    0.1205 (1.12)  0.1178 (1.13)  0.1239 (1.12)  0.1511 (1.1)   0.1517 (1.1)   0.1481 (1.1)   0.2339 (1.07)
test_benchmark_skinny_cross_attention[vec-mat-mul-shape3]       0.1096 (1.23)    0.1097 (1.23)  0.1075 (1.24)  0.1126 (1.23)  0.141 (1.18)   0.1417 (1.18)  0.1385 (1.18)  0.2383 (1.05)

shape=(5, 20, 1500, 64)
Name                                                            Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_skinny_cross_attention[flash-attention-shape5]   0.1546 (1.0)     0.1546 (1.0)   0.1526 (1.0)   0.1567 (1.0)   0.1854 (1.0)   0.1863 (1.0)   0.1829 (1.0)   0.2587 (1.0)
test_benchmark_skinny_cross_attention[split-k-parallel-shape5]  0.1362 (1.14)    0.1364 (1.13)  0.1341 (1.14)  0.1393 (1.12)  0.1679 (1.1)   0.1684 (1.11)  0.1642 (1.11)  0.2481 (1.05)
test_benchmark_skinny_cross_attention[torch-shape5]             0.1495 (1.03)    0.1493 (1.04)  0.1475 (1.03)  0.1516 (1.03)  0.1799 (1.03)  0.1808 (1.03)  0.1772 (1.03)  0.2597 (1.0)
test_benchmark_skinny_cross_attention[vec-mat-mul-shape5]       0.1372 (1.13)    0.1372 (1.13)  0.1352 (1.13)  0.1393 (1.12)  0.1688 (1.1)   0.1697 (1.1)   0.1667 (1.1)   0.2443 (1.06)

shape=(5, 6, 1500, 64)
Name                                                            Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
--------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark_skinny_cross_attention[flash-attention-shape1]   0.0983 (1.0)     0.0979 (1.0)   0.0963 (1.0)   0.0993 (1.0)   0.1288 (1.0)   0.1298 (1.0)   0.1267 (1.0)   0.2451 (1.0)
test_benchmark_skinny_cross_attention[split-k-parallel-shape1]  0.0614 (1.6)     0.0614 (1.59)  0.0594 (1.62)  0.0635 (1.56)  0.0921 (1.4)   0.0931 (1.39)  0.0898 (1.41)  0.2054 (1.19)
test_benchmark_skinny_cross_attention[torch-shape1]             0.0696 (1.41)    0.0696 (1.41)  0.0676 (1.42)  0.0717 (1.39)  0.1001 (1.29)  0.1011 (1.28)  0.0979 (1.29)  0.2215 (1.11)
test_benchmark_skinny_cross_attention[vec-mat-mul-shape1]       0.0553 (1.78)    0.0556 (1.76)  0.0532 (1.81)  0.0573 (1.73)  0.0868 (1.48)  0.0879 (1.48)  0.0837 (1.51)  0.221 (1.11)


=================================================================== 32 passed, 1227 deselected in 66.27s (0:01:06) ====================================================================

@pommedeterresautee
Member

pommedeterresautee commented Jan 27, 2023

linear layer
Triton is slightly faster on short sequences when tensors are contiguous.

root@e9354d46b52b:/kernl# pytest test/test_linear_layer.py -k "relu and contig and no_bias and fp16 and not no_cuda" --benchmark-group-by fullfunc,param:shape
================================================================================= test session starts =================================================================================
platform linux -- Python 3.9.16, pytest-7.2.1, pluggy-1.0.0
rootdir: /kernl
collected 1408 items / 1364 deselected / 44 selected                                                                                                                                  

test/test_linear_layer.py ............................................                                                                                                          [100%]
test/test_linear_layer.py::test_benchmark shape=(1, 8, 8, 8)
Name                                                                          Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
----------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-1x8x8x8-relu-no_bias-contiguous]      0.0174 (1.12)    0.0177 (1.08)  0.0164 (1.12)  0.0184 (1.11)  0.0533 (1.06)  0.0548 (1.09)  0.0505 (1.07)  0.1753 (1.0)
test_benchmark[pytorch-cuda_graphs-fp16-1x8x8x8-relu-no_bias-non-contiguous]  0.0195 (1.0)     0.0192 (1.0)   0.0184 (1.0)   0.0205 (1.0)   0.0539 (1.05)  0.0551 (1.08)  0.0513 (1.05)  0.1737 (1.01)
test_benchmark[triton-cuda_graphs-fp16-1x8x8x8-relu-no_bias-contiguous]       0.0164 (1.19)    0.0166 (1.16)  0.0154 (1.2)   0.0174 (1.18)  0.0565 (1.0)   0.0596 (1.0)   0.054 (1.0)    0.1693 (1.04)
test_benchmark[triton-cuda_graphs-fp16-1x8x8x8-relu-no_bias-non-contiguous]   0.0184 (1.06)    0.018 (1.07)   0.0164 (1.12)  0.0195 (1.05)  0.0562 (1.01)  0.0577 (1.03)  0.0538 (1.0)   0.1746 (1.0)

test/test_linear_layer.py::test_benchmark shape=(24, 128, 1280, 1280)
Name                                                                                   Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
-------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-24x128x1280x1280-relu-no_bias-contiguous]      0.0932 (1.36)    0.0935 (1.36)  0.0922 (1.36)  0.0952 (1.37)  0.1142 (1.29)  0.1153 (1.29)  0.1124 (1.3)   0.2161 (1.2)
test_benchmark[pytorch-cuda_graphs-fp16-24x128x1280x1280-relu-no_bias-non-contiguous]  0.127 (1.0)      0.1268 (1.0)   0.1239 (1.01)  0.13 (1.0)     0.1479 (1.0)   0.1487 (1.0)   0.1449 (1.0)   0.2482 (1.04)
test_benchmark[triton-cuda_graphs-fp16-24x128x1280x1280-relu-no_bias-contiguous]       0.084 (1.51)     0.0839 (1.51)  0.0829 (1.51)  0.085 (1.53)   0.1037 (1.43)  0.1046 (1.42)  0.1022 (1.42)  0.1979 (1.31)
test_benchmark[triton-cuda_graphs-fp16-24x128x1280x1280-relu-no_bias-non-contiguous]   0.126 (1.01)     0.1261 (1.01)  0.1249 (1.0)   0.128 (1.02)   0.1477 (1.0)   0.1485 (1.0)   0.1456 (1.0)   0.2585 (1.0)

test/test_linear_layer.py::test_benchmark shape=(24, 16, 1280, 1280)
Name                                                                                  Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-24x16x1280x1280-relu-no_bias-contiguous]      0.0307 (1.87)    0.0307 (1.88)  0.0297 (1.9)   0.0317 (1.87)  0.0603 (1.44)  0.061 (1.44)   0.0575 (1.47)  0.1722 (1.12)
test_benchmark[pytorch-cuda_graphs-fp16-24x16x1280x1280-relu-no_bias-non-contiguous]  0.0573 (1.0)     0.0576 (1.0)   0.0563 (1.0)   0.0594 (1.0)   0.0866 (1.0)   0.0876 (1.0)   0.0845 (1.0)   0.1927 (1.0)
test_benchmark[triton-cuda_graphs-fp16-24x16x1280x1280-relu-no_bias-contiguous]       0.0287 (2.0)     0.0291 (1.98)  0.0276 (2.04)  0.0307 (1.93)  0.0578 (1.5)   0.0593 (1.48)  0.0558 (1.52)  0.1806 (1.07)
test_benchmark[triton-cuda_graphs-fp16-24x16x1280x1280-relu-no_bias-non-contiguous]   0.0348 (1.65)    0.035 (1.65)   0.0338 (1.67)  0.0369 (1.61)  0.0633 (1.37)  0.0642 (1.36)  0.0613 (1.38)  0.1734 (1.11)

test/test_linear_layer.py::test_benchmark shape=(24, 256, 1280, 1280)
Name                                                                                   Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
-------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-24x256x1280x1280-relu-no_bias-contiguous]      0.1536 (1.4)     0.1539 (1.4)   0.1526 (1.4)   0.1567 (1.39)  0.1747 (1.35)  0.1757 (1.35)  0.1728 (1.36)  0.2656 (1.21)
test_benchmark[pytorch-cuda_graphs-fp16-24x256x1280x1280-relu-no_bias-non-contiguous]  0.1772 (1.21)    0.177 (1.22)   0.1741 (1.23)  0.1802 (1.21)  0.1984 (1.19)  0.1993 (1.19)  0.1964 (1.19)  0.292 (1.1)
test_benchmark[triton-cuda_graphs-fp16-24x256x1280x1280-relu-no_bias-contiguous]       0.1311 (1.64)    0.1315 (1.64)  0.13 (1.65)    0.1341 (1.63)  0.1519 (1.56)  0.1527 (1.55)  0.1497 (1.57)  0.2371 (1.35)
test_benchmark[triton-cuda_graphs-fp16-24x256x1280x1280-relu-no_bias-non-contiguous]   0.215 (1.0)      0.2154 (1.0)   0.214 (1.0)    0.2181 (1.0)   0.2366 (1.0)   0.2374 (1.0)   0.2344 (1.0)   0.3212 (1.0)

test/test_linear_layer.py::test_benchmark shape=(24, 512, 1280, 1280)
Name                                                                                   Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
-------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-24x512x1280x1280-relu-no_bias-contiguous]      0.3052 (1.3)     0.3056 (1.3)   0.3041 (1.3)   0.3082 (1.3)   0.3269 (1.28)  0.3278 (1.28)  0.3242 (1.29)  0.3957 (1.23)
test_benchmark[pytorch-cuda_graphs-fp16-24x512x1280x1280-relu-no_bias-non-contiguous]  0.3359 (1.18)    0.3359 (1.18)  0.3338 (1.18)  0.342 (1.17)   0.3573 (1.17)  0.3584 (1.17)  0.3539 (1.18)  0.4292 (1.13)
test_benchmark[triton-cuda_graphs-fp16-24x512x1280x1280-relu-no_bias-contiguous]       0.2273 (1.75)    0.2278 (1.74)  0.2263 (1.75)  0.2294 (1.74)  0.2486 (1.68)  0.2494 (1.68)  0.2468 (1.69)  0.3136 (1.55)
test_benchmark[triton-cuda_graphs-fp16-24x512x1280x1280-relu-no_bias-non-contiguous]   0.3973 (1.0)     0.3974 (1.0)   0.3953 (1.0)   0.3994 (1.0)   0.4188 (1.0)   0.4197 (1.0)   0.4168 (1.0)   0.4866 (1.0)

test/test_linear_layer.py::test_benchmark shape=(24, 8, 1280, 1280)
Name                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
-----------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-24x8x1280x1280-relu-no_bias-contiguous]      0.0276 (2.19)    0.0275 (2.18)  0.0266 (2.15)  0.0287 (2.14)  0.0567 (1.58)  0.058 (1.56)   0.0543 (1.6)   0.1762 (1.12)
test_benchmark[pytorch-cuda_graphs-fp16-24x8x1280x1280-relu-no_bias-non-contiguous]  0.0604 (1.0)     0.0599 (1.0)   0.0573 (1.0)   0.0614 (1.0)   0.0894 (1.0)   0.0904 (1.0)   0.0869 (1.0)   0.1972 (1.0)
test_benchmark[triton-cuda_graphs-fp16-24x8x1280x1280-relu-no_bias-contiguous]       0.0276 (2.19)    0.028 (2.14)   0.0266 (2.15)  0.0287 (2.14)  0.0601 (1.49)  0.061 (1.48)   0.0576 (1.51)  0.1811 (1.09)
test_benchmark[triton-cuda_graphs-fp16-24x8x1280x1280-relu-no_bias-non-contiguous]   0.0328 (1.84)    0.0328 (1.83)  0.0307 (1.87)  0.0338 (1.82)  0.0646 (1.38)  0.0657 (1.38)  0.0623 (1.4)   0.1764 (1.12)

test/test_linear_layer.py::test_benchmark shape=(32, 128, 1280, 1280)
Name                                                                                   Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
-------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-32x128x1280x1280-relu-no_bias-contiguous]      0.0983 (1.46)    0.0987 (1.46)  0.0973 (1.46)  0.1004 (1.45)  0.119 (1.38)   0.1199 (1.38)  0.1172 (1.39)  0.2193 (1.18)
test_benchmark[pytorch-cuda_graphs-fp16-32x128x1280x1280-relu-no_bias-non-contiguous]  0.1004 (1.43)    0.1002 (1.43)  0.0983 (1.45)  0.1014 (1.43)  0.1209 (1.36)  0.1218 (1.36)  0.119 (1.37)   0.221 (1.17)
test_benchmark[triton-cuda_graphs-fp16-32x128x1280x1280-relu-no_bias-contiguous]       0.0881 (1.63)    0.0885 (1.62)  0.087 (1.64)   0.0901 (1.61)  0.1083 (1.52)  0.109 (1.52)   0.1065 (1.53)  0.2045 (1.27)
test_benchmark[triton-cuda_graphs-fp16-32x128x1280x1280-relu-no_bias-non-contiguous]   0.1434 (1.0)     0.1436 (1.0)   0.1423 (1.0)   0.1454 (1.0)   0.1646 (1.0)   0.1653 (1.0)   0.1626 (1.0)   0.2593 (1.0)

test/test_linear_layer.py::test_benchmark shape=(32, 16, 1280, 1280)
Name                                                                                  Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min           Max
------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  ------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-32x16x1280x1280-relu-no_bias-contiguous]      0.0348 (1.65)    0.0344 (1.67)  0.0317 (1.77)  0.0358 (1.63)  0.061 (1.39)   0.0624 (1.38)  0.0592 (1.4)  0.1741 (1.17)
test_benchmark[pytorch-cuda_graphs-fp16-32x16x1280x1280-relu-no_bias-non-contiguous]  0.0573 (1.0)     0.0573 (1.0)   0.0563 (1.0)   0.0584 (1.0)   0.0848 (1.0)   0.0862 (1.0)   0.0827 (1.0)  0.2034 (1.0)
test_benchmark[triton-cuda_graphs-fp16-32x16x1280x1280-relu-no_bias-contiguous]       0.0317 (1.81)    0.0316 (1.81)  0.0297 (1.9)   0.0328 (1.78)  0.0581 (1.46)  0.0593 (1.45)  0.056 (1.48)  0.1714 (1.19)
test_benchmark[triton-cuda_graphs-fp16-32x16x1280x1280-relu-no_bias-non-contiguous]   0.0369 (1.56)    0.0373 (1.54)  0.0358 (1.57)  0.0389 (1.5)   0.065 (1.3)    0.066 (1.31)   0.063 (1.31)  0.1743 (1.17)

test/test_linear_layer.py::test_benchmark shape=(32, 256, 1280, 1280)
Name                                                                                   Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
-------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-32x256x1280x1280-relu-no_bias-contiguous]      0.1884 (1.44)    0.1879 (1.44)  0.1864 (1.45)  0.1894 (1.44)  0.2083 (1.4)   0.2092 (1.4)   0.2064 (1.4)   0.2872 (1.27)
test_benchmark[pytorch-cuda_graphs-fp16-32x256x1280x1280-relu-no_bias-non-contiguous]  0.1874 (1.45)    0.1874 (1.45)  0.1864 (1.45)  0.1894 (1.44)  0.208 (1.41)   0.2087 (1.4)   0.206 (1.41)   0.288 (1.27)
test_benchmark[triton-cuda_graphs-fp16-32x256x1280x1280-relu-no_bias-contiguous]       0.1587 (1.71)    0.159 (1.71)   0.1577 (1.71)  0.1608 (1.7)   0.1791 (1.63)  0.1799 (1.63)  0.1774 (1.63)  0.2645 (1.38)
test_benchmark[triton-cuda_graphs-fp16-32x256x1280x1280-relu-no_bias-non-contiguous]   0.2714 (1.0)     0.2713 (1.0)   0.2693 (1.0)   0.2734 (1.0)   0.2925 (1.0)   0.2932 (1.0)   0.2899 (1.0)   0.3649 (1.0)

test/test_linear_layer.py::test_benchmark shape=(32, 512, 1280, 1280)
Name                                                                                   Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
-------------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-32x512x1280x1280-relu-no_bias-contiguous]      0.3523 (1.48)    0.3526 (1.48)  0.3512 (1.48)  0.3543 (1.48)  0.3741 (1.45)  0.3751 (1.45)  0.3719 (1.45)  0.4489 (1.36)
test_benchmark[pytorch-cuda_graphs-fp16-32x512x1280x1280-relu-no_bias-non-contiguous]  0.3502 (1.49)    0.3506 (1.49)  0.3482 (1.49)  0.3533 (1.48)  0.3726 (1.46)  0.3735 (1.46)  0.3704 (1.46)  0.4434 (1.38)
test_benchmark[triton-cuda_graphs-fp16-32x512x1280x1280-relu-no_bias-contiguous]       0.299 (1.74)     0.2995 (1.74)  0.2959 (1.75)  0.3041 (1.72)  0.3192 (1.7)   0.3202 (1.7)   0.3172 (1.7)   0.3838 (1.59)
test_benchmark[triton-cuda_graphs-fp16-32x512x1280x1280-relu-no_bias-non-contiguous]   0.5212 (1.0)     0.5218 (1.0)   0.5192 (1.0)   0.5243 (1.0)   0.5435 (1.0)   0.5446 (1.0)   0.5405 (1.0)   0.6117 (1.0)

test/test_linear_layer.py::test_benchmark shape=(32, 8, 1280, 1280)
Name                                                                                 Median (CUDA)    Mean (CUDA)    Min (CUDA)     Max (CUDA)     Median         Mean           Min            Max
-----------------------------------------------------------------------------------  ---------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
test_benchmark[pytorch-cuda_graphs-fp16-32x8x1280x1280-relu-no_bias-contiguous]      0.0276 (2.07)    0.0279 (2.04)  0.0266 (2.12)  0.0287 (2.04)  0.0574 (1.52)  0.0591 (1.5)   0.0553 (1.53)  0.1953 (1.01)
test_benchmark[pytorch-cuda_graphs-fp16-32x8x1280x1280-relu-no_bias-non-contiguous]  0.0573 (1.0)     0.057 (1.0)    0.0563 (1.0)   0.0584 (1.0)   0.0871 (1.0)   0.0886 (1.0)   0.0849 (1.0)   0.1968 (1.0)
test_benchmark[triton-cuda_graphs-fp16-32x8x1280x1280-relu-no_bias-contiguous]       0.0266 (2.15)    0.0267 (2.14)  0.0246 (2.29)  0.0276 (2.11)  0.0574 (1.52)  0.0589 (1.5)   0.0553 (1.53)  0.173 (1.14)
test_benchmark[triton-cuda_graphs-fp16-32x8x1280x1280-relu-no_bias-non-contiguous]   0.0307 (1.87)    0.0305 (1.87)  0.0287 (1.96)  0.0317 (1.84)  0.0612 (1.42)  0.0621 (1.43)  0.0583 (1.46)  0.1774 (1.11)


=================================================================== 44 passed, 1364 deselected in 154.47s (0:02:34) ===================================================================

@pommedeterresautee
Member

For the record, here is the script with batch 1 and beam 5 (which is effectively a batch of 5, since beam search expands the inputs num_beams times); a short illustration of that expansion follows the timings below.

timings_original [0.7823569774627686]
difference between original and optimized model:
time to warmup: 630.16s
timings
[original] average: 0.7823569774627686s / complete: 0.7823569774627686s
[optimized] average: 0.37795209884643555s / complete: 0.37795209884643555s
memory footprint
torch.cuda.memory_allocated: 9.229591GB
torch.cuda.memory_reserved: 10.128906GB
torch.cuda.max_memory_reserved: 10.128906GB
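
For intuition on the "beam 5 is a batch 5" point: during beam search, generate() replicates the encoder output num_beams times along the batch dimension, so the decoder runs on batch_size * num_beams sequences. A minimal illustration with a made-up tensor shape (not taken from the run above):

import torch

num_beams = 5
# hypothetical encoder output: (batch=1, frames, hidden_size), whisper-large-v2-like shape
encoder_hidden = torch.zeros(1, 1500, 1280)
expanded = encoder_hidden.repeat_interleave(num_beams, dim=0)
print(expanded.shape)  # torch.Size([5, 1500, 1280]) -> the decoder effectively sees batch 5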
import time

import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from kernl.model_optimization import optimize_model


torch.set_float32_matmul_precision("high")
max_len = 50
num_beams = 5
# model_name = "openai/whisper-tiny.en"
model_name = "openai/whisper-large-v2"
model = WhisperForConditionalGeneration.from_pretrained(model_name).half().to("cuda").eval()

audio_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# audio_dataset = load_dataset("librispeech_asr", "clean", split="test")


processor = WhisperProcessor.from_pretrained(model_name)


def get_tokens(item: dict[str, dict]) -> torch.Tensor:
    # Convert one dataset sample to Whisper log-mel input features and move it to GPU.
    tensor = processor(item["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_features
    return tensor.cuda()


batch_size = 1
inputs = torch.cat([get_tokens(audio_dataset[i]) for i in range(batch_size)], dim=0).half()


timings_original = list()
transcriptions = list()
with torch.inference_mode():
    # warmup run for the baseline model (excluded from timing)
    model.generate(
        inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    torch.cuda.synchronize()
    start = time.time()
    predicted_ids = model.generate(
        inputs,
        min_length=1,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    torch.cuda.synchronize()
    timings_original.append(time.time() - start)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)
    transcriptions.extend(transcription)

print(f"timings_original {timings_original}")


# Patch beam-search cache reordering: reorder only the self-attention key/value
# states (layer_past[:2]) for the selected beams, and leave the rest of each
# layer's cache (layer_past[2:], the cross-attention states) untouched.
@staticmethod
def fix_reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
        )
    return reordered_past


WhisperForConditionalGeneration._reorder_cache = fix_reorder_cache

optimize_model(model.model.decoder)

nb_diff = 0
timings_optimized = list()
print("difference between original and optimized model:")
with torch.inference_mode():
    # the first generate() call after optimize_model triggers kernel compilation (warmup)
    start = time.time()
    model.generate(
        inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    torch.cuda.synchronize()
    print(f"time to warmup: {time.time() - start:.2f}s")

    torch.cuda.synchronize()
    start = time.time()
    predicted_ids = model.generate(
        inputs,
        min_length=1,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    torch.cuda.synchronize()
    timings_optimized.append(time.time() - start)
    optimized_transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)

print("timings")
print(f"[original] average: {sum(timings_original) / len(timings_original)}s / complete: {sum(timings_original)}s")
print(f"[optimized] average: {sum(timings_optimized) / len(timings_optimized)}s / complete: {sum(timings_optimized)}s")


print("memory footprint")
print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
print("torch.cuda.memory_reserved: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
print("torch.cuda.max_memory_reserved: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))

@pommedeterresautee
Member

FWIW, the notebook has also been run on an A100; the speedup is similar to the one on the 3090 RTX (2.3x).

@pommedeterresautee
Copy link
Member

@TheExGenesis may I close this issue?
Btw, #257 is merged (it contains a few fixes related to some of the points raised here).

@TheExGenesis
Author

Yes, thank you very much for all your attention!
