Skip to content

Commit 3fcc946

Browse files
oraluben and muellerzr authored
Do not import transformer_engine on import (#3056)
* Do not import `transformer_engine` on import
* fix message
* add test
* Update test_imports.py
* resolve comment 1/2
* resolve comment 1.5/2
* lint
* more lint
* Update tests/test_imports.py (Co-authored-by: Zach Mueller <muellerzr@gmail.com>)
* fmt

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
1 parent 939ce40 commit 3fcc946

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

src/accelerate/utils/transformer_engine.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from .operations import GatheredParameters
2121

2222

23-
if is_fp8_available():
24-
import transformer_engine.pytorch as te
23+
# Do not import `transformer_engine` at package level to avoid potential issues
2524

2625

2726
def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
@@ -30,6 +29,8 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
3029
"""
3130
if not is_fp8_available():
3231
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
32+
import transformer_engine.pytorch as te
33+
3334
for name, module in model.named_children():
3435
if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
3536
has_bias = module.bias is not None
@@ -87,6 +88,8 @@ def has_transformer_engine_layers(model):
8788
"""
8889
if not is_fp8_available():
8990
raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
91+
import transformer_engine.pytorch as te
92+
9093
for m in model.modules():
9194
if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
9295
return True
@@ -98,6 +101,8 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
98101
Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
99102
disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
100103
"""
104+
if not is_fp8_available():
105+
raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
101106
from transformer_engine.pytorch import fp8_autocast
102107

103108
def forward(self, *args, **kwargs):
@@ -115,7 +120,8 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
115120
"""
116121
Applies FP8 context manager to the model's forward method
117122
"""
118-
# Import here to keep base imports fast
123+
if not is_fp8_available():
124+
raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
119125
import transformer_engine.common.recipe as te_recipe
120126

121127
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}

tests/test_imports.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import subprocess
15+
import sys
1516

17+
from accelerate.test_utils import require_transformer_engine
1618
from accelerate.test_utils.testing import TempDirTestCase, require_import_timer
1719
from accelerate.utils import is_import_timer_available
1820

@@ -31,7 +33,7 @@ def convert_list_to_string(data):
3133

3234

3335
def run_import_time(command: str):
34-
output = subprocess.run(["python3", "-X", "importtime", "-c", command], capture_output=True, text=True)
36+
output = subprocess.run([sys.executable, "-X", "importtime", "-c", command], capture_output=True, text=True)
3537
return output.stderr
3638

3739

@@ -81,3 +83,18 @@ def test_cli_import(self):
8183
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
8284
err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
8385
self.assertLess(pct_more, 20, err_msg)
86+
87+
88+
@require_transformer_engine
89+
class LazyImportTester(TempDirTestCase):
90+
"""
91+
Test suite which checks if specific packages are lazy-loaded.
92+
93+
Eager-import will trigger circular import in some case,
94+
e.g. in huggingface/accelerate#3056.
95+
"""
96+
97+
def test_te_import(self):
98+
output = run_import_time("import accelerate, accelerate.utils.transformer_engine")
99+
100+
self.assertFalse(" transformer_engine" in output, "`transformer_engine` should not be imported on import")

0 commit comments

Comments (0)