@@ -20,8 +20,7 @@
 from .operations import GatheredParameters


-if is_fp8_available():
-    import transformer_engine.pytorch as te
+# Do not import `transformer_engine` at package level to avoid potential issues


 def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
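Note: the replacement comment above names the pattern the rest of this diff rolls out. A minimal sketch of that deferred-import idiom, assuming `is_fp8_available` is the availability check already imported in this file (the helper name `fp8_helper` is hypothetical):

def fp8_helper():
    # Check availability first, then import lazily, so the heavy optional
    # dependency is only loaded when the helper is actually called and
    # `import accelerate` keeps working without transformer_engine installed.
    if not is_fp8_available():
        raise ImportError("This helper requires transformer_engine to be installed.")
    import transformer_engine.pytorch as te  # deferred, function-local import

    return te.Linear(32, 32)  # `te` is safely in scope here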
@@ -30,6 +29,8 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
     """
     if not is_fp8_available():
         raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
+    import transformer_engine.pytorch as te
+
     for name, module in model.named_children():
         if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
             has_bias = module.bias is not None
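For readers following the loop: after capturing `has_bias`, the conversion replaces the module by name. A condensed, hypothetical sketch of that step (the real file also handles parameter gathering via `GatheredParameters` and recurses into children):

# hypothetical condensation of the replacement step
te_module = te.Linear(module.in_features, module.out_features, bias=has_bias)
te_module.weight.data.copy_(module.weight.data)  # carry over trained parameters
if has_bias:
    te_module.bias.data.copy_(module.bias.data)
setattr(model, name, te_module)  # swap nn.Linear for te.Linear in place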
@@ -87,6 +88,8 @@ def has_transformer_engine_layers(model):
     """
     if not is_fp8_available():
         raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
+    import transformer_engine.pytorch as te
+
     for m in model.modules():
         if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
             return True
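A hedged usage sketch tying the two helpers together; the `accelerate.utils` import path is an assumption about where they are re-exported:

import torch.nn as nn
from accelerate.utils import convert_model, has_transformer_engine_layers  # assumed re-export

model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64), nn.Linear(64, 8))
convert_model(model)  # in place: nn.Linear -> te.Linear, nn.LayerNorm -> te.LayerNorm
assert has_transformer_engine_layers(model)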
@@ -98,6 +101,8 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
     Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
     disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
     """
+    if not is_fp8_available():
+        raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
     from transformer_engine.pytorch import fp8_autocast

     def forward(self, *args, **kwargs):
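The `forward` defined here is where the docstring's context awareness lives; a sketch consistent with it (the exact `enabled` expression is an assumption):

def forward(self, *args, **kwargs):
    # Enable FP8 autocast while training; during eval only if explicitly requested.
    enabled = use_during_eval or self.training
    with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):
        return model_forward(*args, **kwargs)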
@@ -115,7 +120,8 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
     """
     Applies FP8 context manager to the model's forward method
     """
-    # Import here to keep base imports fast
+    if not is_fp8_available():
+        raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
     import transformer_engine.common.recipe as te_recipe

     kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
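Finally, an end-to-end usage sketch: `fp8_recipe_handler` is a kwargs handler whose `to_kwargs()` output presumably feeds a `te_recipe` recipe object. The `FP8RecipeKwargs` name, its `backend` field, and the import paths shown are assumptions:

from accelerate.utils import FP8RecipeKwargs, apply_fp8_autowrap  # assumed import paths

handler = FP8RecipeKwargs(backend="TE")  # hypothetical field; selects the transformer_engine recipe
model = apply_fp8_autowrap(model, handler)  # model.forward now runs under fp8_autocast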