Skip to content

Simplify GraphNode #14908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 56 additions & 192 deletions src/Compiler/Facilities/BuildGraph.fs
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,6 @@ type NodeCode private () =
|> Async.Parallel
|> Node

type private AgentMessage<'T> = GetValue of AsyncReplyChannel<Result<'T, Exception>> * callerCancellationToken: CancellationToken

type private Agent<'T> = MailboxProcessor<AgentMessage<'T>> * CancellationTokenSource

[<RequireQualifiedAccess>]
type private GraphNodeAction<'T> =
| GetValueByAgent
| GetValue
| CachedValue of 'T

[<RequireQualifiedAccess>]
module GraphNode =

Expand All @@ -228,210 +218,84 @@ module GraphNode =
| None -> ()

[<Sealed>]
type GraphNode<'T> private (retryCompute: bool, computation: NodeCode<'T>, cachedResult: Task<'T>, cachedResultNode: NodeCode<'T>) =
type GraphNode<'T> private (computation: NodeCode<'T>, cachedResult: ValueOption<'T>, cachedResultNode: NodeCode<'T>) =

let gate = obj ()
let mutable computation = computation
let mutable requestCount = 0

let mutable cachedResult: Task<'T> = cachedResult
let mutable cachedResult = cachedResult
let mutable cachedResultNode: NodeCode<'T> = cachedResultNode

let isCachedResultNodeNotNull () =
not (obj.ReferenceEquals(cachedResultNode, null))

let isCachedResultNotNull () =
not (obj.ReferenceEquals(cachedResult, null))

