Skip to content

Commit 7df4d05

Browse files
committed
partially port #117
1 parent 9cbd8ff commit 7df4d05

File tree

1 file changed

+42
-16
lines changed

1 file changed

+42
-16
lines changed

light_the_torch/_patch/utils.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,31 @@
66

77
from unittest import mock
88

9+
from pip._vendor.packaging.requirements import Requirement
910

10-
class InternalError(RuntimeError):
11-
def __init__(self) -> None:
12-
# TODO: check against pip version
13-
# TODO: fix wording
14-
msg = (
15-
"Unexpected internal pytorch-pip-shim error. If you ever encounter this "
16-
"message during normal operation, please submit a bug report at "
17-
"https://github.com/pmeier/pytorch-pip-shim/issues"
11+
from light_the_torch._compat import importlib_metadata
12+
13+
14+
class UnexpectedInternalError(Exception):
15+
def __init__(self, msg) -> None:
16+
actual_pip_version = Requirement(f"pip=={importlib_metadata.version('pip')}")
17+
required_pip_version = next(
18+
requirement
19+
for requirement in (
20+
Requirement(requirement_string)
21+
for requirement_string in importlib_metadata.requires("light_the_torch")
22+
)
23+
if requirement.name == "pip"
24+
)
25+
super().__init__(
26+
f"{msg}\n\n"
27+
f"This can happen when the actual pip version (`{actual_pip_version}`) "
28+
f"and the one required by light-the-torch (`{required_pip_version}`) "
29+
f"are out of sync.\n"
30+
f"If that is the case, please reinstall light-the-torch. "
31+
f"Otherwise, please submit a bug report at "
32+
f"https://github.com/pmeier/light-the-torch/issues"
1833
)
19-
super().__init__(msg)
2034

2135

2236
class Input(dict):
@@ -77,7 +91,7 @@ def apply_fn_patch(
7791
postprocessing=lambda input, output: output,
7892
):
7993
target = ".".join(parts)
80-
fn = import_fn(target)
94+
fn = import_obj(target)
8195

8296
@functools.wraps(fn)
8397
def new(*args, **kwargs):
@@ -93,21 +107,33 @@ def new(*args, **kwargs):
93107
yield
94108

95109

96-
def import_fn(target: str):
110+
def import_obj(target: str):
97111
attrs = []
98112
name = target
99113
while name:
100114
try:
101115
module = importlib.import_module(name)
102116
break
103117
except ImportError:
104-
name, attr = name.rsplit(".", 1)
105-
attrs.append(attr)
118+
try:
119+
name, attr = name.rsplit(".", 1)
120+
except ValueError:
121+
attr = name
122+
name = ""
123+
attrs.insert(0, attr)
106124
else:
107-
raise InternalError
125+
raise UnexpectedInternalError(
126+
f"Tried to import `{target}`, "
127+
f"but the top-level namespace `{attrs[0]}` doesn't seem to be a module."
128+
)
108129

109130
obj = module
110-
for attr in attrs[::-1]:
111-
obj = getattr(obj, attr)
131+
for attr in attrs:
132+
try:
133+
obj = getattr(obj, attr)
134+
except AttributeError:
135+
raise UnexpectedInternalError(
136+
f"Failed to access `{attr}` from `{obj.__name__}`"
137+
) from None
112138

113139
return obj

0 commit comments

Comments
 (0)