Skip to content

Commit

Permalink
search interface supprt not specify metricType (milvus-io#750)
Browse files Browse the repository at this point in the history
Signed-off-by: lentitude2tk <xushuang.hu@zilliz.com>
  • Loading branch information
lentitude2tk authored Feb 1, 2024
1 parent 8163dd7 commit 60364f3
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 89 deletions.
2 changes: 1 addition & 1 deletion src/main/java/io/milvus/param/MetricType.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
*/
public enum MetricType {
None,
INVALID,

// Only for float vectors
L2,
IP,
Expand Down
31 changes: 8 additions & 23 deletions src/main/java/io/milvus/param/ParamUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -197,24 +197,6 @@ public static void CheckNullString(String target, String name) throws ParamExcep
}
}

/**
* Checks if a metric is for float vector.
*
* @param metric metric type
*/
public static boolean IsFloatMetric(MetricType metric) {
return metric == MetricType.L2 || metric == MetricType.IP || metric == MetricType.COSINE;
}

/**
* Checks if a metric is for binary vector.
*
* @param metric metric type
*/
public static boolean IsBinaryMetric(MetricType metric) {
return metric != MetricType.INVALID && !IsFloatMetric(metric);
}

public static class InsertBuilderWrapper {
private InsertRequest.Builder insertBuilder;
private UpsertRequest.Builder upsertBuilder;
Expand Down Expand Up @@ -483,11 +465,6 @@ public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam
.setKey(Constant.TOP_K)
.setValue(String.valueOf(requestParam.getTopK()))
.build())
.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.METRIC_TYPE)
.setValue(requestParam.getMetricType())
.build())
.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.ROUND_DECIMAL)
Expand All @@ -499,6 +476,14 @@ public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam
.setValue(String.valueOf(requestParam.isIgnoreGrowing()))
.build());

if (!Objects.equals(requestParam.getMetricType(), MetricType.None.name())) {
builder.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.METRIC_TYPE)
.setValue(requestParam.getMetricType())
.build());
}

if (null != requestParam.getParams() && !requestParam.getParams().isEmpty()) {
try {
Map<String, Object> paramMap = JacksonUtils.fromJson(requestParam.getParams(),Map.class);
Expand Down
4 changes: 0 additions & 4 deletions src/main/java/io/milvus/param/QueryNodeSingleSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ public QueryNodeSingleSearch build() throws ParamException {
ParamUtils.CheckNullEmptyString(collectionName, "Collection name");
ParamUtils.CheckNullEmptyString(vectorFieldName, "Target field name");

if (metricType == MetricType.INVALID) {
throw new ParamException("Metric type is illegal");
}

if (vectors == null || vectors.isEmpty()) {
throw new ParamException("Target vectors can not be empty");
}
Expand Down
16 changes: 1 addition & 15 deletions src/main/java/io/milvus/param/dml/SearchParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public static Builder newBuilder() {
public static class Builder {
private String collectionName;
private final List<String> partitionNames = Lists.newArrayList();
private MetricType metricType = MetricType.L2;
private MetricType metricType = MetricType.None;
private String vectorFieldName;
private Integer topK;
private String expr = "";
Expand Down Expand Up @@ -287,10 +287,6 @@ public SearchParam build() throws ParamException {
throw new ParamException("The guarantee timestamp must be greater than 0");
}

if (metricType == MetricType.INVALID) {
throw new ParamException("Metric type is invalid");
}

if (vectors == null || vectors.isEmpty()) {
throw new ParamException("Target vectors can not be empty");
}
Expand All @@ -309,11 +305,6 @@ public SearchParam build() throws ParamException {
throw new ParamException("Target vector dimension must be equal");
}
}

// check metric type
if (!ParamUtils.IsFloatMetric(metricType)) {
throw new ParamException("Target vector is float but metric type is incorrect");
}
} else if (vectors.get(0) instanceof ByteBuffer) {
// binary vectors
ByteBuffer first = (ByteBuffer) vectors.get(0);
Expand All @@ -324,11 +315,6 @@ public SearchParam build() throws ParamException {
throw new ParamException("Target vector dimension must be equal");
}
}

// check metric type
if (!ParamUtils.IsBinaryMetric(metricType)) {
throw new ParamException("Target vector is binary but metric type is incorrect");
}
} else {
throw new ParamException("Target vector type must be List<Float> or ByteBuffer");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public MetricType getMetricType() {
return MetricType.valueOf(params.get(Constant.METRIC_TYPE));
}

return MetricType.INVALID;
return MetricType.None;
}

public String getExtraParam() {
Expand Down
48 changes: 3 additions & 45 deletions src/test/java/io/milvus/client/MilvusServiceClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1973,20 +1973,6 @@ void searchParam() {
.build()
);

// metric type is invalid
assertThrows(ParamException.class, () -> SearchParam.newBuilder()
.withCollectionName("collection1")
.withPartitionNames(partitions)
.withParams("{}")
.withOutFields(outputFields)
.withVectorFieldName("field1")
.withMetricType(MetricType.INVALID)
.withTopK(5)
.withVectors(vectors)
.withExpr("dummy")
.build()
);

// illegal topk value
assertThrows(ParamException.class, () -> SearchParam.newBuilder()
.withCollectionName("collection1")
Expand Down Expand Up @@ -2066,38 +2052,9 @@ void searchParam() {
.withExpr("dummy")
.build()
);

// float vector metric type is illegal
List<List<Float>> vectors2 = Collections.singletonList(vector2);
assertThrows(ParamException.class, () -> SearchParam.newBuilder()
.withCollectionName("collection1")
.withPartitionNames(partitions)
.withParams("{}")
.withOutFields(outputFields)
.withVectorFieldName("field1")
.withMetricType(MetricType.JACCARD)
.withTopK(5)
.withVectors(vectors2)
.withExpr("dummy")
.build()
);

// binary vector metric type is illegal
List<ByteBuffer> binVectors2 = Collections.singletonList(buf2);
assertThrows(ParamException.class, () -> SearchParam.newBuilder()
.withCollectionName("collection1")
.withPartitionNames(partitions)
.withParams("{}")
.withOutFields(outputFields)
.withVectorFieldName("field1")
.withMetricType(MetricType.IP)
.withTopK(5)
.withVectors(binVectors2)
.withExpr("dummy")
.build()
);


// succeed float vector case
List<List<Float>> vectors2 = Collections.singletonList(vector2);
assertDoesNotThrow(() -> SearchParam.newBuilder()
.withCollectionName("collection1")
.withPartitionNames(partitions)
Expand All @@ -2113,6 +2070,7 @@ void searchParam() {
);

// succeed binary vector case
List<ByteBuffer> binVectors2 = Collections.singletonList(buf2);
assertDoesNotThrow(() -> SearchParam.newBuilder()
.withCollectionName("collection1")
.withPartitionNames(partitions)
Expand Down

0 comments on commit 60364f3

Please sign in to comment.