optimized whisper.generate() hangs forever, including in tests #213
Comments
I poked at it with a debugger and traced it as far as I could; this was the most relevant section of the error log:
|
I noticed there were warmup steps in the tutorials, so I ran Whisper tiny-en, did a long warmup step, and then got a 3x speed-up in subsequent translations. Then I tried running Whisper large-v2 on a single batch of size [16, 80, 3000], but it ran for 5 minutes and then I ran out of memory.
|
I'm curious about what warmup times usually are for models of this size |
Hello, it's caused by two things:
So the warmup can take time, but after the warmup you will have maximum performance. We found ways to improve Whisper performance further, but we are waiting for fixes from a third-party project (TorchDynamo). We are also waiting for an improvement to CUDA graphs (#207). The memory problem is a known issue, probably caused by TorchDynamo (pytorch/torchdynamo#1955); we hope to fix it soon. It could also be caused by CUDA graph memory allocation, which we are investigating too. |
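A minimal sketch of the warmup pattern being described, assembled from the scripts later in this thread; the model size, the random input, and the generation arguments are placeholders, not an official example:

# Warmup sketch (placeholder model/shapes/arguments): the first generate() after optimize_model()
# is slow because kernels are built and CUDA graphs are captured; later calls with the same
# shapes reuse that work and run at full speed.
import torch
from transformers import WhisperForConditionalGeneration
from kernl.model_optimization import optimize_model

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").half().to("cuda").eval()
optimize_model(model.model.decoder)

fake_features = torch.randn(1, 80, 3000, dtype=torch.float16, device="cuda")  # dummy log-mel input
with torch.inference_mode():
    model.generate(fake_features, min_length=25, max_length=25, num_beams=5)  # warmup: slow, one-time cost
    model.generate(fake_features, min_length=1, max_length=25, num_beams=5)   # now reuses the captured graphs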
I wonder if this resolves it: pytorch/torchdynamo#1950 |
FYI, we have pushed a script to the experimental folder; we are currently working on making the warmup much faster in #235. |
Thanks for letting me know, I'll test it. To warm up the large model, it should be about 9 minutes, right (20% of 45 minutes)? I was also wondering whether you, @pommedeterresautee, think Whisper could be optimized by following the instructions from transformer-deploy, or if you can see any obvious reasons why it wouldn't work. |
Running the experimental script as is throws this error: BackendCompilerFailed: _compiler raised Exception: Please convert all Tensors to FakeTensors first or instantiate
FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-1.1314e-02, -8.0719e-03, 1.6571e-02, ..., 4.3564e-03,
-7.1289e-02, 1.1063e-03],
[ 3.3493e-03, -1.7075e-02, -1.3550e-02, ..., -1.0681e-02,
-7.5256e-02, 1.9623e-02],
[ 8.7929e-04, 6.5689e-03, -1.9608e-02, ..., 2.2621e-03,
-6.8542e-02, 4.3274e-02],
...,
[-6.1417e-03, -7.8583e-03, -1.1024e-03, ..., 1.2100e-05,
-7.0312e-02, 3.1082e-02],
[-4.5700e-03, 1.3704e-03, 8.8425e-03, ..., 8.1253e-03,
-6.9458e-02, 3.2257e-02],
[ 5.9242e-03, 2.5421e-02, -8.7891e-03, ..., 1.9779e-03,
-6.7505e-02, 4.9072e-02]], device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(...,
device='meta', size=(5, 1), dtype=torch.int64), cuda:0), 50257), **{})
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
I also tried running |
Are you on main + up-to-date dependencies? |
Yup, on main, and I just updated the dependencies.

Logs:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:680 │
│ in call_user_compiler │
│ │
│ 677 │ │ │ elif config.DO_NOT_USE_legacy_non_fake_example_inputs: │
│ 678 │ │ │ │ compiled_fn = compiler_fn(gm, self.example_inputs()) │
│ 679 │ │ │ else: │
│ ❱ 680 │ │ │ │ compiled_fn = compiler_fn(gm, self.fake_example_inputs()) │
│ 681 │ │ │ _step_logger()(logging.INFO, f"done compiler function {name}") │
│ 682 │ │ │ assert callable(compiled_fn), "compiler_fn did not return callable" │
│ 683 │ │ except Exception as e: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py:1032 │
│ in debug_wrapper │
│ │
│ 1029 │ │ │ │ │ ) │
│ 1030 │ │ │ │ │ raise │
│ 1031 │ │ else: │
│ ❱ 1032 │ │ │ compiled_gm = compiler_fn(gm, example_inputs, **kwargs) │
│ 1033 │ │ │
│ 1034 │ │ return compiled_gm │
│ 1035 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py:33 in │
│ _compiler │
│ │
│ 30 # https://github.com/pytorch/torchdynamo/issues/1816 │
│ 31 def _compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): │
│ 32 │ dynamo_backend_ofi(gm) │
│ ❱ 33 │ return cuda_graphs_wrapper(gm, example_inputs, pool=_pool) │
│ 34 │
│ 35 │
│ 36 def optimize_model(original_model: PreTrainedModel) -> None: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/implementations/cuda_graph. │
│ py:40 in cuda_graphs_wrapper │
│ │
│ 37 │ │ # 2 rounds, 1 to build the model (triton kernels, casting, etc.), │
│ 38 │ │ # and 1 for warmup │
│ 39 │ │ for _ in range(2): │
│ ❱ 40 │ │ │ model(*inputs) │
│ 41 │ stream.synchronize() │
│ 42 │ torch.cuda.current_stream().wait_stream(stream) │
│ 43 │ torch.cuda.synchronize() │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:660 in │
│ call_wrapped │
│ │
│ 657 │ │ │ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined │
│ 658 │ │ │
│ 659 │ │ def call_wrapped(self, *args, **kwargs): │
│ ❱ 660 │ │ │ return self._wrapped_call(self, *args, **kwargs) │
│ 661 │ │ │
│ 662 │ │ cls.__call__ = call_wrapped │
│ 663 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:279 in │
│ __call__ │
│ │
│ 276 │ │ │ │ │ file=sys.stderr) │
│ 277 │ │ │ │ raise e.with_traceback(None) │
│ 278 │ │ │ else: │
│ ❱ 279 │ │ │ │ raise e │
│ 280 │
│ 281 @compatibility(is_backward_compatible=True) │
│ 282 class GraphModule(torch.nn.Module): │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:269 in │
│ __call__ │
│ │
│ 266 │ │ │ if self.cls_call is not None: │
│ 267 │ │ │ │ return self.cls_call(obj, *args, **kwargs) │
│ 268 │ │ │ else: │
│ ❱ 269 │ │ │ │ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[mi │
│ 270 │ │ except Exception as e: │
│ 271 │ │ │ assert e.__traceback__ │
│ 272 │ │ │ topmost_framesummary: traceback.FrameSummary = \ │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ <eval_with_key>.86:8 in forward │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/sparse.py:162 in │
│ forward │
│ │
│ 159 │ │ │ │ self.weight[self.padding_idx].fill_(0) │
│ 160 │ │
│ 161 │ def forward(self, input: Tensor) -> Tensor: │
│ ❱ 162 │ │ return F.embedding( │
│ 163 │ │ │ input, self.weight, self.padding_idx, self.max_norm, │
│ 164 │ │ │ self.norm_type, self.scale_grad_by_freq, self.sparse) │
│ 165 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/functional.py:2210 in │
│ embedding │
│ │
│ 2207 │ │ # torch.embedding_renorm_ │
│ 2208 │ │ # remove once script supports set_grad_enabled │
│ 2209 │ │ _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) │
│ ❱ 2210 │ return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) │
│ 2211 │
│ 2212 │
│ 2213 def embedding_bag( │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 639 in __torch_dispatch__ │
│ │
│ 636 │ │ │
│ 637 │ │ assert fake_mode is not None │
│ 638 │ │ with fake_mode: # type: ignore[attr-defined] │
│ ❱ 639 │ │ │ return func(*args, **kwargs) │
│ 640 │ │
│ 641 │ @staticmethod │
│ 642 │ def _find_common_device(func, args, kwargs) -> Tuple[torch.device, bool]: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_ops.py:284 in __call__ │
│ │
│ 281 │ │ ) │
│ 282 │ │
│ 283 │ def __call__(self, *args, **kwargs): │
│ ❱ 284 │ │ return self._op(*args, **kwargs or {}) │
│ 285 │ │
│ 286 │ def __hash__(self): │
│ 287 │ │ return hash(self._op) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 818 in __torch_dispatch__ │
│ │
│ 815 │ │ │ ), f"{args} {kwargs}" │
│ 816 │ │ │ return converter(self, args[0]) │
│ 817 │ │ │
│ ❱ 818 │ │ args, kwargs = self.validate_and_convert_non_fake_tensors( │
│ 819 │ │ │ func, converter, args, kwargs │
│ 820 │ │ ) │
│ 821 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 966 in validate_and_convert_non_fake_tensors │
│ │
│ 963 │ │ │ │ return converter(self, x) │
│ 964 │ │ │ return x │
│ 965 │ │ │
│ ❱ 966 │ │ return tree_map_only( │
│ 967 │ │ │ torch.Tensor, │
│ 968 │ │ │ validate, │
│ 969 │ │ │ (args, kwargs), │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:259 in │
│ tree_map_only │
│ │
│ 256 │ ... │
│ 257 │
│ 258 def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree: │
│ ❱ 259 │ return tree_map(map_only(ty)(fn), pytree) │
│ 260 │
│ 261 def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool: │
│ 262 │ flat_args, _ = tree_flatten(pytree) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:195 in │
│ tree_map │
│ │
│ 192 │
│ 193 def tree_map(fn: Any, pytree: PyTree) -> PyTree: │
│ 194 │ flat_args, spec = tree_flatten(pytree) │
│ ❱ 195 │ return tree_unflatten([fn(i) for i in flat_args], spec) │
│ 196 │
│ 197 Type2 = Tuple[Type[T], Type[S]] │
│ 198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:195 in │
│ <listcomp> │
│ │
│ 192 │
│ 193 def tree_map(fn: Any, pytree: PyTree) -> PyTree: │
│ 194 │ flat_args, spec = tree_flatten(pytree) │
│ ❱ 195 │ return tree_unflatten([fn(i) for i in flat_args], spec) │
│ 196 │
│ 197 Type2 = Tuple[Type[T], Type[S]] │
│ 198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:244 in │
│ inner │
│ │
│ 241 │ │ @functools.wraps(f) │
│ 242 │ │ def inner(x: T) -> Any: │
│ 243 │ │ │ if isinstance(x, ty): │
│ ❱ 244 │ │ │ │ return f(x) │
│ 245 │ │ │ else: │
│ 246 │ │ │ │ return x │
│ 247 │ │ return inner │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 958 in validate │
│ │
│ 955 │ │ │ │ │ │ f"Can't call metadata mutating ops on non-Fake Tensor inputs. Fo │
│ 956 │ │ │ │ │ ) │
│ 957 │ │ │ │ if not self.allow_non_fake_inputs: │
│ ❱ 958 │ │ │ │ │ raise Exception( │
│ 959 │ │ │ │ │ │ f"Please convert all Tensors to FakeTensors first or instantiate │
│ 960 │ │ │ │ │ │ f"with 'allow_non_fake_inputs'. Found in {func}(*{args}, **{kwar │
│ 961 │ │ │ │ │ ) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with
'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-1.1314e-02, -8.0719e-03, 1.6571e-02, ..., 4.3564e-03,
-7.1289e-02, 1.1063e-03],
[ 3.3493e-03, -1.7075e-02, -1.3550e-02, ..., -1.0681e-02,
-7.5256e-02, 1.9623e-02],
[ 8.7929e-04, 6.5689e-03, -1.9608e-02, ..., 2.2621e-03,
-6.8542e-02, 4.3274e-02],
...,
[-6.1417e-03, -7.8583e-03, -1.1024e-03, ..., 1.2100e-05,
-7.0312e-02, 3.1082e-02],
[-4.5700e-03, 1.3704e-03, 8.8425e-03, ..., 8.1253e-03,
-6.9458e-02, 3.2257e-02],
[ 5.9242e-03, 2.5421e-02, -8.7891e-03, ..., 1.9779e-03,
-6.7505e-02, 4.9072e-02]], device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(...,
device='meta', size=(5, 1), dtype=torch.int64), cuda:0), 50257), **{})
The above exception was the direct cause of the following exception:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-3-5a2c5bbc2633>:5 in <module> │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py: │
│ 1608 in generate │
│ │
│ 1605 │ │ │ │ **model_kwargs, │
│ 1606 │ │ │ ) │
│ 1607 │ │ │ # 12. run beam search │
│ ❱ 1608 │ │ │ return self.beam_search( │
│ 1609 │ │ │ │ input_ids, │
│ 1610 │ │ │ │ beam_scorer, │
│ 1611 │ │ │ │ logits_processor=logits_processor, │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py: │
│ 2799 in beam_search │
│ │
│ 2796 │ │ │ │
│ 2797 │ │ │ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) │
│ 2798 │ │ │ │
│ ❱ 2799 │ │ │ outputs = self( │
│ 2800 │ │ │ │ **model_inputs, │
│ 2801 │ │ │ │ return_dict=True, │
│ 2802 │ │ │ │ output_attentions=output_attentions, │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/model │
│ ing_whisper.py:1337 in forward │
│ │
│ 1334 │ │ │ │ │ labels, self.config.pad_token_id, self.config.decoder_start_token_id │
│ 1335 │ │ │ │ ) │
│ 1336 │ │ │
│ ❱ 1337 │ │ outputs = self.model( │
│ 1338 │ │ │ input_features, │
│ 1339 │ │ │ decoder_input_ids=decoder_input_ids, │
│ 1340 │ │ │ encoder_outputs=encoder_outputs, │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/model │
│ ing_whisper.py:1202 in forward │
│ │
│ 1199 │ │ │ ) │
│ 1200 │ │ │
│ 1201 │ │ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_att │
│ ❱ 1202 │ │ decoder_outputs = self.decoder( │
│ 1203 │ │ │ input_ids=decoder_input_ids, │
│ 1204 │ │ │ attention_mask=decoder_attention_mask, │
│ 1205 │ │ │ encoder_hidden_states=encoder_outputs[0], │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:212 │
│ in _fn │
│ │
│ 209 │ │ │ dynamic_ctx = enable_dynamic(self.dynamic) │
│ 210 │ │ │ dynamic_ctx.__enter__() │
│ 211 │ │ │ try: │
│ ❱ 212 │ │ │ │ return fn(*args, **kwargs) │
│ 213 │ │ │ finally: │
│ 214 │ │ │ │ set_eval_frame(prior) │
│ 215 │ │ │ │ dynamic_ctx.__exit__(None, None, None) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py:67 in │
│ run │
│ │
│ 64 │ │
│ 65 │ @torchdynamo.optimize(_compiler) │
│ 66 │ def run(*args, **kwargs): │
│ ❱ 67 │ │ return original_model.forward2(*args, **kwargs) │
│ 68 │ │
│ 69 │ original_model.forward = run │
│ 70 │
│ <ipython-input-1-09da0ab26e20>:45 in wrapper_stride │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:333 │
│ in catch_errors │
│ │
│ 330 │ │ │ │ │ return hijacked_callback(frame, cache_size, hooks) │
│ 331 │ │ │
│ 332 │ │ with compile_lock: │
│ ❱ 333 │ │ │ return callback(frame, cache_size, hooks) │
│ 334 │ │
│ 335 │ catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined] │
│ 336 │ return catch_errors │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:48 │
│ 0 in _convert_frame │
│ │
│ 477 │ def _convert_frame(frame: types.FrameType, cache_size: int, hooks: Hooks): │
│ 478 │ │ counters["frames"]["total"] += 1 │
│ 479 │ │ try: │
│ ❱ 480 │ │ │ result = inner_convert(frame, cache_size, hooks) │
│ 481 │ │ │ counters["frames"]["ok"] += 1 │
│ 482 │ │ │ return result │
│ 483 │ │ except (NotImplementedError, Unsupported): │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:10 │
│ 3 in _fn │
│ │
│ 100 │ │ prior_fwd_from_src = torch.fx.graph_module._forward_from_src │
│ 101 │ │ torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result │
│ 102 │ │ try: │
│ ❱ 103 │ │ │ return fn(*args, **kwargs) │
│ 104 │ │ finally: │
│ 105 │ │ │ torch._C._set_grad_enabled(prior_grad_mode) │
│ 106 │ │ │ torch.random.set_rng_state(rng_state) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/utils.py:88 in │
│ time_wrapper │
│ │
│ 85 │ │ if key not in compilation_metrics: │
│ 86 │ │ │ compilation_metrics[key] = [] │
│ 87 │ │ t0 = time.time() │
│ ❱ 88 │ │ r = func(*args, **kwargs) │
│ 89 │ │ latency = time.time() - t0 │
│ 90 │ │ # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec") │
│ 91 │ │ compilation_metrics[key].append(latency) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:33 │
│ 9 in _convert_frame_assert │
│ │
│ 336 │ │ global initial_grad_state │
│ 337 │ │ initial_grad_state = torch.is_grad_enabled() │
│ 338 │ │ │
│ ❱ 339 │ │ return _compile( │
│ 340 │ │ │ frame.f_code, │
│ 341 │ │ │ frame.f_globals, │
│ 342 │ │ │ frame.f_locals, │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:40 │
│ 0 in _compile │
│ │
│ 397 │ try: │
│ 398 │ │ for attempt in itertools.count(): │
│ 399 │ │ │ try: │
│ ❱ 400 │ │ │ │ out_code = transform_code_object(code, transform) │
│ 401 │ │ │ │ orig_code_map[out_code] = code │
│ 402 │ │ │ │ break │
│ 403 │ │ │ except exc.RestartAnalysis: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/bytecode_transforma │
│ tion.py:341 in transform_code_object │
│ │
│ 338 │ instructions = cleaned_instructions(code, safe) │
│ 339 │ propagate_line_nums(instructions) │
│ 340 │ │
│ ❱ 341 │ transformations(instructions, code_options) │
│ 342 │ │
│ 343 │ fix_vars(instructions, code_options) │
│ 344 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:38 │
│ 7 in transform │
│ │
│ 384 │ │ │ export, │
│ 385 │ │ │ mutated_closure_cell_contents, │
│ 386 │ │ ) │
│ ❱ 387 │ │ tracer.run() │
│ 388 │ │ output = tracer.output │
│ 389 │ │ assert output is not None │
│ 390 │ │ assert output.output_instructions │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :1684 in run │
│ │
│ 1681 │ │
│ 1682 │ def run(self): │
│ 1683 │ │ _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}") │
│ ❱ 1684 │ │ super().run() │
│ 1685 │ │
│ 1686 │ def match_nested_cell(self, name, cell): │
│ 1687 │ │ """Match a cell in this method to one in a function we are inlining""" │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :538 in run │
│ │
│ 535 │ │ │ while ( │
│ 536 │ │ │ │ self.instruction_pointer is not None │
│ 537 │ │ │ │ and not self.output.should_exit │
│ ❱ 538 │ │ │ │ and self.step() │
│ 539 │ │ │ ): │
│ 540 │ │ │ │ pass │
│ 541 │ │ except BackendCompilerFailed: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :501 in step │
│ │
│ 498 │ │ try: │
│ 499 │ │ │ if not hasattr(self, inst.opname): │
│ 500 │ │ │ │ unimplemented(f"missing: {inst.opname}") │
│ ❱ 501 │ │ │ getattr(self, inst.opname)(inst) │
│ 502 │ │ │ │
│ 503 │ │ │ return inst.opname != "RETURN_VALUE" │
│ 504 │ │ except BackendCompilerFailed: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :1750 in RETURN_VALUE │
│ │
│ 1747 │ │ │ f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)", │
│ 1748 │ │ ) │
│ 1749 │ │ log.debug("RETURN_VALUE triggered compile") │
│ ❱ 1750 │ │ self.output.compile_subgraph(self) │
│ 1751 │ │ self.output.add_output_instructions([create_instruction("RETURN_VALUE")]) │
│ 1752 │
│ 1753 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:557 │
│ in compile_subgraph │
│ │
│ 554 │ │ │ output = [] │
│ 555 │ │ │ if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: │
│ 556 │ │ │ │ output.extend( │
│ ❱ 557 │ │ │ │ │ self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) │
│ 558 │ │ │ │ ) │
│ 559 │ │ │ │ │
│ 560 │ │ │ │ if len(pass2.graph_outputs) != 0: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:604 │
│ in compile_and_call_fx_graph │
│ │
│ 601 │ │ │
│ 602 │ │ assert_no_fake_params_or_buffers(gm) │
│ 603 │ │ with tracing(self.tracing_context): │
│ ❱ 604 │ │ │ compiled_fn = self.call_user_compiler(gm) │
│ 605 │ │ compiled_fn = disable(compiled_fn) │
│ 606 │ │ │
│ 607 │ │ counters["stats"]["unique_graphs"] += 1 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:685 │
│ in call_user_compiler │
│ │
│ 682 │ │ │ assert callable(compiled_fn), "compiler_fn did not return callable" │
│ 683 │ │ except Exception as e: │
│ 684 │ │ │ compiled_fn = gm.forward │
│ ❱ 685 │ │ │ raise BackendCompilerFailed(self.compiler_fn, e) from e │
│ 686 │ │ return compiled_fn │
│ 687 │ │
│ 688 │ def fake_example_inputs(self) -> List[torch.Tensor]: │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BackendCompilerFailed: _compiler raised Exception: Please convert all Tensors to FakeTensors first or instantiate
FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[-1.1314e-02, -8.0719e-03, 1.6571e-02, ..., 4.3564e-03,
-7.1289e-02, 1.1063e-03],
[ 3.3493e-03, -1.7075e-02, -1.3550e-02, ..., -1.0681e-02,
-7.5256e-02, 1.9623e-02],
[ 8.7929e-04, 6.5689e-03, -1.9608e-02, ..., 2.2621e-03,
-6.8542e-02, 4.3274e-02],
...,
[-6.1417e-03, -7.8583e-03, -1.1024e-03, ..., 1.2100e-05,
-7.0312e-02, 3.1082e-02],
[-4.5700e-03, 1.3704e-03, 8.8425e-03, ..., 8.1253e-03,
-6.9458e-02, 3.2257e-02],
[ 5.9242e-03, 2.5421e-02, -8.7891e-03, ..., 1.9779e-03,
-6.7505e-02, 4.9072e-02]], device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(...,
device='meta', size=(5, 1), dtype=torch.int64), cuda:0), 50257), **{})
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True |
Should I make a new issue for this? |
Can you try with this branch? |
Similar thing I'm afraid, running on an A100, Python 3.9.15.

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:680 │
│ in call_user_compiler │
│ │
│ 677 │ │ │ elif config.DO_NOT_USE_legacy_non_fake_example_inputs: │
│ 678 │ │ │ │ compiled_fn = compiler_fn(gm, self.example_inputs()) │
│ 679 │ │ │ else: │
│ ❱ 680 │ │ │ │ compiled_fn = compiler_fn(gm, self.fake_example_inputs()) │
│ 681 │ │ │ _step_logger()(logging.INFO, f"done compiler function {name}") │
│ 682 │ │ │ assert callable(compiled_fn), "compiler_fn did not return callable" │
│ 683 │ │ except Exception as e: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py:1032 │
│ in debug_wrapper │
│ │
│ 1029 │ │ │ │ │ ) │
│ 1030 │ │ │ │ │ raise │
│ 1031 │ │ else: │
│ ❱ 1032 │ │ │ compiled_gm = compiler_fn(gm, example_inputs, **kwargs) │
│ 1033 │ │ │
│ 1034 │ │ return compiled_gm │
│ 1035 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py:33 in │
│ _compiler │
│ │
│ 30 # https://github.com/pytorch/torchdynamo/issues/1816 │
│ 31 def _compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): │
│ 32 │ dynamo_backend_ofi(gm) │
│ ❱ 33 │ return cuda_graphs_wrapper(gm, example_inputs, pool=_pool) │
│ 34 │
│ 35 │
│ 36 def optimize_model(original_model: PreTrainedModel) -> None: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/implementations/cuda_graph. │
│ py:40 in cuda_graphs_wrapper │
│ │
│ 37 │ │ # 2 rounds, 1 to build the model (triton kernels, casting, etc.), │
│ 38 │ │ # and 1 for warmup │
│ 39 │ │ for _ in range(2): │
│ ❱ 40 │ │ │ model(*inputs) │
│ 41 │ stream.synchronize() │
│ 42 │ torch.cuda.current_stream().wait_stream(stream) │
│ 43 │ torch.cuda.synchronize() │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:660 in │
│ call_wrapped │
│ │
│ 657 │ │ │ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined │
│ 658 │ │ │
│ 659 │ │ def call_wrapped(self, *args, **kwargs): │
│ ❱ 660 │ │ │ return self._wrapped_call(self, *args, **kwargs) │
│ 661 │ │ │
│ 662 │ │ cls.__call__ = call_wrapped │
│ 663 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:279 in │
│ __call__ │
│ │
│ 276 │ │ │ │ │ file=sys.stderr) │
│ 277 │ │ │ │ raise e.with_traceback(None) │
│ 278 │ │ │ else: │
│ ❱ 279 │ │ │ │ raise e │
│ 280 │
│ 281 @compatibility(is_backward_compatible=True) │
│ 282 class GraphModule(torch.nn.Module): │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/fx/graph_module.py:269 in │
│ __call__ │
│ │
│ 266 │ │ │ if self.cls_call is not None: │
│ 267 │ │ │ │ return self.cls_call(obj, *args, **kwargs) │
│ 268 │ │ │ else: │
│ ❱ 269 │ │ │ │ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[mi │
│ 270 │ │ except Exception as e: │
│ 271 │ │ │ assert e.__traceback__ │
│ 272 │ │ │ topmost_framesummary: traceback.FrameSummary = \ │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ <eval_with_key>.680:8 in forward │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/sparse.py:162 in │
│ forward │
│ │
│ 159 │ │ │ │ self.weight[self.padding_idx].fill_(0) │
│ 160 │ │
│ 161 │ def forward(self, input: Tensor) -> Tensor: │
│ ❱ 162 │ │ return F.embedding( │
│ 163 │ │ │ input, self.weight, self.padding_idx, self.max_norm, │
│ 164 │ │ │ self.norm_type, self.scale_grad_by_freq, self.sparse) │
│ 165 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/functional.py:2210 in │
│ embedding │
│ │
│ 2207 │ │ # torch.embedding_renorm_ │
│ 2208 │ │ # remove once script supports set_grad_enabled │
│ 2209 │ │ _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) │
│ ❱ 2210 │ return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) │
│ 2211 │
│ 2212 │
│ 2213 def embedding_bag( │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 639 in __torch_dispatch__ │
│ │
│ 636 │ │ │
│ 637 │ │ assert fake_mode is not None │
│ 638 │ │ with fake_mode: # type: ignore[attr-defined] │
│ ❱ 639 │ │ │ return func(*args, **kwargs) │
│ 640 │ │
│ 641 │ @staticmethod │
│ 642 │ def _find_common_device(func, args, kwargs) -> Tuple[torch.device, bool]: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_ops.py:284 in __call__ │
│ │
│ 281 │ │ ) │
│ 282 │ │
│ 283 │ def __call__(self, *args, **kwargs): │
│ ❱ 284 │ │ return self._op(*args, **kwargs or {}) │
│ 285 │ │
│ 286 │ def __hash__(self): │
│ 287 │ │ return hash(self._op) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 818 in __torch_dispatch__ │
│ │
│ 815 │ │ │ ), f"{args} {kwargs}" │
│ 816 │ │ │ return converter(self, args[0]) │
│ 817 │ │ │
│ ❱ 818 │ │ args, kwargs = self.validate_and_convert_non_fake_tensors( │
│ 819 │ │ │ func, converter, args, kwargs │
│ 820 │ │ ) │
│ 821 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 966 in validate_and_convert_non_fake_tensors │
│ │
│ 963 │ │ │ │ return converter(self, x) │
│ 964 │ │ │ return x │
│ 965 │ │ │
│ ❱ 966 │ │ return tree_map_only( │
│ 967 │ │ │ torch.Tensor, │
│ 968 │ │ │ validate, │
│ 969 │ │ │ (args, kwargs), │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:259 in │
│ tree_map_only │
│ │
│ 256 │ ... │
│ 257 │
│ 258 def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree: │
│ ❱ 259 │ return tree_map(map_only(ty)(fn), pytree) │
│ 260 │
│ 261 def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool: │
│ 262 │ flat_args, _ = tree_flatten(pytree) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:195 in │
│ tree_map │
│ │
│ 192 │
│ 193 def tree_map(fn: Any, pytree: PyTree) -> PyTree: │
│ 194 │ flat_args, spec = tree_flatten(pytree) │
│ ❱ 195 │ return tree_unflatten([fn(i) for i in flat_args], spec) │
│ 196 │
│ 197 Type2 = Tuple[Type[T], Type[S]] │
│ 198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:195 in │
│ <listcomp> │
│ │
│ 192 │
│ 193 def tree_map(fn: Any, pytree: PyTree) -> PyTree: │
│ 194 │ flat_args, spec = tree_flatten(pytree) │
│ ❱ 195 │ return tree_unflatten([fn(i) for i in flat_args], spec) │
│ 196 │
│ 197 Type2 = Tuple[Type[T], Type[S]] │
│ 198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_pytree.py:244 in │
│ inner │
│ │
│ 241 │ │ @functools.wraps(f) │
│ 242 │ │ def inner(x: T) -> Any: │
│ 243 │ │ │ if isinstance(x, ty): │
│ ❱ 244 │ │ │ │ return f(x) │
│ 245 │ │ │ else: │
│ 246 │ │ │ │ return x │
│ 247 │ │ return inner │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py: │
│ 958 in validate │
│ │
│ 955 │ │ │ │ │ │ f"Can't call metadata mutating ops on non-Fake Tensor inputs. Fo │
│ 956 │ │ │ │ │ ) │
│ 957 │ │ │ │ if not self.allow_non_fake_inputs: │
│ ❱ 958 │ │ │ │ │ raise Exception( │
│ 959 │ │ │ │ │ │ f"Please convert all Tensors to FakeTensors first or instantiate │
│ 960 │ │ │ │ │ │ f"with 'allow_non_fake_inputs'. Found in {func}(*{args}, **{kwar │
│ 961 │ │ │ │ │ ) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with
'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[ 0.0347, -0.0053, -0.0164, ..., 0.0153, -0.0056, 0.0097],
[ 0.0349, -0.0392, -0.0045, ..., -0.0051, -0.0129, 0.0196],
[ 0.0395, 0.0132, -0.0138, ..., 0.0297, -0.0105, 0.0384],
...,
[ 0.0358, -0.0039, -0.0058, ..., 0.0160, -0.0051, 0.0229],
[ 0.0351, -0.0034, -0.0079, ..., 0.0159, -0.0065, 0.0217],
[ 0.0357, -0.0059, -0.0096, ..., 0.0106, -0.0128, 0.0225]],
device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., device='meta', size=(5, 1),
dtype=torch.int64), cuda:0), 50256), **{})
The above exception was the direct cause of the following exception:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-6-12f3c3cb93e9>:10 in <module> │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py: │
│ 1608 in generate │
│ │
│ 1605 │ │ │ │ **model_kwargs, │
│ 1606 │ │ │ ) │
│ 1607 │ │ │ # 12. run beam search │
│ ❱ 1608 │ │ │ return self.beam_search( │
│ 1609 │ │ │ │ input_ids, │
│ 1610 │ │ │ │ beam_scorer, │
│ 1611 │ │ │ │ logits_processor=logits_processor, │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/generation/utils.py: │
│ 2799 in beam_search │
│ │
│ 2796 │ │ │ │
│ 2797 │ │ │ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) │
│ 2798 │ │ │ │
│ ❱ 2799 │ │ │ outputs = self( │
│ 2800 │ │ │ │ **model_inputs, │
│ 2801 │ │ │ │ return_dict=True, │
│ 2802 │ │ │ │ output_attentions=output_attentions, │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/model │
│ ing_whisper.py:1337 in forward │
│ │
│ 1334 │ │ │ │ │ labels, self.config.pad_token_id, self.config.decoder_start_token_id │
│ 1335 │ │ │ │ ) │
│ 1336 │ │ │
│ ❱ 1337 │ │ outputs = self.model( │
│ 1338 │ │ │ input_features, │
│ 1339 │ │ │ decoder_input_ids=decoder_input_ids, │
│ 1340 │ │ │ encoder_outputs=encoder_outputs, │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/model │
│ ing_whisper.py:1202 in forward │
│ │
│ 1199 │ │ │ ) │
│ 1200 │ │ │
│ 1201 │ │ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_att │
│ ❱ 1202 │ │ decoder_outputs = self.decoder( │
│ 1203 │ │ │ input_ids=decoder_input_ids, │
│ 1204 │ │ │ attention_mask=decoder_attention_mask, │
│ 1205 │ │ │ encoder_hidden_states=encoder_outputs[0], │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/nn/modules/module.py:1482 │
│ in _call_impl │
│ │
│ 1479 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1480 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1481 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1482 │ │ │ return forward_call(*args, **kwargs) │
│ 1483 │ │ # Do not call functions when jit is used │
│ 1484 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1485 │ │ backward_pre_hooks = [] │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:212 │
│ in _fn │
│ │
│ 209 │ │ │ dynamic_ctx = enable_dynamic(self.dynamic) │
│ 210 │ │ │ dynamic_ctx.__enter__() │
│ 211 │ │ │ try: │
│ ❱ 212 │ │ │ │ return fn(*args, **kwargs) │
│ 213 │ │ │ finally: │
│ 214 │ │ │ │ set_eval_frame(prior) │
│ 215 │ │ │ │ dynamic_ctx.__exit__(None, None, None) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/kernl/model_optimization.py:67 in │
│ run │
│ │
│ 64 │ │
│ 65 │ @torchdynamo.optimize(_compiler) │
│ 66 │ def run(*args, **kwargs): │
│ ❱ 67 │ │ return original_model.forward2(*args, **kwargs) │
│ 68 │ │
│ 69 │ original_model.forward = run │
│ 70 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:333 │
│ in catch_errors │
│ │
│ 330 │ │ │ │ │ return hijacked_callback(frame, cache_size, hooks) │
│ 331 │ │ │
│ 332 │ │ with compile_lock: │
│ ❱ 333 │ │ │ return callback(frame, cache_size, hooks) │
│ 334 │ │
│ 335 │ catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined] │
│ 336 │ return catch_errors │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:48 │
│ 0 in _convert_frame │
│ │
│ 477 │ def _convert_frame(frame: types.FrameType, cache_size: int, hooks: Hooks): │
│ 478 │ │ counters["frames"]["total"] += 1 │
│ 479 │ │ try: │
│ ❱ 480 │ │ │ result = inner_convert(frame, cache_size, hooks) │
│ 481 │ │ │ counters["frames"]["ok"] += 1 │
│ 482 │ │ │ return result │
│ 483 │ │ except (NotImplementedError, Unsupported): │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:10 │
│ 3 in _fn │
│ │
│ 100 │ │ prior_fwd_from_src = torch.fx.graph_module._forward_from_src │
│ 101 │ │ torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result │
│ 102 │ │ try: │
│ ❱ 103 │ │ │ return fn(*args, **kwargs) │
│ 104 │ │ finally: │
│ 105 │ │ │ torch._C._set_grad_enabled(prior_grad_mode) │
│ 106 │ │ │ torch.random.set_rng_state(rng_state) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/utils.py:88 in │
│ time_wrapper │
│ │
│ 85 │ │ if key not in compilation_metrics: │
│ 86 │ │ │ compilation_metrics[key] = [] │
│ 87 │ │ t0 = time.time() │
│ ❱ 88 │ │ r = func(*args, **kwargs) │
│ 89 │ │ latency = time.time() - t0 │
│ 90 │ │ # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec") │
│ 91 │ │ compilation_metrics[key].append(latency) │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:33 │
│ 9 in _convert_frame_assert │
│ │
│ 336 │ │ global initial_grad_state │
│ 337 │ │ initial_grad_state = torch.is_grad_enabled() │
│ 338 │ │ │
│ ❱ 339 │ │ return _compile( │
│ 340 │ │ │ frame.f_code, │
│ 341 │ │ │ frame.f_globals, │
│ 342 │ │ │ frame.f_locals, │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:40 │
│ 0 in _compile │
│ │
│ 397 │ try: │
│ 398 │ │ for attempt in itertools.count(): │
│ 399 │ │ │ try: │
│ ❱ 400 │ │ │ │ out_code = transform_code_object(code, transform) │
│ 401 │ │ │ │ orig_code_map[out_code] = code │
│ 402 │ │ │ │ break │
│ 403 │ │ │ except exc.RestartAnalysis: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/bytecode_transforma │
│ tion.py:341 in transform_code_object │
│ │
│ 338 │ instructions = cleaned_instructions(code, safe) │
│ 339 │ propagate_line_nums(instructions) │
│ 340 │ │
│ ❱ 341 │ transformations(instructions, code_options) │
│ 342 │ │
│ 343 │ fix_vars(instructions, code_options) │
│ 344 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:38 │
│ 7 in transform │
│ │
│ 384 │ │ │ export, │
│ 385 │ │ │ mutated_closure_cell_contents, │
│ 386 │ │ ) │
│ ❱ 387 │ │ tracer.run() │
│ 388 │ │ output = tracer.output │
│ 389 │ │ assert output is not None │
│ 390 │ │ assert output.output_instructions │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :1684 in run │
│ │
│ 1681 │ │
│ 1682 │ def run(self): │
│ 1683 │ │ _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}") │
│ ❱ 1684 │ │ super().run() │
│ 1685 │ │
│ 1686 │ def match_nested_cell(self, name, cell): │
│ 1687 │ │ """Match a cell in this method to one in a function we are inlining""" │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :538 in run │
│ │
│ 535 │ │ │ while ( │
│ 536 │ │ │ │ self.instruction_pointer is not None │
│ 537 │ │ │ │ and not self.output.should_exit │
│ ❱ 538 │ │ │ │ and self.step() │
│ 539 │ │ │ ): │
│ 540 │ │ │ │ pass │
│ 541 │ │ except BackendCompilerFailed: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :501 in step │
│ │
│ 498 │ │ try: │
│ 499 │ │ │ if not hasattr(self, inst.opname): │
│ 500 │ │ │ │ unimplemented(f"missing: {inst.opname}") │
│ ❱ 501 │ │ │ getattr(self, inst.opname)(inst) │
│ 502 │ │ │ │
│ 503 │ │ │ return inst.opname != "RETURN_VALUE" │
│ 504 │ │ except BackendCompilerFailed: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py │
│ :1750 in RETURN_VALUE │
│ │
│ 1747 │ │ │ f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)", │
│ 1748 │ │ ) │
│ 1749 │ │ log.debug("RETURN_VALUE triggered compile") │
│ ❱ 1750 │ │ self.output.compile_subgraph(self) │
│ 1751 │ │ self.output.add_output_instructions([create_instruction("RETURN_VALUE")]) │
│ 1752 │
│ 1753 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:557 │
│ in compile_subgraph │
│ │
│ 554 │ │ │ output = [] │
│ 555 │ │ │ if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: │
│ 556 │ │ │ │ output.extend( │
│ ❱ 557 │ │ │ │ │ self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) │
│ 558 │ │ │ │ ) │
│ 559 │ │ │ │ │
│ 560 │ │ │ │ if len(pass2.graph_outputs) != 0: │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:604 │
│ in compile_and_call_fx_graph │
│ │
│ 601 │ │ │
│ 602 │ │ assert_no_fake_params_or_buffers(gm) │
│ 603 │ │ with tracing(self.tracing_context): │
│ ❱ 604 │ │ │ compiled_fn = self.call_user_compiler(gm) │
│ 605 │ │ compiled_fn = disable(compiled_fn) │
│ 606 │ │ │
│ 607 │ │ counters["stats"]["unique_graphs"] += 1 │
│ │
│ /root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py:685 │
│ in call_user_compiler │
│ │
│ 682 │ │ │ assert callable(compiled_fn), "compiler_fn did not return callable" │
│ 683 │ │ except Exception as e: │
│ 684 │ │ │ compiled_fn = gm.forward │
│ ❱ 685 │ │ │ raise BackendCompilerFailed(self.compiler_fn, e) from e │
│ 686 │ │ return compiled_fn │
│ 687 │ │
│ 688 │ def fake_example_inputs(self) -> List[torch.Tensor]: │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BackendCompilerFailed: _compiler raised Exception: Please convert all Tensors to FakeTensors first or instantiate
FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.embedding.default(*(Parameter containing:
tensor([[ 0.0347, -0.0053, -0.0164, ..., 0.0153, -0.0056, 0.0097],
[ 0.0349, -0.0392, -0.0045, ..., -0.0051, -0.0129, 0.0196],
[ 0.0395, 0.0132, -0.0138, ..., 0.0297, -0.0105, 0.0384],
...,
[ 0.0358, -0.0039, -0.0058, ..., 0.0160, -0.0051, 0.0229],
[ 0.0351, -0.0034, -0.0079, ..., 0.0159, -0.0065, 0.0217],
[ 0.0357, -0.0059, -0.0096, ..., 0.0106, -0.0128, 0.0225]],
device='cuda:0', requires_grad=True), FakeTensor(FakeTensor(..., device='meta', size=(5, 1),
dtype=torch.int64), cuda:0), 50256), **{})
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True |
Mmmhhh, this is strange; we are testing this script on different kinds of machines without any issue (or at least not that one :-) ). And you just run the script as-is? Also, just asking: can you test with a pip env instead of Anaconda? That's the only obvious difference I see. |
Also, you could try building the Docker image we have in the repo and testing with Docker. It will ensure we have the exact same setup. |
I tried the Docker image and it did run. However, trying it on a batch of 32, the optimized version runs slower than vanilla HF:

difference between original and optimized model:
time to warmup: 652.45s
timings
[original] average: 2.637366771697998s / complete: 2.637366771697998s
[optimized] average: 3.123673915863037s / complete: 3.123673915863037s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 36.140648GB
torch.cuda.memory_reserved: 49.056641GB
torch.cuda.max_memory_reserved: 49.056641GB

Code to replicate:

# %%
import time

import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

from kernl.model_optimization import optimize_model

torch._dynamo.config.verbose = True
torch.set_float32_matmul_precision("high")
# torchdynamo.config.cache_size_limit = 512
# torchdynamo.config.dynamic_shapes = True

max_len = 50
num_beams = 1
# model_name = "openai/whisper-tiny.en"
model_name = "openai/whisper-large-v2"

model = (
    WhisperForConditionalGeneration.from_pretrained(model_name).half().to("cuda").eval()
)

audio_dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
# audio_dataset = load_dataset("librispeech_asr", "clean", split="test")


def get_tokens(item: dict[str, dict]) -> torch.Tensor:
    tensor = processor(
        item["audio"]["array"], return_tensors="pt", sampling_rate=16_000
    ).input_features
    return tensor.cuda()


processor = WhisperProcessor.from_pretrained(model_name)

# %%
# inputs = get_tokens(audio_dataset[0]).half()
batch_size = 32
inputs = torch.cat(
    [get_tokens(audio_dataset[i]) for i in range(batch_size)], dim=0
).half()

# %%
timings_original = list()
transcriptions = list()
with torch.inference_mode():
    # with torch.inference_mode(), torch.autocast(
    #     dtype=torch.float16, cache_enabled=True, device_type="cuda"
    # ):
    # warmup
    model.generate(
        inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    # torch.cuda.synchronize()
    # for audio in audio_dataset:
    #     inputs = get_tokens(audio).half()
    torch.cuda.synchronize()
    start = time.time()
    predicted_ids = model.generate(
        inputs,
        min_length=1,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
        logits_processor=[WhisperTimeStampLogitsProcessor()],
    )
    torch.cuda.synchronize()
    timings_original.append(time.time() - start)
    transcription = processor.batch_decode(
        predicted_ids, skip_special_tokens=True, normalize=True
    )
    # )[0]
    transcriptions.extend(transcription)

print(f"timings_original {timings_original}")
# assert len(audio_dataset) == len(transcriptions)


# apply efficiency fix to HuggingFace implementation of whisper to limit memory footprint
@staticmethod
def fix_reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2])
            + layer_past[2:],
        )
    return reordered_past


WhisperForConditionalGeneration._reorder_cache = fix_reorder_cache

# %%
optimize_model(model.model.decoder)
model.model.decoder.forward_before = model.model.decoder.forward

# %%
nb_diff = 0
timings_optimized = list()
print("difference between original and optimized model:")
with torch.inference_mode():
    # with torch.inference_mode(), torch.autocast(
    #     dtype=torch.float16, cache_enabled=True, device_type="cuda"
    # ):
    start = time.time()
    model.generate(
        inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    # torch.cuda.synchronize()
    print(f"time to warmup: {time.time() - start:.2f}s")
    # for original_modem_transcription, audio in zip(transcriptions, audio_dataset):
    #     inputs = get_tokens(audio)
    torch.cuda.synchronize()
    start = time.time()
    predicted_ids = model.generate(
        inputs,
        min_length=1,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
        # logits_processor=[WhisperTimeStampLogitsProcessor()],
    )
    torch.cuda.synchronize()
    timings_optimized.append(time.time() - start)
    optimized_transcription = processor.batch_decode(
        predicted_ids, skip_special_tokens=True, normalize=True
    )
    # nb_diff += original_modem_transcription != optimized_transcription

print("timings")
print(
    f"[original] average: {sum(timings_original) / len(timings_original)}s / complete: {sum(timings_original)}s"
)
print(
    f"[optimized] average: {sum(timings_optimized) / len(timings_optimized)}s / complete: {sum(timings_optimized)}s"
)
print(
    f"output differences: {nb_diff}/{len(audio_dataset)} ({nb_diff / len(audio_dataset) * 100:.2f}%)"
)
print("memory footprint")
print(
    "torch.cuda.memory_allocated: %fGB"
    % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024)
)
print(
    "torch.cuda.memory_reserved: %fGB"
    % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024)
)
print(
    "torch.cuda.max_memory_reserved: %fGB"
    % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024)
)

# before modification without stride fix
# difference between original and optimzed model:
# optimized model modifies 31.83% (# 834 examples) of the dataset
# [original] average: 1.0774035485646196s / complete: 2822.7972972393036s
# [optimized] average: 0.45561475826583747s / complete: 1193.7106666564941s
# torch.cuda.memory_allocated: 10.874530GB
# torch.cuda.memory_reserved: 14.064453GB
# torch.cuda.max_memory_reserved: 14.064453GB

# after modification
# difference between original and optimzed model:
# time to warmup: 694.1382637023926
# optimized model modifies 2.06% (# 54/2620 examples)
# [original] average: 1.0491925114893732s / complete: 2748.8843801021576s
# [optimized] average: 0.4339728889574531s / complete: 1137.0089690685272s
# torch.cuda.memory_allocated: 10.873960GB
# torch.cuda.memory_reserved: 13.365234GB
# torch.cuda.max_memory_reserved: 13.853516GB
|
I am on my phone and maybe I am missing something, but I only see the batch of 32 used during the creation of the input variable used for warmup. If true, the warmup is done with a batch of 32, and then the first audio of the evaluation dataset to be transcribed triggers a new warmup for a batch of 1, which may explain the timings you are seeing. In the benchmark we use beam 5, meaning a batch size of 5 plus a reorder operation, and with the model being fairly large (so the GPU is already quite busy with matmuls), I would not expect the speed-up to change significantly by increasing the batch size (but I didn't try, as beam 1 and multiple audios in parallel do not match our use case). |
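A concrete illustration of that point, reusing the variable names from the script above (sizes are placeholders): the warmup call and the timed call should use exactly the same shapes, batch size and beam count included, otherwise the first timed call pays the warmup cost again.

# Sketch: keep warmup and timed shapes identical (continuation of the script above, placeholder sizes).
with torch.inference_mode():
    # warmup at the exact evaluation shape, e.g. [32, 80, 3000] with num_beams = 1
    model.generate(inputs, min_length=max_len, max_length=max_len, num_beams=num_beams, do_sample=False)
    # timed run: same inputs tensor and same generation arguments, so nothing is recompiled here
    model.generate(inputs, min_length=1, max_length=max_len, num_beams=num_beams, do_sample=False)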
Also, I just noticed that the inputs variable is overridden. |
My apologies, I edited the post above. The code I pasted was wrong, but the results are still the same. Also, from experience, increasing the batch size speeds up inference quite a bit, at least on big GPUs. The tiny model shows a speedup, but not large-v2. |
You are right, increasing the batch size generally increases GPU busyness (better use of the silicon) on most models. I have launched your script with a smaller batch size (10), which fills the 24 GB of memory on the 3090. I wonder if the verbose mode enabled in the code you posted is the cause of the slowdown? Can you try without it? Second point: we have a special kernel for cases where there is little opportunity to optimize attention (not enough parallelization opportunity; it's specific to Whisper); it uses a kind of split-K strategy like in matmul kernels (https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/attention_skinny.py). The code is here: https://github.com/ELS-RD/kernl/blob/main/src/kernl/optimizer/attention.py#L45. Could you please replace the attention_vec_mat_forward call there with attention_reference and post the results? FWIW, we are replacing the script with a notebook so results will be easier to compare. Thank you for your help. |
You're welcome, thanks for building kernl! Sadly, with attention_vec_mat_forward:

difference between original and optimized model:
time to warmup: 603.65s
timings
[original] average: 1.6842162609100342s / complete: 1.6842162609100342s
[optimized] average: 0.9564313888549805s / complete: 0.9564313888549805s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 10.535777GB
torch.cuda.memory_reserved: 13.009766GB
torch.cuda.max_memory_reserved: 13.009766GB |
Just to confirm, you are comparing … ? Second confirmation: you have written "with attention_vec_mat_forward", but it's without, right? Your code should look like this:

if q.size(-2) == 1 and k.size(-2) > 50:
    if (attention_mask is None) and (not is_causal):
        attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
    else:
        attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
else:
    attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)

Of course you can simplify it to:

if q.size(-2) == 1 and k.size(-2) > 50:
    attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
else:
    attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)

If that's not the case, can you post measurements with this code? The idea is to check the timings using just CUDA graphs, without any custom kernel. Third point, regarding the memory footprint: using one kernel instead of another does not significantly affect the memory footprint (or not at all, depending on the kernel); it's CUDA graphs that have this effect. What batch size did you try that didn't work? (Have you tried 16 or 24, for instance?) Fourth, if you want, you can comment out this line https://github.com/ELS-RD/kernl/blob/main/src/kernl/model_optimization.py#L28 to check whether the custom kernels make things slower at batch 32 (only CUDA graphs would be applied). |
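To illustrate that fourth point, a hypothetical Dynamo backend that applies only the CUDA graph wrapper and skips the Triton kernel replacement, built from the names visible in the tracebacks above (the pool handle and the patching pattern are assumptions about kernl's internals, not its documented API):

# Hypothetical CUDA-graphs-only backend (assumed internals, for isolating the effect of CUDA graphs).
from typing import List

import torch
from kernl.implementations.cuda_graph import cuda_graphs_wrapper

_pool = torch.cuda.graph_pool_handle()  # assumption: the same kind of pool kernl passes as pool=_pool


def _cuda_graphs_only(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # No dynamo_backend_ofi(gm) call here: the FX graph keeps its original ops,
    # so any speedup or slowdown comes from CUDA graph capture alone.
    return cuda_graphs_wrapper(gm, example_inputs, pool=_pool)


# Patch the decoder the same way kernl's optimize_model() does in the traceback above.
run = torch._dynamo.optimize(_cuda_graphs_only)(model.model.decoder.forward)
model.model.decoder.forward = run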
My bad again, it did run without attention_vec_mat_forward:

difference between original and optimized model:
time to warmup: 792.89s
timings
[original] average: 2.330775022506714s / complete: 2.330775022506714s
[optimized] average: 2.410492420196533s / complete: 2.410492420196533s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 33.665029GB
torch.cuda.memory_reserved: 41.916016GB
torch.cuda.max_memory_reserved: 41.916016GB

Commenting out the line from model_optimization.py#L28 gives:

difference between original and optimized model:
time to warmup: 322.83s
timings
[original] average: 2.3458523750305176s / complete: 2.3458523750305176s
[optimized] average: 2.314619541168213s / complete: 2.314619541168213s
output differences: 0/73 (0.00%)
memory footprint
torch.cuda.memory_allocated: 33.664189GB
torch.cuda.memory_reserved: 41.882812GB
torch.cuda.max_memory_reserved: 41.882812GB |
I took an A100 with 40 GB of RAM and ran experiments on the Docker image. On batch 24 + attention_forward:
On batch 16 + replacing vec mat by attention_reference I get:
Same but using attention_forward (Flash Attention) instead of attention_reference:
Still batch 16, and no more Triton kernels:
And for comparison, on batch 10 (main branch, no change):
At least in both cases, Kernl doesn't make things worse than the baseline :-) |
Unit test outputs for Whisper large batch shapes (comparison should be done on CUDA time, as we use CUDA graphs on the model):

self-attention
cross-attention
|
linear layer
|
For memory, the script with batch 1 and beam 5 (which is indeed a batch of 5):
import time

import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from kernl.model_optimization import optimize_model

torch.set_float32_matmul_precision("high")

max_len = 50
num_beams = 5
# model_name = "openai/whisper-tiny.en"
model_name = "openai/whisper-large-v2"

model = WhisperForConditionalGeneration.from_pretrained(model_name).half().to("cuda").eval()

audio_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# audio_dataset = load_dataset("librispeech_asr", "clean", split="test")


def get_tokens(item: dict[str, dict]) -> torch.Tensor:
    tensor = processor(item["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_features
    return tensor.cuda()


processor = WhisperProcessor.from_pretrained(model_name)

batch_size = 1
inputs = torch.cat([get_tokens(audio_dataset[i]) for i in range(batch_size)], dim=0).half()

timings_original = list()
transcriptions = list()
with torch.inference_mode():
    model.generate(
        inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    torch.cuda.synchronize()
    start = time.time()
    predicted_ids = model.generate(
        inputs,
        min_length=1,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    torch.cuda.synchronize()
    timings_original.append(time.time() - start)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)
    transcriptions.extend(transcription)

print(f"timings_original {timings_original}")


@staticmethod
def fix_reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
        )
    return reordered_past


WhisperForConditionalGeneration._reorder_cache = fix_reorder_cache

optimize_model(model.model.decoder)

nb_diff = 0
timings_optimized = list()
print("difference between original and optimized model:")
with torch.inference_mode():
    start = time.time()
    model.generate(
        inputs,
        min_length=max_len,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    torch.cuda.synchronize()
    print(f"time to warmup: {time.time() - start:.2f}s")
    torch.cuda.synchronize()
    start = time.time()
    predicted_ids = model.generate(
        inputs,
        min_length=1,
        max_length=max_len,
        num_beams=num_beams,
        do_sample=False,
    )
    torch.cuda.synchronize()
    timings_optimized.append(time.time() - start)
    optimized_transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)

print("timings")
print(f"[original] average: {sum(timings_original) / len(timings_original)}s / complete: {sum(timings_original)}s")
print(f"[optimized] average: {sum(timings_optimized) / len(timings_optimized)}s / complete: {sum(timings_optimized)}s")
print("memory footprint")
print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
print("torch.cuda.memory_reserved: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
print("torch.cuda.max_memory_reserved: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))
|
FWIW, the notebook has been run on an A100; the speedup is similar to the 3090 RTX (2.3x). |
@TheExGenesis may I close this issue? |
Yes, thank you very much for all your attention! |
Running test_whisper_hf("optimized") is hanging forever. I'm on Ubuntu, with CUDA enabled and the kernl requirements installed. Congrats on the Whisper support today, by the way; hopefully this is a short fix.