Commit 6aa0111

Fix dtype parsing from vectorizer kwargs (#237)
Fixes a bug in how kwargs were handled within the vectorizer classes after a recent change: dtype is now popped from the caller's kwargs and passed to _process_embedding explicitly, instead of being looked up inside (and left in) the kwargs that some vectorizers also forward to their underlying embedding call.
1 parent bdef909 commit 6aa0111
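
The same two-step pattern repeats in every file below: each embed/embed_many/aembed/aembed_many method pops dtype out of the caller's kwargs first and hands it to _process_embedding explicitly, so the remaining kwargs can be forwarded to the provider client untouched. A minimal, self-contained sketch of that calling pattern (the names here are illustrative stand-ins, not redisvl's own code):

from typing import List, Optional

def provider_call(text: str, **kwargs) -> List[float]:
    # Stand-in for an embeddings client; a stray dtype keyword reaching it
    # would be an error, which is exactly what popping dtype first prevents.
    if "dtype" in kwargs:
        raise TypeError("unexpected keyword argument 'dtype'")
    return [0.1, 0.2, 0.3]

def embed(text: str, as_buffer: bool = False, **kwargs) -> List[float]:
    dtype: Optional[str] = kwargs.pop("dtype", None)  # consumed here, exactly once
    embedding = provider_call(text, **kwargs)         # remaining kwargs forwarded safely
    # The real methods now call self._process_embedding(embedding, as_buffer, dtype);
    # the buffer conversion itself is sketched under base.py below.
    return embedding

print(embed("hello", as_buffer=True, dtype="float32"))  # dtype never reaches provider_call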

File tree

8 files changed: +85 additions, -29 deletions


redisvl/utils/vectorize/base.py

Lines changed: 5 additions & 3 deletions
@@ -81,11 +81,13 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None):
             else:
                 yield seq[pos : pos + size]

-    def _process_embedding(self, embedding: List[float], as_buffer: bool, **kwargs):
+    def _process_embedding(
+        self, embedding: List[float], as_buffer: bool, dtype: Optional[str]
+    ):
         if as_buffer:
-            if "dtype" not in kwargs:
+            if not dtype:
                 raise RuntimeError(
                     "dtype is required if converting from float to byte string."
                 )
-            return array_to_buffer(embedding, kwargs["dtype"])
+            return array_to_buffer(embedding, dtype)
         return embedding
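
To make the new contract concrete, here is a hedged, self-contained approximation of the helper above. The numpy conversion stands in for redisvl's array_to_buffer and is an assumption, not the library's exact implementation:

from typing import List, Optional

import numpy as np

def array_to_buffer(embedding: List[float], dtype: str) -> bytes:
    # Assumed behavior of the helper: pack the floats into the raw byte
    # string that Redis stores for a VECTOR field.
    return np.array(embedding, dtype=dtype).tobytes()

def process_embedding(embedding: List[float], as_buffer: bool, dtype: Optional[str]):
    # Mirrors the patched method: dtype is explicit and only required when a
    # byte buffer is actually requested.
    if as_buffer:
        if not dtype:
            raise RuntimeError(
                "dtype is required if converting from float to byte string."
            )
        return array_to_buffer(embedding, dtype)
    return embedding

print(process_embedding([0.1, 0.2], as_buffer=False, dtype=None))           # [0.1, 0.2]
print(len(process_embedding([0.1, 0.2], as_buffer=True, dtype="float32")))  # 8 (bytes)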

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 14 additions & 4 deletions
@@ -190,11 +190,13 @@ def embed_many(
         if len(texts) > 0 and not isinstance(texts[0], str):
             raise TypeError("Must pass in a list of str values to embed.")

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             response = self._client.embeddings.create(input=batch, model=self.model)
             embeddings += [
-                self._process_embedding(r.embedding, as_buffer, **kwargs)
+                self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
             ]
         return embeddings
@@ -231,8 +233,11 @@ def embed(

         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         result = self._client.embeddings.create(input=[text], model=self.model)
-        return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs)
+        return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @retry(
         wait=wait_random_exponential(min=1, max=60),
@@ -269,13 +274,15 @@ async def aembed_many(
         if len(texts) > 0 and not isinstance(texts[0], str):
             raise TypeError("Must pass in a list of str values to embed.")

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             response = await self._aclient.embeddings.create(
                 input=batch, model=self.model
             )
             embeddings += [
-                self._process_embedding(r.embedding, as_buffer, **kwargs)
+                self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
             ]
         return embeddings
@@ -312,8 +319,11 @@ async def aembed(

         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         result = await self._aclient.embeddings.create(input=[text], model=self.model)
-        return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs)
+        return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @property
     def type(self) -> str:
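
Caller-side, the change means dtype simply rides alongside as_buffer as a vectorizer kwarg. A hedged usage sketch (the constructor arguments and config keys below are placeholders based on the redisvl docs, not verified values):

from redisvl.utils.vectorize import AzureOpenAITextVectorizer

az = AzureOpenAITextVectorizer(
    model="text-embedding-ada-002",                  # assumed deployment name
    api_config={
        "api_key": "<AZURE_OPENAI_API_KEY>",         # placeholder
        "api_version": "<API_VERSION>",              # placeholder
        "azure_endpoint": "<AZURE_OPENAI_ENDPOINT>", # placeholder
    },
)

# dtype is popped inside embed_many and passed to _process_embedding; it is
# never forwarded to the Azure embeddings.create call.
buffers = az.embed_many(["first doc", "second doc"], as_buffer=True, dtype="float32")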

redisvl/utils/vectorize/text/cohere.py

Lines changed: 8 additions & 2 deletions
@@ -155,12 +155,16 @@ def embed(
                 "Must pass in a str value for cohere embedding input_type. \
                     See https://docs.cohere.com/reference/embed."
             )
+
         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         embedding = self._client.embed(
             texts=[text], model=self.model, input_type=input_type
         ).embeddings[0]
-        return self._process_embedding(embedding, as_buffer, **kwargs)
+        return self._process_embedding(embedding, as_buffer, dtype)

     @retry(
         wait=wait_random_exponential(min=1, max=60),
@@ -224,13 +228,15 @@ def embed_many(
                     See https://docs.cohere.com/reference/embed."
             )

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             response = self._client.embed(
                 texts=batch, model=self.model, input_type=input_type
             )
             embeddings += [
-                self._process_embedding(embedding, as_buffer, **kwargs)
+                self._process_embedding(embedding, as_buffer, dtype)
                 for embedding in response.embeddings
             ]
         return embeddings
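
For Cohere the API-specific input_type argument is unchanged; only dtype moves out of the forwarded kwargs. A hedged sketch (model name and api_config key taken from the redisvl/Cohere docs and treated as assumptions):

from redisvl.utils.vectorize import CohereTextVectorizer

co = CohereTextVectorizer(
    model="embed-english-v3.0",                   # assumed model name
    api_config={"api_key": "<COHERE_API_KEY>"},   # placeholder
)

# input_type is still required by the Cohere API; dtype is consumed by the
# vectorizer itself for the as_buffer conversion.
vec = co.embed(
    "search query text",
    input_type="search_query",
    as_buffer=True,
    dtype="float32",
)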

redisvl/utils/vectorize/text/custom.py

Lines changed: 16 additions & 8 deletions
@@ -172,9 +172,11 @@ def embed(

         if preprocess:
             text = preprocess(text)
-        else:
-            result = self._embed_func(text, **kwargs)
-        return self._process_embedding(result, as_buffer, **kwargs)
+
+        dtype = kwargs.pop("dtype", None)
+
+        result = self._embed_func(text, **kwargs)
+        return self._process_embedding(result, as_buffer, dtype)

     def embed_many(
         self,
@@ -210,11 +212,13 @@ def embed_many(
         if not self._embed_many_func:
             raise NotImplementedError

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             results = self._embed_many_func(batch, **kwargs)
             embeddings += [
-                self._process_embedding(r, as_buffer, **kwargs) for r in results
+                self._process_embedding(r, as_buffer, dtype) for r in results
             ]
         return embeddings

@@ -249,9 +253,11 @@ async def aembed(

         if preprocess:
             text = preprocess(text)
-        else:
-            result = await self._aembed_func(text, **kwargs)
-        return self._process_embedding(result, as_buffer, **kwargs)
+
+        dtype = kwargs.pop("dtype", None)
+
+        result = await self._aembed_func(text, **kwargs)
+        return self._process_embedding(result, as_buffer, dtype)

     async def aembed_many(
         self,
@@ -287,11 +293,13 @@ async def aembed_many(
         if not self._aembed_many_func:
             raise NotImplementedError

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             results = await self._aembed_many_func(batch, **kwargs)
             embeddings += [
-                self._process_embedding(r, as_buffer, **kwargs) for r in results
+                self._process_embedding(r, as_buffer, dtype) for r in results
             ]
         return embeddings

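
Besides the dtype handling, these hunks also remove the stray else: branch, under which the embedding function only ran when no preprocess callable was supplied (leaving result undefined otherwise). A hedged sketch of the corrected behavior (the constructor keyword and the stand-in embedding function are assumptions for illustration):

from redisvl.utils.vectorize import CustomTextVectorizer

def my_embed(text: str) -> list:
    # Stand-in embedding function for illustration only.
    return [float(len(text)), 0.0, 1.0]

vectorizer = CustomTextVectorizer(embed=my_embed)   # assumed constructor keyword

# Both calls now reach my_embed; with the old else: branch, the second call
# skipped the embedding step entirely when a preprocess callable was given.
print(vectorizer.embed("Hello"))
print(vectorizer.embed("Hello", preprocess=str.lower, as_buffer=True, dtype="float32"))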

redisvl/utils/vectorize/text/huggingface.py

Lines changed: 7 additions & 2 deletions
@@ -99,8 +99,11 @@ def embed(

         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         embedding = self._client.encode([text], **kwargs)[0]
-        return self._process_embedding(embedding.tolist(), as_buffer, **kwargs)
+        return self._process_embedding(embedding.tolist(), as_buffer, dtype)

     def embed_many(
         self,
@@ -133,12 +136,14 @@ def embed_many(
         if len(texts) > 0 and not isinstance(texts[0], str):
             raise TypeError("Must pass in a list of str values to embed.")

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             batch_embeddings = self._client.encode(batch, **kwargs)
             embeddings.extend(
                 [
-                    self._process_embedding(embedding.tolist(), as_buffer, **kwargs)
+                    self._process_embedding(embedding.tolist(), as_buffer, dtype)
                     for embedding in batch_embeddings
                 ]
             )
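
The HuggingFace vectorizer is where popping dtype matters most visibly, because the remaining kwargs are forwarded straight into self._client.encode(...). A hedged usage sketch (class and model names assumed from the redisvl docs; requires sentence-transformers installed):

from redisvl.utils.vectorize import HFTextVectorizer

hf = HFTextVectorizer(model="sentence-transformers/all-MiniLM-L6-v2")  # assumed default model

floats = hf.embed("Hello, world!")                                 # plain list of floats
raw = hf.embed("Hello, world!", as_buffer=True, dtype="float32")   # byte string for Redis
# dtype is consumed by the vectorizer, so it is not passed on to
# SentenceTransformer.encode via **kwargs.
print(type(floats), type(raw))  # <class 'list'> <class 'bytes'>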

redisvl/utils/vectorize/text/mistral.py

Lines changed: 14 additions & 4 deletions
@@ -140,11 +140,13 @@ def embed_many(
         if len(texts) > 0 and not isinstance(texts[0], str):
             raise TypeError("Must pass in a list of str values to embed.")

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             response = self._client.embeddings(model=self.model, input=batch)
             embeddings += [
-                self._process_embedding(r.embedding, as_buffer, **kwargs)
+                self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
             ]
         return embeddings
@@ -181,8 +183,11 @@ def embed(

         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         result = self._client.embeddings(model=self.model, input=[text])
-        return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs)
+        return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @retry(
         wait=wait_random_exponential(min=1, max=60),
@@ -219,11 +224,13 @@ async def aembed_many(
         if len(texts) > 0 and not isinstance(texts[0], str):
             raise TypeError("Must pass in a list of str values to embed.")

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             response = await self._aclient.embeddings(model=self.model, input=batch)
             embeddings += [
-                self._process_embedding(r.embedding, as_buffer, **kwargs)
+                self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
             ]
         return embeddings
@@ -260,8 +267,11 @@ async def aembed(

         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         result = await self._aclient.embeddings(model=self.model, input=[text])
-        return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs)
+        return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @property
     def type(self) -> str:
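
The guard in _process_embedding still applies across all providers: requesting a buffer without a dtype raises the same RuntimeError. A hedged sketch of that failure path, shown here with the Mistral vectorizer (class and model names assumed; an API key is expected in the environment):

from redisvl.utils.vectorize import MistralAITextVectorizer

mistral = MistralAITextVectorizer(model="mistral-embed")  # assumed model name

try:
    # as_buffer=True without a dtype surfaces the RuntimeError from
    # _process_embedding instead of silently producing a float list.
    mistral.embed("some text", as_buffer=True)
except RuntimeError as err:
    print(err)  # dtype is required if converting from float to byte string.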

redisvl/utils/vectorize/text/openai.py

Lines changed: 14 additions & 4 deletions
@@ -144,11 +144,13 @@ def embed_many(
         if len(texts) > 0 and not isinstance(texts[0], str):
             raise TypeError("Must pass in a list of str values to embed.")

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             response = self._client.embeddings.create(input=batch, model=self.model)
             embeddings += [
-                self._process_embedding(r.embedding, as_buffer, **kwargs)
+                self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
             ]
         return embeddings
@@ -185,8 +187,11 @@ def embed(

         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         result = self._client.embeddings.create(input=[text], model=self.model)
-        return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs)
+        return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @retry(
         wait=wait_random_exponential(min=1, max=60),
@@ -223,13 +228,15 @@ async def aembed_many(
         if len(texts) > 0 and not isinstance(texts[0], str):
             raise TypeError("Must pass in a list of str values to embed.")

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             response = await self._aclient.embeddings.create(
                 input=batch, model=self.model
             )
             embeddings += [
-                self._process_embedding(r.embedding, as_buffer, **kwargs)
+                self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
             ]
         return embeddings
@@ -266,8 +273,11 @@ async def aembed(

         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         result = await self._aclient.embeddings.create(input=[text], model=self.model)
-        return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs)
+        return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @property
     def type(self) -> str:
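
The async paths follow the same pattern. A hedged sketch of aembed_many with the OpenAI vectorizer (model name assumed; OPENAI_API_KEY expected in the environment):

import asyncio

from redisvl.utils.vectorize import OpenAITextVectorizer

async def main():
    oai = OpenAITextVectorizer(model="text-embedding-ada-002")  # assumed model name
    # dtype is popped inside aembed_many and handed to _process_embedding,
    # so it never reaches the OpenAI embeddings.create call.
    buffers = await oai.aembed_many(
        ["first doc", "second doc"], as_buffer=True, dtype="float32"
    )
    print(len(buffers), type(buffers[0]))  # 2 <class 'bytes'>

asyncio.run(main())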

redisvl/utils/vectorize/text/vertexai.py

Lines changed: 7 additions & 2 deletions
@@ -151,11 +151,13 @@ def embed_many(
         if len(texts) > 0 and not isinstance(texts[0], str):
             raise TypeError("Must pass in a list of str values to embed.")

+        dtype = kwargs.pop("dtype", None)
+
         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
             response = self._client.get_embeddings(batch)
             embeddings += [
-                self._process_embedding(r.values, as_buffer, **kwargs) for r in response
+                self._process_embedding(r.values, as_buffer, dtype) for r in response
             ]
         return embeddings

@@ -191,8 +193,11 @@ def embed(

         if preprocess:
             text = preprocess(text)
+
+        dtype = kwargs.pop("dtype", None)
+
         result = self._client.get_embeddings([text])
-        return self._process_embedding(result[0].values, as_buffer, **kwargs)
+        return self._process_embedding(result[0].values, as_buffer, dtype)

     @property
     def type(self) -> str:
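
The Vertex AI vectorizer follows the same contract; the only provider-specific detail is that embeddings come back on r.values. A hedged sketch (model name assumed; GCP project/location setup via api_config is omitted here):

from redisvl.utils.vectorize import VertexAITextVectorizer

vtx = VertexAITextVectorizer(model="textembedding-gecko")  # assumed model name

floats = vtx.embed("hello")
raw = vtx.embed("hello", as_buffer=True, dtype="float32")
# A float32 buffer should be four bytes per dimension:
assert len(raw) == 4 * len(floats)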