// retryCompute indicates that we abandon computations when the originator is
// cancelled.
//
// If retryCompute is 'true', the computation is run directly in the originating requestor's
// thread. If cancelled, other awaiting computations must restart the computation from scratch.
//
// If retryCompute is 'false', a MailboxProcessor is used to allow the cancelled originator
// to detach from the computation, while other awaiting computations continue to wait on the result.
//
// Currently, 'retryCompute' = true for all graph nodes. However, the code for we include the
// code to allow 'retryCompute' = false in case it's needed in the future, and ensure it is under independent
// unit test.
let loop (agent: MailboxProcessor<AgentMessage<'T>>) =
async {
assert (not retryCompute)

try
while true do
match! agent.Receive() with
| GetValue (replyChannel, callerCancellationToken) ->

Thread.CurrentThread.CurrentUICulture <- GraphNode.culture

try
use _reg =
// When a cancellation has occured, notify the reply channel to let the requester stop waiting for a response.
callerCancellationToken.Register(fun () ->
let ex = OperationCanceledException() :> exn
replyChannel.Reply(Result.Error ex))

callerCancellationToken.ThrowIfCancellationRequested()

if isCachedResultNotNull () then
replyChannel.Reply(Ok cachedResult.Result)
else
// This computation can only be canceled if the requestCount reaches zero.
let! result = computation |> Async.AwaitNodeCode
cachedResult <- Task.FromResult(result)
cachedResultNode <- node.Return result
computation <- Unchecked.defaultof<_>

if not callerCancellationToken.IsCancellationRequested then
replyChannel.Reply(Ok result)
with ex ->
if not callerCancellationToken.IsCancellationRequested then
replyChannel.Reply(Result.Error ex)
with _ ->
()
}

let mutable agent: Agent<'T> = Unchecked.defaultof<_>

let semaphore: SemaphoreSlim =
if retryCompute then
new SemaphoreSlim(1, 1)
else
Unchecked.defaultof<_>
let semaphore = new SemaphoreSlim(1, 1)

member _.GetOrComputeValue() =
// fast path
if isCachedResultNodeNotNull () then
cachedResultNode
else
node {
if isCachedResultNodeNotNull () then
return! cachedResult |> NodeCode.AwaitTask
else
let action =
lock gate
<| fun () ->
// We try to get the cached result after the lock so we don't spin up a new mailbox processor.
if isCachedResultNodeNotNull () then
GraphNodeAction<'T>.CachedValue cachedResult.Result
else
requestCount <- requestCount + 1

if retryCompute then
GraphNodeAction<'T>.GetValue
else
match box agent with
| null ->
try
let cts = new CancellationTokenSource()
let mbp = new MailboxProcessor<_>(loop, cancellationToken = cts.Token)
let newAgent = (mbp, cts)
agent <- newAgent
mbp.Start()
GraphNodeAction<'T>.GetValueByAgent
with exn ->
agent <- Unchecked.defaultof<_>
PreserveStackTrace exn
raise exn
| _ -> GraphNodeAction<'T>.GetValueByAgent

match action with
| GraphNodeAction.CachedValue result -> return result
| GraphNodeAction.GetValue ->
try
let! ct = NodeCode.CancellationToken

// We must set 'taken' before any implicit cancellation checks
// occur, making sure we are under the protection of the 'try'.
// For example, NodeCode's 'try/finally' (TryFinally) uses async.TryFinally which does
// implicit cancellation checks even before the try is entered, as do the
// de-sugaring of 'do!' and other NodeCode constructs.
let mutable taken = false

try
do!
semaphore
.WaitAsync(ct)
.ContinueWith(
(fun _ -> taken <- true),
(TaskContinuationOptions.NotOnCanceled
||| TaskContinuationOptions.NotOnFaulted
||| TaskContinuationOptions.ExecuteSynchronously)
)
|> NodeCode.AwaitTask

if isCachedResultNotNull () then
return cachedResult.Result
else
let tcs = TaskCompletionSource<'T>()
let (Node (p)) = computation

Async.StartWithContinuations(
async {
Thread.CurrentThread.CurrentUICulture <- GraphNode.culture
return! p
},
(fun res ->
cachedResult <- Task.FromResult(res)
cachedResultNode <- node.Return res
computation <- Unchecked.defaultof<_>
tcs.SetResult(res)),
(fun ex -> tcs.SetException(ex)),
(fun _ -> tcs.SetCanceled()),
ct
)

return! tcs.Task |> NodeCode.AwaitTask
finally
if taken then semaphore.Release() |> ignore
finally
lock gate <| fun () -> requestCount <- requestCount - 1

| GraphNodeAction.GetValueByAgent ->
assert (not retryCompute)
let mbp, cts = agent

try
let! ct = NodeCode.CancellationToken

let! res =
mbp.PostAndAsyncReply(fun replyChannel -> GetValue(replyChannel, ct))
|> NodeCode.AwaitAsync

match res with
| Ok result -> return result
| Result.Error ex -> return raise ex
finally
lock gate
<| fun () ->
requestCount <- requestCount - 1

if requestCount = 0 then
cts.Cancel() // cancel computation when all requests are cancelled

try
(mbp :> IDisposable).Dispose()
with _ ->
()

cts.Dispose()
agent <- Unchecked.defaultof<_>
Interlocked.Increment(&requestCount) |> ignore
try
let! ct = NodeCode.CancellationToken

// We must set 'taken' before any implicit cancellation checks
// occur, making sure we are under the protection of the 'try'.
// For example, NodeCode's 'try/finally' (TryFinally) uses async.TryFinally which does
// implicit cancellation checks even before the try is entered, as do the
// de-sugaring of 'do!' and other NodeCode constructs.
let mutable taken = false

try
do!
semaphore
.WaitAsync(ct)
.ContinueWith(
(fun _ -> taken <- true),
(TaskContinuationOptions.NotOnCanceled
||| TaskContinuationOptions.NotOnFaulted
||| TaskContinuationOptions.ExecuteSynchronously)
)
|> NodeCode.AwaitTask

match cachedResult with
| ValueSome value -> return value
| _ ->
let tcs = TaskCompletionSource<'T>()
let (Node (p)) = computation

Async.StartWithContinuations(
async {
Thread.CurrentThread.CurrentUICulture <- GraphNode.culture
return! p
},
(fun res ->
cachedResult <- ValueSome res
cachedResultNode <- node.Return res
computation <- Unchecked.defaultof<_>
tcs.SetResult(res)),
(fun ex -> tcs.SetException(ex)),
(fun _ -> tcs.SetCanceled()),
ct
)

return! tcs.Task |> NodeCode.AwaitTask
finally
if taken then semaphore.Release() |> ignore
finally
Interlocked.Decrement(&requestCount) |> ignore
}

member _.TryPeekValue() =
match box cachedResult with
| null -> ValueNone
| _ -> ValueSome cachedResult.Result
member _.TryPeekValue() = cachedResult

member _.HasValue = isCachedResultNotNull ()
member _.HasValue = cachedResult.IsSome

member _.IsComputing = requestCount > 0

static member FromResult(result: 'T) =
let nodeResult = node.Return result
GraphNode(true, nodeResult, Task.FromResult(result), nodeResult)
GraphNode(nodeResult, ValueSome result, nodeResult)

new(retryCompute: bool, computation) = GraphNode(retryCompute, computation, Unchecked.defaultof<_>, Unchecked.defaultof<_>)
new(computation) = GraphNode(retryCompute = true, computation = computation)
new(computation) = GraphNode(computation, ValueNone, Unchecked.defaultof<_>)
4 changes: 0 additions & 4 deletions src/Compiler/Facilities/BuildGraph.fsi
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ module internal GraphNode =
[<Sealed>]
type internal GraphNode<'T> =

/// - retryCompute - When set to 'true', subsequent requesters will retry the computation if the first-in request cancels. Retrying computations will have better callstacks.
/// - computation - The computation code to run.
new: retryCompute: bool * computation: NodeCode<'T> -> GraphNode<'T>

/// By default, 'retryCompute' is 'true'.
new: computation: NodeCode<'T> -> GraphNode<'T>

/// Creates a GraphNode with given result already cached.
Expand Down
46 changes: 0 additions & 46 deletions tests/FSharp.Compiler.UnitTests/BuildGraphTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -227,52 +227,6 @@ module BuildGraphTests =
|> Seq.iter (fun x ->
try x.Wait(1000) |> ignore with | :? TimeoutException -> reraise() | _ -> ())

[<Fact>]
let ``No-RetryCompute - Many requests to get a value asynchronously should only evaluate the computation once even when some requests get canceled``() =
let requests = 10000
let resetEvent = new ManualResetEvent(false)
let mutable computationCountBeforeSleep = 0
let mutable computationCount = 0

let graphNode =
GraphNode(false, node {
computationCountBeforeSleep <- computationCountBeforeSleep + 1
let! _ = NodeCode.AwaitWaitHandle_ForTesting(resetEvent)
computationCount <- computationCount + 1
return 1
})

use cts = new CancellationTokenSource()

let work =
node {
let! _ = graphNode.GetOrComputeValue()
()
}

let tasks = ResizeArray()

for i = 0 to requests - 1 do
if i % 10 = 0 then
NodeCode.StartAsTask_ForTesting(work, ct = cts.Token)
|> tasks.Add
else
NodeCode.StartAsTask_ForTesting(work)
|> tasks.Add

cts.Cancel()
resetEvent.Set() |> ignore
NodeCode.RunImmediateWithoutCancellation(work)
|> ignore

Assert.shouldBeTrue cts.IsCancellationRequested
Assert.shouldBe 1 computationCountBeforeSleep
Assert.shouldBe 1 computationCount

tasks
|> Seq.iter (fun x ->
try x.Wait(1000) |> ignore with | :? TimeoutException -> reraise() | _ -> ())

[<Fact>]
let ``GraphNode created from an already computed result will return it in tryPeekValue`` () =
let graphNode = GraphNode.FromResult 1
Expand Down