Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Several Stream.Read/WriteAsync improvements #2724

Merged
merged 4 commits into from
Jan 19, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 99 additions & 76 deletions src/mscorlib/src/System/IO/Stream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,13 @@ protected virtual WaitHandle CreateWaitHandle()
public virtual IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, Object state)
{
Contract.Ensures(Contract.Result<IAsyncResult>() != null);
return BeginReadInternal(buffer, offset, count, callback, state, serializeAsynchronously: false);
return BeginReadInternal(buffer, offset, count, callback, state, serializeAsynchronously: false, apm: true);
}

[HostProtection(ExternalThreading = true)]
internal IAsyncResult BeginReadInternal(byte[] buffer, int offset, int count, AsyncCallback callback, Object state, bool serializeAsynchronously)
internal IAsyncResult BeginReadInternal(
byte[] buffer, int offset, int count, AsyncCallback callback, Object state,
bool serializeAsynchronously, bool apm)
{
Contract.Ensures(Contract.Result<IAsyncResult>() != null);
if (!CanRead) __Error.ReadNotSupported();
Expand Down Expand Up @@ -326,18 +328,31 @@ internal IAsyncResult BeginReadInternal(byte[] buffer, int offset, int count, As

// Create the task to asynchronously do a Read. This task serves both
// as the asynchronous work item and as the IAsyncResult returned to the user.
var asyncResult = new ReadWriteTask(true /*isRead*/, delegate
var asyncResult = new ReadWriteTask(true /*isRead*/, apm, delegate
{
// The ReadWriteTask stores all of the parameters to pass to Read.
// As we're currently inside of it, we can get the current task
// and grab the parameters from it.
var thisTask = Task.InternalCurrent as ReadWriteTask;
Contract.Assert(thisTask != null, "Inside ReadWriteTask, InternalCurrent should be the ReadWriteTask");

// Do the Read and return the number of bytes read
var bytesRead = thisTask._stream.Read(thisTask._buffer, thisTask._offset, thisTask._count);
thisTask.ClearBeginState(); // just to help alleviate some memory pressure
return bytesRead;
try
{
// Do the Read and return the number of bytes read
return thisTask._stream.Read(thisTask._buffer, thisTask._offset, thisTask._count);
}
finally
{
// If this implementation is part of Begin/EndXx, then the EndXx method will handle
// finishing the async operation. However, if this is part of XxAsync, then there won't
// be an end method, and this task is responsible for cleaning up.
if (!thisTask._apm)
{
thisTask._stream.FinishTrackingAsyncOperation();
}

thisTask.ClearBeginState(); // just to help alleviate some memory pressure
}
}, state, this, buffer, offset, count, callback);

// Schedule it
Expand Down Expand Up @@ -388,9 +403,7 @@ public virtual int EndRead(IAsyncResult asyncResult)
}
finally
{
_activeReadWriteTask = null;
Contract.Assert(_asyncActiveSemaphore != null, "Must have been initialized in order to get here.");
_asyncActiveSemaphore.Release();
FinishTrackingAsyncOperation();
}
#endif
}
Expand All @@ -413,8 +426,20 @@ public virtual Task<int> ReadAsync(Byte[] buffer, int offset, int count, Cancell
: BeginEndReadAsync(buffer, offset, count);
}

[System.Security.SecuritySafeCritical]
[MethodImplAttribute(MethodImplOptions.InternalCall)]
private extern bool HasOverriddenBeginEndRead();

private Task<Int32> BeginEndReadAsync(Byte[] buffer, Int32 offset, Int32 count)
{
{
if (!HasOverriddenBeginEndRead())
{
// If the Stream does not override Begin/EndRead, then we can take an optimized path
// that skips an extra layer of tasks / IAsyncResults.
return (Task<Int32>)BeginReadInternal(buffer, offset, count, null, null, serializeAsynchronously: true, apm: false);
}

// Otherwise, we need to wrap calls to Begin/EndWrite to ensure we use the derived type's functionality.
return TaskFactory<Int32>.FromAsyncTrim(
this, new ReadWriteParameters { Buffer = buffer, Offset = offset, Count = count },
(stream, args, callback, state) => stream.BeginRead(args.Buffer, args.Offset, args.Count, callback, state), // cached by compiler
Expand All @@ -434,11 +459,13 @@ private Task<Int32> BeginEndReadAsync(Byte[] buffer, Int32 offset, Int32 count)
public virtual IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, Object state)
{
Contract.Ensures(Contract.Result<IAsyncResult>() != null);
return BeginWriteInternal(buffer, offset, count, callback, state, serializeAsynchronously: false);
return BeginWriteInternal(buffer, offset, count, callback, state, serializeAsynchronously: false, apm: true);
}

[HostProtection(ExternalThreading = true)]
internal IAsyncResult BeginWriteInternal(byte[] buffer, int offset, int count, AsyncCallback callback, Object state, bool serializeAsynchronously)
internal IAsyncResult BeginWriteInternal(
byte[] buffer, int offset, int count, AsyncCallback callback, Object state,
bool serializeAsynchronously, bool apm)
{
Contract.Ensures(Contract.Result<IAsyncResult>() != null);
if (!CanWrite) __Error.WriteNotSupported();
Expand Down Expand Up @@ -470,18 +497,32 @@ internal IAsyncResult BeginWriteInternal(byte[] buffer, int offset, int count, A

// Create the task to asynchronously do a Write. This task serves both
// as the asynchronous work item and as the IAsyncResult returned to the user.
var asyncResult = new ReadWriteTask(false /*isRead*/, delegate
var asyncResult = new ReadWriteTask(false /*isRead*/, apm, delegate
{
// The ReadWriteTask stores all of the parameters to pass to Write.
// As we're currently inside of it, we can get the current task
// and grab the parameters from it.
var thisTask = Task.InternalCurrent as ReadWriteTask;
Contract.Assert(thisTask != null, "Inside ReadWriteTask, InternalCurrent should be the ReadWriteTask");

// Do the Write
thisTask._stream.Write(thisTask._buffer, thisTask._offset, thisTask._count);
thisTask.ClearBeginState(); // just to help alleviate some memory pressure
return 0; // not used, but signature requires a value be returned
try
{
// Do the Write
thisTask._stream.Write(thisTask._buffer, thisTask._offset, thisTask._count);
return 0; // not used, but signature requires a value be returned
}
finally
{
// If this implementation is part of Begin/EndXx, then the EndXx method will handle
// finishing the async operation. However, if this is part of XxAsync, then there won't
// be an end method, and this task is responsible for cleaning up.
if (!thisTask._apm)
{
thisTask._stream.FinishTrackingAsyncOperation();
}

thisTask.ClearBeginState(); // just to help alleviate some memory pressure
}
}, state, this, buffer, offset, count, callback);

// Schedule it
Expand All @@ -501,23 +542,19 @@ private void RunReadWriteTaskWhenReady(Task asyncWaiter, ReadWriteTask readWrite
// preconditions in async methods that await.
Contract.Assert(asyncWaiter != null); // Ditto

// If the wait has already complete, run the task.
// If the wait has already completed, run the task.
if (asyncWaiter.IsCompleted)
{
Contract.Assert(asyncWaiter.IsRanToCompletion, "The semaphore wait should always complete successfully.");
RunReadWriteTask(readWriteTask);
}
else // Otherwise, wait for our turn, and then run the task.
{
asyncWaiter.ContinueWith((t, state) =>
{
Contract.Assert(t.IsRanToCompletion, "The semaphore wait should always complete successfully.");
var tuple = (Tuple<Stream,ReadWriteTask>)state;
tuple.Item1.RunReadWriteTask(tuple.Item2); // RunReadWriteTask(readWriteTask);
}, Tuple.Create<Stream,ReadWriteTask>(this, readWriteTask),
default(CancellationToken),
TaskContinuationOptions.ExecuteSynchronously,
TaskScheduler.Default);
asyncWaiter.ContinueWith((t, state) => {
Contract.Assert(t.IsRanToCompletion, "The semaphore wait should always complete successfully.");
var rwt = (ReadWriteTask)state;
rwt._stream.RunReadWriteTask(rwt); // RunReadWriteTask(readWriteTask);
}, readWriteTask, default(CancellationToken), TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
}
}

Expand All @@ -534,6 +571,13 @@ private void RunReadWriteTask(ReadWriteTask readWriteTask)
readWriteTask.m_taskScheduler = TaskScheduler.Default;
readWriteTask.ScheduleAndStart(needsProtection: false);
}

private void FinishTrackingAsyncOperation()
{
_activeReadWriteTask = null;
Contract.Assert(_asyncActiveSemaphore != null, "Must have been initialized in order to get here.");
_asyncActiveSemaphore.Release();
}
#endif

public virtual void EndWrite(IAsyncResult asyncResult)
Expand Down Expand Up @@ -574,9 +618,7 @@ public virtual void EndWrite(IAsyncResult asyncResult)
}
finally
{
_activeReadWriteTask = null;
Contract.Assert(_asyncActiveSemaphore != null, "Must have been initialized in order to get here.");
_asyncActiveSemaphore.Release();
FinishTrackingAsyncOperation();
}
#endif
}
Expand All @@ -600,11 +642,12 @@ public virtual void EndWrite(IAsyncResult asyncResult)
// with a single allocation.
private sealed class ReadWriteTask : Task<int>, ITaskCompletionAction
{
internal readonly bool _isRead;
internal readonly bool _isRead;
internal readonly bool _apm; // true if this is from Begin/EndXx; false if it's from XxAsync
internal Stream _stream;
internal byte [] _buffer;
internal int _offset;
internal int _count;
internal readonly int _offset;
internal readonly int _count;
private AsyncCallback _callback;
private ExecutionContext _context;

Expand All @@ -618,6 +661,7 @@ private sealed class ReadWriteTask : Task<int>, ITaskCompletionAction
[MethodImpl(MethodImplOptions.NoInlining)]
public ReadWriteTask(
bool isRead,
bool apm,
Func<object,int> function, object state,
Stream stream, byte[] buffer, int offset, int count, AsyncCallback callback) :
base(function, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach)
Expand All @@ -631,6 +675,7 @@ public ReadWriteTask(

// Store the arguments
_isRead = isRead;
_apm = apm;
_stream = stream;
_buffer = buffer;
_offset = offset;
Expand Down Expand Up @@ -697,6 +742,8 @@ public Task WriteAsync(Byte[] buffer, int offset, int count)
return WriteAsync(buffer, offset, count, CancellationToken.None);
}



[HostProtection(ExternalThreading = true)]
[ComVisible(false)]
public virtual Task WriteAsync(Byte[] buffer, int offset, int count, CancellationToken cancellationToken)
Expand All @@ -708,9 +755,20 @@ public virtual Task WriteAsync(Byte[] buffer, int offset, int count, Cancellatio
: BeginEndWriteAsync(buffer, offset, count);
}

[System.Security.SecuritySafeCritical]
[MethodImplAttribute(MethodImplOptions.InternalCall)]
private extern bool HasOverriddenBeginEndWrite();

private Task BeginEndWriteAsync(Byte[] buffer, Int32 offset, Int32 count)
{
{
if (!HasOverriddenBeginEndWrite())
{
// If the Stream does not override Begin/EndWrite, then we can take an optimized path
// that skips an extra layer of tasks / IAsyncResults.
return (Task)BeginWriteInternal(buffer, offset, count, null, null, serializeAsynchronously: true, apm: false);
}

// Otherwise, we need to wrap calls to Begin/EndWrite to ensure we use the derived type's functionality.
return TaskFactory<VoidTaskResult>.FromAsyncTrim(
this, new ReadWriteParameters { Buffer=buffer, Offset=offset, Count=count },
(stream, args, callback, state) => stream.BeginWrite(args.Buffer, args.Offset, args.Count, callback, state), // cached by compiler
Expand Down Expand Up @@ -1057,10 +1115,6 @@ internal static void EndWrite(IAsyncResult asyncResult) {
internal sealed class SyncStream : Stream, IDisposable
{
private Stream _stream;
[NonSerialized]
private bool? _overridesBeginRead;
[NonSerialized]
private bool? _overridesBeginWrite;

internal SyncStream(Stream stream)
{
Expand Down Expand Up @@ -1179,38 +1233,11 @@ public override int ReadByte()
lock(_stream)
return _stream.ReadByte();
}

private static bool OverridesBeginMethod(Stream stream, string methodName)
{
Contract.Requires(stream != null, "Expected a non-null stream.");
Contract.Requires(methodName == "BeginRead" || methodName == "BeginWrite",
"Expected BeginRead or BeginWrite as the method name to check.");

// Get all of the methods on the underlying stream
var methods = stream.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance);

// If any of the methods have the desired name and are defined on the base Stream
// Type, then the method was not overridden. If none of them were defined on the
// base Stream, then it must have been overridden.
foreach (var method in methods)
{
if (method.DeclaringType == typeof(Stream) &&
method.Name == methodName)
{
return false;
}
}
return true;
}

[HostProtection(ExternalThreading=true)]
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, Object state)
{
// Lazily-initialize whether the wrapped stream overrides BeginRead
if (_overridesBeginRead == null)
{
_overridesBeginRead = OverridesBeginMethod(_stream, "BeginRead");
}
bool overridesBeginRead = _stream.HasOverriddenBeginEndRead();

lock (_stream)
{
Expand All @@ -1220,9 +1247,9 @@ public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, Asy
// than a synchronous wait. A synchronous wait will result in a deadlock condition, because
// the EndXx method for the outstanding async operation won't be able to acquire the lock on
// _stream due to this call blocked while holding the lock.
return _overridesBeginRead.Value ?
return overridesBeginRead ?
_stream.BeginRead(buffer, offset, count, callback, state) :
_stream.BeginReadInternal(buffer, offset, count, callback, state, serializeAsynchronously: true);
_stream.BeginReadInternal(buffer, offset, count, callback, state, serializeAsynchronously: true, apm: true);
}
}

Expand Down Expand Up @@ -1264,11 +1291,7 @@ public override void WriteByte(byte b)
[HostProtection(ExternalThreading=true)]
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, Object state)
{
// Lazily-initialize whether the wrapped stream overrides BeginWrite
if (_overridesBeginWrite == null)
{
_overridesBeginWrite = OverridesBeginMethod(_stream, "BeginWrite");
}
bool overridesBeginWrite = _stream.HasOverriddenBeginEndWrite();

lock (_stream)
{
Expand All @@ -1278,9 +1301,9 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
// than a synchronous wait. A synchronous wait will result in a deadlock condition, because
// the EndXx method for the outstanding async operation won't be able to acquire the lock on
// _stream due to this call blocked while holding the lock.
return _overridesBeginWrite.Value ?
return overridesBeginWrite ?
_stream.BeginWrite(buffer, offset, count, callback, state) :
_stream.BeginWriteInternal(buffer, offset, count, callback, state, serializeAsynchronously: true);
_stream.BeginWriteInternal(buffer, offset, count, callback, state, serializeAsynchronously: true, apm: true);
}
}

Expand Down
Loading