Skip to content

Commit 8e08592

Browse files
authored
Merge branch 'main' into add-pr-url
2 parents 84686b2 + 3f795db commit 8e08592

File tree

4 files changed

+43
-18
lines changed

4 files changed

+43
-18
lines changed

codeflash/models/models.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ class TestingMode(enum.Enum):
364364
PERFORMANCE = "performance"
365365
LINE_PROFILE = "line_profile"
366366

367-
367+
#TODO this class is duplicated in codeflash_capture
368368
class VerificationType(str, Enum):
369369
FUNCTION_CALL = (
370370
"function_call" # Correctness verification for a test function, checks input values and output values)
@@ -537,22 +537,20 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
537537
return tree
538538

539539
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
540+
# Efficient single traversal, directly accumulating into a dict.
541+
by_id: dict[InvocationId, list[int]] = {}
540542
for result in self.test_results:
541-
if result.did_pass and not result.runtime:
542-
msg = (
543-
f"Ignoring test case that passed but had no runtime -> {result.id}, "
544-
f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
545-
f"Verification Type: {result.verification_type}"
546-
)
547-
logger.debug(msg)
548-
549-
usable_runtimes = [
550-
(result.id, result.runtime) for result in self.test_results if result.did_pass and result.runtime
551-
]
552-
return {
553-
usable_id: [runtime[1] for runtime in usable_runtimes if runtime[0] == usable_id]
554-
for usable_id in {runtime[0] for runtime in usable_runtimes}
555-
}
543+
if result.did_pass:
544+
if result.runtime:
545+
by_id.setdefault(result.id, []).append(result.runtime)
546+
else:
547+
msg = (
548+
f"Ignoring test case that passed but had no runtime -> {result.id}, "
549+
f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
550+
f"Verification Type: {result.verification_type}"
551+
)
552+
logger.debug(msg)
553+
return by_id
556554

557555
def total_passed_runtime(self) -> int:
558556
"""Calculate the sum of runtimes of all test cases that passed.

codeflash/verification/codeflash_capture.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

3+
# This file should not have any dependencies on codeflash
34
import functools
45
import gc
56
import inspect
67
import os
78
import sqlite3
89
import time
910
from pathlib import Path
10-
11+
from enum import Enum
1112
import dill as pickle
1213

13-
from codeflash.models.models import VerificationType
14+
class VerificationType(str, Enum):
15+
FUNCTION_CALL = (
16+
"function_call" # Correctness verification for a test function, checks input values and output values)
17+
)
18+
INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init
19+
INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init
1420

1521

1622
def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str]:

codeflash/verification/comparator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
8484
frozenset,
8585
enum.Enum,
8686
type,
87+
range
8788
),
8889
):
8990
return orig == new

tests/test_comparator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,26 @@ def test_basic_python_objects() -> None:
125125
assert comparator(a, b)
126126
assert not comparator(a, c)
127127

128+
@pytest.mark.parametrize("r1, r2, expected", [
129+
(range(1, 10), range(1, 10), True), # equal
130+
(range(0, 10), range(1, 10), False), # different start
131+
(range(2, 10), range(1, 10), False),
132+
(range(1, 5), range(1, 10), False), # different stop
133+
(range(1, 20), range(1, 10), False),
134+
(range(1, 10, 1), range(1, 10, 2), False), # different step
135+
(range(1, 10, 3), range(1, 10, 2), False),
136+
(range(-5, 0), range(-5, 0), True), # negative ranges
137+
(range(-10, 0), range(-5, 0), False),
138+
(range(5, 1), range(10, 5), True), # empty ranges
139+
(range(5, 1), range(5, 1), True),
140+
(range(7), range(0, 7), True),
141+
(range(0, 7), range(0, 7, 1), True),
142+
(range(7), range(0, 7, 1), True),
143+
])
144+
145+
def test_ranges(r1, r2, expected):
146+
assert comparator(r1, r2) == expected
147+
128148

129149
def test_standard_python_library_objects() -> None:
130150
a = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore

0 commit comments

Comments
 (0)