diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index da6667df88..14328e9d3b 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -105,11 +105,20 @@ class TagField(Field): """ SEPARATOR = "SEPARATOR" + CASESENSITIVE = "CASESENSITIVE" - def __init__(self, name: str, separator: str = ",", **kwargs): - Field.__init__( - self, name, args=[Field.TAG, self.SEPARATOR, separator], **kwargs - ) + def __init__( + self, + name: str, + separator: str = ",", + case_sensitive: bool = False, + **kwargs, + ): + args = [Field.TAG, self.SEPARATOR, separator] + if case_sensitive: + args.append(self.CASESENSITIVE) + + Field.__init__(self, name, args=args, **kwargs) class VectorField(Field): diff --git a/tests/test_search.py b/tests/test_search.py index 343ac0d540..aee37cdd6f 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -107,7 +107,7 @@ def client(modclient): def test_client(client): num_docs = 500 createIndex(client.ft(), num_docs=num_docs) - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) # verify info info = client.ft().info() for k in [ @@ -252,7 +252,7 @@ def test_replace(client): client.ft().add_document("doc1", txt="foo bar") client.ft().add_document("doc2", txt="foo bar") - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) res = client.ft().search("foo bar") assert 2 == res.total @@ -272,7 +272,7 @@ def test_stopwords(client): client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"]) client.ft().add_document("doc1", txt="foo bar") client.ft().add_document("doc2", txt="hello world") - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) q1 = Query("foo bar").no_content() q2 = Query("foo bar hello world").no_content() @@ -287,7 +287,7 @@ def test_filters(client): client.ft().add_document("doc1", txt="foo bar", num=3.141, loc="-0.441,51.458") client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) # Test numerical filter q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() q2 = ( @@ -456,7 +456,7 @@ def test_no_index(client): client.ft().add_document( "doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2" ) - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) res = client.ft().search(Query("@text:aa*")) assert 0 == res.total @@ -498,7 +498,7 @@ def test_partial(client): client.ft().add_document("doc2", f1="f1_val", f2="f2_val") client.ft().add_document("doc1", f3="f3_val", partial=True) client.ft().add_document("doc2", f3="f3_val", replace=True) - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) # Search for f3 value. All documents should have it res = client.ft().search("@f3:f3_val") @@ -516,7 +516,7 @@ def test_no_create(client): client.ft().add_document("doc2", f1="f1_val", f2="f2_val") client.ft().add_document("doc1", f3="f3_val", no_create=True) client.ft().add_document("doc2", f3="f3_val", no_create=True, partial=True) - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) # Search for f3 value. All documents should have it res = client.ft().search("@f3:f3_val") @@ -546,7 +546,7 @@ def test_explaincli(client): @pytest.mark.redismod def test_summarize(client): createIndex(client.ft()) - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) q = Query("king henry").paging(0, 1) q.highlight(fields=("play", "txt"), tags=("", "")) @@ -654,7 +654,7 @@ def test_tags(client): client.ft().add_document("doc1", txt="fooz barz", tags=tags) client.ft().add_document("doc2", txt="noodles", tags=tags2) - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) q = Query("@tags:{foo}") res = client.ft().search(q) @@ -714,7 +714,7 @@ def test_spell_check(client): client.ft().add_document("doc1", f1="some valid content", f2="this is sample text") client.ft().add_document("doc2", f1="very important", f2="lorem ipsum") - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) # test spellcheck res = client.ft().spellcheck("impornant") @@ -1304,6 +1304,31 @@ def test_fields_as_name(client): assert "25" == total[0].just_a_number +@pytest.mark.redismod +def test_casesensitive(client): + # create index + SCHEMA = (TagField("t", case_sensitive=False),) + client.ft().create_index(SCHEMA) + client.ft().client.hset("1", "t", "HELLO") + client.ft().client.hset("2", "t", "hello") + + res = client.ft().search("@t:{HELLO}").docs + + assert 2 == len(res) + assert "1" == res[0].id + assert "2" == res[1].id + + # create casesensitive index + client.ft().dropindex() + SCHEMA = (TagField("t", case_sensitive=True),) + client.ft().create_index(SCHEMA) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + + res = client.ft().search("@t:{HELLO}").docs + assert 1 == len(res) + assert "1" == res[0].id + + @pytest.mark.redismod @skip_ifmodversion_lt("2.2.0", "search") def test_search_return_fields(client): @@ -1321,7 +1346,7 @@ def test_search_return_fields(client): NumericField("$.flt"), ) client.ft().create_index(SCHEMA, definition=definition) - waitForIndex(client, "idx") + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs assert 1 == len(total)