Skip to content

Commit

Permalink
issue zio#5878 - Backport ThreadLocalBridge (zio#5980)
Browse files Browse the repository at this point in the history
* Rewritten based on the merged version of PR zio#5907

* call Supervisor.unsafeOnResume() even if `superviseOperations` is false (same as ZIO 2)
  • Loading branch information
dkarlinsky authored May 30, 2022
1 parent 9eaa6bb commit 22921ee
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 11 deletions.
10 changes: 4 additions & 6 deletions core-tests/shared/src/test/scala/zio/SupervisorSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ object SupervisorSpec extends ZIOBaseSpec {
assertTrue(
recorded.toSet == (Set(
s"unsafeOnResume($fiber1)",
s"unsafeOnResume($fiber2)"
).filter(_ => superviseOperations) ++
Set(
s"unsafeOnSuspend($fiber1)",
s"unsafeOnSuspend($fiber2)"
))
s"unsafeOnResume($fiber2)",
s"unsafeOnSuspend($fiber1)",
s"unsafeOnSuspend($fiber2)"
))
)
}
}
Expand Down
87 changes: 87 additions & 0 deletions core-tests/shared/src/test/scala/zio/ThreadLocalBridgeSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package zio

import zio.ThreadLocalBridge.TrackingFiberRef
import zio.test._

object ThreadLocalBridgeSpec extends ZIOBaseSpec {

def spec = suite("SupervisorSpec")(
suite("fiberRefTrackingSupervisor")(
testM("track initial value") {
val tag = "tiv"
val initialValue = s"initial-value-$tag"
tracking(initialValue) { (_, threadLocalGet) =>
for {
ab <- threadLocalGet zipPar threadLocalGet
(a, b) = ab
} yield {
assertTrue(
a.contains(initialValue),
b.contains(initialValue)
)
}
}
},
testM("track FiberRef.set / modify") {
val tag = "modify"
val initialValue = s"initial-value-$tag"
val newValue1 = s"new-value1-$tag"
val newValue2 = s"new-value2-$tag"
tracking(initialValue) { (fiberRef, threadLocalGet) =>
for {
beforeModify <- threadLocalGet
_ <- fiberRef.modify(_ => () -> newValue1)
afterModify <- threadLocalGet
ab <-
(fiberRef.set(newValue2) *> threadLocalGet) zipPar
threadLocalGet
(a, b) = ab
} yield {
assertTrue(
beforeModify.contains(initialValue),
afterModify.contains(newValue1),
a.contains(newValue2),
b.contains(newValue1)
)
}
}
},
testM("track in FiberRef.locally") {
val tag = "locally"
val initialValue = s"initial-value-$tag"
val newValue1 = s"new-value1-$tag"
val newValue2 = s"new-value2-$tag"
tracking(initialValue) { (fiberRef, threadLocalGet) =>
for {
a <- threadLocalGet
bc <- fiberRef.locally(newValue1) {
threadLocalGet zipPar
fiberRef.locally(newValue2)(threadLocalGet)
}
(b, c) = bc
d <- threadLocalGet
} yield assertTrue(
a.contains(initialValue),
b.contains(newValue1),
c.contains(newValue2),
d.contains(initialValue)
)
}
}
)
)

def tracking[R, E, A](
initialValue: String
)(effect: (TrackingFiberRef[String], UIO[Option[String]]) => ZIO[R, E, A]) = {
val threadLocal = aThreadLocal()
val threadLocalGet = ZIO.succeed(threadLocal.get)
ThreadLocalBridge(initialValue)(a => threadLocal.set(Some(a)))
.withFiberRef(fr => effect(fr, threadLocalGet))
}

private def aThreadLocal() =
new ThreadLocal[Option[String]] {
override def initialValue() = None
}
}
133 changes: 133 additions & 0 deletions core/shared/src/main/scala/zio/ThreadLocalBridge.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package zio

import zio.Supervisor.Propagation
import zio.ThreadLocalBridge.TrackingFiberRef
import zio.internal.FiberContext

trait ThreadLocalBridge[A] {
def withFiberRef[R, E, A1](f: TrackingFiberRef[A] => ZIO[R, E, A1]): ZIO[R, E, A1]
}

