|
| 1 | +"""Integration tests for TextQuery with field weights.""" |
| 2 | + |
| 3 | +import uuid |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from redisvl.index import SearchIndex |
| 8 | +from redisvl.query import TextQuery |
| 9 | +from redisvl.query.filter import Tag |
| 10 | + |
| 11 | + |
| 12 | +@pytest.fixture |
| 13 | +def weighted_index(redis_url, worker_id): |
| 14 | + """Create an index with multiple text fields for testing weights.""" |
| 15 | + unique_id = str(uuid.uuid4())[:8] |
| 16 | + schema_dict = { |
| 17 | + "index": { |
| 18 | + "name": f"weighted_test_idx_{worker_id}_{unique_id}", |
| 19 | + "prefix": f"weighted_doc_{worker_id}_{unique_id}", |
| 20 | + "storage_type": "json", |
| 21 | + }, |
| 22 | + "fields": [ |
| 23 | + {"name": "title", "type": "text"}, |
| 24 | + {"name": "content", "type": "text"}, |
| 25 | + {"name": "tags", "type": "text"}, |
| 26 | + {"name": "category", "type": "tag"}, |
| 27 | + {"name": "score", "type": "numeric"}, |
| 28 | + ], |
| 29 | + } |
| 30 | + |
| 31 | + index = SearchIndex.from_dict(schema_dict, redis_url=redis_url) |
| 32 | + index.create(overwrite=True) |
| 33 | + |
| 34 | + # Load test data |
| 35 | + data = [ |
| 36 | + { |
| 37 | + "id": "1", |
| 38 | + "title": "Redis database introduction", |
| 39 | + "content": "A comprehensive guide to getting started with Redis", |
| 40 | + "tags": "tutorial beginner", |
| 41 | + "category": "database", |
| 42 | + "score": 95, |
| 43 | + }, |
| 44 | + { |
| 45 | + "id": "2", |
| 46 | + "title": "Advanced caching strategies", |
| 47 | + "content": "Learn about Redis caching patterns and best practices", |
| 48 | + "tags": "redis cache performance", |
| 49 | + "category": "optimization", |
| 50 | + "score": 88, |
| 51 | + }, |
| 52 | + { |
| 53 | + "id": "3", |
| 54 | + "title": "Python programming basics", |
| 55 | + "content": "Introduction to Python with examples using Redis client", |
| 56 | + "tags": "python redis programming", |
| 57 | + "category": "programming", |
| 58 | + "score": 90, |
| 59 | + }, |
| 60 | + { |
| 61 | + "id": "4", |
| 62 | + "title": "Data structures overview", |
| 63 | + "content": "Understanding Redis data structures and their applications", |
| 64 | + "tags": "redis structures", |
| 65 | + "category": "database", |
| 66 | + "score": 85, |
| 67 | + }, |
| 68 | + ] |
| 69 | + |
| 70 | + index.load(data) |
| 71 | + yield index |
| 72 | + index.delete(drop=True) |
| 73 | + |
| 74 | + |
| 75 | +def test_text_query_with_single_weighted_field(weighted_index): |
| 76 | + """Test TextQuery with a single weighted field.""" |
| 77 | + text = "redis" |
| 78 | + |
| 79 | + # Query with higher weight on title |
| 80 | + query = TextQuery( |
| 81 | + text=text, |
| 82 | + text_field_name={"title": 5.0}, |
| 83 | + return_fields=["title", "content"], |
| 84 | + num_results=4, |
| 85 | + ) |
| 86 | + |
| 87 | + results = weighted_index.query(query) |
| 88 | + assert len(results) > 0 |
| 89 | + |
| 90 | + # The document with "Redis" in the title should rank high |
| 91 | + top_result = results[0] |
| 92 | + assert "redis" in top_result["title"].lower() |
| 93 | + |
| 94 | + |
| 95 | +def test_text_query_with_multiple_weighted_fields(weighted_index): |
| 96 | + """Test TextQuery with multiple weighted fields.""" |
| 97 | + text = "redis" |
| 98 | + |
| 99 | + # Query across multiple fields with different weights |
| 100 | + query = TextQuery( |
| 101 | + text=text, |
| 102 | + text_field_name={"title": 3.0, "content": 2.0, "tags": 1.0}, |
| 103 | + return_fields=["title", "content", "tags"], |
| 104 | + num_results=4, |
| 105 | + ) |
| 106 | + |
| 107 | + results = weighted_index.query(query) |
| 108 | + assert len(results) > 0 |
| 109 | + |
| 110 | + # Check that results contain the search term in at least one field |
| 111 | + for result in results: |
| 112 | + text_found = ( |
| 113 | + "redis" in result.get("title", "").lower() |
| 114 | + or "redis" in result.get("content", "").lower() |
| 115 | + or "redis" in result.get("tags", "").lower() |
| 116 | + ) |
| 117 | + assert text_found |
| 118 | + |
| 119 | + |
| 120 | +def test_text_query_weights_with_filter(weighted_index): |
| 121 | + """Test TextQuery with weights and filter expression.""" |
| 122 | + text = "redis" |
| 123 | + |
| 124 | + # Query with weights and filter |
| 125 | + filter_expr = Tag("category") == "database" |
| 126 | + query = TextQuery( |
| 127 | + text=text, |
| 128 | + text_field_name={"title": 5.0, "content": 1.0}, |
| 129 | + filter_expression=filter_expr, |
| 130 | + return_fields=["title", "content", "category"], |
| 131 | + num_results=4, |
| 132 | + ) |
| 133 | + |
| 134 | + results = weighted_index.query(query) |
| 135 | + |
| 136 | + # Should only get database category results |
| 137 | + for result in results: |
| 138 | + assert result["category"] == "database" |
| 139 | + |
| 140 | + |
| 141 | +def test_dynamic_weight_update(weighted_index): |
| 142 | + """Test updating field weights dynamically.""" |
| 143 | + text = "redis" |
| 144 | + |
| 145 | + # Start with equal weights |
| 146 | + query = TextQuery( |
| 147 | + text=text, |
| 148 | + text_field_name={"title": 1.0, "content": 1.0}, |
| 149 | + return_fields=["title", "content"], |
| 150 | + num_results=4, |
| 151 | + ) |
| 152 | + |
| 153 | + results1 = weighted_index.query(query) |
| 154 | + |
| 155 | + # Update to prioritize title |
| 156 | + query.set_field_weights({"title": 10.0, "content": 1.0}) |
| 157 | + |
| 158 | + results2 = weighted_index.query(query) |
| 159 | + |
| 160 | + # Results might be reordered based on new weights |
| 161 | + # At minimum, both queries should return results |
| 162 | + assert len(results1) > 0 |
| 163 | + assert len(results2) > 0 |
| 164 | + |
| 165 | + |
| 166 | +def test_backward_compatibility_single_field(weighted_index): |
| 167 | + """Test that the original single field API still works.""" |
| 168 | + text = "redis" |
| 169 | + |
| 170 | + # Original API with single field name |
| 171 | + query = TextQuery( |
| 172 | + text=text, |
| 173 | + text_field_name="content", |
| 174 | + return_fields=["title", "content"], |
| 175 | + num_results=4, |
| 176 | + ) |
| 177 | + |
| 178 | + results = weighted_index.query(query) |
| 179 | + assert len(results) > 0 |
| 180 | + |
| 181 | + # Check results are from content field |
| 182 | + for result in results: |
| 183 | + if "redis" in result.get("content", "").lower(): |
| 184 | + break |
| 185 | + else: |
| 186 | + # At least one result should have redis in content |
| 187 | + assert False, "No results with 'redis' in content field" |
0 commit comments