Skip to content

Commit 53ea86b

Browse files
authored
fix: return batch errors alongside successfully inserted objects (#358)
1 parent 6f22833 commit 53ea86b

File tree

7 files changed

+323
-225
lines changed

7 files changed

+323
-225
lines changed

src/main/java/io/weaviate/client/v1/async/batch/api/ObjectsBatcher.java

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import org.apache.commons.lang3.ArrayUtils;
2323
import org.apache.commons.lang3.ObjectUtils;
24-
import org.apache.commons.lang3.StringUtils;
2524
import org.apache.commons.lang3.tuple.Pair;
2625
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
2726
import org.apache.hc.core5.concurrent.FutureCallback;
@@ -328,38 +327,7 @@ private CompletableFuture<Result<ObjectGetResponse[]>> internalGrpcRun(List<Weav
328327
}
329328
}, executor)
330329
.thenApply(batchObjectsReply -> {
331-
List<WeaviateErrorMessage> weaviateErrorMessages = batchObjectsReply.getErrorsList().stream()
332-
.map(WeaviateProtoBatch.BatchObjectsReply.BatchError::getError)
333-
.filter(e -> !e.isEmpty())
334-
.map(msg -> WeaviateErrorMessage.builder().message(msg).build())
335-
.collect(Collectors.toList());
336-
337-
if (!weaviateErrorMessages.isEmpty()) {
338-
int statusCode = HttpStatus.SC_UNPROCESSABLE_CONTENT;
339-
WeaviateErrorResponse weaviateErrorResponse = WeaviateErrorResponse.builder()
340-
.code(statusCode)
341-
.message(StringUtils.join(weaviateErrorMessages, ","))
342-
.error(weaviateErrorMessages)
343-
.build();
344-
return new Result<>(statusCode, null, weaviateErrorResponse);
345-
}
346-
347-
ObjectGetResponse[] objectGetResponses = batch.stream().map(o -> {
348-
ObjectsGetResponseAO2Result result = new ObjectsGetResponseAO2Result();
349-
result.setStatus(ObjectGetResponseStatus.SUCCESS);
350-
351-
ObjectGetResponse resp = new ObjectGetResponse();
352-
resp.setId(o.getId());
353-
resp.setClassName(o.getClassName());
354-
resp.setTenant(o.getTenant());
355-
resp.setVector(o.getVector());
356-
resp.setVectors(o.getVectors());
357-
resp.setMultiVectors(o.getMultiVectors());
358-
resp.setResult(result);
359-
return resp;
360-
}).toArray(ObjectGetResponse[]::new);
361-
362-
return new Result<>(HttpStatus.SC_OK, objectGetResponses, null);
330+
return io.weaviate.client.v1.batch.api.ObjectsBatcher.resultFromBatchObjectsReply(batchObjectsReply, batch);
363331
});
364332
}
365333

src/main/java/io/weaviate/client/v1/batch/api/ObjectsBatcher.java

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.commons.lang3.ObjectUtils;
2525
import org.apache.commons.lang3.StringUtils;
2626
import org.apache.commons.lang3.tuple.Pair;
27+
import org.apache.hc.core5.http.HttpStatus;
2728

2829
import io.weaviate.client.Config;
2930
import io.weaviate.client.base.BaseClient;
@@ -38,12 +39,15 @@
3839
import io.weaviate.client.base.util.GrpcVersionSupport;
3940
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase;
4041
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
42+
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch.BatchObjectsReply;
43+
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch.BatchObjectsReply.BatchError;
4144
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
4245
import io.weaviate.client.v1.batch.grpc.BatchObjectConverter;
4346
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
4447
import io.weaviate.client.v1.batch.model.ObjectGetResponseStatus;
4548
import io.weaviate.client.v1.batch.model.ObjectsBatchRequestBody;
4649
import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result;
50+
import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result.ErrorResponse;
4751
import io.weaviate.client.v1.batch.util.ObjectsPath;
4852
import io.weaviate.client.v1.data.Data;
4953
import io.weaviate.client.v1.data.model.WeaviateObject;
@@ -320,34 +324,58 @@ private Result<ObjectGetResponse[]> internalGrpcRun(List<WeaviateObject> batch)
320324
} finally {
321325
grpcClient.shutdown();
322326
}
327+
return resultFromBatchObjectsReply(batchObjectsReply, batch);
328+
}
323329

