Skip to content

Commit ac3112d

Browse files
authored
Adding vector search tests for types int8/uint8 (#3525)
1 parent cf2341f commit ac3112d

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

tests/test_search.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2856,6 +2856,64 @@ def test_vector_search_with_default_dialect(client):
28562856
assert res["total_results"] == 2
28572857

28582858

2859+
@pytest.mark.redismod
2860+
@skip_if_server_version_lt("7.9.0")
2861+
def test_vector_search_with_int8_type(client):
2862+
client.ft().create_index(
2863+
(VectorField("v", "FLAT", {"TYPE": "INT8", "DIM": 2, "DISTANCE_METRIC": "L2"}),)
2864+
)
2865+
2866+
a = [1.5, 10]
2867+
b = [123, 100]
2868+
c = [1, 1]
2869+
2870+
client.hset("a", "v", np.array(a, dtype=np.int8).tobytes())
2871+
client.hset("b", "v", np.array(b, dtype=np.int8).tobytes())
2872+
client.hset("c", "v", np.array(c, dtype=np.int8).tobytes())
2873+
2874+
query = Query("*=>[KNN 2 @v $vec as score]")
2875+
query_params = {"vec": np.array(a, dtype=np.int8).tobytes()}
2876+
2877+
assert 2 in query.get_args()
2878+
2879+
res = client.ft().search(query, query_params=query_params)
2880+
if is_resp2_connection(client):
2881+
assert res.total == 2
2882+
else:
2883+
assert res["total_results"] == 2
2884+
2885+
2886+
@pytest.mark.redismod
2887+
@skip_if_server_version_lt("7.9.0")
2888+
def test_vector_search_with_uint8_type(client):
2889+
client.ft().create_index(
2890+
(
2891+
VectorField(
2892+
"v", "FLAT", {"TYPE": "UINT8", "DIM": 2, "DISTANCE_METRIC": "L2"}
2893+
),
2894+
)
2895+
)
2896+
2897+
a = [1.5, 10]
2898+
b = [123, 100]
2899+
c = [1, 1]
2900+
2901+
client.hset("a", "v", np.array(a, dtype=np.uint8).tobytes())
2902+
client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes())
2903+
client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes())
2904+
2905+
query = Query("*=>[KNN 2 @v $vec as score]")
2906+
query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()}
2907+
2908+
assert 2 in query.get_args()
2909+
2910+
res = client.ft().search(query, query_params=query_params)
2911+
if is_resp2_connection(client):
2912+
assert res.total == 2
2913+
else:
2914+
assert res["total_results"] == 2
2915+
2916+
28592917
@pytest.mark.redismod
28602918
@skip_ifmodversion_lt("2.4.3", "search")
28612919
def test_search_query_with_different_dialects(client):

0 commit comments

Comments
 (0)