Semantic Text Chunking Indexing Pressure #125517

Merged · 54 commits · Apr 14, 2025
Commits
bc62301
Added circuit breaker
Mikep86 Mar 19, 2025
0a58daa
Pass circuit breaker to action filter
Mikep86 Mar 20, 2025
a77516f
Estimate memory usage before performing inference
Mikep86 Mar 20, 2025
fee2c90
Merge branch 'main' into semantic-text_oom-circuit-breaker
Mikep86 Mar 20, 2025
83395f4
Reset circuit breaker on completion of request handling
Mikep86 Mar 20, 2025
f84ff66
Calculate actual memory usage
Mikep86 Mar 20, 2025
8698ecd
Spotless
Mikep86 Mar 20, 2025
c7b1af1
Added TODOs
Mikep86 Mar 20, 2025
6458bb3
Added more comments
Mikep86 Mar 20, 2025
6b7db55
Merge branch 'main' into semantic-text_oom-circuit-breaker
Mikep86 Mar 21, 2025
f4a4689
Track memory usage of requests that don't perform inference
Mikep86 Mar 21, 2025
c74ca3c
Fix test failures
Mikep86 Mar 21, 2025
f5e8a94
Add circuit breaker unit test
Mikep86 Mar 21, 2025
1d5e5bd
Circuit breaker test development
Mikep86 Mar 21, 2025
d93050f
Fix memory usage tracking in estimateMemoryUsage
Mikep86 Mar 24, 2025
2480955
Make circuit breaker limit setting dynamically updatable
Mikep86 Mar 24, 2025
5d76384
Updated estimateMemoryUsage to throw InferenceException
Mikep86 Mar 24, 2025
4bcff47
Updated InferenceException to retain the original message when it is …
Mikep86 Mar 24, 2025
080ae60
Added circuit breaker trips on estimated inference bytes unit test
Mikep86 Mar 24, 2025
30e0a08
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 24, 2025
8939687
Increment byte counters after updating breaker
Mikep86 Mar 28, 2025
6071cb0
Check that circuit breaker usage is 0
Mikep86 Mar 28, 2025
0507d94
Add indexing pressure to plugin services
Mikep86 Mar 28, 2025
a643877
Merge branch 'main' into semantic-text_oom-circuit-breaker
Mikep86 Mar 28, 2025
a6200d5
Merge branch 'main' into semantic-text_oom-circuit-breaker
Mikep86 Apr 1, 2025
c211ff6
Pass indexing pressure to action filter
Mikep86 Apr 1, 2025
a106f87
Pass coordinating object to AsyncBulkShardInferenceAction
Mikep86 Apr 1, 2025
4a2976c
Use coordinating indexing pressure in ShardBulkInferenceActionFilter
Mikep86 Apr 1, 2025
56a5f97
Update circuit breaker test
Mikep86 Apr 1, 2025
111bb1f
Update circuit breaker trips on estimated inference bytes test
Mikep86 Apr 1, 2025
02d70a0
Remove inference bytes circuit breaker
Mikep86 Apr 1, 2025
bf9d118
Adjust coordinating indexing pressure lifetime
Mikep86 Apr 2, 2025
456fc59
Merge branch 'main' into semantic-text_oom-circuit-breaker
elasticmachine Apr 2, 2025
fee592b
Account for indexing pressure from source in batches
Mikep86 Apr 3, 2025
437ca6b
Account for indexing pressure from empty chunk inference updates
Mikep86 Apr 4, 2025
96f4037
Add indexing pressure from source modifications
Mikep86 Apr 4, 2025
fcf7387
Fix testIndexingPressure
Mikep86 Apr 4, 2025
e5f64ff
Fix testIndexingPressureTripsOnEstimatedInferenceBytes
Mikep86 Apr 4, 2025
b183b20
Merge branch 'main' into semantic-text_oom-circuit-breaker
Mikep86 Apr 4, 2025
3349168
Cleanup
Mikep86 Apr 4, 2025
8a28093
Fix compilation errors
Mikep86 Apr 4, 2025
bf420c8
Added unit test
Mikep86 Apr 4, 2025
2e479a9
Revert changes to InferenceException
Mikep86 Apr 4, 2025
2330e3e
Resolve TODO
Mikep86 Apr 4, 2025
0bb32cc
Merge branch 'main' into semantic-text_oom-circuit-breaker
elasticmachine Apr 4, 2025
0065a87
Merge branch 'main' into semantic-text_oom-circuit-breaker
elasticmachine Apr 7, 2025
0ff48cf
Resolve TODOs
Mikep86 Apr 8, 2025
ecb8e02
Pass indexing pressure in constructor
Mikep86 Apr 8, 2025
6a3a4fd
Merge branch 'main' into semantic-text_oom-circuit-breaker
elasticmachine Apr 8, 2025
f4aef73
Added partial failure test
Mikep86 Apr 8, 2025
118c27f
Merge branch 'main' into semantic-text_oom-circuit-breaker
elasticmachine Apr 8, 2025
8cc4402
Update docs/changelog/125517.yaml
Mikep86 Apr 11, 2025
0a138f2
Fix changelog
Mikep86 Apr 11, 2025
e8742d2
Merge branch 'main' into semantic-text_oom-circuit-breaker
Mikep86 Apr 11, 2025
Files changed
5 changes: 5 additions & 0 deletions docs/changelog/125517.yaml
@@ -0,0 +1,5 @@
pr: 125517
summary: Semantic Text Chunking Indexing Pressure
area: Machine Learning
type: enhancement
issues: []
IndexingPressure.java
@@ -157,11 +157,15 @@ public Incremental startIncrementalCoordinating(int operations, long bytes, boolean forceExecution) {
}

public Coordinating markCoordinatingOperationStarted(int operations, long bytes, boolean forceExecution) {
Coordinating coordinating = new Coordinating(forceExecution);
Coordinating coordinating = createCoordinatingOperation(forceExecution);
coordinating.increment(operations, bytes);
return coordinating;
}

public Coordinating createCoordinatingOperation(boolean forceExecution) {
return new Coordinating(forceExecution);
}

public class Incremental implements Releasable {

private final AtomicBoolean closed = new AtomicBoolean();
@@ -254,7 +258,7 @@ public Coordinating(boolean forceExecution) {
this.forceExecution = forceExecution;
}

private void increment(int operations, long bytes) {
public void increment(int operations, long bytes) {
assert closed.get() == false;
long combinedBytes = currentCombinedCoordinatingAndPrimaryBytes.addAndGet(bytes);
long replicaWriteBytes = currentReplicaBytes.get();
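The change above splits what used to be a single markCoordinatingOperationStarted call into an acquire step (createCoordinatingOperation) and an increment step that can be repeated as byte counts become known. A minimal sketch of that call pattern, assuming only the API visible in this hunk — the wrapper class, method, and byte values are illustrative, not part of the PR:

```java
import org.elasticsearch.index.IndexingPressure;

class CoordinatingPressureSketch {
    // Illustrative caller; the real caller added by this PR is ShardBulkInferenceActionFilter.
    static void trackDeferredBytes(IndexingPressure indexingPressure, long initialBytes, long laterBytes) {
        // createCoordinatingOperation defers the byte accounting that
        // markCoordinatingOperationStarted previously required up front.
        IndexingPressure.Coordinating coordinating = indexingPressure.createCoordinatingOperation(false);
        try {
            // increment throws EsRejectedExecutionException if the coordinating limit
            // would be exceeded (forceExecution == false above).
            coordinating.increment(1, initialBytes); // one operation, bytes known now
            coordinating.increment(0, laterBytes);   // more bytes discovered later, same operation
            // ... perform the coordinated work ...
        } finally {
            coordinating.close(); // Coordinating is Releasable; frees everything tracked by this handle
        }
    }
}
```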
@@ -928,6 +928,8 @@ public Map<String, String> queryFields()
metadataCreateIndexService
);

final IndexingPressure indexingLimits = new IndexingPressure(settings);

PluginServiceInstances pluginServices = new PluginServiceInstances(
client,
clusterService,
@@ -950,7 +952,8 @@ public Map<String, String> queryFields()
documentParsingProvider,
taskManager,
projectResolver,
slowLogFieldProvider
slowLogFieldProvider,
indexingLimits
);

Collection<?> pluginComponents = pluginsService.flatMap(plugin -> {
@@ -983,7 +986,6 @@ public Map<String, String> queryFields()
.map(TerminationHandlerProvider::handler);
terminationHandler = getSinglePlugin(terminationHandlers, TerminationHandler.class).orElse(null);

final IndexingPressure indexingLimits = new IndexingPressure(settings);
final IncrementalBulkService incrementalBulkService = new IncrementalBulkService(client, indexingLimits);

final ResponseCollectorService responseCollectorService = new ResponseCollectorService(clusterService);
PluginServiceInstances.java
@@ -20,6 +20,7 @@
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.features.FeatureService;
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.index.SlowLogFieldProvider;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.SystemIndices;
@@ -55,5 +56,6 @@ public record PluginServiceInstances(
DocumentParsingProvider documentParsingProvider,
TaskManager taskManager,
ProjectResolver projectResolver,
SlowLogFieldProvider slowLogFieldProvider
SlowLogFieldProvider slowLogFieldProvider,
IndexingPressure indexingPressure
) implements Plugin.PluginServices {}
6 changes: 6 additions & 0 deletions server/src/main/java/org/elasticsearch/plugins/Plugin.java
@@ -29,6 +29,7 @@
import org.elasticsearch.features.FeatureService;
import org.elasticsearch.index.IndexModule;
import org.elasticsearch.index.IndexSettingProvider;
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.index.SlowLogFieldProvider;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.SystemIndices;
@@ -186,6 +187,11 @@ public interface PluginServices {
* Provider for additional SlowLog fields
*/
SlowLogFieldProvider slowLogFieldProvider();

/**
* Provider for indexing pressure
*/
IndexingPressure indexingPressure();
Contributor: Why do we need to add this to PluginServices? Can't this be done in the standard way (via CircuitBreakerService), or in another way that does not add yet another service here?

Contributor Author: This isn't a circuit breaker though, so we can't use the CircuitBreakerService pattern. We could theoretically create some new IndexingPressureService pattern, but I think that would result in more changes and complexity for the same end result.

Contributor: I got confused by the title and assumed IndexingPressure was using a CircuitBreaker internally that it could expose. Is this using a different mechanism?

Contributor Author: Yes, it is using coordinating indexing pressure, which was previously only used internally in the bulk action. We decided to extend its usage here because the inference performed in ShardBulkInferenceActionFilter increases indexing pressure.

}

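Tying the discussion above together: a plugin can now pull IndexingPressure out of PluginServices and bind a coordinating operation's lifetime to a request listener. The sketch below shows that generic shape under stated assumptions — the plugin class, method, and byte value are invented for illustration; the actual consumer is InferencePlugin in the next file:

```java
import java.util.Collection;
import java.util.List;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.plugins.Plugin;

// Hypothetical plugin; only the access pattern is taken from this PR.
public class PressureAwarePlugin extends Plugin {
    private IndexingPressure indexingPressure;

    @Override
    public Collection<?> createComponents(PluginServices services) {
        this.indexingPressure = services.indexingPressure(); // accessor added in this hunk
        return List.of();
    }

    // Wrap a listener so the coordinating bytes are released when the request completes.
    <T> ActionListener<T> withCoordinatingPressure(ActionListener<T> listener, long estimatedBytes) {
        IndexingPressure.Coordinating coordinating = indexingPressure.createCoordinatingOperation(false);
        coordinating.increment(1, estimatedBytes); // may reject if the coordinating limit is exceeded
        return ActionListener.releaseAfter(listener, coordinating);
    }
}
```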
InferencePlugin.java
@@ -321,7 +321,13 @@ public Collection<?> createComponents(PluginServices services) {
}
inferenceServiceRegistry.set(serviceRegistry);

var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
var actionFilter = new ShardBulkInferenceActionFilter(
services.clusterService(),
serviceRegistry,
modelRegistry,
getLicenseState(),
services.indexingPressure()
);
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
ShardBulkInferenceActionFilter.java
@@ -29,11 +29,13 @@
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
@@ -109,18 +111,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private final InferenceServiceRegistry inferenceServiceRegistry;
private final ModelRegistry modelRegistry;
private final XPackLicenseState licenseState;
private final IndexingPressure indexingPressure;
private volatile long batchSizeInBytes;

public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState
XPackLicenseState licenseState,
IndexingPressure indexingPressure
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.licenseState = licenseState;
this.indexingPressure = indexingPressure;
this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
}
@@ -146,8 +151,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void apply(
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
// Maintain coordinating indexing pressure from inference until the indexing operations are complete
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
Runnable onInferenceCompletion = () -> chain.proceed(
task,
action,
request,
ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
);
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
return;
}
}
@@ -157,12 +169,14 @@ public <Request extends ActionRequest, Response extends ActionResponse> void apply(
private void processBulkShardRequest(
Map<String, InferenceFieldMetadata> fieldInferenceMap,
BulkShardRequest bulkShardRequest,
Runnable onCompletion
Runnable onCompletion,
IndexingPressure.Coordinating coordinatingIndexingPressure
) {
final ProjectMetadata project = clusterService.state().getMetadata().getProject();
var index = project.index(bulkShardRequest.index());
boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
.run();
}

private record InferenceProvider(InferenceService service, Model model) {}
@@ -232,18 +246,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
private final BulkShardRequest bulkShardRequest;
private final Runnable onCompletion;
private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
private final IndexingPressure.Coordinating coordinatingIndexingPressure;

private AsyncBulkShardInferenceAction(
boolean useLegacyFormat,
Map<String, InferenceFieldMetadata> fieldInferenceMap,
BulkShardRequest bulkShardRequest,
Runnable onCompletion
Runnable onCompletion,
IndexingPressure.Coordinating coordinatingIndexingPressure
) {
this.useLegacyFormat = useLegacyFormat;
this.fieldInferenceMap = fieldInferenceMap;
this.bulkShardRequest = bulkShardRequest;
this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
this.onCompletion = onCompletion;
this.coordinatingIndexingPressure = coordinatingIndexingPressure;
}

@Override
@@ -431,9 +448,9 @@ public void onFailure(Exception exc) {
*/
private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
boolean isUpdateRequest = false;
final IndexRequest indexRequest;
final IndexRequestWithIndexingPressure indexRequest;
if (item.request() instanceof IndexRequest ir) {
indexRequest = ir;
indexRequest = new IndexRequestWithIndexingPressure(ir);
} else if (item.request() instanceof UpdateRequest updateRequest) {
isUpdateRequest = true;
if (updateRequest.script() != null) {
@@ -447,13 +464,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
);
return 0;
}
indexRequest = updateRequest.doc();
indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
} else {
// ignore delete request
return 0;
}

final Map<String, Object> docMap = indexRequest.sourceAsMap();
final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
long inputLength = 0;
for (var entry : fieldInferenceMap.values()) {
String field = entry.getName();
@@ -489,6 +506,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
* This ensures that the field is treated as intentionally cleared,
* preventing any unintended carryover of prior inference results.
*/
if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
return inputLength;
}

var slot = ensureResponseAccumulatorSlot(itemIndex);
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -510,6 +531,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
}
continue;
}

var slot = ensureResponseAccumulatorSlot(itemIndex);
final List<String> values;
try {
@@ -527,7 +549,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
int offsetAdjustment = 0;
for (String v : values) {
inputLength += v.length();
if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
return inputLength;
}

if (v.isBlank()) {
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -536,6 +561,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
requests.add(
new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
);
inputLength += v.length();
}

// When using the inference metadata fields format, all the input values are concatenated so that the
@@ -545,9 +571,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
}
}
}

return inputLength;
}

private static class IndexRequestWithIndexingPressure {
Contributor: What happened to the rough estimation beforehand? I understand that this is much simpler, but it feels bad to start a very resource-heavy process and fail halfway through when a simple back-of-the-envelope calculation could have shown that it was destined to fail...

Member: It should be simple to follow up and add an estimated cost of the inference response to the initial call to incrementIndexingPressure(). We can follow up with that.

Contributor: I'm totally fine with doing that as a follow-up, but I think it's something we should do. For each request, add an estimate to the circuit breaker upfront; when we get the real value, adjust the circuit breaker accordingly.

Contributor Author: I'm all for integrating an estimated cost of the inference response in a follow-up. However, we should stick with the batch-based processing implemented in this PR. Not only is it simpler, but it avoids duplicating expensive operations like parsing the source to a map.

private final IndexRequest indexRequest;
private boolean indexingPressureIncremented;

private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
this.indexRequest = indexRequest;
this.indexingPressureIncremented = false;
}

private IndexRequest getIndexRequest() {
return indexRequest;
}

private boolean isIndexingPressureIncremented() {
return indexingPressureIncremented;
}

private void setIndexingPressureIncremented() {
this.indexingPressureIncremented = true;
}
}

private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
boolean success = true;
if (indexRequest.isIndexingPressureIncremented() == false) {
try {
// Track operation count as one operation per document source update
coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
indexRequest.setIndexingPressureIncremented();
} catch (EsRejectedExecutionException e) {
addInferenceResponseFailure(
itemIndex,
new InferenceException(
"Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
e
)
);
success = false;
}
}

return success;
}

private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
if (acc == null) {
Expand Down Expand Up @@ -624,6 +695,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
inferenceFieldsMap.put(fieldName, result);
}

BytesReference originalSource = indexRequest.source();
if (useLegacyFormat) {
var newDocMap = indexRequest.sourceAsMap();
for (var entry : inferenceFieldsMap.entrySet()) {
Expand All @@ -636,6 +708,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
indexRequest.source(builder);
}
}
long modifiedSourceSize = indexRequest.source().ramBytesUsed();

// Add the indexing pressure from the source modifications.
// Don't increment operation count because we count one source update as one operation, and we already accounted for those
// in addFieldInferenceRequests.
try {
coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
} catch (EsRejectedExecutionException e) {
indexRequest.source(originalSource, indexRequest.getContentType());
item.abort(
item.index(),
new InferenceException(
"Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
e
)
);
}
}
}

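The review thread on IndexRequestWithIndexingPressure leaves upfront estimation of the inference response size as a follow-up. A rough sketch of how that could sit on top of the accounting used here — the estimate parameter, helper names, and adjustment rule are assumptions; nothing below is implemented in this PR:

```java
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.index.IndexingPressure;

class EstimatedInferencePressureSketch {
    // Reserve the document source bytes plus a rough estimate of the inference response
    // before running inference, so a request that is destined to fail is rejected before
    // the expensive work starts.
    static boolean tryReserve(IndexingPressure.Coordinating coordinating, IndexRequest indexRequest, long estimatedInferenceResponseBytes) {
        try {
            // One operation per document source update, matching the accounting in this PR.
            coordinating.increment(1, indexRequest.source().ramBytesUsed() + estimatedInferenceResponseBytes);
            return true;
        } catch (EsRejectedExecutionException e) {
            return false; // caller records an inference response failure for this item
        }
    }

    // Once the actual response size is known, account for anything above the estimate.
    static void adjustToActual(IndexingPressure.Coordinating coordinating, long estimatedBytes, long actualBytes) {
        if (actualBytes > estimatedBytes) {
            coordinating.increment(0, actualBytes - estimatedBytes);
        }
        // Giving back an over-estimate is omitted; that would need explicit support in Coordinating.
    }
}
```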