Skip to content

Commit

Permalink
Add Support for Hybrid Query Type (#850)
Browse files Browse the repository at this point in the history
* Add Support for Hybrid Query Type

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Add samples, guide and integ tests

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Removing wildcard imports

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Adding import

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Adding import

Signed-off-by: Varun Jain <varunudr@amazon.com>

---------

Signed-off-by: Varun Jain <varunudr@amazon.com>
(cherry picked from commit 821dae6)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Feb 20, 2024
1 parent c32c230 commit 49f0392
Show file tree
Hide file tree
Showing 9 changed files with 324 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
## [Unreleased 2.x]
### Added
- Add search role type for nodes in cluster stats ([#848](https://github.com/opensearch-project/opensearch-java/pull/848))
- Add support for Hybrid query type ([#850](https://github.com/opensearch-project/opensearch-java/pull/850))

### Dependencies

Expand Down
19 changes: 19 additions & 0 deletions guides/search.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ for (int i = 0; i < searchResponse.hits().hits().size(); i++) {
}
```

### Search documents using a hybrid query
```java
Query searchQuery = Query.of(
h -> h.hybrid(
q -> q.queries(Arrays.asList(
new MatchQuery.Builder().field("text").query(FieldValue.of("Text for document 2")).build().toQuery(),
new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(),
new NeuralQuery.Builder().field("passage_embedding").queryText("Hi world").modelId("bQ1J8ooBpBj3wT4HVUsb").k(100).build().toQuery()
)
)
)
);
SearchRequest searchRequest = new SearchRequest.Builder().query(searchQuery).build();
SearchResponse<IndexData> searchResponse = client.search(searchRequest, IndexData.class);
for (var hit : searchResponse.hits().hits()) {
LOGGER.info("Found {} with score {}", hit.source(), hit.score());
}
```

### Search documents using suggesters

[AppData](../samples/src/main/java/org/opensearch/client/samples/util/AppData.java) refers to the sample data class used in the below samples.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import java.util.List;
import java.util.function.Function;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

/**
 * Query variant of kind {@code hybrid}: carries a list of sub-queries that are serialized together
 * under the {@code queries} key.
 */
public class HybridQuery extends QueryBase implements QueryVariant {
    private final List<Query> queries;

    private HybridQuery(HybridQuery.Builder builder) {
        super(builder);
        this.queries = ApiTypeHelper.unmodifiable(builder.queries);
    }

    /**
     * Builds a {@code HybridQuery} by applying {@code fn} to a fresh builder.
     *
     * @param fn function that configures the builder and returns it.
     * @return the query produced by the configured builder.
     */
    public static HybridQuery of(Function<HybridQuery.Builder, ObjectBuilder<HybridQuery>> fn) {
        return fn.apply(new HybridQuery.Builder()).build();
    }

    /**
     * Required - list of search queries.
     * <p>
     * API name: {@code queries}
     *
     * @return list of queries provided under hybrid clause.
     */
    public final List<Query> queries() {
        return this.queries;
    }

    /**
     * Writes the fields inherited from {@link QueryBase}, then the {@code queries} array with each
     * sub-query serialized in list order.
     */
    @Override
    protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
        super.serializeInternal(generator, mapper);
        generator.writeKey("queries");
        generator.writeStartArray();
        for (Query item0 : this.queries) {
            item0.serialize(generator, mapper);
        }
        generator.writeEnd();
    }

    /**
     * @return {@link Query.Kind#Hybrid}, the variant kind of this query.
     */
    @Override
    public Query.Kind _queryKind() {
        return Query.Kind.Hybrid;
    }

    /**
     * Creates a new builder pre-populated with this query's state.
     * <p>
     * The fields inherited from {@link QueryBase} ({@code boost} and {@code queryName}) are copied
     * along with {@code queries}; otherwise a round trip through the builder would silently drop
     * them.
     *
     * @return a builder that builds a copy of this query unless modified.
     */
    public HybridQuery.Builder toBuilder() {
        return new HybridQuery.Builder().boost(boost()).queryName(queryName()).queries(queries);
    }

    /**
     * Builder for {@link HybridQuery}.
     */
    public static class Builder extends QueryBase.AbstractBuilder<HybridQuery.Builder> implements ObjectBuilder<HybridQuery> {
        private List<Query> queries;

        /**
         * API name: {@code queries}
         * <p>
         * Adds all elements of <code>list</code> to <code>queries</code>.
         */
        public final HybridQuery.Builder queries(List<Query> list) {
            this.queries = _listAddAll(this.queries, list);
            return this;
        }

        /**
         * API name: {@code queries}
         * <p>
         * Adds one or more values to <code>queries</code>.
         */
        public final HybridQuery.Builder queries(Query value, Query... values) {
            this.queries = _listAdd(this.queries, value, values);
            return this;
        }

        /**
         * API name: {@code queries}
         * <p>
         * Adds a value to <code>queries</code> using a builder lambda.
         */
        public final HybridQuery.Builder queries(Function<Query.Builder, ObjectBuilder<Query>> fn) {
            return queries(fn.apply(new Query.Builder()).build());
        }

        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Builds a {@link HybridQuery} from the current builder state.
         *
         * @return the new query.
         */
        @Override
        public HybridQuery build() {
            _checkSingleUse();
            return new HybridQuery(this);
        }
    }

    /**
     * Json deserializer for {@link HybridQuery}.
     */
    public static final JsonpDeserializer<HybridQuery> _DESERIALIZER = ObjectBuilderDeserializer.lazy(
        HybridQuery.Builder::new,
        HybridQuery::setupHybridQueryDeserializer
    );

    /**
     * Registers the {@link QueryBase} fields and the {@code queries} array (an array of
     * {@link Query}) on the given object deserializer.
     */
    protected static void setupHybridQueryDeserializer(ObjectDeserializer<HybridQuery.Builder> op) {
        setupQueryBaseDeserializer(op);
        op.add(HybridQuery.Builder::queries, JsonpDeserializer.arrayDeserializer(Query._DESERIALIZER), "queries");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ public enum Kind implements JsonEnum {

Neural("neural"),

Hybrid("hybrid"),

ParentId("parent_id"),

Percolate("percolate"),
Expand Down Expand Up @@ -725,6 +727,23 @@ public NeuralQuery neural() {
return TaggedUnionUtils.get(this, Kind.Neural);
}

/**
 * Tells whether this query currently holds the {@code hybrid} variant.
 *
 * @return {@code true} when the variant kind is {@code hybrid}.
 */
public boolean isHybrid() {
    return Kind.Hybrid == _kind;
}

/**
 * Returns the value held by this query as its {@code hybrid} variant.
 *
 * @return the {@code hybrid} variant value.
 * @throws IllegalStateException
 *             if the current variant is not of the {@code hybrid} kind.
 */
public HybridQuery hybrid() {
    final HybridQuery variant = TaggedUnionUtils.get(this, Kind.Hybrid);
    return variant;
}

/**
* Is this variant instance of kind {@code parent_id}?
*/
Expand Down Expand Up @@ -1510,6 +1529,16 @@ public ObjectBuilder<Query> neural(Function<NeuralQuery.Builder, ObjectBuilder<N
return this.neural(fn.apply(new NeuralQuery.Builder()).build());
}

/**
 * Selects the {@code hybrid} variant and stores {@code v} as its value.
 *
 * @param v the hybrid query to hold.
 * @return this builder.
 */
public ObjectBuilder<Query> hybrid(HybridQuery v) {
    this._value = v;
    this._kind = Kind.Hybrid;
    return this;
}

/**
 * Selects the {@code hybrid} variant, building its value by applying {@code fn} to a fresh
 * {@link HybridQuery.Builder}.
 *
 * @param fn function that configures the hybrid query builder.
 * @return this builder.
 */
public ObjectBuilder<Query> hybrid(Function<HybridQuery.Builder, ObjectBuilder<HybridQuery>> fn) {
    final HybridQuery built = fn.apply(new HybridQuery.Builder()).build();
    return this.hybrid(built);
}

public ObjectBuilder<Query> parentId(ParentIdQuery v) {
this._kind = Kind.ParentId;
this._value = v;
Expand Down Expand Up @@ -1818,6 +1847,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::multiMatch, MultiMatchQuery._DESERIALIZER, "multi_match");
op.add(Builder::nested, NestedQuery._DESERIALIZER, "nested");
op.add(Builder::neural, NeuralQuery._DESERIALIZER, "neural");
op.add(Builder::hybrid, HybridQuery._DESERIALIZER, "hybrid");
op.add(Builder::parentId, ParentIdQuery._DESERIALIZER, "parent_id");
op.add(Builder::percolate, PercolateQuery._DESERIALIZER, "percolate");
op.add(Builder::pinned, PinnedQuery._DESERIALIZER, "pinned");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ public static NeuralQuery.Builder neural() {
return new NeuralQuery.Builder();
}

/**
 * Creates a builder for the {@link HybridQuery hybrid} {@code Query} variant.
 *
 * @return a new, empty {@link HybridQuery.Builder}.
 */
public static HybridQuery.Builder hybrid() {
    return new HybridQuery.Builder();
}

/**
* Creates a builder for the {@link ParentIdQuery parent_id} {@code Query}
* variant.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.opensearch.client.opensearch._types.query_dsl;

import java.util.Arrays;
import org.junit.Test;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch.model.ModelTestCase;

/**
 * Unit tests for {@link HybridQuery}.
 */
public class HybridQueryTest extends ModelTestCase {
    /**
     * Round-trips a hybrid query holding term, neural and knn sub-queries through
     * {@link HybridQuery#toBuilder()} and checks that the rebuilt query serializes to the same JSON
     * as the original.
     */
    @Test
    public void toBuilder() {
        HybridQuery origin = new HybridQuery.Builder().queries(
            Arrays.asList(
                new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(),
                new NeuralQuery.Builder().field("passage_embedding")
                    .queryText("Hi world")
                    .modelId("bQ1J8ooBpBj3wT4HVUsb")
                    .k(100)
                    .build()
                    .toQuery(),
                new KnnQuery.Builder().field("passage_embedding").vector(new float[] { 0.01f, 0.02f }).k(2).build().toQuery()
            )
        ).build();
        HybridQuery copied = origin.toBuilder().build();

        // assertEquals takes (expected, actual): the original query is the expectation, the copy
        // is the value under test, so failure messages read the right way round.
        assertEquals(toJson(origin), toJson(copied));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,20 @@
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.SortOrder;
import org.opensearch.client.opensearch._types.mapping.Property;
import org.opensearch.client.opensearch._types.query_dsl.MatchQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch._types.query_dsl.TermQuery;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.core.SearchResponse;
import org.opensearch.client.opensearch.indices.DeleteIndexRequest;
import org.opensearch.client.opensearch.indices.SegmentSortOrder;

public abstract class AbstractSearchRequestIT extends OpenSearchJavaClientTestCase {

@Test
public void shouldReturnSearchResults() throws Exception {
final String index = "search_request";
assertTrue(
javaClient().indices()
.create(
b -> b.index(index)
.mappings(
m -> m.properties("name", Property.of(p -> p.keyword(v -> v.docValues(true))))
.properties("size", Property.of(p -> p.keyword(v -> v.docValues(true))))
)
.settings(settings -> settings.sort(s -> s.field("name").order(SegmentSortOrder.Asc)))
)
.acknowledged()
);

createTestDocuments(index);
javaClient().indices().refresh();
createIndex(index);

final Query query = Query.of(
q -> q.bool(
Expand Down Expand Up @@ -72,23 +60,47 @@ public void shouldReturnSearchResults() throws Exception {
}

@Test
public void shouldReturnSearchResultsWithoutStoredFields() throws Exception {
final String index = "search_request";
assertTrue(
javaClient().indices()
.create(
b -> b.index(index)
.mappings(
m -> m.properties("name", Property.of(p -> p.keyword(v -> v.docValues(true))))
.properties("size", Property.of(p -> p.keyword(v -> v.docValues(true))))
)
.settings(settings -> settings.sort(s -> s.field("name").order(SegmentSortOrder.Asc)))
public void hybridSearchShouldReturnSearchResults() throws Exception {
final String index = "hybrid_search_request";
try {
createIndex(index);
final Query query = Query.of(
h -> h.hybrid(
q -> q.queries(Arrays.asList(new MatchQuery.Builder().field("size").query(FieldValue.of("huge")).build().toQuery()))
)
.acknowledged()
);
);

final SearchRequest request = SearchRequest.of(
r -> r.index(index)
.sort(s -> s.field(f -> f.field("name").order(SortOrder.Asc)))
.fields(f -> f.field("name"))
.query(query)
.source(s -> s.fetch(true))
);

final SearchResponse<ShopItem> response = javaClient().search(request, ShopItem.class);
assertEquals(response.hits().hits().size(), 5);

assertTrue(
Arrays.stream(response.hits().hits().get(2).fields().get("name").to(String[].class))
.collect(Collectors.toList())
.contains("hummer")
);
assertTrue(
Arrays.stream(response.hits().hits().get(3).fields().get("name").to(String[].class))
.collect(Collectors.toList())
.contains("jammer")
);
} finally {
DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest.Builder().index(index).build();
javaClient().indices().delete(deleteIndexRequest);
}
}

createTestDocuments(index);
javaClient().indices().refresh();
@Test
public void shouldReturnSearchResultsWithoutStoredFields() throws Exception {
final String index = "search_request";
createIndex(index);

final Query query = Query.of(
q -> q.bool(
Expand Down Expand Up @@ -117,6 +129,23 @@ private void createTestDocuments(String index) throws IOException {
javaClient().create(_1 -> _1.index(index).id("8").document(createItem("nuts", "small", "no", 2)));
}

/**
 * Creates {@code index} with keyword mappings (doc values enabled) for {@code name} and
 * {@code size} and an ascending index sort on {@code name}, asserts that the creation was
 * acknowledged, then loads the test documents and issues a refresh.
 *
 * @param index name of the index to create.
 * @throws IOException if a client call fails.
 */
private void createIndex(String index) throws IOException {
    final boolean acknowledged = javaClient().indices()
        .create(
            request -> request.index(index)
                .mappings(
                    mapping -> mapping.properties("name", Property.of(prop -> prop.keyword(kw -> kw.docValues(true))))
                        .properties("size", Property.of(prop -> prop.keyword(kw -> kw.docValues(true))))
                )
                .settings(indexSettings -> indexSettings.sort(sort -> sort.field("name").order(SegmentSortOrder.Asc)))
        )
        .acknowledged();
    assertTrue(acknowledged);

    createTestDocuments(index);
    javaClient().indices().refresh();
}

/**
 * Convenience factory for a {@link ShopItem} test document.
 *
 * @return a new item built from the given values, passed through in the same order.
 */
private ShopItem createItem(String name, String size, String company, int quantity) {
    final ShopItem item = new ShopItem(name, size, company, quantity);
    return item;
}
Expand Down
Loading

0 comments on commit 49f0392

Please sign in to comment.