@@ -1105,16 +1105,24 @@ def read_orig_module(path) -> cst.Module:
11051105 return cst .parse_module (f .read ())
11061106
11071107
1108- def find_function (module : cst .Module , name : str ) -> Union [cst .FunctionDef , None ]:
1108+ def find_top_level_function_or_method (module : cst .Module , name : str ) -> Union [cst .FunctionDef , None ]:
11091109 name = name .split ('.' )[- 1 ]
1110- return next (iter (m .findall (module , m .FunctionDef (m .Name (name )))), None ) # type: ignore
1110+ for child in module .body :
1111+ if isinstance (child , cst .FunctionDef ) and child .name .value == name :
1112+ return child
1113+ if isinstance (child , cst .ClassDef ) and isinstance (child .body , cst .IndentedBlock ):
1114+ for method in child .body .body :
1115+ if isinstance (method , cst .FunctionDef ) and method .name .value == name :
1116+ return method
1117+
1118+ return None
11111119
11121120
11131121def read_original_function (module : cst .Module , mutant_name : str ):
11141122 orig_function_name , _ = orig_function_and_class_names_from_key (mutant_name )
11151123 orig_name = mangled_name_from_mutant_name (mutant_name ) + '__mutmut_orig'
11161124
1117- result = find_function (module , orig_name )
1125+ result = find_top_level_function_or_method (module , orig_name )
11181126 if not result :
11191127 raise FileNotFoundError (f'Could not find original function "{ orig_function_name } "' )
11201128 return result .with_changes (name = cst .Name (orig_function_name ))
@@ -1123,7 +1131,7 @@ def read_original_function(module: cst.Module, mutant_name: str):
11231131def read_mutant_function (module : cst .Module , mutant_name : str ):
11241132 orig_function_name , _ = orig_function_and_class_names_from_key (mutant_name )
11251133
1126- result = find_function (module , mutant_name )
1134+ result = find_top_level_function_or_method (module , mutant_name )
11271135 if not result :
11281136 raise FileNotFoundError (f'Could not find original function "{ orig_function_name } "' )
11291137 return result .with_changes (name = cst .Name (orig_function_name ))
@@ -1196,7 +1204,7 @@ def apply_mutant(mutant_name):
11961204 mutant_function = read_mutant_function (mutants_module , mutant_name )
11971205 mutant_function = mutant_function .with_changes (name = cst .Name (orig_function_name ))
11981206
1199- original_function = find_function (orig_module , orig_function_name )
1207+ original_function = find_top_level_function_or_method (orig_module , orig_function_name )
12001208 if not original_function :
12011209 raise FileNotFoundError (f'Could not apply mutant { mutant_name } ' )
12021210
@@ -1270,6 +1278,7 @@ def on_mount(self):
12701278 def read_data (self ):
12711279 ensure_config_loaded ()
12721280 self .source_file_mutation_data_and_stat_by_path = {}
1281+ self .path_by_name = {}
12731282
12741283 for p in walk_source_files ():
12751284 if mutmut .config .should_ignore_for_mutation (p ):
@@ -1279,6 +1288,8 @@ def read_data(self):
12791288 stat = collect_stat (source_file_mutation_data )
12801289
12811290 self .source_file_mutation_data_and_stat_by_path [str (p )] = source_file_mutation_data , stat
1291+ for name in source_file_mutation_data .exit_code_by_key :
1292+ self .path_by_name [name ] = p
12821293
12831294 def populate_files_table (self ):
12841295 # noinspection PyTypeChecker
@@ -1317,11 +1328,12 @@ def on_data_table_row_highlighted(self, event):
13171328 else :
13181329 diff_view .update ('<loading...>' )
13191330 self .loading_id = event .row_key .value
1331+ path = self .path_by_name .get (event .row_key .value )
13201332
13211333 def load_thread ():
13221334 ensure_config_loaded ()
13231335 try :
1324- d = get_diff_for_mutant (event .row_key .value )
1336+ d = get_diff_for_mutant (event .row_key .value , path = path )
13251337 if event .row_key .value == self .loading_id :
13261338 diff_view .update (Syntax (d , "diff" ))
13271339 except Exception as e :
0 commit comments