@@ -20,7 +20,7 @@ def __init__(self) -> None:
2020        self .project_root  =  None 
2121        self .benchmark_timings  =  []
2222
23-     def  setup (self , trace_path :str , project_root :str ) ->  None :
23+     def  setup (self , trace_path :  str , project_root :  str ) ->  None :
2424        try :
2525            # Open connection 
2626            self .project_root  =  project_root 
@@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
3535                "benchmark_time_ns INTEGER)" 
3636            )
3737            self ._connection .commit ()
38-             self .close () # Reopen only at the end of pytest session 
38+             self .close ()   # Reopen only at the end of pytest session 
3939        except  Exception  as  e :
4040            print (f"Database setup error: { e }  " )
4141            if  self ._connection :
@@ -55,14 +55,15 @@ def write_benchmark_timings(self) -> None:
5555            # Insert data into the benchmark_timings table 
5656            cur .executemany (
5757                "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)" ,
58-                 self .benchmark_timings 
58+                 self .benchmark_timings , 
5959            )
6060            self ._connection .commit ()
61-             self .benchmark_timings  =  [] # Clear the benchmark timings list 
61+             self .benchmark_timings  =  []   # Clear the benchmark timings list 
6262        except  Exception  as  e :
6363            print (f"Error writing to benchmark timings database: { e }  " )
6464            self ._connection .rollback ()
6565            raise 
66+ 
6667    def  close (self ) ->  None :
6768        if  self ._connection :
6869            self ._connection .close ()
@@ -196,12 +197,7 @@ def pytest_sessionfinish(self, session, exitstatus):
196197
197198    @staticmethod  
198199    def  pytest_addoption (parser ):
199-         parser .addoption (
200-             "--codeflash-trace" ,
201-             action = "store_true" ,
202-             default = False ,
203-             help = "Enable CodeFlash tracing" 
204-         )
200+         parser .addoption ("--codeflash-trace" , action = "store_true" , default = False , help = "Enable CodeFlash tracing" )
205201
206202    @staticmethod  
207203    def  pytest_plugin_registered (plugin , manager ):
@@ -213,9 +209,9 @@ def pytest_plugin_registered(plugin, manager):
213209    def  pytest_configure (config ):
214210        """Register the benchmark marker.""" 
215211        config .addinivalue_line (
216-             "markers" ,
217-             "benchmark: mark test as a benchmark that should be run with codeflash tracing" 
212+             "markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing" 
218213        )
214+ 
219215    @staticmethod  
220216    def  pytest_collection_modifyitems (config , items ):
221217        # Skip tests that don't have the benchmark fixture 
@@ -224,19 +220,18 @@ def pytest_collection_modifyitems(config, items):
224220
225221        skip_no_benchmark  =  pytest .mark .skip (reason = "Test requires benchmark fixture" )
226222        for  item  in  items :
227-             # Check for direct benchmark fixture usage 
228-             has_fixture  =  hasattr (item , "fixturenames" ) and  "benchmark"  in  item .fixturenames 
223+             if  hasattr (item , "fixturenames" ):
224+                 if  "benchmark"  in  item .fixturenames :
225+                     continue   # Skip current iteration if benchmark fixture is present 
229226
230227            # Check for @pytest.mark.benchmark marker 
231-             has_marker  =  False 
232228            if  hasattr (item , "get_closest_marker" ):
233229                marker  =  item .get_closest_marker ("benchmark" )
234230                if  marker  is  not   None :
235-                     has_marker   =   True 
231+                     continue    # Skip current iteration if benchmark marker is present 
236232
237233            # Skip if neither fixture nor marker is present 
238-             if  not  (has_fixture  or  has_marker ):
239-                 item .add_marker (skip_no_benchmark )
234+             item .add_marker (skip_no_benchmark )
240235
241236    # Benchmark fixture 
242237    class  Benchmark :
@@ -248,16 +243,19 @@ def __call__(self, func, *args, **kwargs):
248243            if  args  or  kwargs :
249244                # Used as benchmark(func, *args, **kwargs) 
250245                return  self ._run_benchmark (func , * args , ** kwargs )
246+ 
251247            # Used as @benchmark decorator 
252248            def  wrapped_func (* args , ** kwargs ):
253249                return  func (* args , ** kwargs )
250+ 
254251            result  =  self ._run_benchmark (func )
255252            return  wrapped_func 
256253
257254        def  _run_benchmark (self , func , * args , ** kwargs ):
258255            """Actual benchmark implementation.""" 
259-             benchmark_module_path  =  module_name_from_file_path (Path (str (self .request .node .fspath )),
260-                                                                Path (codeflash_benchmark_plugin .project_root ))
256+             benchmark_module_path  =  module_name_from_file_path (
257+                 Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root )
258+             )
261259            benchmark_function_name  =  self .request .node .name 
262260            line_number  =  int (str (sys ._getframe (2 ).f_lineno ))  # 2 frames up in the call stack 
263261            # Set env vars 
@@ -278,7 +276,8 @@ def _run_benchmark(self, func, *args, **kwargs):
278276            codeflash_trace .function_call_count  =  0 
279277            # Add to the benchmark timings buffer 
280278            codeflash_benchmark_plugin .benchmark_timings .append (
281-                 (benchmark_module_path , benchmark_function_name , line_number , end  -  start ))
279+                 (benchmark_module_path , benchmark_function_name , line_number , end  -  start )
280+             )
282281
283282            return  result 
284283
@@ -290,4 +289,5 @@ def benchmark(request):
290289
291290        return  CodeFlashBenchmarkPlugin .Benchmark (request )
292291
292+ 
293293codeflash_benchmark_plugin  =  CodeFlashBenchmarkPlugin ()
0 commit comments