Skip to content

Commit

Permalink
Corrected issue where WebSockets weren't closed properly (under Owin).
Browse files Browse the repository at this point in the history
  • Loading branch information
ashmind committed Sep 18, 2016
1 parent 6a5fbea commit 879bafd
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 78 deletions.
5 changes: 3 additions & 2 deletions MirrorSharp.AspNetCore/Internal/Middleware.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Threading.Tasks;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
using Microsoft.AspNetCore.Http;
using MirrorSharp.Advanced;
Expand All @@ -19,7 +20,7 @@ public async Task Invoke(HttpContext context) {
}

using (var socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false)) {
await WebSocketLoopAsync(socket).ConfigureAwait(false);
await WebSocketLoopAsync(socket, CancellationToken.None).ConfigureAwait(false);
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions MirrorSharp.Common/Advanced/MiddlewareBase.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using MirrorSharp.Internal;

Expand All @@ -11,7 +12,7 @@ protected MiddlewareBase(MirrorSharpOptions options) {
_options = options;
}

protected async Task WebSocketLoopAsync(WebSocket socket) {
protected async Task WebSocketLoopAsync(WebSocket socket, CancellationToken cancellationToken) {
WorkSession session = null;
Connection connection = null;
try {
Expand All @@ -20,14 +21,14 @@ protected async Task WebSocketLoopAsync(WebSocket socket) {

while (connection.IsConnected) {
try {
await connection.ReceiveAndProcessAsync().ConfigureAwait(false);
await connection.ReceiveAndProcessAsync(cancellationToken).ConfigureAwait(false);
}
catch {
// this is sent back by connection itself
}
}
}
catch (Exception) when (connection == null && session != null) {
catch when (connection == null && session != null) {
await session.DisposeAsync().ConfigureAwait(false);
throw;
}
Expand Down
75 changes: 38 additions & 37 deletions MirrorSharp.Common/Internal/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ public Connection(WebSocket socket, IWorkSession session, IConnectionOptions opt

public bool IsConnected => _socket.State == WebSocketState.Open;

public async Task ReceiveAndProcessAsync() {
public async Task ReceiveAndProcessAsync(CancellationToken cancellationToken) {
try {
await ReceiveAndProcessInternalAsync().ConfigureAwait(false);
await ReceiveAndProcessInternalAsync(cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) {
try {
await SendErrorAsync(ex.Message).ConfigureAwait(false);
await SendErrorAsync(ex.Message, cancellationToken).ConfigureAwait(false);
}
catch (Exception sendException) {
throw new AggregateException(ex, sendException);
Expand All @@ -57,20 +57,22 @@ public async Task ReceiveAndProcessAsync() {
}
}

private async Task ReceiveAndProcessInternalAsync() {
var received = await _socket.ReceiveAsync(new ArraySegment<byte>(_inputByteBuffer), CancellationToken.None).ConfigureAwait(false);
private async Task ReceiveAndProcessInternalAsync(CancellationToken cancellationToken) {
var received = await _socket.ReceiveAsync(new ArraySegment<byte>(_inputByteBuffer), cancellationToken).ConfigureAwait(false);
if (received.MessageType == WebSocketMessageType.Binary)
throw new FormatException("Expected text data (received binary).");

if (received.MessageType == WebSocketMessageType.Close)
if (received.MessageType == WebSocketMessageType.Close) {
await _socket.CloseAsync(received.CloseStatus ?? WebSocketCloseStatus.Empty, received.CloseStatusDescription, cancellationToken).ConfigureAwait(false);
return;
}

await ProcessMessageAsync(new ArraySegment<byte>(_inputByteBuffer, 0, received.Count)).ConfigureAwait(false);
await ProcessMessageAsync(new ArraySegment<byte>(_inputByteBuffer, 0, received.Count), cancellationToken).ConfigureAwait(false);
if (_options.SendDebugCompareMessages)
await SendDebugCompareAsync(_inputByteBuffer[0]).ConfigureAwait(false);
await SendDebugCompareAsync(_inputByteBuffer[0], cancellationToken).ConfigureAwait(false);
}

private Task ProcessMessageAsync(ArraySegment<byte> data) {
private Task ProcessMessageAsync(ArraySegment<byte> data, CancellationToken cancellationToken) {
var command = data.Array[data.Offset];
switch (command) {
case Commands.ReplaceProgress:
Expand All @@ -82,9 +84,9 @@ private Task ProcessMessageAsync(ArraySegment<byte> data) {
ProcessMoveCursor(Shift(data));
return Done;
}
case Commands.TypeChar: return ProcessTypeCharAsync(Shift(data));
case Commands.CommitCompletion: return ProcessCommitCompletionAsync(Shift(data));
case Commands.SlowUpdate: return ProcessSlowUpdateAsync();
case Commands.TypeChar: return ProcessTypeCharAsync(Shift(data), cancellationToken);
case Commands.CommitCompletion: return ProcessCommitCompletionAsync(Shift(data), cancellationToken);
case Commands.SlowUpdate: return ProcessSlowUpdateAsync(cancellationToken);
default: throw new FormatException($"Unknown command: '{(char)command}'.");
}
}
Expand Down Expand Up @@ -133,17 +135,17 @@ private void ProcessMoveCursor(ArraySegment<byte> data) {
_session.MoveCursor(cursorPosition);
}

private async Task ProcessTypeCharAsync(ArraySegment<byte> data) {
private async Task ProcessTypeCharAsync(ArraySegment<byte> data, CancellationToken cancellationToken) {
var @char = FastConvert.Utf8ByteArrayToChar(data, _charBuffer);

var result = await _session.TypeCharAsync(@char).ConfigureAwait(false);
var result = await _session.TypeCharAsync(@char, cancellationToken).ConfigureAwait(false);
if (result.Completions == null)
return;

await SendTypeCharResultAsync(result).ConfigureAwait(false);
await SendTypeCharResultAsync(result, cancellationToken).ConfigureAwait(false);
}

private Task SendTypeCharResultAsync(TypeCharResult result) {
private Task SendTypeCharResultAsync(TypeCharResult result, CancellationToken cancellationToken) {
var completions = result.Completions;

var writer = StartJsonMessage("completions");
Expand All @@ -168,16 +170,16 @@ private Task SendTypeCharResultAsync(TypeCharResult result) {
}
writer.WriteEndArray();
writer.WriteEndObject();
return SendJsonMessageAsync();
return SendJsonMessageAsync(cancellationToken);
}

private async Task ProcessCommitCompletionAsync(ArraySegment<byte> data) {
private async Task ProcessCommitCompletionAsync(ArraySegment<byte> data, CancellationToken cancellationToken) {
var itemIndex = FastConvert.Utf8ByteArrayToInt32(data);
var change = await _session.GetCompletionChangeAsync(itemIndex);
await SendCompletionChangeAsync(change).ConfigureAwait(false);
var change = await _session.GetCompletionChangeAsync(itemIndex, cancellationToken);
await SendCompletionChangeAsync(change, cancellationToken).ConfigureAwait(false);
}

private Task SendCompletionChangeAsync(CompletionChange change) {
private Task SendCompletionChangeAsync(CompletionChange change, CancellationToken cancellationToken) {
var writer = StartJsonMessage("changes");
writer.WritePropertyStartArray("changes");
foreach (var textChange in change.TextChanges) {
Expand All @@ -188,15 +190,15 @@ private Task SendCompletionChangeAsync(CompletionChange change) {
writer.WriteEndObject();
}
writer.WriteEndArray();
return SendJsonMessageAsync();
return SendJsonMessageAsync(cancellationToken);
}

private async Task ProcessSlowUpdateAsync() {
var update = await _session.GetSlowUpdateAsync().ConfigureAwait(false);
await SendSlowUpdateAsync(update).ConfigureAwait(false);
private async Task ProcessSlowUpdateAsync(CancellationToken cancellationToken) {
var update = await _session.GetSlowUpdateAsync(cancellationToken).ConfigureAwait(false);
await SendSlowUpdateAsync(update, cancellationToken).ConfigureAwait(false);
}

private Task SendSlowUpdateAsync(SlowUpdateResult update) {
private Task SendSlowUpdateAsync(SlowUpdateResult update, CancellationToken cancellationToken) {
var writer = StartJsonMessage("slowUpdate");
writer.WritePropertyStartArray("diagnostics");
foreach (var diagnostic in update.Diagnostics) {
Expand All @@ -215,10 +217,10 @@ private Task SendSlowUpdateAsync(SlowUpdateResult update) {
writer.WriteEndObject();
}
writer.WriteEndArray();
return SendJsonMessageAsync();
return SendJsonMessageAsync(cancellationToken);
}

private Task SendDebugCompareAsync(byte command) {
private Task SendDebugCompareAsync(byte command, CancellationToken cancellationToken) {
if (command == Commands.CommitCompletion || command == Commands.SlowUpdate) // these cannot cause state changes
return Done;

Expand All @@ -229,13 +231,13 @@ private Task SendDebugCompareAsync(byte command) {
if (command != Commands.MoveCursor)
writer.WriteProperty("text", _session.SourceText.ToString());
writer.WriteProperty("cursor", _session.CursorPosition);
return SendJsonMessageAsync();
return SendJsonMessageAsync(cancellationToken);
}

private Task SendErrorAsync(string message) {
private Task SendErrorAsync(string message, CancellationToken cancellationToken) {
var writer = StartJsonMessage("error");
writer.WriteProperty("message", message);
return SendJsonMessageAsync();
return SendJsonMessageAsync(cancellationToken);
}

private JsonWriter StartJsonMessage(string messageType) {
Expand All @@ -245,14 +247,13 @@ private JsonWriter StartJsonMessage(string messageType) {
return _jsonWriter;
}

private Task SendJsonMessageAsync() {
private Task SendJsonMessageAsync(CancellationToken cancellationToken) {
_jsonWriter.WriteEndObject();
_jsonWriter.Flush();
return SendOutputBufferAsync((int)_jsonOutputStream.Position);
}

private Task SendOutputBufferAsync(int byteCount) {
return _socket.SendAsync(new ArraySegment<byte>(_outputByteBuffer, 0, byteCount), WebSocketMessageType.Text, true, CancellationToken.None);
return _socket.SendAsync(
new ArraySegment<byte>(_outputByteBuffer, 0, (int)_jsonOutputStream.Position),
WebSocketMessageType.Text, true, cancellationToken
);
}

public Task DisposeAsync() => _session.DisposeAsync();
Expand Down
9 changes: 4 additions & 5 deletions MirrorSharp.Common/Internal/IWorkSession.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Completion;
using Microsoft.CodeAnalysis.Text;
using MirrorSharp.Internal.Results;
Expand All @@ -12,8 +11,8 @@ public interface IWorkSession : IAsyncDisposable {

void ReplaceText(int start, int length, string newText, int cursorPositionAfter);
void MoveCursor(int cursorPosition);
Task<TypeCharResult> TypeCharAsync(char @char);
Task<CompletionChange> GetCompletionChangeAsync(int itemIndex);
Task<SlowUpdateResult> GetSlowUpdateAsync();
Task<TypeCharResult> TypeCharAsync(char @char, CancellationToken cancellationToken);
Task<CompletionChange> GetCompletionChangeAsync(int itemIndex, CancellationToken cancellationToken);
Task<SlowUpdateResult> GetSlowUpdateAsync(CancellationToken cancellationToken);
}
}
20 changes: 9 additions & 11 deletions MirrorSharp.Common/Internal/WorkSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ public class WorkSession : IWorkSession {
private CompletionList _completionList;

private readonly CompletionService _completionService;

//private readonly Task _compilationLoopTask;
private readonly CancellationTokenSource _disposing;

private static readonly ImmutableList<MetadataReference> DefaultAssemblyReferences = ImmutableList.Create<MetadataReference>(
Expand Down Expand Up @@ -85,26 +83,26 @@ public void MoveCursor(int cursorPosition) {
_cursorPosition = cursorPosition;
}

public Task<TypeCharResult> TypeCharAsync(char @char) {
public Task<TypeCharResult> TypeCharAsync(char @char, CancellationToken cancellationToken) {
ReplaceText(_cursorPosition, 0, FastConvert.CharToString(@char), _cursorPosition + 1);
if (!_completionService.ShouldTriggerCompletion(_sourceText, _cursorPosition, CompletionTrigger.CreateInsertionTrigger(@char)))
return TypeCharEmptyResultTask;
return CreateResultFromCompletionsAsync();
return CreateResultFromCompletionsAsync(cancellationToken);
}

public Task<CompletionChange> GetCompletionChangeAsync(int itemIndex) {
public Task<CompletionChange> GetCompletionChangeAsync(int itemIndex, CancellationToken cancellationToken) {
var item = _completionList.Items[itemIndex];
return _completionService.GetChangeAsync(_document, item);
return _completionService.GetChangeAsync(_document, item, cancellationToken: cancellationToken);
}

public async Task<SlowUpdateResult> GetSlowUpdateAsync() {
var compilation = await _document.Project.GetCompilationAsync();
var diagnostics = await compilation.WithAnalyzers(_analyzers).GetAllDiagnosticsAsync();
public async Task<SlowUpdateResult> GetSlowUpdateAsync(CancellationToken cancellationToken) {
var compilation = await _document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
var diagnostics = await compilation.WithAnalyzers(_analyzers).GetAllDiagnosticsAsync(cancellationToken).ConfigureAwait(false);
return new SlowUpdateResult(diagnostics);
}

private async Task<TypeCharResult> CreateResultFromCompletionsAsync() {
_completionList = await _completionService.GetCompletionsAsync(_document, _cursorPosition).ConfigureAwait(false);
private async Task<TypeCharResult> CreateResultFromCompletionsAsync(CancellationToken cancellationToken) {
_completionList = await _completionService.GetCompletionsAsync(_document, _cursorPosition, cancellationToken: cancellationToken).ConfigureAwait(false);
return new TypeCharResult(_completionList);
}

Expand Down
12 changes: 9 additions & 3 deletions MirrorSharp.Owin/Internal/Middleware.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using System;
using System.Collections.Generic;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
using MirrorSharp.Advanced;
using MirrorSharp.Internal;

namespace MirrorSharp.Owin.Internal {
using AppFunc = Func<IDictionary<string, object>, Task>;
Expand Down Expand Up @@ -37,9 +39,13 @@ public Task Invoke(IDictionary<string, object> environment) {
);
}

using (context.WebSocket) {
await WebSocketLoopAsync(context.WebSocket).ConfigureAwait(false);
}
var callCancelled = (CancellationToken)e["websocket.CallCancelled"];
// there is a weird issue where a socket never gets closed (deadlock?)
// if the loop is done in the standard ASP.NET thread
await Task.Run(
() => WebSocketLoopAsync(context.WebSocket, callCancelled),
callCancelled
);
});
return Done;
}
Expand Down
Loading

0 comments on commit 879bafd

Please sign in to comment.