diff --git a/MirrorSharp.AspNetCore/Internal/Middleware.cs b/MirrorSharp.AspNetCore/Internal/Middleware.cs index 9f4b8e51..7c10920f 100644 --- a/MirrorSharp.AspNetCore/Internal/Middleware.cs +++ b/MirrorSharp.AspNetCore/Internal/Middleware.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using System.Threading; +using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.AspNetCore.Http; using MirrorSharp.Advanced; @@ -19,7 +20,7 @@ public async Task Invoke(HttpContext context) { } using (var socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false)) { - await WebSocketLoopAsync(socket).ConfigureAwait(false); + await WebSocketLoopAsync(socket, CancellationToken.None).ConfigureAwait(false); } } } diff --git a/MirrorSharp.Common/Advanced/MiddlewareBase.cs b/MirrorSharp.Common/Advanced/MiddlewareBase.cs index f0d47d4e..5411a8f8 100644 --- a/MirrorSharp.Common/Advanced/MiddlewareBase.cs +++ b/MirrorSharp.Common/Advanced/MiddlewareBase.cs @@ -1,5 +1,6 @@ using System; using System.Net.WebSockets; +using System.Threading; using System.Threading.Tasks; using MirrorSharp.Internal; @@ -11,7 +12,7 @@ protected MiddlewareBase(MirrorSharpOptions options) { _options = options; } - protected async Task WebSocketLoopAsync(WebSocket socket) { + protected async Task WebSocketLoopAsync(WebSocket socket, CancellationToken cancellationToken) { WorkSession session = null; Connection connection = null; try { @@ -20,14 +21,14 @@ protected async Task WebSocketLoopAsync(WebSocket socket) { while (connection.IsConnected) { try { - await connection.ReceiveAndProcessAsync().ConfigureAwait(false); + await connection.ReceiveAndProcessAsync(cancellationToken).ConfigureAwait(false); } catch { // this is sent back by connection itself } } } - catch (Exception) when (connection == null && session != null) { + catch when (connection == null && session != null) { await session.DisposeAsync().ConfigureAwait(false); throw; } diff --git a/MirrorSharp.Common/Internal/Connection.cs b/MirrorSharp.Common/Internal/Connection.cs index 12090629..ce327577 100644 --- a/MirrorSharp.Common/Internal/Connection.cs +++ b/MirrorSharp.Common/Internal/Connection.cs @@ -42,13 +42,13 @@ public Connection(WebSocket socket, IWorkSession session, IConnectionOptions opt public bool IsConnected => _socket.State == WebSocketState.Open; - public async Task ReceiveAndProcessAsync() { + public async Task ReceiveAndProcessAsync(CancellationToken cancellationToken) { try { - await ReceiveAndProcessInternalAsync().ConfigureAwait(false); + await ReceiveAndProcessInternalAsync(cancellationToken).ConfigureAwait(false); } catch (Exception ex) { try { - await SendErrorAsync(ex.Message).ConfigureAwait(false); + await SendErrorAsync(ex.Message, cancellationToken).ConfigureAwait(false); } catch (Exception sendException) { throw new AggregateException(ex, sendException); @@ -57,20 +57,22 @@ public async Task ReceiveAndProcessAsync() { } } - private async Task ReceiveAndProcessInternalAsync() { - var received = await _socket.ReceiveAsync(new ArraySegment(_inputByteBuffer), CancellationToken.None).ConfigureAwait(false); + private async Task ReceiveAndProcessInternalAsync(CancellationToken cancellationToken) { + var received = await _socket.ReceiveAsync(new ArraySegment(_inputByteBuffer), cancellationToken).ConfigureAwait(false); if (received.MessageType == WebSocketMessageType.Binary) throw new FormatException("Expected text data (received binary)."); - if (received.MessageType == WebSocketMessageType.Close) + if (received.MessageType == WebSocketMessageType.Close) { + await _socket.CloseAsync(received.CloseStatus ?? WebSocketCloseStatus.Empty, received.CloseStatusDescription, cancellationToken).ConfigureAwait(false); return; + } - await ProcessMessageAsync(new ArraySegment(_inputByteBuffer, 0, received.Count)).ConfigureAwait(false); + await ProcessMessageAsync(new ArraySegment(_inputByteBuffer, 0, received.Count), cancellationToken).ConfigureAwait(false); if (_options.SendDebugCompareMessages) - await SendDebugCompareAsync(_inputByteBuffer[0]).ConfigureAwait(false); + await SendDebugCompareAsync(_inputByteBuffer[0], cancellationToken).ConfigureAwait(false); } - private Task ProcessMessageAsync(ArraySegment data) { + private Task ProcessMessageAsync(ArraySegment data, CancellationToken cancellationToken) { var command = data.Array[data.Offset]; switch (command) { case Commands.ReplaceProgress: @@ -82,9 +84,9 @@ private Task ProcessMessageAsync(ArraySegment data) { ProcessMoveCursor(Shift(data)); return Done; } - case Commands.TypeChar: return ProcessTypeCharAsync(Shift(data)); - case Commands.CommitCompletion: return ProcessCommitCompletionAsync(Shift(data)); - case Commands.SlowUpdate: return ProcessSlowUpdateAsync(); + case Commands.TypeChar: return ProcessTypeCharAsync(Shift(data), cancellationToken); + case Commands.CommitCompletion: return ProcessCommitCompletionAsync(Shift(data), cancellationToken); + case Commands.SlowUpdate: return ProcessSlowUpdateAsync(cancellationToken); default: throw new FormatException($"Unknown command: '{(char)command}'."); } } @@ -133,17 +135,17 @@ private void ProcessMoveCursor(ArraySegment data) { _session.MoveCursor(cursorPosition); } - private async Task ProcessTypeCharAsync(ArraySegment data) { + private async Task ProcessTypeCharAsync(ArraySegment data, CancellationToken cancellationToken) { var @char = FastConvert.Utf8ByteArrayToChar(data, _charBuffer); - var result = await _session.TypeCharAsync(@char).ConfigureAwait(false); + var result = await _session.TypeCharAsync(@char, cancellationToken).ConfigureAwait(false); if (result.Completions == null) return; - await SendTypeCharResultAsync(result).ConfigureAwait(false); + await SendTypeCharResultAsync(result, cancellationToken).ConfigureAwait(false); } - private Task SendTypeCharResultAsync(TypeCharResult result) { + private Task SendTypeCharResultAsync(TypeCharResult result, CancellationToken cancellationToken) { var completions = result.Completions; var writer = StartJsonMessage("completions"); @@ -168,16 +170,16 @@ private Task SendTypeCharResultAsync(TypeCharResult result) { } writer.WriteEndArray(); writer.WriteEndObject(); - return SendJsonMessageAsync(); + return SendJsonMessageAsync(cancellationToken); } - private async Task ProcessCommitCompletionAsync(ArraySegment data) { + private async Task ProcessCommitCompletionAsync(ArraySegment data, CancellationToken cancellationToken) { var itemIndex = FastConvert.Utf8ByteArrayToInt32(data); - var change = await _session.GetCompletionChangeAsync(itemIndex); - await SendCompletionChangeAsync(change).ConfigureAwait(false); + var change = await _session.GetCompletionChangeAsync(itemIndex, cancellationToken); + await SendCompletionChangeAsync(change, cancellationToken).ConfigureAwait(false); } - private Task SendCompletionChangeAsync(CompletionChange change) { + private Task SendCompletionChangeAsync(CompletionChange change, CancellationToken cancellationToken) { var writer = StartJsonMessage("changes"); writer.WritePropertyStartArray("changes"); foreach (var textChange in change.TextChanges) { @@ -188,15 +190,15 @@ private Task SendCompletionChangeAsync(CompletionChange change) { writer.WriteEndObject(); } writer.WriteEndArray(); - return SendJsonMessageAsync(); + return SendJsonMessageAsync(cancellationToken); } - private async Task ProcessSlowUpdateAsync() { - var update = await _session.GetSlowUpdateAsync().ConfigureAwait(false); - await SendSlowUpdateAsync(update).ConfigureAwait(false); + private async Task ProcessSlowUpdateAsync(CancellationToken cancellationToken) { + var update = await _session.GetSlowUpdateAsync(cancellationToken).ConfigureAwait(false); + await SendSlowUpdateAsync(update, cancellationToken).ConfigureAwait(false); } - private Task SendSlowUpdateAsync(SlowUpdateResult update) { + private Task SendSlowUpdateAsync(SlowUpdateResult update, CancellationToken cancellationToken) { var writer = StartJsonMessage("slowUpdate"); writer.WritePropertyStartArray("diagnostics"); foreach (var diagnostic in update.Diagnostics) { @@ -215,10 +217,10 @@ private Task SendSlowUpdateAsync(SlowUpdateResult update) { writer.WriteEndObject(); } writer.WriteEndArray(); - return SendJsonMessageAsync(); + return SendJsonMessageAsync(cancellationToken); } - private Task SendDebugCompareAsync(byte command) { + private Task SendDebugCompareAsync(byte command, CancellationToken cancellationToken) { if (command == Commands.CommitCompletion || command == Commands.SlowUpdate) // these cannot cause state changes return Done; @@ -229,13 +231,13 @@ private Task SendDebugCompareAsync(byte command) { if (command != Commands.MoveCursor) writer.WriteProperty("text", _session.SourceText.ToString()); writer.WriteProperty("cursor", _session.CursorPosition); - return SendJsonMessageAsync(); + return SendJsonMessageAsync(cancellationToken); } - private Task SendErrorAsync(string message) { + private Task SendErrorAsync(string message, CancellationToken cancellationToken) { var writer = StartJsonMessage("error"); writer.WriteProperty("message", message); - return SendJsonMessageAsync(); + return SendJsonMessageAsync(cancellationToken); } private JsonWriter StartJsonMessage(string messageType) { @@ -245,14 +247,13 @@ private JsonWriter StartJsonMessage(string messageType) { return _jsonWriter; } - private Task SendJsonMessageAsync() { + private Task SendJsonMessageAsync(CancellationToken cancellationToken) { _jsonWriter.WriteEndObject(); _jsonWriter.Flush(); - return SendOutputBufferAsync((int)_jsonOutputStream.Position); - } - - private Task SendOutputBufferAsync(int byteCount) { - return _socket.SendAsync(new ArraySegment(_outputByteBuffer, 0, byteCount), WebSocketMessageType.Text, true, CancellationToken.None); + return _socket.SendAsync( + new ArraySegment(_outputByteBuffer, 0, (int)_jsonOutputStream.Position), + WebSocketMessageType.Text, true, cancellationToken + ); } public Task DisposeAsync() => _session.DisposeAsync(); diff --git a/MirrorSharp.Common/Internal/IWorkSession.cs b/MirrorSharp.Common/Internal/IWorkSession.cs index 983de35f..5bf9f36b 100644 --- a/MirrorSharp.Common/Internal/IWorkSession.cs +++ b/MirrorSharp.Common/Internal/IWorkSession.cs @@ -1,6 +1,5 @@ -using System.Collections.Immutable; +using System.Threading; using System.Threading.Tasks; -using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Completion; using Microsoft.CodeAnalysis.Text; using MirrorSharp.Internal.Results; @@ -12,8 +11,8 @@ public interface IWorkSession : IAsyncDisposable { void ReplaceText(int start, int length, string newText, int cursorPositionAfter); void MoveCursor(int cursorPosition); - Task TypeCharAsync(char @char); - Task GetCompletionChangeAsync(int itemIndex); - Task GetSlowUpdateAsync(); + Task TypeCharAsync(char @char, CancellationToken cancellationToken); + Task GetCompletionChangeAsync(int itemIndex, CancellationToken cancellationToken); + Task GetSlowUpdateAsync(CancellationToken cancellationToken); } } \ No newline at end of file diff --git a/MirrorSharp.Common/Internal/WorkSession.cs b/MirrorSharp.Common/Internal/WorkSession.cs index a66acdb1..cbce9b80 100644 --- a/MirrorSharp.Common/Internal/WorkSession.cs +++ b/MirrorSharp.Common/Internal/WorkSession.cs @@ -30,8 +30,6 @@ public class WorkSession : IWorkSession { private CompletionList _completionList; private readonly CompletionService _completionService; - - //private readonly Task _compilationLoopTask; private readonly CancellationTokenSource _disposing; private static readonly ImmutableList DefaultAssemblyReferences = ImmutableList.Create( @@ -85,26 +83,26 @@ public void MoveCursor(int cursorPosition) { _cursorPosition = cursorPosition; } - public Task TypeCharAsync(char @char) { + public Task TypeCharAsync(char @char, CancellationToken cancellationToken) { ReplaceText(_cursorPosition, 0, FastConvert.CharToString(@char), _cursorPosition + 1); if (!_completionService.ShouldTriggerCompletion(_sourceText, _cursorPosition, CompletionTrigger.CreateInsertionTrigger(@char))) return TypeCharEmptyResultTask; - return CreateResultFromCompletionsAsync(); + return CreateResultFromCompletionsAsync(cancellationToken); } - public Task GetCompletionChangeAsync(int itemIndex) { + public Task GetCompletionChangeAsync(int itemIndex, CancellationToken cancellationToken) { var item = _completionList.Items[itemIndex]; - return _completionService.GetChangeAsync(_document, item); + return _completionService.GetChangeAsync(_document, item, cancellationToken: cancellationToken); } - public async Task GetSlowUpdateAsync() { - var compilation = await _document.Project.GetCompilationAsync(); - var diagnostics = await compilation.WithAnalyzers(_analyzers).GetAllDiagnosticsAsync(); + public async Task GetSlowUpdateAsync(CancellationToken cancellationToken) { + var compilation = await _document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false); + var diagnostics = await compilation.WithAnalyzers(_analyzers).GetAllDiagnosticsAsync(cancellationToken).ConfigureAwait(false); return new SlowUpdateResult(diagnostics); } - private async Task CreateResultFromCompletionsAsync() { - _completionList = await _completionService.GetCompletionsAsync(_document, _cursorPosition).ConfigureAwait(false); + private async Task CreateResultFromCompletionsAsync(CancellationToken cancellationToken) { + _completionList = await _completionService.GetCompletionsAsync(_document, _cursorPosition, cancellationToken: cancellationToken).ConfigureAwait(false); return new TypeCharResult(_completionList); } diff --git a/MirrorSharp.Owin/Internal/Middleware.cs b/MirrorSharp.Owin/Internal/Middleware.cs index 5b1d404e..0c92b1b5 100644 --- a/MirrorSharp.Owin/Internal/Middleware.cs +++ b/MirrorSharp.Owin/Internal/Middleware.cs @@ -1,9 +1,11 @@ using System; using System.Collections.Generic; using System.Net.WebSockets; +using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using MirrorSharp.Advanced; +using MirrorSharp.Internal; namespace MirrorSharp.Owin.Internal { using AppFunc = Func, Task>; @@ -37,9 +39,13 @@ public Task Invoke(IDictionary environment) { ); } - using (context.WebSocket) { - await WebSocketLoopAsync(context.WebSocket).ConfigureAwait(false); - } + var callCancelled = (CancellationToken)e["websocket.CallCancelled"]; + // there is a weird issue where a socket never gets closed (deadlock?) + // if the loop is done in the standard ASP.NET thread + await Task.Run( + () => WebSocketLoopAsync(context.WebSocket, callCancelled), + callCancelled + ); }); return Done; } diff --git a/MirrorSharp.Tests/ConnectionTests.cs b/MirrorSharp.Tests/ConnectionTests.cs index 046236bb..5ff2882b 100644 --- a/MirrorSharp.Tests/ConnectionTests.cs +++ b/MirrorSharp.Tests/ConnectionTests.cs @@ -5,7 +5,6 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Classification; using Microsoft.CodeAnalysis.Completion; using Microsoft.CodeAnalysis.Text; using MirrorSharp.Internal; @@ -27,7 +26,7 @@ public async void ReceiveAndProcessAsync_CallsMoveCursorOnSession_AfterReceiving SetupReceive(socketMock, command); var sessionMock = Mock.Of(); - await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(); + await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(CancellationToken.None); Mock.Get(sessionMock).Verify(s => s.MoveCursor(expectedPosition)); } @@ -40,9 +39,10 @@ public async void ReceiveAndProcessAsync_CallsTypeCharAsyncOnSession_AfterReceiv var socketMock = Mock.Of(); SetupReceive(socketMock, command); var sessionMock = Mock.Of(); + var cancellationToken = new CancellationTokenSource().Token; - await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(); - Mock.Get(sessionMock).Verify(s => s.TypeCharAsync(expectedChar)); + await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(cancellationToken); + Mock.Get(sessionMock).Verify(s => s.TypeCharAsync(expectedChar, cancellationToken)); } [Theory] @@ -54,7 +54,7 @@ public async void ReceiveAndProcessAsync_CallsReplaceTextOnSession_AfterReceivin SetupReceive(socketMock, command); var sessionMock = Mock.Of(); - await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(); + await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(CancellationToken.None); Mock.Get(sessionMock).Verify(s => s.ReplaceText(expectedStart, expectedLength, expectedText, expectedPosition)); } @@ -66,11 +66,12 @@ public async void ReceiveAndProcessAsync_CallsGetCompletionChangeAsyncOnSession_ var socketMock = Mock.Of(); SetupReceive(socketMock, command); var sessionMock = Mock.Of( - s => s.GetCompletionChangeAsync(It.IsAny()) == Task.FromResult(NoCompletionChange) + s => s.GetCompletionChangeAsync(It.IsAny(), It.IsAny()) == Task.FromResult(NoCompletionChange) ); + var cancellationToken = new CancellationTokenSource().Token; - await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(); - Mock.Get(sessionMock).Verify(s => s.GetCompletionChangeAsync(expectedItemIndex)); + await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(cancellationToken); + Mock.Get(sessionMock).Verify(s => s.GetCompletionChangeAsync(expectedItemIndex, cancellationToken)); } [Fact] @@ -78,11 +79,12 @@ public async void ReceiveAndProcessAsync_CallsGetSlowUpdateAsyncOnSession_AfterR var socketMock = Mock.Of(); SetupReceive(socketMock, "U"); var sessionMock = Mock.Of( - s => s.GetSlowUpdateAsync() == Task.FromResult(NoSlowUpdate) + s => s.GetSlowUpdateAsync(It.IsAny()) == Task.FromResult(NoSlowUpdate) ); + var cancellationToken = new CancellationTokenSource().Token; - await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(); - Mock.Get(sessionMock).Verify(s => s.GetSlowUpdateAsync()); + await new Connection(socketMock, sessionMock).ReceiveAndProcessAsync(cancellationToken); + Mock.Get(sessionMock).Verify(s => s.GetSlowUpdateAsync(cancellationToken)); } private static void SetupReceive(WebSocket socket, string command) { diff --git a/MirrorSharp.Tests/SessionTests.cs b/MirrorSharp.Tests/SessionTests.cs index 1869fc56..ec36dab3 100644 --- a/MirrorSharp.Tests/SessionTests.cs +++ b/MirrorSharp.Tests/SessionTests.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using MirrorSharp.Internal; @@ -19,7 +20,7 @@ public class SessionTests { public async Task TypeChar_InsertsSingleChar() { var session = SessionFromTextWithCursor("class A| {}"); - await session.TypeCharAsync('1'); + await session.TypeCharAsync('1', CancellationToken.None); Assert.Equal("class A1 {}", session.SourceText.ToString()); } @@ -29,7 +30,7 @@ public async Task TypeChar_MovesCursorBySingleChar() { var session = SessionFromTextWithCursor("class A| {}"); var cursorPosition = session.CursorPosition; - await session.TypeCharAsync('1'); + await session.TypeCharAsync('1', CancellationToken.None); Assert.Equal(cursorPosition + 1, session.CursorPosition); } @@ -41,7 +42,7 @@ class A { public int x; } class B { void M(A a) { a| } } "); - var result = await session.TypeCharAsync('.'); + var result = await session.TypeCharAsync('.', CancellationToken.None); Assert.Equal( new[] { "x" }.Concat(ObjectMemberNames).OrderBy(n => n), @@ -52,14 +53,14 @@ class B { void M(A a) { a| } } [Fact] public async Task SlowUpdate_ProducesDiagnosticWithCustomTagUnnecessary_ForUnusedNamespace() { var session = SessionFromTextWithCursor(@"using System;|"); - var result = await session.GetSlowUpdateAsync(); + var result = await session.GetSlowUpdateAsync(CancellationToken.None); - /*Assert.Contains( + Assert.Contains( new { Severity = DiagnosticSeverity.Hidden, IsUnnecessary = true }, result.Diagnostics.Select( d => new { d.Severity, IsUnnecessary = d.Descriptor.CustomTags.Contains(WellKnownDiagnosticTags.Unnecessary) } ).ToArray() - );*/ + ); } private WorkSession SessionFromTextWithCursor(string textWithCursor) { diff --git a/MirrorSharp.sln.DotSettings b/MirrorSharp.sln.DotSettings index bc74da31..55e76611 100644 --- a/MirrorSharp.sln.DotSettings +++ b/MirrorSharp.sln.DotSettings @@ -13,6 +13,7 @@ END_OF_LINE False END_OF_LINE + False <Policy Inspect="True" Prefix="" Suffix="" Style="aaBb"><ExtraRule Prefix="" Suffix="" Style="AaBb" /></Policy> <Policy Inspect="True" Prefix="" Suffix="" Style="aaBb"><ExtraRule Prefix="" Suffix="" Style="AaBb" /></Policy> <Policy Inspect="True" Prefix="" Suffix="" Style="AaBb" />