Skip to content

Commit

Permalink
Fixed Default Source Stream Error Behaviour (#6170)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored May 19, 2023
1 parent eadbb47 commit 35a4119
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ public interface ISourceStream : IAsyncDisposable
/// <summary>
/// Reads the subscription result from the pub/sub system.
/// </summary>
IAsyncEnumerable<object> ReadEventsAsync();
IAsyncEnumerable<object?> ReadEventsAsync();
}
108 changes: 90 additions & 18 deletions src/HotChocolate/Core/src/Subscriptions/DefaultSourceStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public IAsyncEnumerable<TMessage> ReadEventsAsync()
=> new MessageEnumerable(_outgoing.Reader, _completed);

/// <inheritdoc />
IAsyncEnumerable<object> ISourceStream.ReadEventsAsync()
IAsyncEnumerable<object?> ISourceStream.ReadEventsAsync()
=> new MessageEnumerableAsObject(_outgoing.Reader, _completed);

/// <inheritdoc />
Expand All @@ -56,17 +56,41 @@ public MessageEnumerable(
_completed = completed;
}

public async IAsyncEnumerator<TMessage> GetAsyncEnumerator(
public IAsyncEnumerator<TMessage> GetAsyncEnumerator(
CancellationToken cancellationToken)
=> new MessageEnumerator(_reader, _completed, cancellationToken);
}

private sealed class MessageEnumerator : IAsyncEnumerator<TMessage>
{
private readonly ChannelReader<TMessage> _reader;
private readonly TaskCompletionSource<bool> _completed;
private readonly CancellationToken _cancellationToken;

public MessageEnumerator(
ChannelReader<TMessage> reader,
TaskCompletionSource<bool> completed,
CancellationToken cancellationToken)
{
while (!_reader.Completion.IsCompleted)
_reader = reader;
_completed = completed;
_cancellationToken = cancellationToken;
}

public TMessage Current { get; private set; } = default!;

public async ValueTask<bool> MoveNextAsync()
{
try
{
if (_reader.TryRead(out var message))
{
yield return message;
}
else
while (!_reader.Completion.IsCompleted)
{
if (_reader.TryRead(out var message))
{
Current = message;
return true;
}

if (_completed.Task.IsCompleted)
{
break;
Expand All @@ -81,13 +105,25 @@ await Task.WhenAny(_completed.Task, WaitForMessages())
}
}
}
catch
{
// ignore errors
}

return false;

async Task WaitForMessages()
=> await _reader.WaitToReadAsync(cancellationToken);
=> await _reader.WaitToReadAsync(_cancellationToken);
}

public ValueTask DisposeAsync()
{
_completed.TrySetCanceled();
return default;
}
}

private sealed class MessageEnumerableAsObject : IAsyncEnumerable<object>
private sealed class MessageEnumerableAsObject : IAsyncEnumerable<object?>
{
private readonly ChannelReader<TMessage> _reader;
private readonly TaskCompletionSource<bool> _completed;
Expand All @@ -100,17 +136,41 @@ public MessageEnumerableAsObject(
_completed = completed;
}

public async IAsyncEnumerator<object> GetAsyncEnumerator(
public IAsyncEnumerator<object?> GetAsyncEnumerator(
CancellationToken cancellationToken)
=> new MessageEnumeratorAsObject(_reader, _completed, cancellationToken);
}

private sealed class MessageEnumeratorAsObject : IAsyncEnumerator<object?>
{
private readonly ChannelReader<TMessage> _reader;
private readonly TaskCompletionSource<bool> _completed;
private readonly CancellationToken _cancellationToken;

public MessageEnumeratorAsObject(
ChannelReader<TMessage> reader,
TaskCompletionSource<bool> completed,
CancellationToken cancellationToken)
{
while (!_reader.Completion.IsCompleted)
_reader = reader;
_completed = completed;
_cancellationToken = cancellationToken;
}

public object? Current { get; private set; }

public async ValueTask<bool> MoveNextAsync()
{
try
{
if (_reader.TryRead(out var message))
{
yield return message!;
}
else
while (!_reader.Completion.IsCompleted)
{
if (_reader.TryRead(out var message))
{
Current = message;
return true;
}

if (_completed.Task.IsCompleted)
{
break;
Expand All @@ -125,9 +185,21 @@ await Task.WhenAny(_completed.Task, WaitForMessages())
}
}
}
catch
{
// ignore errors
}

return false;

async Task WaitForMessages()
=> await _reader.WaitToReadAsync(cancellationToken);
=> await _reader.WaitToReadAsync(_cancellationToken);
}

public ValueTask DisposeAsync()
{
_completed.TrySetCanceled();
return default;
}
}
}

0 comments on commit 35a4119

Please sign in to comment.