Skip to content

Commit 41137a3

Browse files
authored
Prevent concurrent access to local breaker in rerank (#128162) (#128945)
When an async operator receives a response, it must not create new blocks on the responding thread: multiple threads could then adjust the local breaker simultaneously, which is a data race. There are two ways to address this: use the global breaker, or defer block creation until getOutput. Using the global block factory is simpler, but I prefer the second option because it keeps using the local breaker whenever possible. This change therefore keeps the results in the queue and creates the new blocks in getOutput.

Our tests didn't catch this issue because (1) the test creates only one block, and (2) the simulated inference service responds without any delay.

Closes #127638
Closes #127051
1 parent dd4a6d5 commit 41137a3

File tree

5 files changed

+75
-12
lines changed

5 files changed

+75
-12
lines changed

test/framework/src/main/java/org/elasticsearch/indices/CrankyCircuitBreakerService.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public class CrankyCircuitBreakerService extends CircuitBreakerService {
2929
*/
3030
public static final String ERROR_MESSAGE = "cranky breaker";
3131

32-
private final CircuitBreaker breaker = new CircuitBreaker() {
32+
public static final class CrankyCircuitBreaker implements CircuitBreaker {
3333
private final AtomicLong used = new AtomicLong();
3434

3535
@Override
@@ -82,7 +82,9 @@ public Durability getDurability() {
8282
public void setLimitAndOverhead(long limit, double overhead) {
8383

8484
}
85-
};
85+
}
86+
87+
private final CrankyCircuitBreaker breaker = new CrankyCircuitBreaker();
8688

8789
@Override
8890
public CircuitBreaker getBreaker(String name) {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/LocalCircuitBreaker.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public final class LocalCircuitBreaker implements CircuitBreaker, Releasable {
2828
private final long maxOverReservedBytes;
2929
private long reservedBytes;
3030
private final AtomicBoolean closed = new AtomicBoolean(false);
31+
private volatile Thread activeThread;
3132

3233
public record SizeSettings(long overReservedBytes, long maxOverReservedBytes) {
3334
public SizeSettings(Settings settings) {
@@ -57,6 +58,7 @@ public void circuitBreak(String fieldName, long bytesNeeded) {
5758

5859
@Override
5960
public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException {
61+
assert assertSingleThread();
6062
if (bytes <= reservedBytes) {
6163
reservedBytes -= bytes;
6264
maybeReduceReservedBytes();
@@ -68,6 +70,7 @@ public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws Circu
6870

6971
@Override
7072
public void addWithoutBreaking(long bytes) {
73+
assert assertSingleThread();
7174
if (bytes <= reservedBytes) {
7275
reservedBytes -= bytes;
7376
maybeReduceReservedBytes();
@@ -130,6 +133,7 @@ public void setLimitAndOverhead(long limit, double overhead) {
130133

131134
@Override
132135
public void close() {
136+
assert assertSingleThread();
133137
if (closed.compareAndSet(false, true)) {
134138
breaker.addWithoutBreaking(-reservedBytes);
135139
}
@@ -139,4 +143,34 @@ public void close() {
139143
public String toString() {
140144
return "LocalCircuitBreaker[" + reservedBytes + "/" + overReservedBytes + ":" + maxOverReservedBytes + "]";
141145
}
146+
147+
/**
 * Assertion helper that checks the calling thread against the recorded {@code activeThread}.
 * The inner assertion passes when no active thread is recorded ({@code null}) or when the
 * recorded thread is the calling thread; otherwise it fails with a message naming both threads.
 *
 * @return always {@code true}, so callers can invoke it as {@code assert assertSingleThread();}
 */
private boolean assertSingleThread() {
148+
// Snapshot the volatile field once; the check and the message below both use this snapshot.
Thread activeThread = this.activeThread;
149+
Thread currentThread = Thread.currentThread();
150+
assert activeThread == null || activeThread == currentThread
151+
: "Local breaker must be accessed by a single thread at a time: expected ["
152+
+ activeThread
153+
+ "] != actual ["
154+
+ currentThread
155+
+ "]";
156+
return true;
157+
}
158+
159+
/**
160+
* Marks the beginning of a run loop for assertion purposes.
161+
* Sets the current thread as the only thread allowed to access this breaker.
* The restriction is checked by {@code assertSingleThread()}, which is asserted in
* {@code addEstimateBytesAndMaybeBreak}, {@code addWithoutBreaking} and {@code close}.
*
* @return always {@code true}, so the call can be wrapped in an {@code assert} statement
162+
*/
163+
public boolean assertBeginRunLoop() {
164+
activeThread = Thread.currentThread();
165+
return true;
166+
}
167+
168+
/**
169+
* Marks the end of a run loop for assertion purposes.
170+
* Clears the active thread to allow other threads to access this breaker.
* With no active thread recorded, the single-thread assertion passes for any calling thread.
*
* @return always {@code true}, so the call can be wrapped in an {@code assert} statement
171+
*/
172+
public boolean assertEndRunLoop() {
173+
activeThread = null;
174+
return true;
175+
}
142176
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,13 @@ SubscribableListener<Void> run(TimeValue maxTime, int maxIterations, LongSupplie
211211
while (true) {
212212
IsBlockedResult isBlocked = Operator.NOT_BLOCKED;
213213
try {
214+
assert driverContext.assertBeginRunLoop();
214215
isBlocked = runSingleLoopIteration();
215216
} catch (DriverEarlyTerminationException unused) {
216217
closeEarlyFinishedOperators();
217218
assert isFinished() : "not finished after early termination";
219+
} finally {
220+
assert driverContext.assertEndRunLoop();
218221
}
219222
totalIterationsThisRun++;
220223
iterationsSinceLastStatusUpdate++;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.action.support.SubscribableListener;
1212
import org.elasticsearch.common.breaker.CircuitBreaker;
13-
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
1413
import org.elasticsearch.common.util.BigArrays;
1514
import org.elasticsearch.compute.data.BlockFactory;
15+
import org.elasticsearch.compute.data.LocalCircuitBreaker;
1616
import org.elasticsearch.core.Releasable;
1717
import org.elasticsearch.core.Releasables;
1818

@@ -74,14 +74,6 @@ private DriverContext(BigArrays bigArrays, BlockFactory blockFactory, WarningsMo
7474
this.warningsMode = warningsMode;
7575
}
7676

77-
public static DriverContext getLocalDriver() {
78-
return new DriverContext(
79-
BigArrays.NON_RECYCLING_INSTANCE,
80-
// TODO maybe this should have a small fixed limit?
81-
new BlockFactory(new NoopCircuitBreaker(CircuitBreaker.REQUEST), BigArrays.NON_RECYCLING_INSTANCE)
82-
);
83-
}
84-
8577
public BigArrays bigArrays() {
8678
return bigArrays;
8779
}
@@ -208,6 +200,26 @@ public enum WarningsMode {
208200
IGNORE
209201
}
210202

203+
/**
204+
* Marks the beginning of a run loop for assertion purposes.
* If the block factory's breaker is a {@link LocalCircuitBreaker}, forwards to its
* {@code assertBeginRunLoop()}; for any other breaker type this method does nothing.
*
* @return always {@code true}, so the call can be wrapped in an {@code assert} statement
205+
*/
206+
public boolean assertBeginRunLoop() {
207+
if (blockFactory.breaker() instanceof LocalCircuitBreaker localBreaker) {
208+
// The forwarding call is itself inside an assert, so it only executes when assertions are enabled.
assert localBreaker.assertBeginRunLoop();
209+
}
210+
return true;
211+
}
212+
213+
/**
214+
* Marks the end of a run loop for assertion purposes.
* If the block factory's breaker is a {@link LocalCircuitBreaker}, forwards to its
* {@code assertEndRunLoop()}; for any other breaker type this method does nothing.
*
* @return always {@code true}, so the call can be wrapped in an {@code assert} statement
215+
*/
216+
public boolean assertEndRunLoop() {
217+
if (blockFactory.breaker() instanceof LocalCircuitBreaker localBreaker) {
218+
// The forwarding call is itself inside an assert, so it only executes when assertions are enabled.
assert localBreaker.assertEndRunLoop();
219+
}
220+
return true;
221+
}
222+
211223
private static class AsyncActions {
212224
private final SubscribableListener<Void> completion = new SubscribableListener<>();
213225
private final AtomicBoolean finished = new AtomicBoolean();

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/OperatorTestCase.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.compute.data.Block;
2424
import org.elasticsearch.compute.data.BlockFactory;
2525
import org.elasticsearch.compute.data.Page;
26+
import org.elasticsearch.compute.operator.AsyncOperator;
2627
import org.elasticsearch.compute.operator.Driver;
2728
import org.elasticsearch.compute.operator.DriverContext;
2829
import org.elasticsearch.compute.operator.DriverRunner;
@@ -249,9 +250,20 @@ public void testSimpleFinishClose() {
249250
try (var operator = simple().get(driverContext)) {
250251
assert operator.needsInput();
251252
for (Page page : input) {
252-
operator.addInput(page);
253+
if (operator.needsInput()) {
254+
operator.addInput(page);
255+
} else {
256+
page.releaseBlocks();
257+
}
253258
}
254259
operator.finish();
260+
// for async operator, we need to wait for async actions to finish.
261+
if (operator instanceof AsyncOperator<?> || randomBoolean()) {
262+
driverContext.finish();
263+
PlainActionFuture<Void> waitForAsync = new PlainActionFuture<>();
264+
driverContext.waitForAsyncActions(waitForAsync);
265+
waitForAsync.actionGet(TimeValue.timeValueSeconds(30));
266+
}
255267
}
256268
}
257269

0 commit comments

Comments
 (0)