[export] Fix graph_break log registration error when importing export/_trace.py (pytorch#131523)

Summary:
When importing `_trace.py`, placing `torch._dynamo.exc.Unsupported` in the global variable `_ALLOW_LIST` can cause the import of `export/_trace.py` to fail with the error:

ValueError: Artifact name: 'graph_breaks' not registered, please call register_artifact('graph_breaks') in torch._logging.registrations.

The error is raised directly on the line `graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")` in `_dynamo/exc.py`. I've checked that `register_artifact('graph_breaks')` does already exist in `torch._logging.registrations`.

Explicitly calling `import torch._logging` first doesn't fix the issue.

(see T196719676)

We move `_ALLOW_LIST` to be a local variable, so the `torch._dynamo.exc` import is deferred until the function is actually called instead of running at module import time.

Test Plan:
buck2 test 'fbcode//mode/opt' fbcode//aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test -- --exact 'aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test - test_serialized_model_for_disagg_acc (aiplatform.modelstore.publish.utils.tests.fc_transform_utils_test.PrepareSerializedModelTest)'

buck2 test 'fbcode//mode/opt' fbcode//aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test -- --exact 'aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test - test_serialized_test_dsnn_module (aiplatform.modelstore.publish.utils.tests.fc_transform_utils_test.PrepareSerializedModelTest)'

Differential Revision: D60136706

Pull Request resolved: pytorch#131523
Approved by: https://github.com/zhxchen17
yushangdi authored and pytorchmergebot committed Jul 24, 2024
1 parent 236e06f commit 29c9f8c
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions torch/export/_trace.py
@@ -991,14 +991,14 @@ def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
 _EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None
 
 
-_ALLOW_LIST = {
-    torch._dynamo.exc.Unsupported,
-    torch._dynamo.exc.UserError,
-    torch._dynamo.exc.TorchRuntimeError,
-}
-
-
 def _get_class_if_classified_error(e):
+    from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError
+
+    _ALLOW_LIST = {
+        Unsupported,
+        UserError,
+        TorchRuntimeError,
+    }
     case_name = getattr(e, "case_name", None)
     if type(e) in _ALLOW_LIST and case_name is not None:
         return case_name
