Skip to content

Commit c172954

Browse files
committed
fix(vectorizers): add inputType parameter for Cohere v3 models
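Cohere v3 embedding models (such as embed-english-v3.0) require an explicit inputType on every embed request; this commit sets it to "search_query" in the integration test and the vectorizers notebook.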
1 parent eeef663 commit c172954

File tree

2 files changed

+121
-79
lines changed

2 files changed

+121
-79
lines changed

core/src/test/java/com/redis/vl/notebooks/VectorizersNotebookIntegrationTest.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,13 @@ public void testCohereVectorizer() {
9898
}
9999

100100
// Create a vectorizer using Cohere (same model as Python)
101+
// Note: Cohere v3 models require inputType to be specified
101102
var cohereModel =
102-
CohereEmbeddingModel.builder().apiKey(apiKey).modelName("embed-english-v3.0").build();
103+
CohereEmbeddingModel.builder()
104+
.apiKey(apiKey)
105+
.modelName("embed-english-v3.0")
106+
.inputType("search_query")
107+
.build();
103108
var co = new LangChain4JVectorizer("embed-english-v3.0", cohereModel);
104109

105110
// Embed a search query
@@ -108,9 +113,6 @@ public void testCohereVectorizer() {
108113
assertThat(queryEmbed.length).isEqualTo(1024); // embed-english-v3.0 produces 1024 dims
109114
System.out.println(
110115
"First 10 dimensions: " + Arrays.toString(Arrays.copyOfRange(queryEmbed, 0, 10)));
111-
112-
// Note: LangChain4j Cohere doesn't expose input_type in the same way Python does
113-
// The model handles query vs document distinction internally
114116
}
115117

116118
@Test

notebooks/04_vectorizers.ipynb

Lines changed: 115 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -25,43 +25,7 @@
2525
"execution_count": null,
2626
"metadata": {},
2727
"outputs": [],
28-
"source": [
29-
"// Load Maven dependencies\n",
30-
"%maven redis.clients:jedis:5.2.0\n",
31-
"%maven org.slf4j:slf4j-nop:2.0.16\n",
32-
"%maven com.fasterxml.jackson.core:jackson-databind:2.18.0\n",
33-
"%maven com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.0\n",
34-
"%maven com.github.f4b6a3:ulid-creator:5.2.3\n",
35-
"%maven dev.langchain4j:langchain4j:0.36.2\n",
36-
"%maven dev.langchain4j:langchain4j-open-ai:0.36.2\n",
37-
"%maven dev.langchain4j:langchain4j-cohere:0.36.2\n",
38-
"%maven dev.langchain4j:langchain4j-voyage-ai:0.36.2\n",
39-
"%maven com.microsoft.onnxruntime:onnxruntime:1.16.3\n",
40-
"%maven com.squareup.okhttp3:okhttp:4.12.0\n",
41-
"%maven com.google.code.gson:gson:2.10.1\n",
42-
"%maven ai.djl.huggingface:tokenizers:0.30.0\n",
43-
"\n",
44-
"// Note: RedisVL JAR is in classpath (loaded automatically by Docker container)\n",
45-
"\n",
46-
"// Import RedisVL classes\n",
47-
"import com.redis.vl.utils.vectorize.*;\n",
48-
"import com.redis.vl.index.SearchIndex;\n",
49-
"import com.redis.vl.schema.IndexSchema;\n",
50-
"import com.redis.vl.schema.VectorField;\n",
51-
"import com.redis.vl.query.VectorQuery;\n",
52-
"\n",
53-
"// Import Redis client\n",
54-
"import redis.clients.jedis.UnifiedJedis;\n",
55-
"import redis.clients.jedis.HostAndPort;\n",
56-
"\n",
57-
"// Import LangChain4J\n",
58-
"import dev.langchain4j.model.openai.OpenAiEmbeddingModel;\n",
59-
"import dev.langchain4j.model.cohere.CohereEmbeddingModel;\n",
60-
"import dev.langchain4j.model.voyageai.VoyageAiEmbeddingModel;\n",
61-
"\n",
62-
"// Import Java standard libraries\n",
63-
"import java.util.*;"
64-
]
28+
"source": "// Load Maven dependencies\n%maven redis.clients:jedis:6.2.0\n%maven org.slf4j:slf4j-nop:2.0.16\n%maven com.fasterxml.jackson.core:jackson-databind:2.18.0\n%maven com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.0\n%maven com.github.f4b6a3:ulid-creator:5.2.3\n%maven dev.langchain4j:langchain4j:0.36.2\n%maven dev.langchain4j:langchain4j-open-ai:0.36.2\n%maven dev.langchain4j:langchain4j-cohere:0.36.2\n%maven dev.langchain4j:langchain4j-voyage-ai:0.36.2\n%maven com.microsoft.onnxruntime:onnxruntime:1.16.3\n%maven com.squareup.okhttp3:okhttp:4.12.0\n%maven com.google.code.gson:gson:2.10.1\n%maven ai.djl.huggingface:tokenizers:0.30.0\n\n// Note: RedisVL JAR is in classpath (loaded automatically by Docker container)\n\n// Import RedisVL classes\nimport com.redis.vl.utils.vectorize.*;\nimport com.redis.vl.index.SearchIndex;\nimport com.redis.vl.schema.IndexSchema;\nimport com.redis.vl.schema.VectorField;\nimport com.redis.vl.query.VectorQuery;\n\n// Import Redis client\nimport redis.clients.jedis.UnifiedJedis;\nimport redis.clients.jedis.HostAndPort;\n\n// Import LangChain4J\nimport dev.langchain4j.model.openai.OpenAiEmbeddingModel;\nimport dev.langchain4j.model.cohere.CohereEmbeddingModel;\nimport dev.langchain4j.model.voyageai.VoyageAiEmbeddingModel;\n\n// Import Java standard libraries\nimport java.util.*;"
6529
},
6630
{
6731
"cell_type": "markdown",
@@ -78,7 +42,7 @@
7842
},
7943
{
8044
"cell_type": "code",
81-
"execution_count": null,
45+
"execution_count": 2,
8246
"metadata": {},
8347
"outputs": [],
8448
"source": [
@@ -103,9 +67,27 @@
10367
},
10468
{
10569
"cell_type": "code",
106-
"execution_count": null,
70+
"execution_count": 3,
10771
"metadata": {},
108-
"outputs": [],
72+
"outputs": [
73+
{
74+
"name": "stderr",
75+
"output_type": "stream",
76+
"text": [
77+
"SLF4J: Failed to load class \"org.slf4j.impl.StaticLoggerBinder\".\n",
78+
"SLF4J: Defaulting to no-operation (NOP) logger implementation\n",
79+
"SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.\n"
80+
]
81+
},
82+
{
83+
"name": "stdout",
84+
"output_type": "stream",
85+
"text": [
86+
"Vector dimensions: 768\n",
87+
"First 10 dimensions: [3.7808096E-4, -0.050803404, -0.035147283, -0.02325103, -0.044158336, 0.020487826, 0.0014618257, 0.031261846, 0.056051537, 0.018815337]\n"
88+
]
89+
}
90+
],
10991
"source": [
11092
"// Create a vectorizer using HuggingFace Sentence Transformers\n",
11193
"// This model runs locally - no API key needed!\n",
@@ -119,9 +101,18 @@
119101
},
120102
{
121103
"cell_type": "code",
122-
"execution_count": null,
104+
"execution_count": 4,
123105
"metadata": {},
124-
"outputs": [],
106+
"outputs": [
107+
{
108+
"name": "stdout",
109+
"output_type": "stream",
110+
"text": [
111+
"Created 3 embeddings\n",
112+
"First embedding (first 10): [-0.019594658, -0.03229537, -0.011301539, 0.019816127, 0.07692131, 0.026120719, -0.04493195, 0.011639726, 1.3129076E-4, -0.006208265]\n"
113+
]
114+
}
115+
],
125116
"source": [
126117
"// Create many embeddings at once\n",
127118
"List<float[]> embeddings = hf.embedBatch(sentences);\n",
@@ -142,9 +133,19 @@
142133
},
143134
{
144135
"cell_type": "code",
145-
"execution_count": null,
136+
"execution_count": 5,
146137
"metadata": {},
147-
"outputs": [],
138+
"outputs": [
139+
{
140+
"name": "stdout",
141+
"output_type": "stream",
142+
"text": [
143+
"OpenAI Vector dimensions: 1536\n",
144+
"First 10 dimensions: [-0.0011391325, -0.0032063872, 0.0023801322, -0.004501554, -0.010328997, 0.012922565, -0.00549112, -0.0029864837, -0.0073279613, -0.033658173]\n",
145+
"Created 3 embeddings\n"
146+
]
147+
}
148+
],
148149
"source": [
149150
"// Get API key from environment\n",
150151
"String apiKey = System.getenv(\"OPENAI_API_KEY\");\n",
@@ -187,23 +188,7 @@
187188
"execution_count": null,
188189
"metadata": {},
189190
"outputs": [],
190-
"source": [
191-
"String cohereApiKey = System.getenv(\"COHERE_API_KEY\");\n",
192-
"if (cohereApiKey == null || cohereApiKey.isEmpty()) {\n",
193-
" System.out.println(\"Skipping Cohere example - COHERE_API_KEY not set\");\n",
194-
"} else {\n",
195-
" var cohereModel = CohereEmbeddingModel.builder()\n",
196-
" .apiKey(cohereApiKey)\n",
197-
" .modelName(\"embed-english-v3.0\")\n",
198-
" .build();\n",
199-
" \n",
200-
" BaseVectorizer co = new LangChain4JVectorizer(\"embed-english-v3.0\", cohereModel);\n",
201-
" \n",
202-
" float[] cohereTest = co.embed(\"This is a test sentence.\");\n",
203-
" System.out.println(\"Cohere Vector dimensions: \" + cohereTest.length);\n",
204-
" System.out.println(\"First 10 dimensions: \" + Arrays.toString(Arrays.copyOfRange(cohereTest, 0, 10)));\n",
205-
"}"
206-
]
191+
"source": "String cohereApiKey = System.getenv(\"COHERE_API_KEY\");\nif (cohereApiKey == null || cohereApiKey.isEmpty()) {\n System.out.println(\"Skipping Cohere example - COHERE_API_KEY not set\");\n} else {\n // Note: Cohere v3 models require inputType to be specified\n var cohereModel = CohereEmbeddingModel.builder()\n .apiKey(cohereApiKey)\n .modelName(\"embed-english-v3.0\")\n .inputType(\"search_query\")\n .build();\n \n BaseVectorizer co = new LangChain4JVectorizer(\"embed-english-v3.0\", cohereModel);\n \n float[] cohereTest = co.embed(\"This is a test sentence.\");\n System.out.println(\"Cohere Vector dimensions: \" + cohereTest.length);\n System.out.println(\"First 10 dimensions: \" + Arrays.toString(Arrays.copyOfRange(cohereTest, 0, 10)));\n}"
207192
},
208193
{
209194
"cell_type": "markdown",
@@ -218,9 +203,18 @@
218203
},
219204
{
220205
"cell_type": "code",
221-
"execution_count": null,
206+
"execution_count": 7,
222207
"metadata": {},
223-
"outputs": [],
208+
"outputs": [
209+
{
210+
"name": "stdout",
211+
"output_type": "stream",
212+
"text": [
213+
"VoyageAI Vector dimensions: 1024\n",
214+
"First 10 dimensions: [0.021035708, 0.029149586, -0.012015127, -0.06209732, -0.004808997, -0.09621606, 0.046154775, -0.006889617, 0.01176895, 0.06111594]\n"
215+
]
216+
}
217+
],
224218
"source": [
225219
"String voyageApiKey = System.getenv(\"VOYAGE_API_KEY\");\n",
226220
"if (voyageApiKey == null || voyageApiKey.isEmpty()) {\n",
@@ -250,9 +244,18 @@
250244
},
251245
{
252246
"cell_type": "code",
253-
"execution_count": null,
247+
"execution_count": 8,
254248
"metadata": {},
255-
"outputs": [],
249+
"outputs": [
250+
{
251+
"name": "stdout",
252+
"output_type": "stream",
253+
"text": [
254+
"Custom vectorizer dimensions: 768\n",
255+
"First 10 values: [0.101, 0.101, 0.101, 0.101, 0.101, 0.101, 0.101, 0.101, 0.101, 0.101]\n"
256+
]
257+
}
258+
],
256259
"source": [
257260
"// Create a simple custom vectorizer\n",
258261
"class CustomVectorizer extends BaseVectorizer {\n",
@@ -298,9 +301,17 @@
298301
},
299302
{
300303
"cell_type": "code",
301-
"execution_count": null,
304+
"execution_count": 9,
302305
"metadata": {},
303-
"outputs": [],
306+
"outputs": [
307+
{
308+
"name": "stdout",
309+
"output_type": "stream",
310+
"text": [
311+
"Index created: vectorizers\n"
312+
]
313+
}
314+
],
304315
"source": [
305316
"// Connect to Redis\n",
306317
"UnifiedJedis jedis = new UnifiedJedis(new HostAndPort(\"redis-stack\", 6379));\n",
@@ -334,9 +345,18 @@
334345
},
335346
{
336347
"cell_type": "code",
337-
"execution_count": null,
348+
"execution_count": 10,
338349
"metadata": {},
339-
"outputs": [],
350+
"outputs": [
351+
{
352+
"name": "stdout",
353+
"output_type": "stream",
354+
"text": [
355+
"Loaded 3 documents\n",
356+
"Keys: [doc:01K70W4SXV6219Z5SD4RJW986B, doc:01K70W4SXV6X0MWYDKENSG41JB, doc:01K70W4SXV35AGKEG2MK8CZRPZ]\n"
357+
]
358+
}
359+
],
340360
"source": [
341361
"// Create embeddings for our sentences using HuggingFace\n",
342362
"List<float[]> sentenceEmbeddings = hf.embedBatch(sentences);\n",
@@ -358,9 +378,21 @@
358378
},
359379
{
360380
"cell_type": "code",
361-
"execution_count": null,
381+
"execution_count": 11,
362382
"metadata": {},
363-
"outputs": [],
383+
"outputs": [
384+
{
385+
"name": "stdout",
386+
"output_type": "stream",
387+
"text": [
388+
"\n",
389+
"Search results for: 'That is a happy cat'\n",
390+
"That is a happy dog - Distance: 0.160861968994\n",
391+
"That is a happy person - Distance: 0.273597836494\n",
392+
"Today is a sunny day - Distance: 0.744559407234\n"
393+
]
394+
}
395+
],
364396
"source": [
365397
"// Create a query embedding for \"That is a happy cat\"\n",
366398
"float[] queryEmbedding = hf.embed(\"That is a happy cat\");\n",
@@ -390,9 +422,17 @@
390422
},
391423
{
392424
"cell_type": "code",
393-
"execution_count": null,
425+
"execution_count": 12,
394426
"metadata": {},
395-
"outputs": [],
427+
"outputs": [
428+
{
429+
"name": "stdout",
430+
"output_type": "stream",
431+
"text": [
432+
"Index deleted and connection closed\n"
433+
]
434+
}
435+
],
396436
"source": [
397437
"// Cleanup\n",
398438
"index.delete(true);\n",
@@ -428,9 +468,9 @@
428468
"mimetype": "text/x-java-source",
429469
"name": "Java",
430470
"pygments_lexer": "java",
431-
"version": "21+35"
471+
"version": "21.0.8+9-Ubuntu-0ubuntu124.04.1"
432472
}
433473
},
434474
"nbformat": 4,
435475
"nbformat_minor": 4
436-
}
476+
}

0 commit comments

Comments
 (0)