Skip to content

Fix bug in rule_query that resulted in errors when performing text_ex… #105311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/105311.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 105311
summary: Fix bug in `rule_query` that resulted in errors when performing text_expansion queries
area: Application
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,16 @@
import static org.elasticsearch.xpack.searchbusinessrules.PinnedQueryBuilder.Item;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not related to your changes, but do you think the name Item still fits our purposes? My guess is that when we added the pinned query, we did not think this class will get reused, so having it called Item has not an issue.
But now when I see Item referenced outside of the pinned query, I have to stop and think where that comes from.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good question. Item is nice because it's short, but PinnedDocument would probably be a better name in retrospect since Document would be too overloaded. Or PinnedDoc if we thought PinnedDocument was way too long.

I'd be open to a quick rename PR as a followup.


public class AppliedQueryRules {

private final List<String> pinnedIds;
private final List<Item> pinnedDocs;

public AppliedQueryRules() {
this(new ArrayList<>(0), new ArrayList<>(0));
this(new ArrayList<>(0));
}

public AppliedQueryRules(List<String> pinnedIds, List<Item> pinnedDocs) {
this.pinnedIds = pinnedIds;
public AppliedQueryRules(List<Item> pinnedDocs) {
this.pinnedDocs = pinnedDocs;
}

public List<String> pinnedIds() {
return pinnedIds;
}

public List<Item> pinnedDocs() {
return pinnedDocs;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ public AppliedQueryRules applyRule(AppliedQueryRules appliedRules, Map<String, O
throw new UnsupportedOperationException("Only pinned query rules are supported");
}

List<String> matchingPinnedIds = new ArrayList<>();
List<PinnedQueryBuilder.Item> matchingPinnedDocs = new ArrayList<>();
Boolean isRuleMatch = null;

Expand All @@ -302,7 +301,8 @@ public AppliedQueryRules applyRule(AppliedQueryRules appliedRules, Map<String, O

if (isRuleMatch != null && isRuleMatch) {
if (actions.containsKey(IDS_FIELD.getPreferredName())) {
matchingPinnedIds.addAll((List<String>) actions.get(IDS_FIELD.getPreferredName()));
List<String> ids = (List<String>) actions.get(IDS_FIELD.getPreferredName());
matchingPinnedDocs.addAll(ids.stream().map(id -> new PinnedQueryBuilder.Item(null, id)).toList());
} else if (actions.containsKey(DOCS_FIELD.getPreferredName())) {
List<Map<String, String>> docsToPin = (List<Map<String, String>>) actions.get(DOCS_FIELD.getPreferredName());
List<PinnedQueryBuilder.Item> items = docsToPin.stream()
Expand All @@ -317,10 +317,8 @@ public AppliedQueryRules applyRule(AppliedQueryRules appliedRules, Map<String, O
}
}

List<String> pinnedIds = appliedRules.pinnedIds();
List<PinnedQueryBuilder.Item> pinnedDocs = appliedRules.pinnedDocs();
pinnedIds.addAll(matchingPinnedIds);
pinnedDocs.addAll(matchingPinnedDocs);
return new AppliedQueryRules(pinnedIds, pinnedDocs);
return new AppliedQueryRules(pinnedDocs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ public class RuleQueryBuilder extends AbstractQueryBuilder<RuleQueryBuilder> {
private final String rulesetId;
private final Map<String, Object> matchCriteria;
private final QueryBuilder organicQuery;

private final List<String> pinnedIds;
private final Supplier<List<String>> pinnedIdsSupplier;
private final List<Item> pinnedDocs;
private final Supplier<List<Item>> pinnedDocsSupplier;

Expand All @@ -77,16 +74,14 @@ public TransportVersion getMinimalSupportedVersion() {
}

public RuleQueryBuilder(QueryBuilder organicQuery, Map<String, Object> matchCriteria, String rulesetId) {
this(organicQuery, matchCriteria, rulesetId, null, null, null, null);
this(organicQuery, matchCriteria, rulesetId, null, null);
}

public RuleQueryBuilder(StreamInput in) throws IOException {
super(in);
organicQuery = in.readNamedWriteable(QueryBuilder.class);
matchCriteria = in.readGenericMap();
rulesetId = in.readString();
pinnedIds = in.readOptionalStringCollectionAsList();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cannot be removed without a backward compatibility check.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤦‍♀️ Thanks for the catch!

pinnedIdsSupplier = null;
pinnedDocs = in.readOptionalCollectionAsList(Item::new);
pinnedDocsSupplier = null;
}
Expand All @@ -95,9 +90,7 @@ private RuleQueryBuilder(
QueryBuilder organicQuery,
Map<String, Object> matchCriteria,
String rulesetId,
List<String> pinnedIds,
List<Item> pinnedDocs,
Supplier<List<String>> pinnedIdsSupplier,
Supplier<List<Item>> pinnedDocsSupplier

) {
Expand All @@ -113,11 +106,6 @@ private RuleQueryBuilder(

// PinnedQueryBuilder will return an error if we attempt to return more than the maximum number of
// pinned hits. Here, we truncate matching rules rather than return an error.
if (pinnedIds != null && pinnedIds.size() > MAX_NUM_PINNED_HITS) {
HeaderWarning.addWarning("Truncating query rule pinned hits to " + MAX_NUM_PINNED_HITS + " documents");
pinnedIds = pinnedIds.subList(0, MAX_NUM_PINNED_HITS);
}

if (pinnedDocs != null && pinnedDocs.size() > MAX_NUM_PINNED_HITS) {
HeaderWarning.addWarning("Truncating query rule pinned hits to " + MAX_NUM_PINNED_HITS + " documents");
pinnedDocs = pinnedDocs.subList(0, MAX_NUM_PINNED_HITS);
Expand All @@ -126,25 +114,19 @@ private RuleQueryBuilder(
this.organicQuery = organicQuery;
this.matchCriteria = matchCriteria;
this.rulesetId = rulesetId;
this.pinnedIds = pinnedIds;
this.pinnedIdsSupplier = pinnedIdsSupplier;
this.pinnedDocs = pinnedDocs;
this.pinnedDocsSupplier = pinnedDocsSupplier;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (pinnedIdsSupplier != null) {
throw new IllegalStateException("pinnedIdsSupplier must be null, can't serialize suppliers, missing a rewriteAndFetch?");
}
if (pinnedDocsSupplier != null) {
throw new IllegalStateException("pinnedDocsSupplier must be null, can't serialize suppliers, missing a rewriteAndFetch?");
}

out.writeNamedWriteable(organicQuery);
out.writeGenericMap(matchCriteria);
out.writeString(rulesetId);
out.writeOptionalStringCollection(pinnedIds);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

out.writeOptionalCollection(pinnedDocs);
}

Expand Down Expand Up @@ -173,41 +155,23 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
}

@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
if ((pinnedIds != null && pinnedIds.isEmpty() == false) && (pinnedDocs != null && pinnedDocs.isEmpty() == false)) {
throw new IllegalArgumentException("applied rules contain both pinned ids and pinned docs, only one of ids or docs is allowed");
}

if (pinnedIds != null && pinnedIds.isEmpty() == false) {
PinnedQueryBuilder pinnedQueryBuilder = new PinnedQueryBuilder(organicQuery, pinnedIds.toArray(new String[0]));
return pinnedQueryBuilder.toQuery(context);
} else if (pinnedDocs != null && pinnedDocs.isEmpty() == false) {
PinnedQueryBuilder pinnedQueryBuilder = new PinnedQueryBuilder(organicQuery, pinnedDocs.toArray(new Item[0]));
return pinnedQueryBuilder.toQuery(context);
} else {
return organicQuery.toQuery(context);
}
protected Query doToQuery(SearchExecutionContext context) {
throw new IllegalStateException(NAME + " should have been rewritten to another query type");

}

@SuppressWarnings("unchecked")
@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
if (pinnedIds != null || pinnedDocs != null) {
return this;
} else if (pinnedIdsSupplier != null || pinnedDocsSupplier != null) {
List<String> identifiedPinnedIds = pinnedIdsSupplier != null ? pinnedIdsSupplier.get() : null;
List<Item> identifiedPinnedDocs = pinnedDocsSupplier != null ? pinnedDocsSupplier.get() : null;
if (identifiedPinnedIds == null && identifiedPinnedDocs == null) {
return this; // not executed yet
} else {
return new RuleQueryBuilder(organicQuery, matchCriteria, rulesetId, identifiedPinnedIds, identifiedPinnedDocs, null, null);
}
if (pinnedDocs != null) {
return new PinnedQueryBuilder(organicQuery, pinnedDocs.toArray(new Item[0]));
} else if (pinnedDocsSupplier != null) {
List<Item> identifiedPinnedDocs = pinnedDocsSupplier.get();
return identifiedPinnedDocs != null ? new PinnedQueryBuilder(organicQuery, identifiedPinnedDocs.toArray(new Item[0])) : this;
}

// Identify matching rules and apply them as applicable
GetRequest getRequest = new GetRequest(QueryRulesIndexService.QUERY_RULES_ALIAS_NAME, rulesetId);
SetOnce<List<String>> pinnedIdsSetOnce = new SetOnce<>();
SetOnce<List<Item>> pinnedDocsSetOnce = new SetOnce<>();
AppliedQueryRules appliedRules = new AppliedQueryRules();

Expand All @@ -227,7 +191,6 @@ public void onResponse(GetResponse getResponse) {
for (QueryRule rule : queryRuleset.rules()) {
rule.applyRule(appliedRules, matchCriteria);
}
pinnedIdsSetOnce.set(appliedRules.pinnedIds().stream().distinct().toList());
pinnedDocsSetOnce.set(appliedRules.pinnedDocs().stream().distinct().toList());
listener.onResponse(null);
}
Expand All @@ -245,16 +208,18 @@ public void onFailure(Exception e) {
});

QueryBuilder newOrganicQuery = organicQuery.rewrite(queryRewriteContext);
RuleQueryBuilder rewritten = new RuleQueryBuilder(
newOrganicQuery,
matchCriteria,
this.rulesetId,
null,
null,
pinnedIdsSetOnce::get,
pinnedDocsSetOnce::get
);
List<Item> docsToPin = pinnedDocsSetOnce.get();
QueryBuilder rewritten;

if (docsToPin != null) {
rewritten = docsToPin.isEmpty()
? newOrganicQuery // We've identified there are no documents to pin so let's bypass returning a pinned query
: new RuleQueryBuilder(newOrganicQuery, matchCriteria, rulesetId, docsToPin, null).rewrite(queryRewriteContext);
} else {
rewritten = new RuleQueryBuilder(newOrganicQuery, matchCriteria, this.rulesetId, null, pinnedDocsSetOnce::get);
}
rewritten.boost(this.boost);
rewritten.queryName(this.queryName);
return rewritten;
}

Expand All @@ -265,15 +230,13 @@ protected boolean doEquals(RuleQueryBuilder other) {
return Objects.equals(rulesetId, other.rulesetId)
&& Objects.equals(matchCriteria, other.matchCriteria)
&& Objects.equals(organicQuery, other.organicQuery)
&& Objects.equals(pinnedIds, other.pinnedIds)
&& Objects.equals(pinnedDocs, other.pinnedDocs)
&& Objects.equals(pinnedIdsSupplier, other.pinnedIdsSupplier)
&& Objects.equals(pinnedDocsSupplier, other.pinnedDocsSupplier);
}

@Override
protected int doHashCode() {
return Objects.hash(rulesetId, matchCriteria, organicQuery, pinnedIds, pinnedDocs, pinnedIdsSupplier, pinnedDocsSupplier);
return Objects.hash(rulesetId, matchCriteria, organicQuery, pinnedDocs, pinnedDocsSupplier);
}

private static final ConstructingObjectParser<RuleQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(NAME, a -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import static org.elasticsearch.xpack.application.rules.QueryRuleCriteriaType.EXACT;
import static org.elasticsearch.xpack.application.rules.QueryRuleCriteriaType.PREFIX;
import static org.elasticsearch.xpack.application.rules.QueryRuleCriteriaType.SUFFIX;
import static org.elasticsearch.xpack.searchbusinessrules.PinnedQueryBuilder.Item;
import static org.hamcrest.CoreMatchers.equalTo;

public class QueryRuleTests extends ESTestCase {
Expand Down Expand Up @@ -174,11 +175,11 @@ public void testApplyRuleWithOneCriteria() {
);
AppliedQueryRules appliedQueryRules = new AppliedQueryRules();
rule.applyRule(appliedQueryRules, Map.of("query", "elastic"));
assertEquals(List.of("id1", "id2"), appliedQueryRules.pinnedIds());
assertEquals(List.of(new Item(null, "id1"), new Item(null, "id2")), appliedQueryRules.pinnedDocs());

appliedQueryRules = new AppliedQueryRules();
rule.applyRule(appliedQueryRules, Map.of("query", "elastic1"));
assertEquals(Collections.emptyList(), appliedQueryRules.pinnedIds());
assertEquals(Collections.emptyList(), appliedQueryRules.pinnedDocs());
}

public void testApplyRuleWithMultipleCriteria() {
Expand All @@ -190,11 +191,11 @@ public void testApplyRuleWithMultipleCriteria() {
);
AppliedQueryRules appliedQueryRules = new AppliedQueryRules();
rule.applyRule(appliedQueryRules, Map.of("query", "elastic - you know, for search"));
assertEquals(List.of("id1", "id2"), appliedQueryRules.pinnedIds());
assertEquals(List.of(new Item(null, "id1"), new Item(null, "id2")), appliedQueryRules.pinnedDocs());

appliedQueryRules = new AppliedQueryRules();
rule.applyRule(appliedQueryRules, Map.of("query", "elastic"));
assertEquals(Collections.emptyList(), appliedQueryRules.pinnedIds());
assertEquals(Collections.emptyList(), appliedQueryRules.pinnedDocs());
}

private void assertXContent(QueryRule queryRule, boolean humanReadable) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@

package org.elasticsearch.xpack.application.rules;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.get.GetRequest;
Expand Down Expand Up @@ -181,4 +186,39 @@ protected Map<String, String> getObjectsHoldingArbitraryContent() {
objects.put(RuleQueryBuilder.MATCH_CRITERIA_FIELD.getPreferredName(), null);
return objects;
}

/**
* Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader}; this query should always be rewritten
*/
@Override
public void testToQuery() throws IOException {
try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
Document document = new Document();
document.add(new StringField("foo", "bar", org.apache.lucene.document.Field.Store.NO));
iw.addDocument(document);
try (IndexReader reader = iw.getReader()) {
SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
RuleQueryBuilder queryBuilder = createTestQueryBuilder();
IllegalStateException e = expectThrows(IllegalStateException.class, () -> queryBuilder.toQuery(context));
assertEquals("rule_query should have been rewritten to another query type", e.getMessage());
}
}
}

@Override
public void testMustRewrite() {
SearchExecutionContext context = createSearchExecutionContext();
RuleQueryBuilder builder = createTestQueryBuilder();
IllegalStateException e = expectThrows(IllegalStateException.class, () -> builder.toQuery(context));
assertEquals("rule_query should have been rewritten to another query type", e.getMessage());
}

@Override
public void testCacheability() throws IOException {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This had to be overwritten to avoid throwing indoToQuery

RuleQueryBuilder queryBuilder = createTestQueryBuilder();
SearchExecutionContext context = createSearchExecutionContext();
queryBuilder.rewrite(new SearchExecutionContext(context));
assertTrue("query should be cacheable: " + queryBuilder, context.isCacheable());
}

}