object ThreadLocalBridge {
def apply[A](initialValue: A)(link: A => Unit) =
new ThreadLocalBridge[A] {
override def withFiberRef[R, E, A1](f: TrackingFiberRef[A] => ZIO[R, E, A1]): ZIO[R, E, A1] =
for {
fiberRef <- FiberRef.make(initialValue)
_ = link(initialValue)
supervisor = new FiberRefTrackingSupervisor
_ = supervisor.trackFiberRef(fiberRef, link)
res <- f(new TrackingFiberRef(fiberRef, link))
.supervised(supervisor)
} yield res
}

private class FiberRefTrackingSupervisor extends Supervisor[Unit] {

private val trackedRefs: Ref[Set[(FiberRef[_], Any => Unit)]] = Ref.unsafeMake(Set.empty)

override def value: UIO[Unit] = ZIO.unit

override def unsafeOnEnd[R, E, A1](value: Exit[E, A1], fiber: Fiber.Runtime[E, A1]) = Propagation.Continue

override private[zio] def unsafeOnStart[R, E, A](
environment: R,
effect: ZIO[R, E, A],
parent: Option[Fiber.Runtime[Any, Any]],
fiber: Fiber.Runtime[E, A]
) = Propagation.Continue

def trackFiberRef[B](fiberRef: FiberRef[B], link: B => Unit): Unit =
trackedRefs.unsafeUpdate(old => old + ((fiberRef, link.asInstanceOf[Any => Unit])))

override def unsafeOnSuspend[E, A1](fiber: Fiber.Runtime[E, A1]): Unit =
foreachTrackedRef { (fiberRef, link) =>
link(fiberRef.initial)
}

override def unsafeOnResume[E, A1](fiber: Fiber.Runtime[E, A1]): Unit =
foreachTrackedRef { (fiberRef, link) =>
val value = fiber.asInstanceOf[FiberContext[E, A1]].fiberRefLocals.get(fiberRef)
if (value == null) {
link(fiberRef.initial)
} else {
link(value)
}
}

private def foreachTrackedRef(f: (FiberRef[_], Any => Unit) => Unit): Unit =
trackedRefs.unsafeGet.foreach { case (fiberRef, link) =>
f(fiberRef, link)
}
}

final class TrackingFiberRef[A] private[zio] (fiberRef: FiberRef[A], link: A => Unit) {

def set(value: A): IO[Nothing, Unit] =
fiberRef.set(value) <* linkM(value)

def update(f: A => A): UIO[Unit] = modify { v =>
val result = f(v)
((), result)
}

def modify[B](f: A => (B, A)): UIO[B] =
fiberRef.modify(f) <* (fiberRef.get >>= linkM)

val get: UIO[A] = fiberRef.get

def getAndSet(a: A): UIO[A] =
fiberRef.getAndSet(a) <* linkM(a)

def getAndUpdate(f: A => A): UIO[A] = modify { v =>
val result = f(v)
(v, result)
}

def getAndUpdateSome(pf: PartialFunction[A, A]): UIO[A] = modify { v =>
val result = pf.applyOrElse[A, A](v, identity)
(v, result)
}

def locally[R, E, B](value: A)(use: ZIO[R, E, B]): ZIO[R, E, B] =
for {
oldValue <- get
b <- set(value).bracket_(set(oldValue))(use)
} yield b

def modifySome[B](default: B)(pf: PartialFunction[A, (B, A)]): UIO[B] = modify { v =>
pf.applyOrElse[A, (B, A)](v, _ => (default, v))
}

def updateAndGet(f: A => A): UIO[A] = modify { v =>
val result = f(v)
(result, result)
}

def updateLocally[R, E, B](f: A => A)(use: ZIO[R, E, B]): ZIO[R, E, B] =
for {
oldValue <- get
b <- set(f(oldValue)).bracket_(set(oldValue))(use)
} yield b

def updateSomeLocally[R, E, B](pf: PartialFunction[A, A])(use: ZIO[R, E, B]): ZIO[R, E, B] =
for {
oldValue <- get
value = pf.applyOrElse[A, A](oldValue, identity)
b <- set(value).bracket_(set(oldValue))(use)
} yield b

def updateSome(pf: PartialFunction[A, A]): UIO[Unit] = modify { v =>
val result = pf.applyOrElse[A, A](v, identity)
((), result)
}

def updateSomeAndGet(pf: PartialFunction[A, A]): UIO[A] = modify { v =>
val result = pf.applyOrElse[A, A](v, identity)
(result, result)
}

private def linkM(a: A) = UIO(link(a))
}

}
8 changes: 3 additions & 5 deletions core/shared/src/main/scala/zio/internal/FiberContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,9 @@ private[zio] final class FiberContext[E, A](
} else null

Fiber._currentFiber.set(this)
if (platform.superviseOperations) {
val currentSupervisor = supervisors.peek()
if ((currentSupervisor ne null) && (currentSupervisor ne Supervisor.none))
currentSupervisor.unsafeOnResume(self)
}
val currentSupervisor = supervisors.peek()
if ((currentSupervisor ne null) && (currentSupervisor ne Supervisor.none))
currentSupervisor.unsafeOnResume(self)

while (curZio ne null) {
try {
Expand Down

0 comments on commit 22921ee

Please sign in to comment.