@@ -2856,6 +2856,64 @@ def test_vector_search_with_default_dialect(client):
2856
2856
assert res ["total_results" ] == 2
2857
2857
2858
2858
2859
+ @pytest .mark .redismod
2860
+ @skip_if_server_version_lt ("7.9.0" )
2861
+ def test_vector_search_with_int8_type (client ):
2862
+ client .ft ().create_index (
2863
+ (VectorField ("v" , "FLAT" , {"TYPE" : "INT8" , "DIM" : 2 , "DISTANCE_METRIC" : "L2" }),)
2864
+ )
2865
+
2866
+ a = [1.5 , 10 ]
2867
+ b = [123 , 100 ]
2868
+ c = [1 , 1 ]
2869
+
2870
+ client .hset ("a" , "v" , np .array (a , dtype = np .int8 ).tobytes ())
2871
+ client .hset ("b" , "v" , np .array (b , dtype = np .int8 ).tobytes ())
2872
+ client .hset ("c" , "v" , np .array (c , dtype = np .int8 ).tobytes ())
2873
+
2874
+ query = Query ("*=>[KNN 2 @v $vec as score]" )
2875
+ query_params = {"vec" : np .array (a , dtype = np .int8 ).tobytes ()}
2876
+
2877
+ assert 2 in query .get_args ()
2878
+
2879
+ res = client .ft ().search (query , query_params = query_params )
2880
+ if is_resp2_connection (client ):
2881
+ assert res .total == 2
2882
+ else :
2883
+ assert res ["total_results" ] == 2
2884
+
2885
+
2886
+ @pytest .mark .redismod
2887
+ @skip_if_server_version_lt ("7.9.0" )
2888
+ def test_vector_search_with_uint8_type (client ):
2889
+ client .ft ().create_index (
2890
+ (
2891
+ VectorField (
2892
+ "v" , "FLAT" , {"TYPE" : "UINT8" , "DIM" : 2 , "DISTANCE_METRIC" : "L2" }
2893
+ ),
2894
+ )
2895
+ )
2896
+
2897
+ a = [1.5 , 10 ]
2898
+ b = [123 , 100 ]
2899
+ c = [1 , 1 ]
2900
+
2901
+ client .hset ("a" , "v" , np .array (a , dtype = np .uint8 ).tobytes ())
2902
+ client .hset ("b" , "v" , np .array (b , dtype = np .uint8 ).tobytes ())
2903
+ client .hset ("c" , "v" , np .array (c , dtype = np .uint8 ).tobytes ())
2904
+
2905
+ query = Query ("*=>[KNN 2 @v $vec as score]" )
2906
+ query_params = {"vec" : np .array (a , dtype = np .uint8 ).tobytes ()}
2907
+
2908
+ assert 2 in query .get_args ()
2909
+
2910
+ res = client .ft ().search (query , query_params = query_params )
2911
+ if is_resp2_connection (client ):
2912
+ assert res .total == 2
2913
+ else :
2914
+ assert res ["total_results" ] == 2
2915
+
2916
+
2859
2917
@pytest .mark .redismod
2860
2918
@skip_ifmodversion_lt ("2.4.3" , "search" )
2861
2919
def test_search_query_with_different_dialects (client ):
0 commit comments