324-
List<WeaviateErrorMessage> weaviateErrorMessages = batchObjectsReply.getErrorsList().stream()
325-
.map(WeaviateProtoBatch.BatchObjectsReply.BatchError::getError)
326-
.filter(e -> !e.isEmpty())
327-
.map(msg -> WeaviateErrorMessage.builder().message(msg).build())
328-
.collect(Collectors.toList());
329-
330-
if (!weaviateErrorMessages.isEmpty()) {
331-
WeaviateErrorResponse weaviateErrorResponse = WeaviateErrorResponse.builder()
332-
.code(422).message(StringUtils.join(weaviateErrorMessages, ",")).error(weaviateErrorMessages).build();
333-
return new Result<>(422, null, weaviateErrorResponse);
334-
}
335-
336-
ObjectGetResponse[] objectGetResponses = batch.stream().map(o -> {
337-
ObjectGetResponse resp = new ObjectGetResponse();
338-
resp.setId(o.getId());
339-
resp.setClassName(o.getClassName());
340-
resp.setTenant(o.getTenant());
341-
resp.setVector(o.getVector());
342-
resp.setVectors(o.getVectors());
343-
resp.setMultiVectors(o.getMultiVectors());
344-
ObjectsGetResponseAO2Result result = new ObjectsGetResponseAO2Result();
345-
result.setStatus(ObjectGetResponseStatus.SUCCESS);
346-
resp.setResult(result);
347-
return resp;
348-
}).toArray(ObjectGetResponse[]::new);
349-
350-
return new Result<>(200, objectGetResponses, null);
330+
public static Result<ObjectGetResponse[]> resultFromBatchObjectsReply(BatchObjectsReply reply,
331+
List<WeaviateObject> batch) {
332+
Map<Integer, String> errors = reply.getErrorsList()
333+
.stream().collect(Collectors.toMap(BatchError::getIndex, BatchError::getError));
334+
List<WeaviateErrorMessage> errorMessages = new ArrayList<>();
335+
WeaviateErrorResponse responseError = null;
336+
int responseCode = HttpStatus.SC_SUCCESS;
337+
338+
ObjectGetResponse[] responseObjects = new ObjectGetResponse[batch.size()];
339+
for (int i = 0; i < responseObjects.length; i++) {
340+
ObjectGetResponse r = new ObjectGetResponse();
341+
ObjectsGetResponseAO2Result insertResult = new ObjectsGetResponseAO2Result();
342+
if (errors.containsKey(i)) {
343+
insertResult.setStatus(ObjectGetResponseStatus.FAILED);
344+
insertResult.setErrors(new ErrorResponse(errors.get(i)));
345+
346+
errorMessages.add(WeaviateErrorMessage.builder().message(errors.get(i)).build());
347+
} else {
348+
insertResult.setStatus(ObjectGetResponseStatus.SUCCESS);
349+
350+
WeaviateObject batchObject = batch.get(i);
351+
r.setId(batchObject.getId());
352+
r.setClassName(batchObject.getClassName());
353+
r.setTenant(batchObject.getTenant());
354+
r.setVector(batchObject.getVector());
355+
r.setVectors(batchObject.getVectors());
356+
r.setMultiVectors(batchObject.getMultiVectors());
357+
}
358+
r.setResult(insertResult);
359+
responseObjects[i] = r;
360+
}
361+
362+
if (!errors.isEmpty()) {
363+
responseCode = HttpStatus.SC_UNPROCESSABLE_CONTENT;
364+
365+
// An important distinction between internalGrpcRun and internalRun
366+
// is that the regular batching (non-gRPC) method will not surface
367+
// an error on the "response level" and only report errors on the
368+
// object level.
369+
//
370+
// Because previously internalGrpcRun used to return 422 and a
371+
// WeaviateErrorResponse on partial errors too, we preserve this
372+
// behavior for b/c.
373+
responseError = WeaviateErrorResponse.builder()
374+
.code(responseCode)
375+
.message(StringUtils.join(errors.values(), ","))
376+
.error(errorMessages).build();
377+
}
378+
return new Result<>(responseCode, responseObjects, responseError);
351379
}
352380

