|
24 | 24 | import org.apache.commons.lang3.ObjectUtils; |
25 | 25 | import org.apache.commons.lang3.StringUtils; |
26 | 26 | import org.apache.commons.lang3.tuple.Pair; |
| 27 | +import org.apache.hc.core5.http.HttpStatus; |
27 | 28 |
|
28 | 29 | import io.weaviate.client.Config; |
29 | 30 | import io.weaviate.client.base.BaseClient; |
|
38 | 39 | import io.weaviate.client.base.util.GrpcVersionSupport; |
39 | 40 | import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase; |
40 | 41 | import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch; |
| 42 | +import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch.BatchObjectsReply; |
| 43 | +import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch.BatchObjectsReply.BatchError; |
41 | 44 | import io.weaviate.client.v1.auth.provider.AccessTokenProvider; |
42 | 45 | import io.weaviate.client.v1.batch.grpc.BatchObjectConverter; |
43 | 46 | import io.weaviate.client.v1.batch.model.ObjectGetResponse; |
44 | 47 | import io.weaviate.client.v1.batch.model.ObjectGetResponseStatus; |
45 | 48 | import io.weaviate.client.v1.batch.model.ObjectsBatchRequestBody; |
46 | 49 | import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result; |
| 50 | +import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result.ErrorResponse; |
47 | 51 | import io.weaviate.client.v1.batch.util.ObjectsPath; |
48 | 52 | import io.weaviate.client.v1.data.Data; |
49 | 53 | import io.weaviate.client.v1.data.model.WeaviateObject; |
@@ -320,34 +324,58 @@ private Result<ObjectGetResponse[]> internalGrpcRun(List<WeaviateObject> batch) |
320 | 324 | } finally { |
321 | 325 | grpcClient.shutdown(); |
322 | 326 | } |
| 327 | + return resultFromBatchObjectsReply(batchObjectsReply, batch); |
| 328 | + } |
323 | 329 |
|
324 | | - List<WeaviateErrorMessage> weaviateErrorMessages = batchObjectsReply.getErrorsList().stream() |
325 | | - .map(WeaviateProtoBatch.BatchObjectsReply.BatchError::getError) |
326 | | - .filter(e -> !e.isEmpty()) |
327 | | - .map(msg -> WeaviateErrorMessage.builder().message(msg).build()) |
328 | | - .collect(Collectors.toList()); |
329 | | - |
330 | | - if (!weaviateErrorMessages.isEmpty()) { |
331 | | - WeaviateErrorResponse weaviateErrorResponse = WeaviateErrorResponse.builder() |
332 | | - .code(422).message(StringUtils.join(weaviateErrorMessages, ",")).error(weaviateErrorMessages).build(); |
333 | | - return new Result<>(422, null, weaviateErrorResponse); |
334 | | - } |
335 | | - |
336 | | - ObjectGetResponse[] objectGetResponses = batch.stream().map(o -> { |
337 | | - ObjectGetResponse resp = new ObjectGetResponse(); |
338 | | - resp.setId(o.getId()); |
339 | | - resp.setClassName(o.getClassName()); |
340 | | - resp.setTenant(o.getTenant()); |
341 | | - resp.setVector(o.getVector()); |
342 | | - resp.setVectors(o.getVectors()); |
343 | | - resp.setMultiVectors(o.getMultiVectors()); |
344 | | - ObjectsGetResponseAO2Result result = new ObjectsGetResponseAO2Result(); |
345 | | - result.setStatus(ObjectGetResponseStatus.SUCCESS); |
346 | | - resp.setResult(result); |
347 | | - return resp; |
348 | | - }).toArray(ObjectGetResponse[]::new); |
349 | | - |
350 | | - return new Result<>(200, objectGetResponses, null); |
| 330 | + public static Result<ObjectGetResponse[]> resultFromBatchObjectsReply(BatchObjectsReply reply, |
| 331 | + List<WeaviateObject> batch) { |
| 332 | + Map<Integer, String> errors = reply.getErrorsList() |
| 333 | + .stream().collect(Collectors.toMap(BatchError::getIndex, BatchError::getError)); |
| 334 | + List<WeaviateErrorMessage> errorMessages = new ArrayList<>(); |
| 335 | + WeaviateErrorResponse responseError = null; |
| 336 | + int responseCode = HttpStatus.SC_SUCCESS; |
| 337 | + |
| 338 | + ObjectGetResponse[] responseObjects = new ObjectGetResponse[batch.size()]; |
| 339 | + for (int i = 0; i < responseObjects.length; i++) { |
| 340 | + ObjectGetResponse r = new ObjectGetResponse(); |
| 341 | + ObjectsGetResponseAO2Result insertResult = new ObjectsGetResponseAO2Result(); |
| 342 | + if (errors.containsKey(i)) { |
| 343 | + insertResult.setStatus(ObjectGetResponseStatus.FAILED); |
| 344 | + insertResult.setErrors(new ErrorResponse(errors.get(i))); |
| 345 | + |
| 346 | + errorMessages.add(WeaviateErrorMessage.builder().message(errors.get(i)).build()); |
| 347 | + } else { |
| 348 | + insertResult.setStatus(ObjectGetResponseStatus.SUCCESS); |
| 349 | + |
| 350 | + WeaviateObject batchObject = batch.get(i); |
| 351 | + r.setId(batchObject.getId()); |
| 352 | + r.setClassName(batchObject.getClassName()); |
| 353 | + r.setTenant(batchObject.getTenant()); |
| 354 | + r.setVector(batchObject.getVector()); |
| 355 | + r.setVectors(batchObject.getVectors()); |
| 356 | + r.setMultiVectors(batchObject.getMultiVectors()); |
| 357 | + } |
| 358 | + r.setResult(insertResult); |
| 359 | + responseObjects[i] = r; |
| 360 | + } |
| 361 | + |
| 362 | + if (!errors.isEmpty()) { |
| 363 | + responseCode = HttpStatus.SC_UNPROCESSABLE_CONTENT; |
| 364 | + |
| 365 | + // An important distinction between internalGrpcRun and internalRun |
| 366 | + // is that the regular batching (non-gRPC) method will not surface |
| 367 | + // an error on the "response level" and only report errors on the |
| 368 | + // object level. |
| 369 | + // |
| 370 | + // Because previously internalGrpcRun used to return 422 and a |
| 371 | + // WeaviateErrorResponse on partial errors too, we preserve this |
| 372 | + // behavior for b/c. |
| 373 | + responseError = WeaviateErrorResponse.builder() |
| 374 | + .code(responseCode) |
| 375 | + .message(StringUtils.join(errors.values(), ",")) |
| 376 | + .error(errorMessages).build(); |
| 377 | + } |
| 378 | + return new Result<>(responseCode, responseObjects, responseError); |
351 | 379 | } |
352 | 380 |
|
353 | 381 | private Pair<List<ObjectGetResponse>, List<WeaviateObject>> fetchCreatedAndBuildBatchToReRun( |
|
0 commit comments