6
6
7
7
from unittest import mock
8
8
9
+ from pip ._vendor .packaging .requirements import Requirement
9
10
10
- class InternalError (RuntimeError ):
11
- def __init__ (self ) -> None :
12
- # TODO: check against pip version
13
- # TODO: fix wording
14
- msg = (
15
- "Unexpected internal pytorch-pip-shim error. If you ever encounter this "
16
- "message during normal operation, please submit a bug report at "
17
- "https://github.com/pmeier/pytorch-pip-shim/issues"
11
+ from light_the_torch ._compat import importlib_metadata
12
+
13
+
14
+ class UnexpectedInternalError (Exception ):
15
+ def __init__ (self , msg ) -> None :
16
+ actual_pip_version = Requirement (f"pip=={ importlib_metadata .version ('pip' )} " )
17
+ required_pip_version = next (
18
+ requirement
19
+ for requirement in (
20
+ Requirement (requirement_string )
21
+ for requirement_string in importlib_metadata .requires ("light_the_torch" )
22
+ )
23
+ if requirement .name == "pip"
24
+ )
25
+ super ().__init__ (
26
+ f"{ msg } \n \n "
27
+ f"This can happen when the actual pip version (`{ actual_pip_version } `) "
28
+ f"and the one required by light-the-torch (`{ required_pip_version } `) "
29
+ f"are out of sync.\n "
30
+ f"If that is the case, please reinstall light-the-torch. "
31
+ f"Otherwise, please submit a bug report at "
32
+ f"https://github.com/pmeier/light-the-torch/issues"
18
33
)
19
- super ().__init__ (msg )
20
34
21
35
22
36
class Input (dict ):
@@ -77,7 +91,7 @@ def apply_fn_patch(
77
91
postprocessing = lambda input , output : output ,
78
92
):
79
93
target = "." .join (parts )
80
- fn = import_fn (target )
94
+ fn = import_obj (target )
81
95
82
96
@functools .wraps (fn )
83
97
def new (* args , ** kwargs ):
@@ -93,21 +107,33 @@ def new(*args, **kwargs):
93
107
yield
94
108
95
109
96
- def import_fn (target : str ):
110
+ def import_obj (target : str ):
97
111
attrs = []
98
112
name = target
99
113
while name :
100
114
try :
101
115
module = importlib .import_module (name )
102
116
break
103
117
except ImportError :
104
- name , attr = name .rsplit ("." , 1 )
105
- attrs .append (attr )
118
+ try :
119
+ name , attr = name .rsplit ("." , 1 )
120
+ except ValueError :
121
+ attr = name
122
+ name = ""
123
+ attrs .insert (0 , attr )
106
124
else :
107
- raise InternalError
125
+ raise UnexpectedInternalError (
126
+ f"Tried to import `{ target } `, "
127
+ f"but the top-level namespace `{ attrs [0 ]} ` doesn't seem to be a module."
128
+ )
108
129
109
130
obj = module
110
- for attr in attrs [::- 1 ]:
111
- obj = getattr (obj , attr )
131
+ for attr in attrs :
132
+ try :
133
+ obj = getattr (obj , attr )
134
+ except AttributeError :
135
+ raise UnexpectedInternalError (
136
+ f"Failed to access `{ attr } ` from `{ obj .__name__ } `"
137
+ ) from None
112
138
113
139
return obj
0 commit comments