Skip to content

Commit 92a6032

Browse files
committed
add array.array t comparator
1 parent 3dd4459 commit 92a6032

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

codeflash/verification/comparator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import array
12
import ast
23
import datetime
34
import decimal
@@ -170,6 +171,15 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
170171
if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new):
171172
return True
172173

174+
if isinstance(orig, array.array):
175+
if not isinstance(new, array.array):
176+
return False
177+
if orig.typecode != new.typecode:
178+
return False
179+
if len(orig) != len(new):
180+
return False
181+
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
182+
173183
# This should be at the end of all numpy checking
174184
try:
175185
if HAS_NUMPY and np.isnan(orig):

tests/test_comparator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
from enum import Enum, Flag, IntFlag, auto
99
from pathlib import Path
10+
import array # Add import for array
1011

1112
import pydantic
1213
import pytest
@@ -203,6 +204,24 @@ class Color4(IntFlag):
203204
assert not comparator(a, c)
204205
assert not comparator(a, d)
205206

207+
arr1 = array.array('i', [1, 2, 3])
208+
arr2 = array.array('i', [1, 2, 3])
209+
arr3 = array.array('i', [4, 5, 6])
210+
arr4 = array.array('f', [1.0, 2.0, 3.0])
211+
212+
assert comparator(arr1, arr2)
213+
assert not comparator(arr1, arr3)
214+
assert not comparator(arr1, arr4)
215+
assert not comparator(arr1, [1, 2, 3])
216+
217+
empty_arr_i1 = array.array('i')
218+
empty_arr_i2 = array.array('i')
219+
empty_arr_f = array.array('f')
220+
assert comparator(empty_arr_i1, empty_arr_i2)
221+
assert not comparator(empty_arr_i1, empty_arr_f)
222+
assert not comparator(empty_arr_i1, arr1)
223+
224+
206225

207226
def test_numpy():
208227
try:

0 commit comments

Comments
 (0)