Skip to content

Prevent concurrent access to local breaker in rerank #128162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class CrankyCircuitBreakerService extends CircuitBreakerService {
*/
public static final String ERROR_MESSAGE = "cranky breaker";

private final CircuitBreaker breaker = new CircuitBreaker() {
public static final class CrankyCircuitBreaker implements CircuitBreaker {
private final AtomicLong used = new AtomicLong();

@Override
Expand Down Expand Up @@ -82,7 +82,9 @@ public Durability getDurability() {
public void setLimitAndOverhead(long limit, double overhead) {
    // Intentionally empty: the supplied limit and overhead are ignored by this breaker.

}
};
}

private final CrankyCircuitBreaker breaker = new CrankyCircuitBreaker();

@Override
public CircuitBreaker getBreaker(String name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public final class LocalCircuitBreaker implements CircuitBreaker, Releasable {
private final long maxOverReservedBytes;
private long reservedBytes;
private final AtomicBoolean closed = new AtomicBoolean(false);
private volatile Thread activeThread;

public record SizeSettings(long overReservedBytes, long maxOverReservedBytes) {
public SizeSettings(Settings settings) {
Expand Down Expand Up @@ -57,6 +58,7 @@ public void circuitBreak(String fieldName, long bytesNeeded) {

@Override
public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException {
assert assertSingleThread();
if (bytes <= reservedBytes) {
reservedBytes -= bytes;
maybeReduceReservedBytes();
Expand All @@ -68,6 +70,7 @@ public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws Circu

@Override
public void addWithoutBreaking(long bytes) {
assert assertSingleThread();
if (bytes <= reservedBytes) {
reservedBytes -= bytes;
maybeReduceReservedBytes();
Expand Down Expand Up @@ -130,6 +133,7 @@ public void setLimitAndOverhead(long limit, double overhead) {

@Override
public void close() {
assert assertSingleThread();
if (closed.compareAndSet(false, true)) {
breaker.addWithoutBreaking(-reservedBytes);
}
Expand All @@ -139,4 +143,34 @@ public void close() {
public String toString() {
    // Rendered as: LocalCircuitBreaker[<reservedBytes>/<overReservedBytes>:<maxOverReservedBytes>]
    final String sizes = reservedBytes + "/" + overReservedBytes + ":" + maxOverReservedBytes;
    return "LocalCircuitBreaker[" + sizes + "]";
}

/**
 * Assertion helper that checks the calling thread is allowed to use this breaker.
 * The check trips when an active thread has been recorded and it is not the caller;
 * when no active thread is recorded, any caller passes.
 *
 * @return always {@code true}, so the call can be placed inside an {@code assert} statement
 */
private boolean assertSingleThread() {
    // Snapshot the field once so the condition and the failure message agree on its value.
    final Thread owner = this.activeThread;
    final Thread caller = Thread.currentThread();
    assert owner == null || owner == caller
        : "Local breaker must be accessed by a single thread at a time: expected [" + owner + "] != actual [" + caller + "]";
    return true;
}

/**
 * Assertion hook for the start of a run loop: records the calling thread as the
 * active thread of this breaker.
 *
 * @return always {@code true}, so the call can be placed inside an {@code assert} statement
 */
public boolean assertBeginRunLoop() {
    final Thread caller = Thread.currentThread();
    this.activeThread = caller;
    return true;
}

/**
 * Assertion hook for the end of a run loop: forgets the recorded active thread
 * by resetting it to {@code null}.
 *
 * @return always {@code true}, so the call can be placed inside an {@code assert} statement
 */
public boolean assertEndRunLoop() {
    this.activeThread = null;
    return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,13 @@ SubscribableListener<Void> run(TimeValue maxTime, int maxIterations, LongSupplie
while (true) {
IsBlockedResult isBlocked = Operator.NOT_BLOCKED;
try {
assert driverContext.assertBeginRunLoop();
isBlocked = runSingleLoopIteration();
} catch (DriverEarlyTerminationException unused) {
closeEarlyFinishedOperators();
assert isFinished() : "not finished after early termination";
} finally {
assert driverContext.assertEndRunLoop();
}
totalIterationsThisRun++;
iterationsSinceLastStatusUpdate++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.LocalCircuitBreaker;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

Expand Down Expand Up @@ -74,14 +74,6 @@ private DriverContext(BigArrays bigArrays, BlockFactory blockFactory, WarningsMo
this.warningsMode = warningsMode;
}

/**
 * Builds a {@link DriverContext} from {@link BigArrays#NON_RECYCLING_INSTANCE} and a
 * {@link BlockFactory} whose breaker is a {@link NoopCircuitBreaker} constructed with
 * {@link CircuitBreaker#REQUEST}.
 *
 * @return a newly created driver context
 */
public static DriverContext getLocalDriver() {
    final BigArrays nonRecycling = BigArrays.NON_RECYCLING_INSTANCE;
    // TODO maybe this should have a small fixed limit?
    final NoopCircuitBreaker noopBreaker = new NoopCircuitBreaker(CircuitBreaker.REQUEST);
    final BlockFactory localFactory = new BlockFactory(noopBreaker, BigArrays.NON_RECYCLING_INSTANCE);
    return new DriverContext(nonRecycling, localFactory);
}

/**
 * Accessor for the {@link BigArrays} instance held by this context.
 *
 * @return the value of the {@code bigArrays} field
 */
public BigArrays bigArrays() {
    return this.bigArrays;
}
Expand Down Expand Up @@ -208,6 +200,26 @@ public enum WarningsMode {
IGNORE
}

/**
 * Assertion hook for the start of a run loop. When the block factory's breaker is a
 * {@link LocalCircuitBreaker}, the call is forwarded to that breaker; for any other
 * breaker type nothing happens.
 *
 * @return always {@code true}, so the call can be placed inside an {@code assert} statement
 */
public boolean assertBeginRunLoop() {
    final CircuitBreaker current = blockFactory.breaker();
    if (current instanceof LocalCircuitBreaker local) {
        assert local.assertBeginRunLoop();
    }
    return true;
}

/**
 * Assertion hook for the end of a run loop. When the block factory's breaker is a
 * {@link LocalCircuitBreaker}, the call is forwarded to that breaker; for any other
 * breaker type nothing happens.
 *
 * @return always {@code true}, so the call can be placed inside an {@code assert} statement
 */
public boolean assertEndRunLoop() {
    final CircuitBreaker current = blockFactory.breaker();
    if (current instanceof LocalCircuitBreaker local) {
        assert local.assertEndRunLoop();
    }
    return true;
}

private static class AsyncActions {
private final SubscribableListener<Void> completion = new SubscribableListener<>();
private final AtomicBoolean finished = new AtomicBoolean();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.DriverRunner;
Expand Down Expand Up @@ -247,9 +248,20 @@ public void testSimpleFinishClose() {
try (var operator = simple().get(driverContext)) {
assert operator.needsInput();
for (Page page : input) {
operator.addInput(page);
if (operator.needsInput()) {
operator.addInput(page);
} else {
page.releaseBlocks();
}
}
operator.finish();
// for async operator, we need to wait for async actions to finish.
if (operator instanceof AsyncOperator<?> || randomBoolean()) {
driverContext.finish();
PlainActionFuture<Void> waitForAsync = new PlainActionFuture<>();
driverContext.waitForAsyncActions(waitForAsync);
waitForAsync.actionGet(TimeValue.timeValueSeconds(30));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@

package org.elasticsearch.compute.test;

import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.LocalCircuitBreaker;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.SinkOperator;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.indices.CrankyCircuitBreakerService;

import java.util.List;

Expand All @@ -38,6 +44,20 @@ public static Driver create(
SinkOperator sink,
Releasable releasable
) {
// Do not wrap the local breaker for small local breakers, as the output might not match expectations.
if (driverContext.breaker() instanceof CrankyCircuitBreakerService.CrankyCircuitBreaker == false
&& driverContext.breaker() instanceof LocalCircuitBreaker == false
&& driverContext.breaker().getLimit() >= ByteSizeValue.ofMb(100).getBytes()
&& Randomness.get().nextBoolean()) {
final int overReservedBytes = Randomness.get().nextInt(1024 * 1024);
final int maxOverReservedBytes = overReservedBytes + Randomness.get().nextInt(1024 * 1024);
var localBreaker = new LocalCircuitBreaker(driverContext.breaker(), overReservedBytes, maxOverReservedBytes);
BlockFactory localBlockFactory = driverContext.blockFactory().newChildFactory(localBreaker);
driverContext = new DriverContext(localBlockFactory.bigArrays(), localBlockFactory);
}
if (driverContext.breaker() instanceof LocalCircuitBreaker localBreaker) {
releasable = Releasables.wrap(releasable, localBreaker);
}
return new Driver(
"unset",
"test-task",
Expand Down
Loading