353381
private Pair<List<ObjectGetResponse>, List<WeaviateObject>> fetchCreatedAndBuildBatchToReRun(

src/main/java/io/weaviate/client/v1/batch/model/ObjectsGetResponseAO2Result.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
package io.weaviate.client.v1.batch.model;
22

3+
import java.util.Arrays;
4+
import java.util.List;
5+
import java.util.stream.Collectors;
6+
37
import lombok.AccessLevel;
8+
import lombok.AllArgsConstructor;
49
import lombok.EqualsAndHashCode;
510
import lombok.Getter;
611
import lombok.Setter;
712
import lombok.ToString;
813
import lombok.experimental.FieldDefaults;
914

10-
import java.util.List;
11-
1215
@Getter
1316
@Setter
1417
@ToString
@@ -18,19 +21,23 @@ public class ObjectsGetResponseAO2Result {
1821
ErrorResponse errors;
1922
String status;
2023

21-
2224
@Getter
2325
@ToString
2426
@EqualsAndHashCode
2527
@FieldDefaults(level = AccessLevel.PRIVATE)
2628
public static class ErrorResponse {
2729
List<ErrorItem> error;
30+
31+
public ErrorResponse(String... errors) {
32+
this.error = Arrays.stream(errors).map(ErrorItem::new).collect(Collectors.toList());
33+
}
2834
}
2935

3036
@Getter
3137
@ToString
3238
@EqualsAndHashCode
3339
@FieldDefaults(level = AccessLevel.PRIVATE)
40+
@AllArgsConstructor
3441
public static class ErrorItem {
3542
String message;
3643
}

src/test/java/io/weaviate/integration/client/async/batch/ClientBatchGrpcCreateTest.java

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
package io.weaviate.integration.client.async.batch;
22

3+
import java.util.List;
4+
import java.util.UUID;
5+
import java.util.concurrent.ExecutionException;
6+
import java.util.function.Function;
7+
8+
import org.assertj.core.api.Assertions;
9+
import org.assertj.core.api.InstanceOfAssertFactories;
10+
import org.junit.Before;
11+
import org.junit.ClassRule;
12+
import org.junit.Test;
13+
314
import io.weaviate.client.Config;
415
import io.weaviate.client.WeaviateClient;
516
import io.weaviate.client.base.Result;
@@ -8,26 +19,29 @@
819
import io.weaviate.client.v1.data.model.WeaviateObject;
920
import io.weaviate.client.v1.schema.model.WeaviateClass;
1021
import io.weaviate.integration.client.WeaviateDockerCompose;
22+
import io.weaviate.integration.client.WeaviateTestGenerics;
23+
import io.weaviate.integration.tests.batch.BatchObjectsTestSuite;
1124
import io.weaviate.integration.tests.batch.ClientBatchGrpcCreateTestSuite;
12-
import java.util.List;
13-
import java.util.concurrent.ExecutionException;
14-
import java.util.function.Function;
15-
import org.junit.Before;
16-
import org.junit.ClassRule;
17-
import org.junit.Test;
1825

1926
public class ClientBatchGrpcCreateTest {
2027

2128
private static String httpHost;
2229
private static String grpcHost;
2330

31+
private final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics();
32+
2433
@ClassRule
2534
public static WeaviateDockerCompose compose = new WeaviateDockerCompose();
2635

2736
@Before
2837
public void before() {
2938
httpHost = compose.getHttpHostAddress();
3039
grpcHost = compose.getGrpcHostAddress();
40+
41+
WeaviateClient client = createClient(false);
42+
43+
testGenerics.cleanupWeaviate(client);
44+
testGenerics.createWeaviateTestSchemaFood(client);
3145
}
3246

3347
@Test
@@ -46,8 +60,8 @@ public void shouldCreate(boolean useGRPC) {
4660
Function<WeaviateClass, Result<Boolean>> createClass = (weaviateClass) -> {
4761
try (WeaviateAsyncClient asyncClient = client.async()) {
4862
return asyncClient.schema().classCreator()
49-
.withClass(weaviateClass)
50-
.run().get();
63+
.withClass(weaviateClass)
64+
.run().get();
5165
} catch (InterruptedException | ExecutionException e) {
5266
throw new RuntimeException(e);
5367
}
@@ -56,8 +70,8 @@ public void shouldCreate(boolean useGRPC) {
5670
Function<WeaviateObject[], Result<ObjectGetResponse[]>> batchCreate = (objects) -> {
5771
try (WeaviateAsyncClient asyncClient = client.async()) {
5872
return asyncClient.batch().objectsBatcher()
59-
.withObjects(objects)
60-
.run().get();
73+
.withObjects(objects)
74+
.run().get();
6175
} catch (InterruptedException | ExecutionException e) {
6276
throw new RuntimeException(e);
6377
}
@@ -66,8 +80,8 @@ public void shouldCreate(boolean useGRPC) {
6680
Function<WeaviateObject, Result<List<WeaviateObject>>> fetchObject = (obj) -> {
6781
try (WeaviateAsyncClient asyncClient = client.async()) {
6882
return asyncClient.data().objectsGetter()
69-
.withID(obj.getId()).withClassName(obj.getClassName()).withVector()
70-
.run().get();
83+
.withID(obj.getId()).withClassName(obj.getClassName()).withVector()
84+
.run().get();
7185
} catch (InterruptedException | ExecutionException e) {
7286
throw new RuntimeException(e);
7387
}
@@ -84,6 +98,39 @@ public void shouldCreate(boolean useGRPC) {
8498
ClientBatchGrpcCreateTestSuite.shouldCreateBatch(client, createClass, batchCreate, fetchObject, deleteClass);
8599
}
86100

101+
@Test
102+
public void testPartialErrorResponse() throws ExecutionException, InterruptedException {
103+
WeaviateClient syncClient = createClient(true);
104+
105+
try (WeaviateAsyncClient client = syncClient.async()) {
106+
107+
WeaviateObject[] batchObjects = {
108+
WeaviateObject.builder()
109+
.className("Pizza")
110+
.id(UUID.randomUUID().toString())
111+
.properties(BatchObjectsTestSuite.createFoodProperties(1, "This pizza should throw a invalid name error"))
112+
.build(),
113+
WeaviateObject.builder()
114+
.className("Pizza")
115+
.id(UUID.randomUUID().toString())
116+
.properties(BatchObjectsTestSuite.PIZZA_2_PROPS)
117+
.build(),
118+
};
119+
120+
Result<ObjectGetResponse[]> result = client.batch().objectsBatcher()
121+
.withObjects(batchObjects)
122+
.run().get();
123+
124+
Assertions.assertThat(result).as("batch insert result")
125+
.returns(true, Result::hasErrors)
126+
.extracting(Result::getResult).asInstanceOf(InstanceOfAssertFactories.array(ObjectGetResponse[].class))
127+
.hasSameSizeAs(batchObjects).as("all batch objects included in the response");
128+
129+
Assertions.assertThat(result.getResult()[0].getResult().getErrors().getError().get(0).getMessage())
130+
.contains("invalid text property 'name' on class 'Pizza': not a string, but float64");
131+
}
132+
}
133+
87134
private WeaviateClient createClient(Boolean useGRPC) {
88135
Config config = new Config("http", httpHost);
89136
if (useGRPC) {

0 commit comments

Comments
 (0)