-
Notifications
You must be signed in to change notification settings - Fork 25.3k
Semantic Text Chunking Indexing Pressure #125517
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bc62301
0a58daa
a77516f
fee2c90
83395f4
f84ff66
8698ecd
c7b1af1
6458bb3
6b7db55
f4a4689
c74ca3c
f5e8a94
1d5e5bd
d93050f
2480955
5d76384
4bcff47
080ae60
30e0a08
8939687
6071cb0
0507d94
a643877
a6200d5
c211ff6
a106f87
4a2976c
56a5f97
111bb1f
02d70a0
bf9d118
456fc59
fee592b
437ca6b
96f4037
fcf7387
e5f64ff
b183b20
3349168
8a28093
bf420c8
2e479a9
2330e3e
0bb32cc
0065a87
0ff48cf
ecb8e02
6a3a4fd
f4aef73
118c27f
8cc4402
0a138f2
e8742d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 125517 | ||
summary: Semantic Text Chunking Indexing Pressure | ||
area: Machine Learning | ||
type: enhancement | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,11 +29,13 @@ | |
import org.elasticsearch.common.settings.Setting; | ||
import org.elasticsearch.common.unit.ByteSizeValue; | ||
import org.elasticsearch.common.util.concurrent.AtomicArray; | ||
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; | ||
import org.elasticsearch.common.xcontent.XContentHelper; | ||
import org.elasticsearch.common.xcontent.support.XContentMapValues; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.core.Releasable; | ||
import org.elasticsearch.core.TimeValue; | ||
import org.elasticsearch.index.IndexingPressure; | ||
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; | ||
import org.elasticsearch.inference.ChunkInferenceInput; | ||
import org.elasticsearch.inference.ChunkedInference; | ||
|
@@ -109,18 +111,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { | |
private final InferenceServiceRegistry inferenceServiceRegistry; | ||
private final ModelRegistry modelRegistry; | ||
private final XPackLicenseState licenseState; | ||
private final IndexingPressure indexingPressure; | ||
private volatile long batchSizeInBytes; | ||
|
||
public ShardBulkInferenceActionFilter( | ||
ClusterService clusterService, | ||
InferenceServiceRegistry inferenceServiceRegistry, | ||
ModelRegistry modelRegistry, | ||
XPackLicenseState licenseState | ||
XPackLicenseState licenseState, | ||
IndexingPressure indexingPressure | ||
) { | ||
this.clusterService = clusterService; | ||
this.inferenceServiceRegistry = inferenceServiceRegistry; | ||
this.modelRegistry = modelRegistry; | ||
this.licenseState = licenseState; | ||
this.indexingPressure = indexingPressure; | ||
this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes(); | ||
clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize); | ||
} | ||
|
@@ -146,8 +151,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app | |
BulkShardRequest bulkShardRequest = (BulkShardRequest) request; | ||
var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap(); | ||
if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { | ||
Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); | ||
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); | ||
// Maintain coordinating indexing pressure from inference until the indexing operations are complete | ||
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false); | ||
Runnable onInferenceCompletion = () -> chain.proceed( | ||
task, | ||
action, | ||
request, | ||
ActionListener.releaseAfter(listener, coordinatingIndexingPressure) | ||
); | ||
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure); | ||
return; | ||
} | ||
} | ||
|
@@ -157,12 +169,14 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app | |
private void processBulkShardRequest( | ||
Map<String, InferenceFieldMetadata> fieldInferenceMap, | ||
BulkShardRequest bulkShardRequest, | ||
Runnable onCompletion | ||
Runnable onCompletion, | ||
IndexingPressure.Coordinating coordinatingIndexingPressure | ||
) { | ||
final ProjectMetadata project = clusterService.state().getMetadata().getProject(); | ||
var index = project.index(bulkShardRequest.index()); | ||
boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false; | ||
new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run(); | ||
new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure) | ||
.run(); | ||
} | ||
|
||
private record InferenceProvider(InferenceService service, Model model) {} | ||
|
@@ -232,18 +246,21 @@ private class AsyncBulkShardInferenceAction implements Runnable { | |
private final BulkShardRequest bulkShardRequest; | ||
private final Runnable onCompletion; | ||
private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults; | ||
private final IndexingPressure.Coordinating coordinatingIndexingPressure; | ||
|
||
private AsyncBulkShardInferenceAction( | ||
boolean useLegacyFormat, | ||
Map<String, InferenceFieldMetadata> fieldInferenceMap, | ||
BulkShardRequest bulkShardRequest, | ||
Runnable onCompletion | ||
Runnable onCompletion, | ||
IndexingPressure.Coordinating coordinatingIndexingPressure | ||
) { | ||
this.useLegacyFormat = useLegacyFormat; | ||
this.fieldInferenceMap = fieldInferenceMap; | ||
this.bulkShardRequest = bulkShardRequest; | ||
this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); | ||
this.onCompletion = onCompletion; | ||
this.coordinatingIndexingPressure = coordinatingIndexingPressure; | ||
} | ||
|
||
@Override | ||
|
@@ -431,9 +448,9 @@ public void onFailure(Exception exc) { | |
*/ | ||
private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) { | ||
boolean isUpdateRequest = false; | ||
final IndexRequest indexRequest; | ||
final IndexRequestWithIndexingPressure indexRequest; | ||
if (item.request() instanceof IndexRequest ir) { | ||
indexRequest = ir; | ||
indexRequest = new IndexRequestWithIndexingPressure(ir); | ||
} else if (item.request() instanceof UpdateRequest updateRequest) { | ||
isUpdateRequest = true; | ||
if (updateRequest.script() != null) { | ||
|
@@ -447,13 +464,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< | |
); | ||
return 0; | ||
} | ||
indexRequest = updateRequest.doc(); | ||
indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc()); | ||
} else { | ||
// ignore delete request | ||
return 0; | ||
} | ||
|
||
final Map<String, Object> docMap = indexRequest.sourceAsMap(); | ||
final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap(); | ||
long inputLength = 0; | ||
for (var entry : fieldInferenceMap.values()) { | ||
String field = entry.getName(); | ||
|
@@ -489,6 +506,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< | |
* This ensures that the field is treated as intentionally cleared, | ||
* preventing any unintended carryover of prior inference results. | ||
*/ | ||
if (incrementIndexingPressure(indexRequest, itemIndex) == false) { | ||
return inputLength; | ||
Mikep86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
var slot = ensureResponseAccumulatorSlot(itemIndex); | ||
slot.addOrUpdateResponse( | ||
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) | ||
|
@@ -510,6 +531,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< | |
} | ||
continue; | ||
} | ||
|
||
var slot = ensureResponseAccumulatorSlot(itemIndex); | ||
final List<String> values; | ||
try { | ||
|
@@ -527,7 +549,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< | |
List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); | ||
int offsetAdjustment = 0; | ||
for (String v : values) { | ||
inputLength += v.length(); | ||
if (incrementIndexingPressure(indexRequest, itemIndex) == false) { | ||
return inputLength; | ||
} | ||
|
||
if (v.isBlank()) { | ||
slot.addOrUpdateResponse( | ||
new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) | ||
|
@@ -536,6 +561,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< | |
requests.add( | ||
new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings) | ||
); | ||
inputLength += v.length(); | ||
} | ||
|
||
// When using the inference metadata fields format, all the input values are concatenated so that the | ||
|
@@ -545,9 +571,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< | |
} | ||
} | ||
} | ||
|
||
return inputLength; | ||
} | ||
|
||
private static class IndexRequestWithIndexingPressure { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What happened to the rough estimation beforehand? I understand that this is much simpler, but it feels bad to start a very resource-heavy process and fail halfway through, when a simple back-of-the-envelope calculation could have shown that it was destined to fail... There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It should be simple to follow up and add an estimated cost of the inference response to the initial call to There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm totally fine to do that as a follow-up, but I think it's something we should do. For each request, add an estimate to the circuit breaker upfront. When we get the real value, adjust the circuit breaker accordingly. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm all for integrating an estimated cost of the inference response in a follow-up. However, we should stick with the batch-based processing implemented in this PR. Not only is it simpler, but it avoids duplicating expensive operations like parsing the source to a map. |
||
private final IndexRequest indexRequest; | ||
private boolean indexingPressureIncremented; | ||
|
||
private IndexRequestWithIndexingPressure(IndexRequest indexRequest) { | ||
this.indexRequest = indexRequest; | ||
this.indexingPressureIncremented = false; | ||
} | ||
|
||
private IndexRequest getIndexRequest() { | ||
return indexRequest; | ||
} | ||
|
||
private boolean isIndexingPressureIncremented() { | ||
return indexingPressureIncremented; | ||
} | ||
|
||
private void setIndexingPressureIncremented() { | ||
this.indexingPressureIncremented = true; | ||
} | ||
} | ||
|
||
private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) { | ||
boolean success = true; | ||
if (indexRequest.isIndexingPressureIncremented() == false) { | ||
try { | ||
// Track operation count as one operation per document source update | ||
coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed()); | ||
indexRequest.setIndexingPressureIncremented(); | ||
} catch (EsRejectedExecutionException e) { | ||
addInferenceResponseFailure( | ||
itemIndex, | ||
new InferenceException( | ||
"Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]", | ||
e | ||
) | ||
); | ||
success = false; | ||
} | ||
} | ||
|
||
return success; | ||
} | ||
|
||
private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { | ||
FieldInferenceResponseAccumulator acc = inferenceResults.get(id); | ||
if (acc == null) { | ||
|
@@ -624,6 +695,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons | |
inferenceFieldsMap.put(fieldName, result); | ||
} | ||
|
||
BytesReference originalSource = indexRequest.source(); | ||
if (useLegacyFormat) { | ||
var newDocMap = indexRequest.sourceAsMap(); | ||
for (var entry : inferenceFieldsMap.entrySet()) { | ||
|
@@ -636,6 +708,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons | |
indexRequest.source(builder); | ||
} | ||
} | ||
long modifiedSourceSize = indexRequest.source().ramBytesUsed(); | ||
|
||
// Add the indexing pressure from the source modifications. | ||
// Don't increment operation count because we count one source update as one operation, and we already accounted for those | ||
// in addFieldInferenceRequests. | ||
try { | ||
coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed()); | ||
} catch (EsRejectedExecutionException e) { | ||
indexRequest.source(originalSource, indexRequest.getContentType()); | ||
item.abort( | ||
item.index(), | ||
new InferenceException( | ||
"Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]", | ||
e | ||
) | ||
); | ||
} | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to add this to PluginServices? Can't this be done in the standard way? (via CircuitBreakerService) or in another way that does not add yet another service here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't a circuit breaker though, so we can't use the
CircuitBreakerService
pattern. We could theoretically create some new IndexingPressureService
pattern, but I think that would result in more changes and complexity for the same end result. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I got confused by the title, and assumed
IndexingPressure
was using a CircuitBreaker internally that it could expose. Is this using a different mechanism? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it is using coordinating indexing pressure, which was previously only used internally in the bulk action. We decided to extend its usage to here as well because the inference performed in
ShardBulkInferenceActionFilter
increases indexing pressure.