6565R = TypeVar ("R" )
6666
6767
68- class NonStrictTypeError (Exception ):
68+ class ModelCheckerException (Exception ):
6969 """Dummy exception. Allows us to detect unwanted types during a module import."""
7070
7171
72+ class MissingStrictInConstrainedTypeException (ModelCheckerException ):
73+ factory_name : str
74+
75+ def __init__ (self , factory_name : str ):
76+ self .factory_name = factory_name
77+
78+
79+ class FieldHasUnwantedTypeException (ModelCheckerException ):
80+ message : str
81+
82+ def __init__ (self , message : str ):
83+ self .message = message
84+
85+
7286def make_wrapper (factory : Callable [P , R ]) -> Callable [P , R ]:
73- """We patch `constr` and friends with wrappers that enforce strict=True. """
87+ """We patch `constr` and friends with wrappers that enforce strict=True."""
7488
7589 @functools .wraps (factory )
7690 def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> R :
7791 # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668
7892 if "strict" not in kwargs : # type: ignore[attr-defined]
79- raise NonStrictTypeError ( )
93+ raise MissingStrictInConstrainedTypeException ( factory . __name__ )
8094 if not kwargs ["strict" ]: # type: ignore[index]
81- raise NonStrictTypeError ( )
95+ raise MissingStrictInConstrainedTypeException ( factory . __name__ )
8296 return factory (* args , ** kwargs )
8397
8498 return wrapper
@@ -113,18 +127,25 @@ def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
113127 # Note that field.type_ and field.outer_type are computed based on the
114128 # annotation type, see pydantic.fields.ModelField._type_analysis
115129 if field_type_unwanted (field .outer_type_ ):
116- raise NonStrictTypeError ()
130+ # TODO: this only reports the first bad field. Can we find all bad ones
131+ # and report them all?
132+ raise FieldHasUnwantedTypeException (
133+ f"{ cls .__module__ } .{ cls .__qualname__ } has field '{ field .name } ' "
134+ f"with unwanted type `{ field .outer_type_ } `"
135+ )
117136
118137
119138@contextmanager
120139def monkeypatch_pydantic () -> Generator [None , None , None ]:
121140 """Patch pydantic with our snooping versions of BaseModel and the con* functions.
122141
123- Most Synapse code ought to import the patched objects directly from `pydantic`.
124- But we include their containing models `pydantic.main` and `pydantic.types` for
125- completeness.
142+ If the snooping functions see something they don't like, they'll raise a
143+ ModelCheckingException instance.
126144 """
127145 with contextlib .ExitStack () as patches :
146+ # Most Synapse code ought to import the patched objects directly from
147+ # `pydantic`. But we also patch their containing modules `pydantic.main` and
148+ # `pydantic.types` for completeness.
128149 patch_basemodel1 = unittest .mock .patch (
129150 "pydantic.BaseModel" , new = PatchedBaseModel
130151 )
@@ -144,10 +165,20 @@ def monkeypatch_pydantic() -> Generator[None, None, None]:
144165 yield
145166
146167
147- def format_error (e : NonStrictTypeError ) -> str :
168+ def format_model_checker_exception (e : ModelCheckerException ) -> str :
148169 """Work out which line of code caused e. Format the line in a human-friendly way."""
149- frame_summary = traceback .extract_tb (e .__traceback__ )[- 2 ]
150- return traceback .format_list ([frame_summary ])[0 ].lstrip ()
170+ # TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the
171+ # patches of constr() etc, and instead inspect fields to look for ConstrainedStr
172+ # with strict=False? There is some difficulty with the inheritance hierarchy
173+ # because StrictStr < ConstrainedStr < str.
174+ if isinstance (e , FieldHasUnwantedTypeException ):
175+ return e .message
176+ elif isinstance (e , MissingStrictInConstrainedTypeException ):
177+ frame_summary = traceback .extract_tb (e .__traceback__ )[- 2 ]
178+ return (
179+ f"Missing `strict=True` from { e .factory_name } () call \n "
180+ + traceback .format_list ([frame_summary ])[0 ].lstrip ()
181+ )
151182
152183
153184def lint () -> int :
@@ -168,26 +199,30 @@ def do_lint() -> Set[str]:
168199
169200 with monkeypatch_pydantic ():
170201 try :
171- synapse = importlib .import_module ("synapse" )
172- except NonStrictTypeError as e :
202+ # TODO: make "synapse" an argument so we can target this script at
203+ # a subpackage
204+ module = importlib .import_module ("synapse" )
205+ except ModelCheckerException as e :
173206 logger .warning ("Bad annotation found when importing synapse" )
174- failures .add (format_error (e ))
207+ failures .add (format_model_checker_exception (e ))
175208 return failures
176209
177210 try :
178- modules = list (pkgutil .walk_packages (synapse .__path__ , "synapse." ))
179- except NonStrictTypeError as e :
211+ modules = list (
212+ pkgutil .walk_packages (module .__path__ , f"{ module .__name__ } ." )
213+ )
214+ except ModelCheckerException as e :
180215 logger .warning ("Bad annotation found when looking for modules to import" )
181- failures .add (format_error (e ))
216+ failures .add (format_model_checker_exception (e ))
182217 return failures
183218
184219 for module in modules :
185220 logger .debug ("Importing %s" , module .name )
186221 try :
187222 importlib .import_module (module .name )
188- except NonStrictTypeError as e :
223+ except ModelCheckerException as e :
189224 logger .warning (f"Bad annotation found when importing { module .name } " )
190- failures .add (format_error (e ))
225+ failures .add (format_model_checker_exception (e ))
191226
192227 return failures
193228
@@ -208,7 +243,7 @@ def run_test_snippet(source: str) -> None:
208243
209244class TestConstrainedTypesPatch (unittest .TestCase ):
210245 def test_expression_without_strict_raises (self ) -> None :
211- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
246+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
212247 run_test_snippet (
213248 """
214249 from pydantic import constr
@@ -217,7 +252,7 @@ def test_expression_without_strict_raises(self) -> None:
217252 )
218253
219254 def test_called_as_module_attribute_raises (self ) -> None :
220- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
255+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
221256 run_test_snippet (
222257 """
223258 import pydantic
@@ -226,7 +261,7 @@ def test_called_as_module_attribute_raises(self) -> None:
226261 )
227262
228263 def test_wildcard_import_raises (self ) -> None :
229- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
264+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
230265 run_test_snippet (
231266 """
232267 from pydantic import *
@@ -235,7 +270,7 @@ def test_wildcard_import_raises(self) -> None:
235270 )
236271
237272 def test_alternative_import_raises (self ) -> None :
238- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
273+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
239274 run_test_snippet (
240275 """
241276 from pydantic.types import constr
@@ -244,7 +279,7 @@ def test_alternative_import_raises(self) -> None:
244279 )
245280
246281 def test_alternative_import_attribute_raises (self ) -> None :
247- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
282+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
248283 run_test_snippet (
249284 """
250285 import pydantic.types
@@ -253,7 +288,7 @@ def test_alternative_import_attribute_raises(self) -> None:
253288 )
254289
255290 def test_kwarg_but_no_strict_raises (self ) -> None :
256- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
291+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
257292 run_test_snippet (
258293 """
259294 from pydantic import constr
@@ -262,7 +297,7 @@ def test_kwarg_but_no_strict_raises(self) -> None:
262297 )
263298
264299 def test_kwarg_strict_False_raises (self ) -> None :
265- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
300+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
266301 run_test_snippet (
267302 """
268303 from pydantic import constr
@@ -280,7 +315,7 @@ def test_kwarg_strict_True_doesnt_raise(self) -> None:
280315 )
281316
282317 def test_annotation_without_strict_raises (self ) -> None :
283- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
318+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
284319 run_test_snippet (
285320 """
286321 from pydantic import constr
@@ -289,7 +324,7 @@ def test_annotation_without_strict_raises(self) -> None:
289324 )
290325
291326 def test_field_annotation_without_strict_raises (self ) -> None :
292- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
327+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
293328 run_test_snippet (
294329 """
295330 from pydantic import BaseModel, conint
@@ -317,7 +352,7 @@ class TestFieldTypeInspection(unittest.TestCase):
317352 ]
318353 )
319354 def test_field_holding_unwanted_type_raises (self , annotation : str ) -> None :
320- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
355+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
321356 run_test_snippet (
322357 f"""
323358 from typing import *
@@ -355,7 +390,7 @@ class C(BaseModel):
355390 )
356391
357392 def test_field_holding_str_raises_with_alternative_import (self ) -> None :
358- with monkeypatch_pydantic (), self .assertRaises (NonStrictTypeError ):
393+ with monkeypatch_pydantic (), self .assertRaises (ModelCheckerException ):
359394 run_test_snippet (
360395 """
361396 from pydantic.main import BaseModel
0 commit comments