File tree Expand file tree Collapse file tree 1 file changed +2
-1
lines changed
src/transformers/integrations Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Original file line number Diff line number Diff line change 2929from typing import Optional , Tuple , Union
3030
3131import torch
32+ from packaging import version
3233
3334from ..utils import is_torch_flex_attn_available
3435from ..utils .import_utils import _torch_version
@@ -66,7 +67,7 @@ def __init__(self, training):
6667 # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
6768 # see https://github.com/pytorch/pytorch/issues/146260 for training
6869 self .training = training
69- if _torch_version . split ( "+" )[ 0 ] == "2.6.0" and training :
70+ if version . parse ( _torch_version ). base_version == "2.6.0" and training :
7071 self ._compiled_flex_attention = torch .compile (
7172 flex_attention , dynamic = False , mode = "max-autotune-no-cudagraphs"
7273 )
You can’t perform that action at this time.
0 commit comments