Skip to content

Commit b530ea5

Browse files
committed
fix(test): clean VoyageAI test and rename TokenizationDebugTest -> CrossEncoderTokenizationTest
1 parent 2021c19 commit b530ea5

12 files changed

+226
-196
lines changed

core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,7 @@ public java.util.Map<String, Object> toDict() {
194194

195195
java.util.Map<String, Object> configDict = new java.util.HashMap<>();
196196
configDict.put("max_k", routingConfig.getMaxK());
197-
configDict.put(
198-
"aggregation_method", routingConfig.getAggregationMethod().name().toLowerCase());
197+
configDict.put("aggregation_method", routingConfig.getAggregationMethod().name().toLowerCase());
199198
dict.put("routing_config", configDict);
200199

201200
return dict;

core/src/test/java/com/redis/vl/extensions/cache/NotebookSemanticCacheTest.java

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,8 @@ public class NotebookSemanticCacheTest extends BaseIntegrationTest {
2727
public void setUp() {
2828

2929
// Cell 5: Create vectorizer using SentenceTransformersVectorizer
30-
// This should download the redis/langcache-embed-v3 model from HuggingFace on first use
31-
System.out.println("Initializing SentenceTransformersVectorizer with redis/langcache-embed-v3");
3230
try {
3331
vectorizer = new SentenceTransformersVectorizer("redis/langcache-embed-v3");
34-
System.out.println("Model dimensions: " + vectorizer.getDimensions());
3532
} catch (Exception e) {
3633
System.err.println("Failed to initialize SentenceTransformersVectorizer: " + e.getMessage());
3734
e.printStackTrace();
@@ -46,32 +43,27 @@ public void setUp() {
4643
.distanceThreshold(0.1f)
4744
.vectorizer(vectorizer)
4845
.build();
49-
50-
System.out.println("SemanticCache initialized with index: " + llmcache.getName());
5146
}
5247

5348
@Test
5449
public void testNotebookFlow() {
5550
// Cell 6: Verify cache is ready
5651
assertNotNull(llmcache);
5752
assertEquals("llmcache_test", llmcache.getName());
58-
System.out.println("Cache index '" + llmcache.getName() + "' is ready for use");
5953

6054
// Cell 8: Define question
6155
String question = "What is the capital of France?";
6256

6357
// Cell 9: Check empty cache
6458
Optional<CacheHit> response = llmcache.check(question);
6559
assertFalse(response.isPresent(), "Cache should be empty initially");
66-
System.out.println("Initial cache check: " + (response.isPresent() ? "Found" : "Empty"));
6760

6861
// Cell 11: Store in cache
6962
Map<String, Object> metadata = new HashMap<>();
7063
metadata.put("city", "Paris");
7164
metadata.put("country", "france");
7265

7366
llmcache.store(question, "Paris", metadata);
74-
System.out.println("Stored in cache");
7567

7668
// Cell 13: Check cache again
7769
Optional<CacheHit> cacheResponse = llmcache.check(question);
@@ -84,11 +76,6 @@ public void testNotebookFlow() {
8476
assertNotNull(hit.getMetadata());
8577
assertEquals("Paris", hit.getMetadata().get("city"));
8678
assertEquals("france", hit.getMetadata().get("country"));
87-
System.out.println("Found in cache:");
88-
System.out.println(" Prompt: " + hit.getPrompt());
89-
System.out.println(" Response: " + hit.getResponse());
90-
System.out.println(" Distance: " + hit.getDistance());
91-
System.out.println(" Metadata: " + hit.getMetadata());
9279
}
9380

9481
// Cell 14: Check semantically similar question
@@ -97,29 +84,22 @@ public void testNotebookFlow() {
9784
assertTrue(similarResponse.isPresent(), "Should find semantically similar entry");
9885
if (similarResponse.isPresent()) {
9986
assertEquals("Paris", similarResponse.get().getResponse());
100-
System.out.println("Similar question result: " + similarResponse.get().getResponse());
10187
}
10288

10389
// Cell 16: Adjust distance threshold
10490
llmcache.setDistanceThreshold(0.5f);
10591
assertEquals(0.5f, llmcache.getDistanceThreshold(), 0.001f);
106-
System.out.println("Distance threshold set to 0.5");
10792

10893
// Cell 17: Try with tricky question
10994
String trickQuestion =
11095
"What is the capital city of the country in Europe that also has a city named Nice?";
11196
Optional<CacheHit> trickResponse = llmcache.check(trickQuestion);
11297
// With wider threshold, this might match
113-
System.out.println(
114-
"Trick question result: "
115-
+ (trickResponse.isPresent() ? trickResponse.get().getResponse() : "Not found"));
11698

11799
// Cell 18: Clear cache
118100
llmcache.clear();
119101
Optional<CacheHit> clearedResponse = llmcache.check(trickQuestion);
120102
assertFalse(clearedResponse.isPresent(), "Cache should be empty after clear");
121-
System.out.println(
122-
"Cache after clear: " + (clearedResponse.isPresent() ? "Not empty" : "Empty"));
123103
}
124104

125105
@Test
@@ -134,11 +114,8 @@ public void testTTLCache() throws InterruptedException {
134114
.ttl(5) // 5 seconds
135115
.build();
136116

137-
System.out.println("Created cache with 5 second TTL");
138-
139117
// Cell 21: Store with TTL
140118
ttlCache.store("This is a TTL test", "This is a TTL test response");
141-
System.out.println("Stored entry with TTL");
142119

143120
// Verify it's there immediately
144121
Optional<CacheHit> immediateCheck = ttlCache.check("This is a TTL test");
@@ -150,8 +127,6 @@ public void testTTLCache() throws InterruptedException {
150127
// Cell 22: Check after TTL expiry
151128
Optional<CacheHit> ttlResult = ttlCache.check("This is a TTL test");
152129
assertFalse(ttlResult.isPresent(), "Entry should have expired");
153-
System.out.println(
154-
"Result after TTL expiry: " + (ttlResult.isPresent() ? "Found" : "Empty (expired)"));
155130

156131
// Cell 23: Clean up
157132
ttlCache.clear();
@@ -204,23 +179,18 @@ public void testUserMetadataFiltering() {
204179
"The number on file is 123-555-1111",
205180
userDef);
206181

207-
System.out.println("Stored user-specific cache entries");
208-
209182
// Cell 32: Check cache entries
210183
Optional<CacheHit> phoneResponse =
211184
llmcache.check("What is the phone number linked to my account?");
212185

213186
assertTrue(phoneResponse.isPresent());
214187
if (phoneResponse.isPresent()) {
215-
System.out.println("Found entry: " + phoneResponse.get().getResponse());
216188
// Should return one of the phone numbers based on similarity
217189
String response = phoneResponse.get().getResponse();
218190
assertTrue(response.contains("123-555-"));
219191
}
220192

221193
// Cell 33: Final cleanup
222194
llmcache.clear();
223-
System.out.println("\nAll caches cleaned up.");
224-
System.out.println("SemanticCache demonstration complete!");
225195
}
226196
}

core/src/test/java/com/redis/vl/extensions/router/SemanticRouterIntegrationTest.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,14 @@ void testNotebookQuery() {
215215
// Test no-match query from Python notebook cell 9
216216
// Python output: RouteMatch(name=None, distance=None)
217217
RouteMatch noMatch = router.route("are aliens real?");
218-
System.out.println("DEBUG: aliens query - name=" + noMatch.getName() + ", distance=" + noMatch.getDistance());
219-
System.out.println("DEBUG: technology threshold=" + notebookRoutes.get(0).getDistanceThreshold());
218+
System.out.println(
219+
"DEBUG: aliens query - name=" + noMatch.getName() + ", distance=" + noMatch.getDistance());
220+
System.out.println(
221+
"DEBUG: technology threshold=" + notebookRoutes.get(0).getDistanceThreshold());
220222

221223
assertThat(noMatch).isNotNull();
222-
// NOTE: Java ONNX embeddings differ from Python, this query may match with distance near threshold
224+
// NOTE: Java ONNX embeddings differ from Python, this query may match with distance near
225+
// threshold
223226
// Python: None, Java: may match technology with distance ~0.33
224227
// Accept either outcome as embedding implementations differ
225228
if (noMatch.getName() != null) {

core/src/test/java/com/redis/vl/utils/rerank/BAAIModelRealIntegrationTest.java

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import org.junit.jupiter.api.Test;
1010

1111
/**
12-
* Real integration test with BAAI/bge-reranker-base model.
12+
* Integration test with BAAI/bge-reranker-base model.
1313
*
1414
* <p>Compares outputs to Python notebook to verify correctness.
1515
*
@@ -23,12 +23,10 @@ class BAAIModelRealIntegrationTest {
2323

2424
@BeforeAll
2525
static void setUp() {
26-
System.out.println("=== LOADING BAAI/bge-reranker-base MODEL ===");
2726
reranker = HFCrossEncoderReranker.builder().model("BAAI/bge-reranker-base").build();
2827

2928
assertNotNull(reranker, "Reranker must initialize");
3029
assertEquals("BAAI/bge-reranker-base", reranker.getModel());
31-
System.out.println("=== MODEL LOADED ===");
3230
}
3331

3432
@AfterAll
@@ -57,18 +55,6 @@ void testBAAIModelProducesCorrectScores() {
5755
List<?> results = result.getDocuments();
5856
List<Double> scores = result.getScores();
5957

60-
System.out.println("\n=== JAVA OUTPUT ===");
61-
for (int i = 0; i < results.size(); i++) {
62-
String docPreview =
63-
results.get(i).toString().substring(0, Math.min(50, results.get(i).toString().length()));
64-
System.out.println(scores.get(i) + " -- " + docPreview + "...");
65-
}
66-
67-
System.out.println("\n=== EXPECTED PYTHON OUTPUT (with sigmoid) ===");
68-
System.out.println("0.9999381 -- Washington, D.C. ...");
69-
System.out.println("0.3802366 -- Charlotte Amalie ...");
70-
System.out.println("0.0746112 -- Carson City ...");
71-
7258
// Verify we got 3 results
7359
assertEquals(3, results.size(), "Should return 3 results");
7460
assertEquals(3, scores.size(), "Should return 3 scores");
@@ -81,9 +67,6 @@ void testBAAIModelProducesCorrectScores() {
8167

8268
// Score for Washington D.C. should be ~0.9999 (after sigmoid)
8369
double topScore = scores.get(0);
84-
System.out.println("\n=== SCORE COMPARISON ===");
85-
System.out.println("Expected top score: ~0.9999");
86-
System.out.println("Actual top score: " + topScore);
8770

8871
assertTrue(
8972
topScore > 0.0 && topScore < 1.0,

core/src/test/java/com/redis/vl/utils/rerank/CohereRerankerIntegrationTest.java

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,6 @@ void testRerankStringDocuments() {
5757
assertEquals(3, docs.size(), "Should return 3 results");
5858
assertEquals(3, scores.size(), "Should return 3 scores");
5959

60-
System.out.println("\n=== JAVA STRING DOCS OUTPUT ===");
61-
for (int i = 0; i < docs.size(); i++) {
62-
String docPreview =
63-
docs.get(i).toString().substring(0, Math.min(50, docs.get(i).toString().length()));
64-
System.out.println(scores.get(i) + " -- " + docPreview + "...");
65-
}
66-
67-
System.out.println("\n=== EXPECTED PYTHON OUTPUT ===");
68-
System.out.println("0.9990564 -- Washington, D.C. ...");
69-
System.out.println("0.7516481 -- Capital punishment ...");
70-
System.out.println("0.08882029 -- Northern Mariana Islands ...");
71-
7260
// Top result must be Washington D.C.
7361
String topDoc = (String) docs.get(0);
7462
assertTrue(
@@ -77,9 +65,6 @@ void testRerankStringDocuments() {
7765

7866
// Top score should be ~0.999
7967
double topScore = scores.get(0);
80-
System.out.println("\n=== SCORE COMPARISON ===");
81-
System.out.println("Expected top score: ~0.999");
82-
System.out.println("Actual top score: " + topScore);
8368

8469
assertTrue(topScore > 0.9, "Top score should be > 0.9, but was: " + topScore);
8570
assertTrue(
@@ -140,17 +125,6 @@ void testRerankDictionaryDocumentsWithRankBy() {
140125
assertEquals(3, docs.size(), "Should return 3 results");
141126
assertEquals(3, scores.size(), "Should return 3 scores");
142127

143-
System.out.println("\n=== JAVA DICT DOCS OUTPUT ===");
144-
for (int i = 0; i < docs.size(); i++) {
145-
System.out.println(scores.get(i) + " -- " + docs.get(i));
146-
}
147-
148-
System.out.println("\n=== EXPECTED PYTHON OUTPUT ===");
149-
System.out.println("0.9988121 -- {'source': 'textbook', 'passage': 'Washington, D.C. ...'}");
150-
System.out.println("0.5974905 -- {'source': 'wiki', 'passage': 'Capital punishment ...'}");
151-
System.out.println(
152-
"0.059101548 -- {'source': 'encyclopedia', 'passage': 'Northern Mariana ...'}");
153-
154128
// Top result must be Washington D.C. with source=textbook
155129
@SuppressWarnings("unchecked")
156130
Map<String, Object> topDoc = (Map<String, Object>) docs.get(0);
@@ -161,10 +135,6 @@ void testRerankDictionaryDocumentsWithRankBy() {
161135

162136
// Top score should be ~0.998
163137
double topScore = scores.get(0);
164-
System.out.println("\n=== SCORE COMPARISON ===");
165-
System.out.println("Expected top score: ~0.998");
166-
System.out.println("Actual top score: " + topScore);
167-
168138
assertTrue(topScore > 0.9, "Top score should be > 0.9, but was: " + topScore);
169139
assertTrue(
170140
Math.abs(topScore - 0.998) < 0.05,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.redis.vl.utils.rerank;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import java.util.Arrays;
6+
import java.util.Map;
7+
import org.junit.jupiter.api.Tag;
8+
import org.junit.jupiter.api.Test;
9+
10+
/**
11+
* Test to validate cross-encoder tokenization matches Python transformers library.
12+
*
13+
* <p>Compares Java WordPiece tokenization against Python reference values to ensure embeddings and
14+
* reranking scores match Python implementation.
15+
*/
16+
@Tag("integration")
17+
class CrossEncoderTokenizationTest {
18+
19+
@Test
20+
void testTokenizationMatchesPython() throws Exception {
21+
HFCrossEncoderReranker reranker =
22+
HFCrossEncoderReranker.builder().model("BAAI/bge-reranker-base").build();
23+
24+
String query = "What is the capital of the United States?";
25+
String doc =
26+
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States.";
27+
28+
// Access the internal tokenizer through reflection
29+
java.lang.reflect.Field loaderField =
30+
HFCrossEncoderReranker.class.getDeclaredField("modelLoader");
31+
loaderField.setAccessible(true);
32+
CrossEncoderLoader loader = (CrossEncoderLoader) loaderField.get(reranker);
33+
34+
Map<String, long[][]> tokens = loader.tokenizePair(query, doc);
35+
36+
long[] inputIds = tokens.get("input_ids")[0];
37+
long[] tokenTypeIds = tokens.get("token_type_ids")[0];
38+
long[] attentionMask = tokens.get("attention_mask")[0];
39+
40+
// Expected token IDs from Python transformers tokenizer
41+
// Generated with: tokenizer("What is the capital...", "Washington, D.C...")
42+
long[] expectedTokenIds = {
43+
0, 4865, 83, 70, 10323, 111, 70, 14098, 46684, 32, 2, 2, 17955, 4, 391, 5, 441, 5, 15, 289
44+
};
45+
46+
// Validate token IDs match Python (first 20 tokens)
47+
long[] actualFirst20 = Arrays.copyOf(inputIds, Math.min(20, inputIds.length));
48+
assertThat(actualFirst20)
49+
.as("Token IDs should match Python transformers tokenizer")
50+
.containsExactly(expectedTokenIds);
51+
52+
// Validate total token count matches Python
53+
assertThat(inputIds.length).as("Total tokens should match Python tokenization").isEqualTo(49);
54+
55+
// Validate attention mask is correct (all 1s for non-padding tokens)
56+
for (int i = 0; i < attentionMask.length; i++) {
57+
assertThat(attentionMask[i])
58+
.as("Attention mask[%d] should be 1 (no padding)", i)
59+
.isEqualTo(1);
60+
}
61+
62+
// XLM-Roberta doesn't use token type IDs, so they should all be 0
63+
for (int i = 0; i < tokenTypeIds.length; i++) {
64+
assertThat(tokenTypeIds[i])
65+
.as("Token type ID[%d] should be 0 for XLM-Roberta", i)
66+
.isEqualTo(0);
67+
}
68+
69+
System.out.println("\n✓ Tokenization matches Python transformers");
70+
System.out.println(" Query: " + query.substring(0, Math.min(50, query.length())));
71+
System.out.println(" Document: " + doc.substring(0, Math.min(50, doc.length())));
72+
System.out.println(" Token IDs (first 20): " + Arrays.toString(actualFirst20));
73+
System.out.println(" Total tokens: " + inputIds.length);
74+
75+
reranker.close();
76+
}
77+
}

core/src/test/java/com/redis/vl/utils/rerank/HFCrossEncoderRerankerNotebookTest.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,6 @@ void testNotebookSimpleReranking() {
9797
assertTrue(
9898
topDoc.contains("Washington, D.C.") || topDoc.contains("capital of the United States"),
9999
"Top result should be about Washington D.C., but was: " + topDoc);
100-
101-
// Print results like notebook does
102-
System.out.println("\nNotebook test results (BAAI model):");
103-
for (int i = 0; i < results.size(); i++) {
104-
System.out.println(scores.get(i) + " -- " + results.get(i));
105-
}
106100
}
107101

108102
@Test
@@ -167,12 +161,6 @@ void testNotebookStructuredDocuments() {
167161
assertTrue(doc.containsKey("source"), "Should preserve 'source' field");
168162
assertTrue(doc.containsKey("content"), "Should preserve 'content' field");
169163
}
170-
171-
// Print like notebook
172-
System.out.println("\nNotebook structured doc results (BAAI model):");
173-
for (int i = 0; i < rerankedResults.size(); i++) {
174-
System.out.println(structuredScores.get(i) + " -- " + rerankedResults.get(i));
175-
}
176164
}
177165

178166
@Test

0 commit comments

Comments
 (0)