2222import static org .neo4j .driver .internal .async .connection .ChannelAttributes .poolId ;
2323import static org .neo4j .driver .internal .async .connection .ChannelAttributes .setTerminationReason ;
2424import static org .neo4j .driver .internal .util .Futures .asCompletionStage ;
25+ import static org .neo4j .driver .internal .util .LockUtil .executeWithLock ;
2526
2627import io .netty .channel .Channel ;
2728import io .netty .channel .ChannelHandler ;
2829import java .time .Clock ;
2930import java .util .concurrent .CompletableFuture ;
3031import java .util .concurrent .CompletionStage ;
3132import java .util .concurrent .TimeUnit ;
32- import java .util .concurrent .atomic .AtomicReference ;
33+ import java .util .concurrent .locks .Lock ;
34+ import java .util .concurrent .locks .ReentrantLock ;
3335import org .neo4j .driver .Logger ;
3436import org .neo4j .driver .Logging ;
3537import org .neo4j .driver .internal .BoltServerAddress ;
4143import org .neo4j .driver .internal .handlers .ResetResponseHandler ;
4244import org .neo4j .driver .internal .messaging .BoltProtocol ;
4345import org .neo4j .driver .internal .messaging .Message ;
46+ import org .neo4j .driver .internal .messaging .request .DiscardAllMessage ;
47+ import org .neo4j .driver .internal .messaging .request .DiscardMessage ;
48+ import org .neo4j .driver .internal .messaging .request .PullAllMessage ;
49+ import org .neo4j .driver .internal .messaging .request .PullMessage ;
4450import org .neo4j .driver .internal .messaging .request .ResetMessage ;
51+ import org .neo4j .driver .internal .messaging .request .RunWithMetadataMessage ;
4552import org .neo4j .driver .internal .metrics .ListenerEvent ;
4653import org .neo4j .driver .internal .metrics .MetricsListener ;
4754import org .neo4j .driver .internal .spi .Connection ;
5360 */
5461public class NetworkConnection implements Connection {
5562 private final Logger log ;
63+ private final Lock lock ;
5664 private final Channel channel ;
5765 private final InboundMessageDispatcher messageDispatcher ;
5866 private final String serverAgent ;
@@ -61,12 +69,13 @@ public class NetworkConnection implements Connection {
6169 private final ExtendedChannelPool channelPool ;
6270 private final CompletableFuture <Void > releaseFuture ;
6371 private final Clock clock ;
64-
65- private final AtomicReference <Status > status = new AtomicReference <>(Status .OPEN );
6672 private final MetricsListener metricsListener ;
6773 private final ListenerEvent <?> inUseEvent ;
6874
6975 private final Long connectionReadTimeout ;
76+
77+ private Status status = Status .OPEN ;
78+ private TerminationAwareStateLockingExecutor terminationAwareStateLockingExecutor ;
7079 private ChannelHandler connectionReadTimeoutHandler ;
7180
7281 public NetworkConnection (
@@ -76,6 +85,7 @@ public NetworkConnection(
7685 MetricsListener metricsListener ,
7786 Logging logging ) {
7887 this .log = logging .getLog (getClass ());
88+ this .lock = new ReentrantLock ();
7989 this .channel = channel ;
8090 this .messageDispatcher = ChannelAttributes .messageDispatcher (channel );
8191 this .serverAgent = ChannelAttributes .serverAgent (channel );
@@ -93,7 +103,7 @@ public NetworkConnection(
93103
94104 @ Override
95105 public boolean isOpen () {
96- return status . get () == Status .OPEN ;
106+ return executeWithLock ( lock , () -> status == Status .OPEN ) ;
97107 }
98108
99109 @ Override
@@ -110,52 +120,31 @@ public void disableAutoRead() {
110120 }
111121 }
112122
113- @ Override
114- public void flush () {
115- if (verifyOpen (null , null )) {
116- flushInEventLoop ();
117- }
118- }
119-
120123 @ Override
121124 public void write (Message message , ResponseHandler handler ) {
122- if (verifyOpen (handler , null )) {
125+ if (verifyOpen (handler )) {
123126 writeMessageInEventLoop (message , handler , false );
124127 }
125128 }
126129
127- @ Override
128- public void write (Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 ) {
129- if (verifyOpen (handler1 , handler2 )) {
130- writeMessagesInEventLoop (message1 , handler1 , message2 , handler2 , false );
131- }
132- }
133-
134130 @ Override
135131 public void writeAndFlush (Message message , ResponseHandler handler ) {
136- if (verifyOpen (handler , null )) {
132+ if (verifyOpen (handler )) {
137133 writeMessageInEventLoop (message , handler , true );
138134 }
139135 }
140136
141137 @ Override
142- public void writeAndFlush (Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 ) {
143- if (verifyOpen (handler1 , handler2 )) {
144- writeMessagesInEventLoop (message1 , handler1 , message2 , handler2 , true );
145- }
146- }
147-
148- @ Override
149- public CompletionStage <Void > reset () {
150- CompletableFuture <Void > result = new CompletableFuture <>();
151- ResetResponseHandler handler = new ResetResponseHandler (messageDispatcher , result );
138+ public CompletionStage <Void > reset (Throwable throwable ) {
139+ var result = new CompletableFuture <Void >();
140+ var handler = new ResetResponseHandler (messageDispatcher , result , throwable );
152141 writeResetMessageIfNeeded (handler , true );
153142 return result ;
154143 }
155144
156145 @ Override
157146 public CompletionStage <Void > release () {
158- if (status . compareAndSet ( Status . OPEN , Status .RELEASED )) {
147+ if (executeWithLock ( lock , () -> updateStateIfOpen ( Status .RELEASED ) )) {
159148 ChannelReleasingResetResponseHandler handler = new ChannelReleasingResetResponseHandler (
160149 channel , channelPool , messageDispatcher , clock , releaseFuture );
161150
@@ -167,7 +156,7 @@ public CompletionStage<Void> release() {
167156
168157 @ Override
169158 public void terminateAndRelease (String reason ) {
170- if (status . compareAndSet ( Status . OPEN , Status .TERMINATED )) {
159+ if (executeWithLock ( lock , () -> updateStateIfOpen ( Status .TERMINATED ) )) {
171160 setTerminationReason (channel , reason );
172161 asCompletionStage (channel .close ())
173162 .exceptionally (throwable -> null )
@@ -194,6 +183,25 @@ public BoltProtocol protocol() {
194183 return protocol ;
195184 }
196185
186+ @ Override
187+ public void bindTerminationAwareStateLockingExecutor (TerminationAwareStateLockingExecutor executor ) {
188+ executeWithLock (lock , () -> {
189+ if (this .terminationAwareStateLockingExecutor != null ) {
190+ throw new IllegalStateException ("terminationAwareStateLockingExecutor is already set" );
191+ }
192+ this .terminationAwareStateLockingExecutor = executor ;
193+ });
194+ }
195+
196+ private boolean updateStateIfOpen (Status newStatus ) {
197+ if (Status .OPEN .equals (status )) {
198+ status = newStatus ;
199+ return true ;
200+ } else {
201+ return false ;
202+ }
203+ }
204+
197205 private void writeResetMessageIfNeeded (ResponseHandler resetHandler , boolean isSessionReset ) {
198206 channel .eventLoop ().execute (() -> {
199207 if (isSessionReset && !isOpen ()) {
@@ -208,73 +216,49 @@ private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isS
208216 });
209217 }
210218
211- private void flushInEventLoop () {
212- channel .eventLoop ().execute (() -> {
213- channel .flush ();
214- registerConnectionReadTimeout (channel );
215- });
216- }
217-
218219 private void writeMessageInEventLoop (Message message , ResponseHandler handler , boolean flush ) {
219- channel .eventLoop ().execute (() -> {
220- messageDispatcher .enqueue (handler );
221-
222- if (flush ) {
223- channel .writeAndFlush (message ).addListener (future -> registerConnectionReadTimeout (channel ));
224- } else {
225- channel .write (message , channel .voidPromise ());
226- }
227- });
228- }
229-
230- private void writeMessagesInEventLoop (
231- Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 , boolean flush ) {
232- channel .eventLoop ().execute (() -> {
233- messageDispatcher .enqueue (handler1 );
234- messageDispatcher .enqueue (handler2 );
235-
236- channel .write (message1 , channel .voidPromise ());
237-
238- if (flush ) {
239- channel .writeAndFlush (message2 ).addListener (future -> registerConnectionReadTimeout (channel ));
240- } else {
241- channel .write (message2 , channel .voidPromise ());
242- }
243- });
220+ channel .eventLoop ()
221+ .execute (() -> terminationAwareStateLockingExecutor (message ).execute (causeOfTermination -> {
222+ if (causeOfTermination == null ) {
223+ messageDispatcher .enqueue (handler );
224+
225+ if (flush ) {
226+ channel .writeAndFlush (message )
227+ .addListener (future -> registerConnectionReadTimeout (channel ));
228+ } else {
229+ channel .write (message , channel .voidPromise ());
230+ }
231+ } else {
232+ handler .onFailure (causeOfTermination );
233+ }
234+ }));
244235 }
245236
246237 private void setAutoRead (boolean value ) {
247238 channel .config ().setAutoRead (value );
248239 }
249240
250- private boolean verifyOpen (ResponseHandler handler1 , ResponseHandler handler2 ) {
251- Status connectionStatus = this .status .get ();
252- switch (connectionStatus ) {
253- case OPEN :
254- return true ;
255- case RELEASED :
241+ private boolean verifyOpen (ResponseHandler handler ) {
242+ var connectionStatus = executeWithLock (lock , () -> status );
243+ return switch (connectionStatus ) {
244+ case OPEN -> true ;
245+ case RELEASED -> {
256246 Exception error =
257247 new IllegalStateException ("Connection has been released to the pool and can't be used" );
258- if (handler1 != null ) {
259- handler1 .onFailure (error );
248+ if (handler != null ) {
249+ handler .onFailure (error );
260250 }
261- if (handler2 != null ) {
262- handler2 .onFailure (error );
263- }
264- return false ;
265- case TERMINATED :
251+ yield false ;
252+ }
253+ case TERMINATED -> {
266254 Exception terminatedError =
267255 new IllegalStateException ("Connection has been terminated and can't be used" );
268- if (handler1 != null ) {
269- handler1 .onFailure (terminatedError );
270- }
271- if (handler2 != null ) {
272- handler2 .onFailure (terminatedError );
256+ if (handler != null ) {
257+ handler .onFailure (terminatedError );
273258 }
274- return false ;
275- default :
276- throw new IllegalStateException ("Unknown status: " + connectionStatus );
277- }
259+ yield false ;
260+ }
261+ };
278262 }
279263
280264 private void registerConnectionReadTimeout (Channel channel ) {
@@ -295,6 +279,25 @@ private void registerConnectionReadTimeout(Channel channel) {
295279 }
296280 }
297281
282+ private TerminationAwareStateLockingExecutor terminationAwareStateLockingExecutor (Message message ) {
283+ var result = (TerminationAwareStateLockingExecutor ) consumer -> consumer .accept (null );
284+ if (isQueryMessage (message )) {
285+ var lockingExecutor = executeWithLock (lock , () -> this .terminationAwareStateLockingExecutor );
286+ if (lockingExecutor != null ) {
287+ result = lockingExecutor ;
288+ }
289+ }
290+ return result ;
291+ }
292+
293+ private boolean isQueryMessage (Message message ) {
294+ return message instanceof RunWithMetadataMessage
295+ || message instanceof PullMessage
296+ || message instanceof PullAllMessage
297+ || message instanceof DiscardMessage
298+ || message instanceof DiscardAllMessage ;
299+ }
300+
298301 private enum Status {
299302 OPEN ,
300303 RELEASED ,
0 commit comments