@@ -89,9 +89,7 @@ def import_module(module_path):
8989 Import a module from a given file path, in a way that works with both
9090 MicroPython and Pyodide.
9191 """
92- dotted_path = (
93- str (module_path ).replace ("/" , "." ).replace (".py" , "" )
94- )
92+ dotted_path = str (module_path ).replace ("/" , "." ).replace (".py" , "" )
9593 dotted_path = dotted_path .lstrip ("." )
9694 module = __import__ (dotted_path )
9795 for part in dotted_path .split ("." )[1 :]:
@@ -125,29 +123,28 @@ class TestCase:
125123 Represents an individual test to run.
126124 """
127125
128- def __init__ (self , test_function , module_name , test_name ):
126+ def __init__ (self , test_function , module_name , test_name , function_id ):
129127 """
130128 A TestCase is instantiated with a callable test_function, the name of
131- the module containing the test, and the name of the test within the
132- module .
129+ the module containing the test, the name of the test within the module
130+ and the unique Python id of the test function .
133131 """
134132 self .test_function = test_function
135133 self .module_name = module_name
136134 self .test_name = test_name
135+ self .function_id = function_id
137136 self .status = PENDING # the initial state of the test.
138137 self .traceback = None # to contain details of any failure.
139- self .reason = (
140- None # to contain the reason for skipping the test.
141- )
138+ self .reason = None # to contain the reason for skipping the test.
142139
143140 async def run (self ):
144141 """
145142 Run the test function and set the status and traceback attributes, as
146143 required.
147144 """
148- if id ( self .test_function ) in _SKIPPED_TESTS :
145+ if self .function_id in _SKIPPED_TESTS :
149146 self .status = SKIPPED
150- self .reason = _SKIPPED_TESTS .get (id ( self .test_function ) )
147+ self .reason = _SKIPPED_TESTS .get (self .function_id )
151148 if not self .reason :
152149 self .reason = "No reason given."
153150 return
@@ -186,12 +183,28 @@ def __init__(self, path, module, setup=None, teardown=None):
186183 for name , item in self .module .__dict__ .items ():
187184 if callable (item ) or is_awaitable (item ):
188185 if name .startswith ("test" ):
189- t = TestCase (item , self .path , name )
186+ # A simple test function.
187+ t = TestCase (item , self .path , name , id (item ))
190188 self ._tests .append (t )
189+ elif inspect .isclass (item ) and name .startswith ("Test" ):
190+ # A test class, so check for test methods.
191+ instance = item ()
192+ for method_name , method in item .__dict__ .items ():
193+ if callable (method ) or is_awaitable (method ):
194+ if method_name .startswith ("test" ):
195+ t = TestCase (
196+ getattr (instance , method_name ),
197+ self .path ,
198+ f"{ name } .{ method_name } " ,
199+ id (method ),
200+ )
201+ self ._tests .append (t )
191202 elif name == "setup" :
203+ # A local setup function.
192204 self ._setup = item
193205 local_setup_teardown = True
194206 elif name == "teardown" :
207+ # A local teardown function.
195208 self ._teardown = item
196209 local_setup_teardown = True
197210 if local_setup_teardown :
@@ -223,10 +236,14 @@ def teardown(self):
223236
224237 def limit_tests_to (self , test_names ):
225238 """
226- Limit the tests run to the provided test_names list of names.
239+ Limit the tests run to the provided test_names list of names of test
240+ functions or test classes.
227241 """
228242 self ._tests = [
229- t for t in self ._tests if t .test_name in test_names
243+ t
244+ for t in self ._tests
245+ if (t .test_name in test_names )
246+ or (t .test_name .split ("." )[0 ] in test_names )
230247 ]
231248
232249 async def run (self ):
@@ -270,11 +287,7 @@ def gather_conftest_functions(conftest_path, target):
270287 )
271288 conftest = import_module (conftest_path )
272289 setup = conftest .setup if hasattr (conftest , "setup" ) else None
273- teardown = (
274- conftest .teardown
275- if hasattr (conftest , "teardown" )
276- else None
277- )
290+ teardown = conftest .teardown if hasattr (conftest , "teardown" ) else None
278291 return setup , teardown
279292 return None , None
280293
@@ -302,24 +315,16 @@ def discover(targets, pattern, setup=None, teardown=None):
302315 result = []
303316 for target in targets :
304317 if "::" in target :
305- conftest_path = (
306- Path (target .split ("::" )[0 ]).parent / "conftest.py"
307- )
308- setup , teardown = gather_conftest_functions (
309- conftest_path , target
310- )
318+ conftest_path = Path (target .split ("::" )[0 ]).parent / "conftest.py"
319+ setup , teardown = gather_conftest_functions (conftest_path , target )
311320 module_path , test_names = target .split ("::" )
312321 module_instance = import_module (module_path )
313- module = TestModule (
314- module_path , module_instance , setup , teardown
315- )
322+ module = TestModule (module_path , module_instance , setup , teardown )
316323 module .limit_tests_to (test_names .split ("," ))
317324 result .append (module )
318325 elif os .path .isdir (target ):
319326 conftest_path = Path (target ) / "conftest.py"
320- setup , teardown = gather_conftest_functions (
321- conftest_path , target
322- )
327+ setup , teardown = gather_conftest_functions (conftest_path , target )
323328 for module_path in Path (target ).rglob (pattern ):
324329 module_instance = import_module (module_path )
325330 module = TestModule (
@@ -328,13 +333,9 @@ def discover(targets, pattern, setup=None, teardown=None):
328333 result .append (module )
329334 else :
330335 conftest_path = Path (target ).parent / "conftest.py"
331- setup , teardown = gather_conftest_functions (
332- conftest_path , target
333- )
336+ setup , teardown = gather_conftest_functions (conftest_path , target )
334337 module_instance = import_module (target )
335- module = TestModule (
336- target , module_instance , setup , teardown
337- )
338+ module = TestModule (target , module_instance , setup , teardown )
338339 result .append (module )
339340 return result
340341
0 commit comments