From 35a4119c9527c4338ffa690190172dbb62806b4b Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Fri, 19 May 2023 16:44:38 +0200 Subject: [PATCH] Fixed Default Source Stream Error Behaviour (#6170) --- .../Abstractions/Execution/ISourceStream.cs | 2 +- .../src/Subscriptions/DefaultSourceStream.cs | 108 +++++++++++++++--- 2 files changed, 91 insertions(+), 19 deletions(-) diff --git a/src/HotChocolate/Core/src/Abstractions/Execution/ISourceStream.cs b/src/HotChocolate/Core/src/Abstractions/Execution/ISourceStream.cs index 35e8360eaf9..4987c617673 100644 --- a/src/HotChocolate/Core/src/Abstractions/Execution/ISourceStream.cs +++ b/src/HotChocolate/Core/src/Abstractions/Execution/ISourceStream.cs @@ -13,5 +13,5 @@ public interface ISourceStream : IAsyncDisposable /// /// Reads the subscription result from the pub/sub system. /// - IAsyncEnumerable ReadEventsAsync(); + IAsyncEnumerable ReadEventsAsync(); } diff --git a/src/HotChocolate/Core/src/Subscriptions/DefaultSourceStream.cs b/src/HotChocolate/Core/src/Subscriptions/DefaultSourceStream.cs index 74072d74189..d14b4f26290 100644 --- a/src/HotChocolate/Core/src/Subscriptions/DefaultSourceStream.cs +++ b/src/HotChocolate/Core/src/Subscriptions/DefaultSourceStream.cs @@ -30,7 +30,7 @@ public IAsyncEnumerable ReadEventsAsync() => new MessageEnumerable(_outgoing.Reader, _completed); /// - IAsyncEnumerable ISourceStream.ReadEventsAsync() + IAsyncEnumerable ISourceStream.ReadEventsAsync() => new MessageEnumerableAsObject(_outgoing.Reader, _completed); /// @@ -56,17 +56,41 @@ public MessageEnumerable( _completed = completed; } - public async IAsyncEnumerator GetAsyncEnumerator( + public IAsyncEnumerator GetAsyncEnumerator( + CancellationToken cancellationToken) + => new MessageEnumerator(_reader, _completed, cancellationToken); + } + + private sealed class MessageEnumerator : IAsyncEnumerator + { + private readonly ChannelReader _reader; + private readonly TaskCompletionSource _completed; + private readonly CancellationToken _cancellationToken; + + public MessageEnumerator( + ChannelReader reader, + TaskCompletionSource completed, CancellationToken cancellationToken) { - while (!_reader.Completion.IsCompleted) + _reader = reader; + _completed = completed; + _cancellationToken = cancellationToken; + } + + public TMessage Current { get; private set; } = default!; + + public async ValueTask MoveNextAsync() + { + try { - if (_reader.TryRead(out var message)) - { - yield return message; - } - else + while (!_reader.Completion.IsCompleted) { + if (_reader.TryRead(out var message)) + { + Current = message; + return true; + } + if (_completed.Task.IsCompleted) { break; @@ -81,13 +105,25 @@ await Task.WhenAny(_completed.Task, WaitForMessages()) } } } + catch + { + // ignore errors + } + + return false; async Task WaitForMessages() - => await _reader.WaitToReadAsync(cancellationToken); + => await _reader.WaitToReadAsync(_cancellationToken); + } + + public ValueTask DisposeAsync() + { + _completed.TrySetCanceled(); + return default; } } - private sealed class MessageEnumerableAsObject : IAsyncEnumerable + private sealed class MessageEnumerableAsObject : IAsyncEnumerable { private readonly ChannelReader _reader; private readonly TaskCompletionSource _completed; @@ -100,17 +136,41 @@ public MessageEnumerableAsObject( _completed = completed; } - public async IAsyncEnumerator GetAsyncEnumerator( + public IAsyncEnumerator GetAsyncEnumerator( + CancellationToken cancellationToken) + => new MessageEnumeratorAsObject(_reader, _completed, cancellationToken); + } + + private sealed class MessageEnumeratorAsObject : IAsyncEnumerator + { + private readonly ChannelReader _reader; + private readonly TaskCompletionSource _completed; + private readonly CancellationToken _cancellationToken; + + public MessageEnumeratorAsObject( + ChannelReader reader, + TaskCompletionSource completed, CancellationToken cancellationToken) { - while (!_reader.Completion.IsCompleted) + _reader = reader; + _completed = completed; + _cancellationToken = cancellationToken; + } + + public object? Current { get; private set; } + + public async ValueTask MoveNextAsync() + { + try { - if (_reader.TryRead(out var message)) - { - yield return message!; - } - else + while (!_reader.Completion.IsCompleted) { + if (_reader.TryRead(out var message)) + { + Current = message; + return true; + } + if (_completed.Task.IsCompleted) { break; @@ -125,9 +185,21 @@ await Task.WhenAny(_completed.Task, WaitForMessages()) } } } + catch + { + // ignore errors + } + + return false; async Task WaitForMessages() - => await _reader.WaitToReadAsync(cancellationToken); + => await _reader.WaitToReadAsync(_cancellationToken); + } + + public ValueTask DisposeAsync() + { + _completed.TrySetCanceled(); + return default; } } }