diff --git a/src/mscorlib/src/System/IO/Stream.cs b/src/mscorlib/src/System/IO/Stream.cs index a39e43828185..ac1811f4d7a3 100644 --- a/src/mscorlib/src/System/IO/Stream.cs +++ b/src/mscorlib/src/System/IO/Stream.cs @@ -289,11 +289,13 @@ protected virtual WaitHandle CreateWaitHandle() public virtual IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, Object state) { Contract.Ensures(Contract.Result() != null); - return BeginReadInternal(buffer, offset, count, callback, state, serializeAsynchronously: false); + return BeginReadInternal(buffer, offset, count, callback, state, serializeAsynchronously: false, apm: true); } [HostProtection(ExternalThreading = true)] - internal IAsyncResult BeginReadInternal(byte[] buffer, int offset, int count, AsyncCallback callback, Object state, bool serializeAsynchronously) + internal IAsyncResult BeginReadInternal( + byte[] buffer, int offset, int count, AsyncCallback callback, Object state, + bool serializeAsynchronously, bool apm) { Contract.Ensures(Contract.Result() != null); if (!CanRead) __Error.ReadNotSupported(); @@ -326,7 +328,7 @@ internal IAsyncResult BeginReadInternal(byte[] buffer, int offset, int count, As // Create the task to asynchronously do a Read. This task serves both // as the asynchronous work item and as the IAsyncResult returned to the user. - var asyncResult = new ReadWriteTask(true /*isRead*/, delegate + var asyncResult = new ReadWriteTask(true /*isRead*/, apm, delegate { // The ReadWriteTask stores all of the parameters to pass to Read. // As we're currently inside of it, we can get the current task @@ -334,10 +336,23 @@ internal IAsyncResult BeginReadInternal(byte[] buffer, int offset, int count, As var thisTask = Task.InternalCurrent as ReadWriteTask; Contract.Assert(thisTask != null, "Inside ReadWriteTask, InternalCurrent should be the ReadWriteTask"); - // Do the Read and return the number of bytes read - var bytesRead = thisTask._stream.Read(thisTask._buffer, thisTask._offset, thisTask._count); - thisTask.ClearBeginState(); // just to help alleviate some memory pressure - return bytesRead; + try + { + // Do the Read and return the number of bytes read + return thisTask._stream.Read(thisTask._buffer, thisTask._offset, thisTask._count); + } + finally + { + // If this implementation is part of Begin/EndXx, then the EndXx method will handle + // finishing the async operation. However, if this is part of XxAsync, then there won't + // be an end method, and this task is responsible for cleaning up. + if (!thisTask._apm) + { + thisTask._stream.FinishTrackingAsyncOperation(); + } + + thisTask.ClearBeginState(); // just to help alleviate some memory pressure + } }, state, this, buffer, offset, count, callback); // Schedule it @@ -388,9 +403,7 @@ public virtual int EndRead(IAsyncResult asyncResult) } finally { - _activeReadWriteTask = null; - Contract.Assert(_asyncActiveSemaphore != null, "Must have been initialized in order to get here."); - _asyncActiveSemaphore.Release(); + FinishTrackingAsyncOperation(); } #endif } @@ -413,8 +426,20 @@ public virtual Task ReadAsync(Byte[] buffer, int offset, int count, Cancell : BeginEndReadAsync(buffer, offset, count); } + [System.Security.SecuritySafeCritical] + [MethodImplAttribute(MethodImplOptions.InternalCall)] + private extern bool HasOverriddenBeginEndRead(); + private Task BeginEndReadAsync(Byte[] buffer, Int32 offset, Int32 count) - { + { + if (!HasOverriddenBeginEndRead()) + { + // If the Stream does not override Begin/EndRead, then we can take an optimized path + // that skips an extra layer of tasks / IAsyncResults. + return (Task)BeginReadInternal(buffer, offset, count, null, null, serializeAsynchronously: true, apm: false); + } + + // Otherwise, we need to wrap calls to Begin/EndWrite to ensure we use the derived type's functionality. return TaskFactory.FromAsyncTrim( this, new ReadWriteParameters { Buffer = buffer, Offset = offset, Count = count }, (stream, args, callback, state) => stream.BeginRead(args.Buffer, args.Offset, args.Count, callback, state), // cached by compiler @@ -434,11 +459,13 @@ private Task BeginEndReadAsync(Byte[] buffer, Int32 offset, Int32 count) public virtual IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, Object state) { Contract.Ensures(Contract.Result() != null); - return BeginWriteInternal(buffer, offset, count, callback, state, serializeAsynchronously: false); + return BeginWriteInternal(buffer, offset, count, callback, state, serializeAsynchronously: false, apm: true); } [HostProtection(ExternalThreading = true)] - internal IAsyncResult BeginWriteInternal(byte[] buffer, int offset, int count, AsyncCallback callback, Object state, bool serializeAsynchronously) + internal IAsyncResult BeginWriteInternal( + byte[] buffer, int offset, int count, AsyncCallback callback, Object state, + bool serializeAsynchronously, bool apm) { Contract.Ensures(Contract.Result() != null); if (!CanWrite) __Error.WriteNotSupported(); @@ -470,7 +497,7 @@ internal IAsyncResult BeginWriteInternal(byte[] buffer, int offset, int count, A // Create the task to asynchronously do a Write. This task serves both // as the asynchronous work item and as the IAsyncResult returned to the user. - var asyncResult = new ReadWriteTask(false /*isRead*/, delegate + var asyncResult = new ReadWriteTask(false /*isRead*/, apm, delegate { // The ReadWriteTask stores all of the parameters to pass to Write. // As we're currently inside of it, we can get the current task @@ -478,10 +505,24 @@ internal IAsyncResult BeginWriteInternal(byte[] buffer, int offset, int count, A var thisTask = Task.InternalCurrent as ReadWriteTask; Contract.Assert(thisTask != null, "Inside ReadWriteTask, InternalCurrent should be the ReadWriteTask"); - // Do the Write - thisTask._stream.Write(thisTask._buffer, thisTask._offset, thisTask._count); - thisTask.ClearBeginState(); // just to help alleviate some memory pressure - return 0; // not used, but signature requires a value be returned + try + { + // Do the Write + thisTask._stream.Write(thisTask._buffer, thisTask._offset, thisTask._count); + return 0; // not used, but signature requires a value be returned + } + finally + { + // If this implementation is part of Begin/EndXx, then the EndXx method will handle + // finishing the async operation. However, if this is part of XxAsync, then there won't + // be an end method, and this task is responsible for cleaning up. + if (!thisTask._apm) + { + thisTask._stream.FinishTrackingAsyncOperation(); + } + + thisTask.ClearBeginState(); // just to help alleviate some memory pressure + } }, state, this, buffer, offset, count, callback); // Schedule it @@ -501,7 +542,7 @@ private void RunReadWriteTaskWhenReady(Task asyncWaiter, ReadWriteTask readWrite // preconditions in async methods that await. Contract.Assert(asyncWaiter != null); // Ditto - // If the wait has already complete, run the task. + // If the wait has already completed, run the task. if (asyncWaiter.IsCompleted) { Contract.Assert(asyncWaiter.IsRanToCompletion, "The semaphore wait should always complete successfully."); @@ -509,15 +550,11 @@ private void RunReadWriteTaskWhenReady(Task asyncWaiter, ReadWriteTask readWrite } else // Otherwise, wait for our turn, and then run the task. { - asyncWaiter.ContinueWith((t, state) => - { - Contract.Assert(t.IsRanToCompletion, "The semaphore wait should always complete successfully."); - var tuple = (Tuple)state; - tuple.Item1.RunReadWriteTask(tuple.Item2); // RunReadWriteTask(readWriteTask); - }, Tuple.Create(this, readWriteTask), - default(CancellationToken), - TaskContinuationOptions.ExecuteSynchronously, - TaskScheduler.Default); + asyncWaiter.ContinueWith((t, state) => { + Contract.Assert(t.IsRanToCompletion, "The semaphore wait should always complete successfully."); + var rwt = (ReadWriteTask)state; + rwt._stream.RunReadWriteTask(rwt); // RunReadWriteTask(readWriteTask); + }, readWriteTask, default(CancellationToken), TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); } } @@ -534,6 +571,13 @@ private void RunReadWriteTask(ReadWriteTask readWriteTask) readWriteTask.m_taskScheduler = TaskScheduler.Default; readWriteTask.ScheduleAndStart(needsProtection: false); } + + private void FinishTrackingAsyncOperation() + { + _activeReadWriteTask = null; + Contract.Assert(_asyncActiveSemaphore != null, "Must have been initialized in order to get here."); + _asyncActiveSemaphore.Release(); + } #endif public virtual void EndWrite(IAsyncResult asyncResult) @@ -574,9 +618,7 @@ public virtual void EndWrite(IAsyncResult asyncResult) } finally { - _activeReadWriteTask = null; - Contract.Assert(_asyncActiveSemaphore != null, "Must have been initialized in order to get here."); - _asyncActiveSemaphore.Release(); + FinishTrackingAsyncOperation(); } #endif } @@ -600,11 +642,12 @@ public virtual void EndWrite(IAsyncResult asyncResult) // with a single allocation. private sealed class ReadWriteTask : Task, ITaskCompletionAction { - internal readonly bool _isRead; + internal readonly bool _isRead; + internal readonly bool _apm; // true if this is from Begin/EndXx; false if it's from XxAsync internal Stream _stream; internal byte [] _buffer; - internal int _offset; - internal int _count; + internal readonly int _offset; + internal readonly int _count; private AsyncCallback _callback; private ExecutionContext _context; @@ -618,6 +661,7 @@ private sealed class ReadWriteTask : Task, ITaskCompletionAction [MethodImpl(MethodImplOptions.NoInlining)] public ReadWriteTask( bool isRead, + bool apm, Func function, object state, Stream stream, byte[] buffer, int offset, int count, AsyncCallback callback) : base(function, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach) @@ -631,6 +675,7 @@ public ReadWriteTask( // Store the arguments _isRead = isRead; + _apm = apm; _stream = stream; _buffer = buffer; _offset = offset; @@ -697,6 +742,8 @@ public Task WriteAsync(Byte[] buffer, int offset, int count) return WriteAsync(buffer, offset, count, CancellationToken.None); } + + [HostProtection(ExternalThreading = true)] [ComVisible(false)] public virtual Task WriteAsync(Byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -708,9 +755,20 @@ public virtual Task WriteAsync(Byte[] buffer, int offset, int count, Cancellatio : BeginEndWriteAsync(buffer, offset, count); } + [System.Security.SecuritySafeCritical] + [MethodImplAttribute(MethodImplOptions.InternalCall)] + private extern bool HasOverriddenBeginEndWrite(); private Task BeginEndWriteAsync(Byte[] buffer, Int32 offset, Int32 count) - { + { + if (!HasOverriddenBeginEndWrite()) + { + // If the Stream does not override Begin/EndWrite, then we can take an optimized path + // that skips an extra layer of tasks / IAsyncResults. + return (Task)BeginWriteInternal(buffer, offset, count, null, null, serializeAsynchronously: true, apm: false); + } + + // Otherwise, we need to wrap calls to Begin/EndWrite to ensure we use the derived type's functionality. return TaskFactory.FromAsyncTrim( this, new ReadWriteParameters { Buffer=buffer, Offset=offset, Count=count }, (stream, args, callback, state) => stream.BeginWrite(args.Buffer, args.Offset, args.Count, callback, state), // cached by compiler @@ -1057,10 +1115,6 @@ internal static void EndWrite(IAsyncResult asyncResult) { internal sealed class SyncStream : Stream, IDisposable { private Stream _stream; - [NonSerialized] - private bool? _overridesBeginRead; - [NonSerialized] - private bool? _overridesBeginWrite; internal SyncStream(Stream stream) { @@ -1179,38 +1233,11 @@ public override int ReadByte() lock(_stream) return _stream.ReadByte(); } - - private static bool OverridesBeginMethod(Stream stream, string methodName) - { - Contract.Requires(stream != null, "Expected a non-null stream."); - Contract.Requires(methodName == "BeginRead" || methodName == "BeginWrite", - "Expected BeginRead or BeginWrite as the method name to check."); - - // Get all of the methods on the underlying stream - var methods = stream.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance); - - // If any of the methods have the desired name and are defined on the base Stream - // Type, then the method was not overridden. If none of them were defined on the - // base Stream, then it must have been overridden. - foreach (var method in methods) - { - if (method.DeclaringType == typeof(Stream) && - method.Name == methodName) - { - return false; - } - } - return true; - } [HostProtection(ExternalThreading=true)] public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, Object state) { - // Lazily-initialize whether the wrapped stream overrides BeginRead - if (_overridesBeginRead == null) - { - _overridesBeginRead = OverridesBeginMethod(_stream, "BeginRead"); - } + bool overridesBeginRead = _stream.HasOverriddenBeginEndRead(); lock (_stream) { @@ -1220,9 +1247,9 @@ public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, Asy // than a synchronous wait. A synchronous wait will result in a deadlock condition, because // the EndXx method for the outstanding async operation won't be able to acquire the lock on // _stream due to this call blocked while holding the lock. - return _overridesBeginRead.Value ? + return overridesBeginRead ? _stream.BeginRead(buffer, offset, count, callback, state) : - _stream.BeginReadInternal(buffer, offset, count, callback, state, serializeAsynchronously: true); + _stream.BeginReadInternal(buffer, offset, count, callback, state, serializeAsynchronously: true, apm: true); } } @@ -1264,11 +1291,7 @@ public override void WriteByte(byte b) [HostProtection(ExternalThreading=true)] public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, Object state) { - // Lazily-initialize whether the wrapped stream overrides BeginWrite - if (_overridesBeginWrite == null) - { - _overridesBeginWrite = OverridesBeginMethod(_stream, "BeginWrite"); - } + bool overridesBeginWrite = _stream.HasOverriddenBeginEndWrite(); lock (_stream) { @@ -1278,9 +1301,9 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As // than a synchronous wait. A synchronous wait will result in a deadlock condition, because // the EndXx method for the outstanding async operation won't be able to acquire the lock on // _stream due to this call blocked while holding the lock. - return _overridesBeginWrite.Value ? + return overridesBeginWrite ? _stream.BeginWrite(buffer, offset, count, callback, state) : - _stream.BeginWriteInternal(buffer, offset, count, callback, state, serializeAsynchronously: true); + _stream.BeginWriteInternal(buffer, offset, count, callback, state, serializeAsynchronously: true, apm: true); } } diff --git a/src/vm/comutilnative.cpp b/src/vm/comutilnative.cpp index 9664bf93257d..b3f267028215 100644 --- a/src/vm/comutilnative.cpp +++ b/src/vm/comutilnative.cpp @@ -3159,3 +3159,76 @@ INT32 QCALLTYPE CoreFxGlobalization::HashSortKey(PCBYTE pSortKey, INT32 cbSortKe return retVal; } #endif //FEATURE_COREFX_GLOBALIZATION + +static MethodTable * g_pStreamMT; +static WORD g_slotBeginRead, g_slotEndRead; +static WORD g_slotBeginWrite, g_slotEndWrite; + +static bool HasOverriddenStreamMethod(MethodTable * pMT, WORD slot) +{ + CONTRACTL{ + NOTHROW; + GC_NOTRIGGER; + MODE_ANY; + SO_TOLERANT; + } CONTRACTL_END; + + PCODE actual = pMT->GetRestoredSlot(slot); + PCODE base = g_pStreamMT->GetRestoredSlot(slot); + if (actual == base) + return false; + + if (!g_pStreamMT->IsZapped()) + { + // If mscorlib is JITed, the slots can be patched and thus we need to compare the actual MethodDescs + // to detect match reliably + if (MethodTable::GetMethodDescForSlotAddress(actual) == MethodTable::GetMethodDescForSlotAddress(base)) + return false; + } + + return true; +} + +FCIMPL1(FC_BOOL_RET, StreamNative::HasOverriddenBeginEndRead, Object *stream) +{ + FCALL_CONTRACT; + + if (stream == NULL) + FC_RETURN_BOOL(TRUE); + + if (g_pStreamMT == NULL || g_slotBeginRead == 0 || g_slotEndRead == 0) + { + HELPER_METHOD_FRAME_BEGIN_RET_1(stream); + g_pStreamMT = MscorlibBinder::GetClass(CLASS__STREAM); + g_slotBeginRead = MscorlibBinder::GetMethod(METHOD__STREAM__BEGIN_READ)->GetSlot(); + g_slotEndRead = MscorlibBinder::GetMethod(METHOD__STREAM__END_READ)->GetSlot(); + HELPER_METHOD_FRAME_END(); + } + + MethodTable * pMT = stream->GetMethodTable(); + + FC_RETURN_BOOL(HasOverriddenStreamMethod(pMT, g_slotBeginRead) || HasOverriddenStreamMethod(pMT, g_slotEndRead)); +} +FCIMPLEND + +FCIMPL1(FC_BOOL_RET, StreamNative::HasOverriddenBeginEndWrite, Object *stream) +{ + FCALL_CONTRACT; + + if (stream == NULL) + FC_RETURN_BOOL(TRUE); + + if (g_pStreamMT == NULL || g_slotBeginWrite == 0 || g_slotEndWrite == 0) + { + HELPER_METHOD_FRAME_BEGIN_RET_1(stream); + g_pStreamMT = MscorlibBinder::GetClass(CLASS__STREAM); + g_slotBeginWrite = MscorlibBinder::GetMethod(METHOD__STREAM__BEGIN_WRITE)->GetSlot(); + g_slotEndWrite = MscorlibBinder::GetMethod(METHOD__STREAM__END_WRITE)->GetSlot(); + HELPER_METHOD_FRAME_END(); + } + + MethodTable * pMT = stream->GetMethodTable(); + + FC_RETURN_BOOL(HasOverriddenStreamMethod(pMT, g_slotBeginWrite) || HasOverriddenStreamMethod(pMT, g_slotEndWrite)); +} +FCIMPLEND diff --git a/src/vm/comutilnative.h b/src/vm/comutilnative.h index 3a9b35a365cf..21d7b91823b8 100644 --- a/src/vm/comutilnative.h +++ b/src/vm/comutilnative.h @@ -316,4 +316,10 @@ class CoreFxGlobalization { }; #endif // FEATURE_COREFX_GLOBALIZATION +class StreamNative { +public: + static FCDECL1(FC_BOOL_RET, HasOverriddenBeginEndRead, Object *stream); + static FCDECL1(FC_BOOL_RET, HasOverriddenBeginEndWrite, Object *stream); +}; + #endif // _COMUTILNATIVE_H_ diff --git a/src/vm/ecalllist.h b/src/vm/ecalllist.h index 504802b50d4d..9461deff138e 100644 --- a/src/vm/ecalllist.h +++ b/src/vm/ecalllist.h @@ -2069,6 +2069,11 @@ FCFuncStart(gVersioningHelperFuncs) FCFuncElement("GetRuntimeId", GetRuntimeId_Wrapper) FCFuncEnd() +FCFuncStart(gStreamFuncs) + FCFuncElement("HasOverriddenBeginEndRead", StreamNative::HasOverriddenBeginEndRead) + FCFuncElement("HasOverriddenBeginEndWrite", StreamNative::HasOverriddenBeginEndWrite) +FCFuncEnd() + #ifndef FEATURE_CORECLR FCFuncStart(gConsoleStreamFuncs) FCFuncElement("WaitForAvailableConsoleInput", ConsoleStreamHelper::WaitForAvailableConsoleInput) @@ -2420,6 +2425,7 @@ FCClassElement("SizedReference", "System", gSizedRefHandleFuncs) FCClassElement("StackBuilderSink", "System.Runtime.Remoting.Messaging", gStackBuilderSinkFuncs) #endif FCClassElement("StackTrace", "System.Diagnostics", gDiagnosticsStackTrace) +FCClassElement("Stream", "System.IO", gStreamFuncs) FCClassElement("String", "System", gStringFuncs) FCClassElement("StringBuilder", "System.Text", gStringBufferFuncs) FCClassElement("StringExpressionSet", "System.Security.Util", gCOMStringExpressionSetFuncs) diff --git a/src/vm/metasig.h b/src/vm/metasig.h index c64ab675e3bc..7836fbaa4770 100644 --- a/src/vm/metasig.h +++ b/src/vm/metasig.h @@ -679,6 +679,10 @@ DEFINE_METASIG_T(SM(RefCleanupWorkList_SafeHandle_RetIntPtr, r(C(CLEANUP_WORK_LI DEFINE_METASIG_T(IM(RuntimeTypeHandle_RefException_RetBool, g(RT_TYPE_HANDLE) r(C(EXCEPTION)), F)) DEFINE_METASIG_T(IM(RuntimeTypeHandle_RetRuntimeTypeHandle, g(RT_TYPE_HANDLE), g(RT_TYPE_HANDLE))) +DEFINE_METASIG_T(IM(ArrByte_Int_Int_AsyncCallback_Object_RetIAsyncResult, a(b) i i C(ASYNCCALLBACK) j, C(IASYNCRESULT))) +DEFINE_METASIG_T(IM(IAsyncResult_RetInt, C(IASYNCRESULT), i)) +DEFINE_METASIG_T(IM(IAsyncResult_RetVoid, C(IASYNCRESULT), v)) + // Undefine macros in case we include the file again in the compilation unit #undef DEFINE_METASIG diff --git a/src/vm/mscorlib.h b/src/vm/mscorlib.h index 8fb7853bb081..978b79d5b7b2 100644 --- a/src/vm/mscorlib.h +++ b/src/vm/mscorlib.h @@ -1495,6 +1495,10 @@ DEFINE_CLASS(STACK_TRACE, Diagnostics, StackTrace) DEFINE_METHOD(STACK_TRACE, GET_MANAGED_STACK_TRACE_HELPER, GetManagedStackTraceStringHelper, SM_Bool_RetStr) DEFINE_CLASS(STREAM, IO, Stream) +DEFINE_METHOD(STREAM, BEGIN_READ, BeginRead, IM_ArrByte_Int_Int_AsyncCallback_Object_RetIAsyncResult) +DEFINE_METHOD(STREAM, END_READ, EndRead, IM_IAsyncResult_RetInt) +DEFINE_METHOD(STREAM, BEGIN_WRITE, BeginWrite, IM_ArrByte_Int_Int_AsyncCallback_Object_RetIAsyncResult) +DEFINE_METHOD(STREAM, END_WRITE, EndWrite, IM_IAsyncResult_RetVoid) // Defined as element type alias // DEFINE_CLASS(INTPTR, System, IntPtr)