From a2805041734335972f593a33ec940e14ff3ce24c Mon Sep 17 00:00:00 2001 From: opg1 Date: Wed, 28 Dec 2022 21:19:14 +0900 Subject: [PATCH] Add redis pubsub layer Fix RedisExecutor structure Add PubSub api Fix broken compile Add PushProtocolOutput suite Fix message field type as generic Fix broken output Fix PushProtocol output spec Add key property in PushProtocol Add test implementation Add PubSub integration test Fix formatting Refactor RedisPubSub Apply RedisPubSub refactoring Apply RedisPubSub refactoring to t/c Remove unused file Fix logic bugs Fix broken t/c Fix unsubscribe process Fix pubSubSpec Add request message broker in SingleNodeRedisPubSub Simplify RedisPubSub's public api Revert unrelated changes --- .../redis/benchmarks/BenchmarkRuntime.scala | 1 + example/src/main/scala/example/Main.scala | 3 +- redis/src/main/scala/zio/redis/Output.scala | 46 ++++ redis/src/main/scala/zio/redis/Redis.scala | 13 +- .../main/scala/zio/redis/RedisExecutor.scala | 4 +- .../main/scala/zio/redis/RedisPubSub.scala | 25 +++ .../scala/zio/redis/RedisPubSubCommand.scala | 19 ++ .../main/scala/zio/redis/ResultBuilder.scala | 5 + .../zio/redis/SingleNodeRedisPubSub.scala | 169 ++++++++++++++ .../main/scala/zio/redis/TestExecutor.scala | 209 +++++++++++++++++- .../src/main/scala/zio/redis/api/PubSub.scala | 89 ++++++++ .../scala/zio/redis/options/Cluster.scala | 6 +- .../main/scala/zio/redis/options/PubSub.scala | 49 ++++ redis/src/main/scala/zio/redis/package.scala | 3 +- redis/src/test/scala/zio/redis/ApiSpec.scala | 18 +- .../scala/zio/redis/ClusterExecutorSpec.scala | 1 + redis/src/test/scala/zio/redis/KeysSpec.scala | 1 + .../src/test/scala/zio/redis/OutputSpec.scala | 88 ++++++++ .../src/test/scala/zio/redis/PubSubSpec.scala | 175 +++++++++++++++ 19 files changed, 903 insertions(+), 21 deletions(-) create mode 100644 redis/src/main/scala/zio/redis/RedisPubSub.scala create mode 100644 redis/src/main/scala/zio/redis/RedisPubSubCommand.scala create mode 100644 redis/src/main/scala/zio/redis/SingleNodeRedisPubSub.scala create mode 100644 redis/src/main/scala/zio/redis/api/PubSub.scala create mode 100644 redis/src/main/scala/zio/redis/options/PubSub.scala create mode 100644 redis/src/test/scala/zio/redis/PubSubSpec.scala diff --git a/benchmarks/src/main/scala/zio/redis/benchmarks/BenchmarkRuntime.scala b/benchmarks/src/main/scala/zio/redis/benchmarks/BenchmarkRuntime.scala index 59d2b1ff7..c91c41fe9 100644 --- a/benchmarks/src/main/scala/zio/redis/benchmarks/BenchmarkRuntime.scala +++ b/benchmarks/src/main/scala/zio/redis/benchmarks/BenchmarkRuntime.scala @@ -35,6 +35,7 @@ object BenchmarkRuntime { private final val Layer = ZLayer.make[Redis]( RedisExecutor.local, + RedisPubSub.local, ZLayer.succeed[BinaryCodec](ProtobufCodec), RedisLive.layer ) diff --git a/example/src/main/scala/example/Main.scala b/example/src/main/scala/example/Main.scala index 55fb13398..5b07c4b95 100644 --- a/example/src/main/scala/example/Main.scala +++ b/example/src/main/scala/example/Main.scala @@ -21,7 +21,7 @@ import example.config.AppConfig import sttp.client3.httpclient.zio.HttpClientZioBackend import zhttp.service.Server import zio._ -import zio.redis.{RedisExecutor, RedisLive} +import zio.redis.{RedisExecutor, RedisLive, RedisPubSub} import zio.schema.codec.{BinaryCodec, ProtobufCodec} object Main extends ZIOAppDefault { @@ -33,6 +33,7 @@ object Main extends ZIOAppDefault { ContributorsCache.layer, HttpClientZioBackend.layer(), RedisExecutor.layer, + RedisPubSub.layer, RedisLive.layer, ZLayer.succeed[BinaryCodec](ProtobufCodec) ) diff --git a/redis/src/main/scala/zio/redis/Output.scala b/redis/src/main/scala/zio/redis/Output.scala index 21a29b3a5..f72db1fa1 100644 --- a/redis/src/main/scala/zio/redis/Output.scala +++ b/redis/src/main/scala/zio/redis/Output.scala @@ -827,4 +827,50 @@ object Output { case other => throw ProtocolError(s"$other isn't an array") } } + + case object PushProtocolOutput extends Output[PushProtocol] { + protected def tryDecode(respValue: RespValue)(implicit codec: BinaryCodec): PushProtocol = + respValue match { + case RespValue.NullArray => throw ProtocolError(s"Array must not be empty") + case RespValue.Array(values) => + val name = MultiStringOutput.unsafeDecode(values(0)) + val key = MultiStringOutput.unsafeDecode(values(1)) + name match { + case "subscribe" => + val num = LongOutput.unsafeDecode(values(2)) + PushProtocol.Subscribe(key, num) + case "psubscribe" => + val num = LongOutput.unsafeDecode(values(2)) + PushProtocol.PSubscribe(key, num) + case "unsubscribe" => + val num = LongOutput.unsafeDecode(values(2)) + PushProtocol.Unsubscribe(key, num) + case "punsubscribe" => + val num = LongOutput.unsafeDecode(values(2)) + PushProtocol.PUnsubscribe(key, num) + case "message" => + val message = values(2) + PushProtocol.Message(key, message) + case "pmessage" => + val channel = MultiStringOutput.unsafeDecode(values(2)) + val message = values(3) + PushProtocol.PMessage(key, channel, message) + case other => throw ProtocolError(s"$other isn't a pushed message") + } + case other => throw ProtocolError(s"$other isn't an array") + } + } + + case object NumSubResponseOutput extends Output[Chunk[NumSubResponse]] { + protected def tryDecode(respValue: RespValue)(implicit codec: BinaryCodec): Chunk[NumSubResponse] = + respValue match { + case RespValue.Array(values) => + Chunk.fromIterator(values.grouped(2).map { chunk => + val channel = MultiStringOutput.unsafeDecode(chunk(0)) + val numOfSubscription = LongOutput.unsafeDecode(chunk(1)) + NumSubResponse(channel, numOfSubscription) + }) + case other => throw ProtocolError(s"$other isn't an array") + } + } } diff --git a/redis/src/main/scala/zio/redis/Redis.scala b/redis/src/main/scala/zio/redis/Redis.scala index 282fde35d..4781ee32b 100644 --- a/redis/src/main/scala/zio/redis/Redis.scala +++ b/redis/src/main/scala/zio/redis/Redis.scala @@ -34,11 +34,18 @@ trait Redis with api.Cluster { def codec: BinaryCodec def executor: RedisExecutor + def pubSub: RedisPubSub } -final case class RedisLive(codec: BinaryCodec, executor: RedisExecutor) extends Redis +final case class RedisLive(codec: BinaryCodec, executor: RedisExecutor, pubSub: RedisPubSub) extends Redis object RedisLive { - lazy val layer: URLayer[RedisExecutor with BinaryCodec, Redis] = - ZLayer.fromFunction(RedisLive.apply _) + lazy val layer: URLayer[RedisPubSub with RedisExecutor with BinaryCodec, Redis] = + ZLayer.fromZIO( + for { + codec <- ZIO.service[BinaryCodec] + executor <- ZIO.service[RedisExecutor] + pubSub <- ZIO.service[RedisPubSub] + } yield RedisLive(codec, executor, pubSub) + ) } diff --git a/redis/src/main/scala/zio/redis/RedisExecutor.scala b/redis/src/main/scala/zio/redis/RedisExecutor.scala index 36e88794b..f4fcaba42 100644 --- a/redis/src/main/scala/zio/redis/RedisExecutor.scala +++ b/redis/src/main/scala/zio/redis/RedisExecutor.scala @@ -8,10 +8,10 @@ trait RedisExecutor { object RedisExecutor { lazy val layer: ZLayer[RedisConfig, RedisError.IOError, RedisExecutor] = - RedisConnectionLive.layer >>> SingleNodeExecutor.layer + RedisConnectionLive.layer.fresh >>> SingleNodeExecutor.layer lazy val local: ZLayer[Any, RedisError.IOError, RedisExecutor] = - RedisConnectionLive.default >>> SingleNodeExecutor.layer + RedisConnectionLive.default.fresh >>> SingleNodeExecutor.layer lazy val test: ULayer[RedisExecutor] = TestExecutor.layer diff --git a/redis/src/main/scala/zio/redis/RedisPubSub.scala b/redis/src/main/scala/zio/redis/RedisPubSub.scala new file mode 100644 index 000000000..36f9951f9 --- /dev/null +++ b/redis/src/main/scala/zio/redis/RedisPubSub.scala @@ -0,0 +1,25 @@ +package zio.redis + +import zio.schema.codec.BinaryCodec +import zio.stream._ +import zio.{ULayer, ZIO, ZLayer} + +trait RedisPubSub { + def execute(command: RedisPubSubCommand): ZStream[BinaryCodec, RedisError, PushProtocol] +} + +object RedisPubSub { + lazy val layer: ZLayer[RedisConfig with BinaryCodec, RedisError.IOError, RedisPubSub] = + RedisConnectionLive.layer.fresh >>> pubSublayer + + lazy val local: ZLayer[BinaryCodec, RedisError.IOError, RedisPubSub] = + RedisConnectionLive.default.fresh >>> pubSublayer + + lazy val test: ULayer[RedisPubSub] = + TestExecutor.layer + + private lazy val pubSublayer: ZLayer[RedisConnection with BinaryCodec, RedisError.IOError, RedisPubSub] = + ZLayer.scoped( + ZIO.service[RedisConnection].flatMap(SingleNodeRedisPubSub.create(_)) + ) +} diff --git a/redis/src/main/scala/zio/redis/RedisPubSubCommand.scala b/redis/src/main/scala/zio/redis/RedisPubSubCommand.scala new file mode 100644 index 000000000..fcfef72e6 --- /dev/null +++ b/redis/src/main/scala/zio/redis/RedisPubSubCommand.scala @@ -0,0 +1,19 @@ +package zio.redis + +import zio.ZLayer +import zio.stream.ZStream + +sealed abstract class RedisPubSubCommand + +object RedisPubSubCommand { + case class Subscribe(channel: String, channels: List[String]) extends RedisPubSubCommand + case class PSubscribe(pattern: String, patterns: List[String]) extends RedisPubSubCommand + case class Unsubscribe(channels: List[String]) extends RedisPubSubCommand + case class PUnsubscribe(patterns: List[String]) extends RedisPubSubCommand + + def run(command: RedisPubSubCommand): ZStream[Redis, RedisError, PushProtocol] = + ZStream.serviceWithStream { redis => + val codecLayer = ZLayer.succeed(redis.codec) + redis.pubSub.execute(command).provideLayer(codecLayer) + } +} diff --git a/redis/src/main/scala/zio/redis/ResultBuilder.scala b/redis/src/main/scala/zio/redis/ResultBuilder.scala index 7f89f91e0..c3d1a533c 100644 --- a/redis/src/main/scala/zio/redis/ResultBuilder.scala +++ b/redis/src/main/scala/zio/redis/ResultBuilder.scala @@ -19,6 +19,7 @@ package zio.redis import zio.IO import zio.redis.ResultBuilder.NeedsReturnType import zio.schema.Schema +import zio.stream.ZStream sealed trait ResultBuilder { final def map(f: Nothing => Any)(implicit nrt: NeedsReturnType): IO[Nothing, Nothing] = ??? @@ -46,4 +47,8 @@ object ResultBuilder { trait ResultOutputBuilder extends ResultBuilder { def returning[R: Output]: IO[RedisError, R] } + + trait ResultOutputStreamBuilder { + def returning[R: Schema]: ZStream[Redis, RedisError, R] + } } diff --git a/redis/src/main/scala/zio/redis/SingleNodeRedisPubSub.scala b/redis/src/main/scala/zio/redis/SingleNodeRedisPubSub.scala new file mode 100644 index 000000000..cfdbd5915 --- /dev/null +++ b/redis/src/main/scala/zio/redis/SingleNodeRedisPubSub.scala @@ -0,0 +1,169 @@ +package zio.redis + +import zio.redis.Input.{NonEmptyList, StringInput, Varargs} +import zio.redis.Output.PushProtocolOutput +import zio.redis.SingleNodeRedisPubSub.{Request, RequestQueueSize, True} +import zio.redis.api.PubSub +import zio.schema.codec.BinaryCodec +import zio.stream._ +import zio.{Chunk, ChunkBuilder, Hub, Promise, Queue, Ref, Schedule, UIO, ZIO} + +import scala.reflect.ClassTag + +final class SingleNodeRedisPubSub( + pubSubHubsRef: Ref[Map[SubscriptionKey, Hub[PushProtocol]]], + reqQueue: Queue[Request], + connection: RedisConnection +) extends RedisPubSub { + + def execute(command: RedisPubSubCommand): ZStream[BinaryCodec, RedisError, PushProtocol] = + command match { + case RedisPubSubCommand.Subscribe(channel, channels) => subscribe(channel, channels) + case RedisPubSubCommand.PSubscribe(pattern, patterns) => pSubscribe(pattern, patterns) + case RedisPubSubCommand.Unsubscribe(channels) => unsubscribe(channels) + case RedisPubSubCommand.PUnsubscribe(patterns) => pUnsubscribe(patterns) + } + + private def subscribe( + channel: String, + channels: List[String] + ): ZStream[BinaryCodec, RedisError, PushProtocol] = + makeSubscriptionStream(PubSub.Subscribe, SubscriptionKey.Channel(channel), channels.map(SubscriptionKey.Channel(_))) + + private def pSubscribe( + pattern: String, + patterns: List[String] + ): ZStream[BinaryCodec, RedisError, PushProtocol] = + makeSubscriptionStream( + PubSub.PSubscribe, + SubscriptionKey.Pattern(pattern), + patterns.map(SubscriptionKey.Pattern(_)) + ) + + private def unsubscribe(channels: List[String]): ZStream[BinaryCodec, RedisError, PushProtocol] = + makeUnsubscriptionStream(PubSub.Unsubscribe, channels.map(SubscriptionKey.Channel(_))) + + private def pUnsubscribe(patterns: List[String]): ZStream[BinaryCodec, RedisError, PushProtocol] = + makeUnsubscriptionStream(PubSub.PUnsubscribe, patterns.map(SubscriptionKey.Pattern(_))) + + private def makeSubscriptionStream(command: String, key: SubscriptionKey, keys: List[SubscriptionKey]) = + ZStream.unwrap[BinaryCodec, RedisError, PushProtocol]( + ZIO.serviceWithZIO[BinaryCodec] { implicit codec => + for { + promise <- Promise.make[RedisError, Unit] + chunk = StringInput.encode(command) ++ NonEmptyList(StringInput).encode((key.value, keys.map(_.value))) + stream <- makeStream(key :: keys) + _ <- reqQueue.offer(Request(chunk, promise)) + _ <- promise.await + } yield stream + } + ) + + private def makeUnsubscriptionStream[T <: SubscriptionKey: ClassTag](command: String, keys: List[T]) = + ZStream.unwrap[BinaryCodec, RedisError, PushProtocol]( + ZIO.serviceWithZIO[BinaryCodec] { implicit codec => + for { + targets <- if (keys.isEmpty) pubSubHubsRef.get.map(_.keys.collect { case t: T => t }.toList) + else ZIO.succeedNow(keys) + chunk = StringInput.encode(command) ++ Varargs(StringInput).encode(keys.map(_.value)) + promise <- Promise.make[RedisError, Unit] + stream <- makeStream(targets) + _ <- reqQueue.offer(Request(chunk, promise)) + _ <- promise.await + } yield stream + } + ) + + private def makeStream(keys: List[SubscriptionKey]): UIO[Stream[RedisError, PushProtocol]] = + for { + streams <- ZIO.foreach(keys)(getHub(_).map(ZStream.fromHub(_))) + stream = streams.fold(ZStream.empty)(_ merge _) + } yield stream + + private def getHub(key: SubscriptionKey) = { + def makeNewHub = + Hub + .unbounded[PushProtocol] + .tap(hub => pubSubHubsRef.update(_ + (key -> hub))) + + for { + hubs <- pubSubHubsRef.get + hub <- ZIO.fromOption(hubs.get(key)).orElse(makeNewHub) + } yield hub + } + + private def send = + reqQueue.takeBetween(1, RequestQueueSize).flatMap { reqs => + val buffer = ChunkBuilder.make[Byte]() + val it = reqs.iterator + + while (it.hasNext) { + val req = it.next() + buffer ++= RespValue.Array(req.command).serialize + } + + val bytes = buffer.result() + + connection + .write(bytes) + .mapError(RedisError.IOError(_)) + .tapBoth( + e => ZIO.foreachDiscard(reqs)(_.promise.fail(e)), + _ => ZIO.foreachDiscard(reqs)(_.promise.succeed(())) + ) + } + + private def receive: ZIO[BinaryCodec, RedisError, Unit] = + ZIO.serviceWithZIO[BinaryCodec] { implicit codec => + connection.read + .mapError(RedisError.IOError(_)) + .via(RespValue.decoder) + .collectSome + .mapZIO(resp => ZIO.attempt(PushProtocolOutput.unsafeDecode(resp))) + .refineToOrDie[RedisError] + .foreach(push => getHub(push.key).flatMap(_.offer(push))) + } + + private def resubscribe: ZIO[BinaryCodec, RedisError, Unit] = + ZIO.serviceWithZIO[BinaryCodec] { implicit codec => + def makeCommand(name: String, keys: Set[String]) = + RespValue.Array(StringInput.encode(name) ++ Varargs(StringInput).encode(keys)).serialize + + for { + keySet <- pubSubHubsRef.get.map(_.keySet) + (channels, patterns) = keySet.partition(_.isChannelKey) + _ <- (connection.write(makeCommand(PubSub.Subscribe, channels.map(_.value))).when(channels.nonEmpty) *> + connection.write(makeCommand(PubSub.PSubscribe, patterns.map(_.value))).when(patterns.nonEmpty)) + .mapError(RedisError.IOError(_)) + .retryWhile(True) + } yield () + } + + /** + * Opens a connection to the server and launches receive operations. All failures are retried by opening a new + * connection. Only exits by interruption or defect. + */ + val run: ZIO[BinaryCodec, RedisError, AnyVal] = + ZIO.logTrace(s"$this Executable sender and reader has been started") *> + (send.repeat[BinaryCodec, Long](Schedule.forever) race receive) + .tapError(e => ZIO.logWarning(s"Reconnecting due to error: $e") *> resubscribe) + .retryWhile(True) + .tapError(e => ZIO.logError(s"Executor exiting: $e")) +} + +object SingleNodeRedisPubSub { + final case class Request(command: Chunk[RespValue.BulkString], promise: Promise[RedisError, Unit]) + + private final val True: Any => Boolean = _ => true + + private final val RequestQueueSize = 16 + + def create(conn: RedisConnection) = + for { + hubRef <- Ref.make(Map.empty[SubscriptionKey, Hub[PushProtocol]]) + reqQueue <- Queue.bounded[Request](RequestQueueSize) + pubSub = new SingleNodeRedisPubSub(hubRef, reqQueue, conn) + _ <- pubSub.run.forkScoped + _ <- logScopeFinalizer(s"$pubSub Node PubSub is closed") + } yield pubSub +} diff --git a/redis/src/main/scala/zio/redis/TestExecutor.scala b/redis/src/main/scala/zio/redis/TestExecutor.scala index 1dee64973..38ed91fdc 100644 --- a/redis/src/main/scala/zio/redis/TestExecutor.scala +++ b/redis/src/main/scala/zio/redis/TestExecutor.scala @@ -17,10 +17,15 @@ package zio.redis import zio._ +import zio.redis.Input.StringInput +import zio.redis.Output.PushProtocolOutput import zio.redis.RedisError.ProtocolError import zio.redis.RespValue.{BulkString, bulkString} import zio.redis.TestExecutor._ +import zio.redis.api.PubSub +import zio.schema.codec.BinaryCodec import zio.stm._ +import zio.stream.ZStream import java.nio.file.{FileSystems, Paths} import java.time.Instant @@ -38,8 +43,75 @@ final class TestExecutor private ( randomPick: Int => USTM[Int], hyperLogLogs: TMap[String, Set[String]], hashes: TMap[String, Map[String, String]], - sortedSets: TMap[String, Map[String, Double]] -) extends RedisExecutor { + sortedSets: TMap[String, Map[String, Double]], + pubSubs: TMap[SubscriptionKey, THub[RespValue]] +) extends RedisExecutor + with RedisPubSub { + + def execute(command: RedisPubSubCommand): ZStream[BinaryCodec, RedisError, PushProtocol] = + command match { + case RedisPubSubCommand.Subscribe(channel, channels) => subscribe(channel, channels) + case RedisPubSubCommand.PSubscribe(pattern, patterns) => pSubscribe(pattern, patterns) + case RedisPubSubCommand.Unsubscribe(channels) => unsubscribe(channels) + case RedisPubSubCommand.PUnsubscribe(patterns) => pUnsubscribe(patterns) + } + + private def subscribe(channel: String, channels: List[String]): ZStream[BinaryCodec, RedisError, PushProtocol] = + pubSubStream(PubSub.Subscribe, (channel :: channels).map(SubscriptionKey.Channel(_)), false) + + private def unsubscribe(channels: List[String]): ZStream[BinaryCodec, RedisError, PushProtocol] = + ZStream + .unwrap( + (for { + keys <- if (channels.isEmpty) pubSubs.keys.map(_.collect { case t: SubscriptionKey.Channel => t }) + else ZSTM.succeed(channels.map(SubscriptionKey.Channel(_))) + stream = pubSubStream(PubSub.Unsubscribe, keys, true) + } yield stream).commit + ) + + private def pSubscribe(pattern: String, patterns: List[String]): ZStream[BinaryCodec, RedisError, PushProtocol] = + pubSubStream(PubSub.PSubscribe, (pattern :: patterns).map(SubscriptionKey.Pattern(_)), false) + + private def pUnsubscribe(patterns: List[String]): ZStream[BinaryCodec, RedisError, PushProtocol] = + ZStream + .unwrap( + (for { + keys <- if (patterns.isEmpty) pubSubs.keys.map(_.collect { case t: SubscriptionKey.Pattern => t }) + else ZSTM.succeed(patterns.map(SubscriptionKey.Pattern(_))) + stream = pubSubStream(PubSub.PUnsubscribe, keys, true) + } yield stream).commit + ) + private def pubSubStream(cmd: String, keys: List[SubscriptionKey], isUnSubs: Boolean) = + ZStream + .unwrap[BinaryCodec, RedisError, PushProtocol]( + ZIO + .clockWith(_.instant) + .flatMap(now => + ZSTM + .serviceWithSTM[BinaryCodec] { implicit codec => + for { + streams <- + ZSTM.foreach(keys) { key => + for { + resp <- runCommand(cmd, StringInput.encode(key.value), now) + hub <- getPubSubHub(key) + queue <- hub.subscribe + _ <- resp match { + case RespValue.ArrayValues(value) => + hub.offer(value) + case other => + ZSTM.fail(RedisError.ProtocolError(s"Invalid pubsub command response $other")) + } + _ <- pubSubs.delete(key).when(isUnSubs) + } yield ZStream + .fromTQueue(queue) + .mapZIO(resp => ZIO.attempt(PushProtocolOutput.unsafeDecode(resp)).refineToOrDie[RedisError]) + } + } yield streams.fold(ZStream.empty)(_ merge _) + } + .commit + ) + ) override def execute(command: Chunk[RespValue.BulkString]): IO[RedisError, RespValue] = for { @@ -71,7 +143,7 @@ final class TestExecutor private ( val timeout = command.tail.last.asString.toInt runBlockingCommand(name.asString, command.tail, timeout, RespValue.NullBulkString, now) - case "CLIENT" | "STRALGO" => + case "CLIENT" | "STRALGO" | "PUBSUB" => val command1 = command.tail val name1 = command1.head runCommand(name.asString + " " + name1.asString, command1.tail, now).commit @@ -104,6 +176,123 @@ final class TestExecutor private ( private[this] def runCommand(name: String, input: Chunk[RespValue.BulkString], now: Instant): USTM[RespValue] = { name match { + case api.PubSub.Publish => + for { + (channels, patterns) <- pubSubs.keys.map(_.partition(_.isChannelKey)) + keyString = input(0).asString + msg = input(1) + targetChannels = channels.filter(_.value == keyString) + messages = + targetChannels.map { ch => + ( + ch, + RespValue.array(RespValue.bulkString("message"), RespValue.bulkString(ch.value), msg) + ) + } + targetPatterns = patterns.filter { pattern => + val matcher = FileSystems.getDefault.getPathMatcher("glob:" + pattern.value) + matcher.matches(Paths.get(keyString)) + } + pMessages = + targetPatterns.map { pattern => + ( + pattern, + RespValue + .array(RespValue.bulkString("pmessage"), RespValue.bulkString(pattern.value), input(0), msg) + ) + } + _ <- ZSTM.foreachDiscard(messages ++ pMessages) { case (key, message) => + getPubSubHub(key).flatMap(_.offer(message)) + } + } yield RespValue.Integer(targetChannels.size + targetPatterns.size.toLong) + + case api.PubSub.PubSubChannels => + for { + channels <- pubSubs.keys.map(_.collect { case t: SubscriptionKey.Channel => t.value }) + pattern = input(0).asString + matcher = FileSystems.getDefault.getPathMatcher("glob:" + pattern) + matchedChannels = channels + .filter(ch => matcher.matches(Paths.get(ch))) + .map(channel => RespValue.bulkString(channel)) + } yield RespValue.Array(Chunk.fromIterable(matchedChannels)) + + case api.PubSub.PubSubNumSub => + for { + channels <- pubSubs.keys.map(_.collect { case t: SubscriptionKey.Channel => t.value }) + keys = input.map(_.asString) + targets = keys.map(key => + RespValue.bulkString(key) -> RespValue.Integer( + if (channels contains key) 1L + else 0L + ) + ) + } yield RespValue.Array(targets.foldLeft(Chunk.empty[RespValue]) { case (acc, (bulk, num)) => + acc.appended(bulk).appended(num) + }) + + case api.PubSub.PubSubNumPat => + pubSubs.keys + .map(_.collect { case t: SubscriptionKey.Pattern => t.value }) + .map(patterns => RespValue.Integer(patterns.size.toLong)) + + case api.PubSub.Unsubscribe => + for { + channels <- pubSubs.keys.map(_.collect { case t: SubscriptionKey.Channel => t }) + keys = input.map(_.asString) + subsCount = channels.size.toLong - 1 + messages = keys.zipWithIndex.map { case (key, idx) => + RespValue.array( + RespValue.bulkString("unsubscribe"), + RespValue.bulkString(key), + RespValue.Integer((subsCount - idx) max 0L) + ) + } + } yield RespValue.Array(messages) + + case api.PubSub.PUnsubscribe => + for { + patterns <- pubSubs.keys.map(_.collect { case t: SubscriptionKey.Pattern => t }) + keys = input.map(_.asString) + subsCount = patterns.size.toLong - 1 + messages = keys.zipWithIndex.map { case (key, idx) => + RespValue.array( + RespValue.bulkString("punsubscribe"), + RespValue.bulkString(key), + RespValue.Integer((subsCount - idx) max 0L) + ) + } + } yield RespValue.Array(messages) + + case api.PubSub.Subscribe => + for { + subsCount <- pubSubs.keys.map(_.size + 1L) + keys = input.map(_.asString) + messages = + keys.zipWithIndex.map { case (key, idx) => + RespValue.array( + RespValue.bulkString("subscribe"), + RespValue.bulkString(key), + RespValue.Integer(subsCount + idx) + ) + } + _ <- ZSTM.foreachDiscard(keys.map(SubscriptionKey.Channel(_)))(getPubSubHub(_)) + } yield RespValue.Array(messages) + + case api.PubSub.PSubscribe => + for { + subsCount <- pubSubs.keys.map(_.size + 1L) + keys = input.map(_.asString) + subsMessages = + keys.zipWithIndex.map { case (key, idx) => + RespValue.array( + RespValue.bulkString("psubscribe"), + RespValue.bulkString(key), + RespValue.Integer(subsCount + idx) + ) + } + _ <- ZSTM.foreachDiscard(keys.map(SubscriptionKey.Pattern(_)))(getPubSubHub(_)) + } yield RespValue.Array(subsMessages) + case api.Connection.Auth => onConnection(name, input)(RespValue.bulkString("OK")) @@ -3582,6 +3771,13 @@ final class TestExecutor private ( } } + private def getPubSubHub(key: SubscriptionKey) = + for { + hubOpt <- pubSubs.get(key) + hub <- ZSTM.fromOption(hubOpt).orElse(THub.unbounded[RespValue]) + _ <- pubSubs.putIfAbsent(key, hub) + } yield hub + private[this] def orWrongType(predicate: USTM[Boolean])( program: => USTM[RespValue] ): USTM[RespValue] = @@ -3922,7 +4118,7 @@ object TestExecutor { lazy val redisType: RedisType = KeyType.toRedisType(`type`) } - lazy val layer: ULayer[RedisExecutor] = + lazy val layer: ULayer[RedisExecutor with RedisPubSub] = ZLayer { for { seed <- ZIO.randomWith(_.nextInt) @@ -3940,6 +4136,7 @@ object TestExecutor { clientTInfo = ClientTrackingInfo(ClientTrackingFlags(clientSideCaching = false), ClientTrackingRedirect.NotEnabled) clientTrackingInfo <- TRef.make(clientTInfo).commit + pubSubs <- TMap.empty[SubscriptionKey, THub[RespValue]].commit } yield new TestExecutor( clientInfo, clientTrackingInfo, @@ -3950,8 +4147,8 @@ object TestExecutor { randomPick, hyperLogLogs, hashes, - sortedSets + sortedSets, + pubSubs ) } - } diff --git a/redis/src/main/scala/zio/redis/api/PubSub.scala b/redis/src/main/scala/zio/redis/api/PubSub.scala new file mode 100644 index 000000000..9e47fcefa --- /dev/null +++ b/redis/src/main/scala/zio/redis/api/PubSub.scala @@ -0,0 +1,89 @@ +package zio.redis.api + +import zio.redis.Input._ +import zio.redis.Output._ +import zio.redis.ResultBuilder.ResultOutputStreamBuilder +import zio.redis._ +import zio.redis.api.PubSub._ +import zio.schema.Schema +import zio.stream._ +import zio.{Chunk, ZIO} + +trait PubSub { + final def subscribeStreamBuilder(channel: String, channels: String*): ResultOutputStreamBuilder = + new ResultOutputStreamBuilder { + override def returning[R: Schema]: ZStream[Redis, RedisError, R] = + ZStream.serviceWithStream[Redis] { redis => + RedisPubSubCommand + .run(RedisPubSubCommand.Subscribe(channel, channels.toList)) + .collect { case t: PushProtocol.Message => t.message } + .mapZIO(resp => + ZIO + .attempt(ArbitraryOutput[R]().unsafeDecode(resp)(redis.codec)) + .refineToOrDie[RedisError] + ) + } + } + + final def subscribe(channel: String, channels: String*): ZStream[Redis, RedisError, PushProtocol] = + RedisPubSubCommand.run(RedisPubSubCommand.Subscribe(channel, channels.toList)) + + final def unsubscribe(channels: String*): ZStream[Redis, RedisError, PushProtocol.Unsubscribe] = + RedisPubSubCommand.run(RedisPubSubCommand.Unsubscribe(channels.toList)).collect { + case t: PushProtocol.Unsubscribe => t + } + + final def pSubscribe(pattern: String, patterns: String*): ZStream[Redis, RedisError, PushProtocol] = + RedisPubSubCommand.run(RedisPubSubCommand.PSubscribe(pattern, patterns.toList)) + + final def pSubscribeStreamBuilder(pattern: String, patterns: String*): ResultOutputStreamBuilder = + new ResultOutputStreamBuilder { + override def returning[R: Schema]: ZStream[Redis, RedisError, R] = + ZStream.serviceWithStream[Redis] { redis => + RedisPubSubCommand + .run(RedisPubSubCommand.PSubscribe(pattern, patterns.toList)) + .collect { case t: PushProtocol.PMessage => t.message } + .mapZIO(resp => + ZIO + .attempt(ArbitraryOutput[R]().unsafeDecode(resp)(redis.codec)) + .refineToOrDie[RedisError] + ) + } + } + + final def pUnsubscribe(patterns: String*): ZStream[Redis, RedisError, PushProtocol.PUnsubscribe] = + RedisPubSubCommand.run(RedisPubSubCommand.PUnsubscribe(patterns.toList)).collect { + case t: PushProtocol.PUnsubscribe => t + } + + final def publish[A: Schema](channel: String, message: A): ZIO[Redis, RedisError, Long] = { + val command = RedisCommand(Publish, Tuple2(StringInput, ArbitraryInput[A]()), LongOutput) + command.run((channel, message)) + } + + final def pubSubChannels(pattern: String): ZIO[Redis, RedisError, Chunk[String]] = { + val command = RedisCommand(PubSubChannels, StringInput, ChunkOutput(MultiStringOutput)) + command.run(pattern) + } + + final def pubSubNumPat: ZIO[Redis, RedisError, Long] = { + val command = RedisCommand(PubSubNumPat, NoInput, LongOutput) + command.run(()) + } + + final def pubSubNumSub(channel: String, channels: String*): ZIO[Redis, RedisError, Chunk[NumSubResponse]] = { + val command = RedisCommand(PubSubNumSub, NonEmptyList(StringInput), NumSubResponseOutput) + command.run((channel, channels.toList)) + } +} + +private[redis] object PubSub { + final val Subscribe = "SUBSCRIBE" + final val Unsubscribe = "UNSUBSCRIBE" + final val PSubscribe = "PSUBSCRIBE" + final val PUnsubscribe = "PUNSUBSCRIBE" + final val Publish = "PUBLISH" + final val PubSubChannels = "PUBSUB CHANNELS" + final val PubSubNumPat = "PUBSUB NUMPAT" + final val PubSubNumSub = "PUBSUB NUMSUB" +} diff --git a/redis/src/main/scala/zio/redis/options/Cluster.scala b/redis/src/main/scala/zio/redis/options/Cluster.scala index 5c11752c4..35bf83aa8 100644 --- a/redis/src/main/scala/zio/redis/options/Cluster.scala +++ b/redis/src/main/scala/zio/redis/options/Cluster.scala @@ -15,14 +15,14 @@ */ package zio.redis.options -import zio.redis.{RedisExecutor, RedisUri} +import zio.redis.{RedisExecutor, RedisPubSub, RedisUri} import zio.{Chunk, Scope} object Cluster { private[redis] final val SlotsAmount = 16384 - final case class ExecutorScope(executor: RedisExecutor, scope: Scope.Closeable) + final case class ExecutorScope(executor: RedisExecutor, pubSub: RedisPubSub, scope: Scope.Closeable) final case class ClusterConnection( partitions: Chunk[Partition], @@ -31,6 +31,8 @@ object Cluster { ) { def executor(slot: Slot): Option[RedisExecutor] = executors.get(slots(slot)).map(_.executor) + def pubSub(slot: Slot): Option[RedisPubSub] = executors.get(slots(slot)).map(_.pubSub) + def addExecutor(uri: RedisUri, es: ExecutorScope): ClusterConnection = copy(executors = executors + (uri -> es)) } diff --git a/redis/src/main/scala/zio/redis/options/PubSub.scala b/redis/src/main/scala/zio/redis/options/PubSub.scala new file mode 100644 index 000000000..f8915607c --- /dev/null +++ b/redis/src/main/scala/zio/redis/options/PubSub.scala @@ -0,0 +1,49 @@ +package zio.redis.options + +import zio.redis.RespValue + +trait PubSub { + + sealed trait SubscriptionKey { self => + def value: String + + def isChannelKey = self match { + case _: SubscriptionKey.Channel => true + case _: SubscriptionKey.Pattern => false + } + + def isPatternKey = !isChannelKey + } + + object SubscriptionKey { + case class Channel(value: String) extends SubscriptionKey + case class Pattern(value: String) extends SubscriptionKey + } + + case class NumSubResponse(channel: String, subscriberCount: Long) + + sealed trait PushProtocol { + def key: SubscriptionKey + } + + object PushProtocol { + case class Subscribe(channel: String, numOfSubscription: Long) extends PushProtocol { + def key: SubscriptionKey = SubscriptionKey.Channel(channel) + } + case class PSubscribe(pattern: String, numOfSubscription: Long) extends PushProtocol { + def key: SubscriptionKey = SubscriptionKey.Pattern(pattern) + } + case class Unsubscribe(channel: String, numOfSubscription: Long) extends PushProtocol { + def key: SubscriptionKey = SubscriptionKey.Channel(channel) + } + case class PUnsubscribe(pattern: String, numOfSubscription: Long) extends PushProtocol { + def key: SubscriptionKey = SubscriptionKey.Pattern(pattern) + } + case class Message(channel: String, message: RespValue) extends PushProtocol { + def key: SubscriptionKey = SubscriptionKey.Channel(channel) + } + case class PMessage(pattern: String, channel: String, message: RespValue) extends PushProtocol { + def key: SubscriptionKey = SubscriptionKey.Pattern(pattern) + } + } +} diff --git a/redis/src/main/scala/zio/redis/package.scala b/redis/src/main/scala/zio/redis/package.scala index 948ec1623..6dd5ef9f6 100644 --- a/redis/src/main/scala/zio/redis/package.scala +++ b/redis/src/main/scala/zio/redis/package.scala @@ -25,7 +25,8 @@ package object redis with options.Strings with options.Lists with options.Streams - with options.Scripting { + with options.Scripting + with options.PubSub { type Id[+A] = A diff --git a/redis/src/test/scala/zio/redis/ApiSpec.scala b/redis/src/test/scala/zio/redis/ApiSpec.scala index 58fe1d456..4a445edf7 100644 --- a/redis/src/test/scala/zio/redis/ApiSpec.scala +++ b/redis/src/test/scala/zio/redis/ApiSpec.scala @@ -16,7 +16,8 @@ object ApiSpec with HashSpec with StreamsSpec with ScriptingSpec - with ClusterSpec { + with ClusterSpec + with PubSubSpec { def spec: Spec[TestEnvironment, Any] = suite("Redis commands")( @@ -38,10 +39,12 @@ object ApiSpec hyperLogLogSuite, hashSuite, streamsSuite, - scriptingSpec + scriptingSpec, + pubSubSuite ) - val Layer: Layer[Any, Redis] = ZLayer.make[Redis](RedisExecutor.local.orDie, ZLayer.succeed(codec), RedisLive.layer) + val Layer: Layer[Any, Redis] = + ZLayer.make[Redis](RedisExecutor.local.orDie, RedisPubSub.local.orDie, ZLayer.succeed(codec), RedisLive.layer) } private object Test { @@ -55,10 +58,12 @@ object ApiSpec hashSuite, sortedSetsSuite, geoSuite, - stringsSuite + stringsSuite, + pubSubSuite ).filterAnnotations(TestAnnotation.tagged)(t => !t.contains(BaseSpec.TestExecutorUnsupported)).get - val Layer: Layer[Any, Redis] = ZLayer.make[Redis](RedisExecutor.test, ZLayer.succeed(codec), RedisLive.layer) + val Layer: Layer[Any, Redis] = + ZLayer.make[Redis](RedisExecutor.test, RedisPubSub.test, ZLayer.succeed(codec), RedisLive.layer) } private object Cluster { @@ -75,7 +80,8 @@ object ApiSpec geoSuite, streamsSuite @@ clusterExecutorUnsupported, scriptingSpec @@ clusterExecutorUnsupported, - clusterSpec + clusterSpec, + pubSubSuite ).filterAnnotations(TestAnnotation.tagged)(t => !t.contains(BaseSpec.ClusterExecutorUnsupported)).get val Layer: Layer[Any, Redis] = diff --git a/redis/src/test/scala/zio/redis/ClusterExecutorSpec.scala b/redis/src/test/scala/zio/redis/ClusterExecutorSpec.scala index ab5b9a79f..502584458 100644 --- a/redis/src/test/scala/zio/redis/ClusterExecutorSpec.scala +++ b/redis/src/test/scala/zio/redis/ClusterExecutorSpec.scala @@ -69,6 +69,7 @@ object ClusterExecutorSpec extends BaseSpec { ZLayer.make[Redis]( ZLayer.succeed(RedisConfig(uri.host, uri.port)), RedisExecutor.layer, + RedisPubSub.layer, ZLayer.succeed(codec), RedisLive.layer ) diff --git a/redis/src/test/scala/zio/redis/KeysSpec.scala b/redis/src/test/scala/zio/redis/KeysSpec.scala index 5e3105082..c5ce10c2a 100644 --- a/redis/src/test/scala/zio/redis/KeysSpec.scala +++ b/redis/src/test/scala/zio/redis/KeysSpec.scala @@ -464,6 +464,7 @@ object KeysSpec { ZLayer.succeed(RedisConfig("localhost", 6380)), RedisConnectionLive.layer, SingleNodeExecutor.layer, + RedisPubSub.layer, ZLayer.succeed[BinaryCodec](ProtobufCodec), RedisLive.layer ) diff --git a/redis/src/test/scala/zio/redis/OutputSpec.scala b/redis/src/test/scala/zio/redis/OutputSpec.scala index e305130cb..eceae9120 100644 --- a/redis/src/test/scala/zio/redis/OutputSpec.scala +++ b/redis/src/test/scala/zio/redis/OutputSpec.scala @@ -1035,6 +1035,94 @@ object OutputSpec extends BaseSpec { res <- ZIO.attempt(ClientTrackingRedirectOutput.unsafeDecode(resp)).either } yield assert(res)(isLeft(isSubtype[ProtocolError](anything))) } + ), + suite("PushProtocol")( + test("subscribe") { + val channel = "foo" + val numOfSubscription = 1L + val input = + RespValue.array( + RespValue.bulkString("subscribe"), + RespValue.bulkString(channel), + RespValue.Integer(numOfSubscription) + ) + val expected = PushProtocol.Subscribe(channel, numOfSubscription) + assertZIO(ZIO.attempt(PushProtocolOutput.unsafeDecode(input)))( + equalTo(expected) + ) + }, + test("psubscribe") { + val pattern = "f*" + val numOfSubscription = 1L + val input = + RespValue.array( + RespValue.bulkString("psubscribe"), + RespValue.bulkString(pattern), + RespValue.Integer(numOfSubscription) + ) + val expected = PushProtocol.PSubscribe(pattern, numOfSubscription) + assertZIO(ZIO.attempt(PushProtocolOutput.unsafeDecode(input)))( + equalTo(expected) + ) + }, + test("unsubscribe") { + val channel = "foo" + val numOfSubscription = 1L + val input = + RespValue.array( + RespValue.bulkString("unsubscribe"), + RespValue.bulkString(channel), + RespValue.Integer(numOfSubscription) + ) + val expected = PushProtocol.Unsubscribe(channel, numOfSubscription) + assertZIO(ZIO.attempt(PushProtocolOutput.unsafeDecode(input)))( + equalTo(expected) + ) + }, + test("punsubscribe") { + val pattern = "f*" + val numOfSubscription = 1L + val input = + RespValue.array( + RespValue.bulkString("punsubscribe"), + RespValue.bulkString(pattern), + RespValue.Integer(numOfSubscription) + ) + val expected = PushProtocol.PUnsubscribe(pattern, numOfSubscription) + assertZIO(ZIO.attempt(PushProtocolOutput.unsafeDecode(input)))( + equalTo(expected) + ) + }, + test("message") { + val channel = "foo" + val message = RespValue.bulkString("bar") + val input = + RespValue.array( + RespValue.bulkString("message"), + RespValue.bulkString(channel), + message + ) + val expected = PushProtocol.Message(channel, message) + assertZIO(ZIO.attempt(PushProtocolOutput.unsafeDecode(input)))( + equalTo(expected) + ) + }, + test("pmessage") { + val pattern = "f*" + val channel = "foo" + val message = RespValue.bulkString("bar") + val input = + RespValue.array( + RespValue.bulkString("pmessage"), + RespValue.bulkString(pattern), + RespValue.bulkString(channel), + message + ) + val expected = PushProtocol.PMessage(pattern, channel, message) + assertZIO(ZIO.attempt(PushProtocolOutput.unsafeDecode(input)))( + equalTo(expected) + ) + } ) ) diff --git a/redis/src/test/scala/zio/redis/PubSubSpec.scala b/redis/src/test/scala/zio/redis/PubSubSpec.scala new file mode 100644 index 000000000..f1dcb38d0 --- /dev/null +++ b/redis/src/test/scala/zio/redis/PubSubSpec.scala @@ -0,0 +1,175 @@ +package zio.redis + +import zio.test.Assertion._ +import zio.test._ +import zio.{Chunk, ZIO} + +import scala.util.Random + +trait PubSubSpec extends BaseSpec { + def pubSubSuite: Spec[Redis, RedisError] = + suite("pubSubs")( + suite("subscribe")( + test("subscribe response") { + for { + channel <- generateRandomString() + res <- subscribe(channel).runHead + } yield assertTrue(res.get.key == SubscriptionKey.Channel(channel)) + }, + test("message response") { + for { + channel <- generateRandomString() + message = "bar" + stream <- subscribeStreamBuilder(channel) + .returning[String] + .runHead + .fork + _ <- pubSubChannels(channel) + .repeatUntil(_ contains channel) + _ <- publish(channel, message) + res <- stream.join + } yield assertTrue(res.get == message) + }, + test("multiple subscribe") { + val numOfPublish = 20 + for { + prefix <- generateRandomString(5) + channel1 <- generateRandomString(prefix) + channel2 <- generateRandomString(prefix) + pattern = prefix + '*' + message <- generateRandomString(5) + stream1 <- subscribe(channel1) + .runFoldWhile(Chunk.empty[PushProtocol])( + _.forall(_.isInstanceOf[PushProtocol.Unsubscribe] == false) + )(_ appended _) + .fork + stream2 <- subscribe(channel2) + .runFoldWhile(Chunk.empty[PushProtocol])( + _.forall(_.isInstanceOf[PushProtocol.Unsubscribe] == false) + )(_ appended _) + .fork + _ <- pubSubChannels(pattern) + .repeatUntil(channels => channels.size >= 2) + ch1SubsCount <- publish(channel1, message).replicateZIO(numOfPublish).map(_.head) + ch2SubsCount <- publish(channel2, message).replicateZIO(numOfPublish).map(_.head) + _ <- unsubscribe().runDrain.fork + res1 <- stream1.join + res2 <- stream2.join + } yield assertTrue(ch1SubsCount == 1L) && + assertTrue(ch2SubsCount == 1L) && + assertTrue(res1.size == numOfPublish + 2) && + assertTrue(res2.size == numOfPublish + 2) + }, + test("psubscribe response") { + for { + pattern <- generateRandomString() + res <- pSubscribe(pattern).runHead + } yield assertTrue(res.get.key.value == pattern) + }, + test("pmessage response") { + for { + prefix <- generateRandomString(5) + pattern = prefix + '*' + channel <- generateRandomString(prefix) + message <- generateRandomString(prefix) + stream <- pSubscribeStreamBuilder(pattern) + .returning[String] + .runHead + .fork + _ <- pubSubNumPat.repeatUntil(_ > 0) + _ <- publish(channel, message) + res <- stream.join + } yield assertTrue(res.get == message) + } + ), + suite("publish")(test("publish long type message") { + val message = 1L + assertZIO( + for { + channel <- generateRandomString() + stream <- subscribeStreamBuilder(channel) + .returning[Long] + .runFoldWhile(0L)(_ < 10L) { case (sum, message) => + sum + message + } + .fork + _ <- pubSubChannels(channel).repeatUntil(_ contains channel) + _ <- ZIO.replicateZIO(10)(publish(channel, message)) + res <- stream.join + } yield res + )(equalTo(10L)) + }), + suite("unsubscribe")( + test("don't receive message type after unsubscribe") { + val numOfPublished = 5 + for { + prefix <- generateRandomString(5) + pattern = prefix + '*' + channel <- generateRandomString(prefix) + message <- generateRandomString() + stream <- subscribe(channel) + .runFoldWhile(Chunk.empty[PushProtocol])(_.size < 2)(_ appended _) + .fork + _ <- pubSubChannels(pattern) + .repeatUntil(_ contains channel) + _ <- unsubscribe(channel).runHead + receiverCount <- publish(channel, message).replicateZIO(numOfPublished).map(_.head) + res <- stream.join + } yield assertTrue( + res.size == 2 + ) && assertTrue(receiverCount == 0L) + }, + test("unsubscribe response") { + for { + channel <- generateRandomString() + res <- unsubscribe(channel).runHead + } yield assertTrue(res.get.key.value == channel) + }, + test("punsubscribe response") { + for { + pattern <- generateRandomString() + res <- pUnsubscribe(pattern).runHead + } yield assertTrue(res.get.key.value == pattern) + }, + test("unsubscribe with empty param") { + for { + prefix <- generateRandomString(5) + pattern = prefix + '*' + channel1 <- generateRandomString(prefix) + channel2 <- generateRandomString(prefix) + stream1 <- + subscribe(channel1) + .runFoldWhile(Chunk.empty[PushProtocol])(_.forall(_.isInstanceOf[PushProtocol.Unsubscribe] == false))( + _ appended _ + ) + .fork + stream2 <- + subscribe(channel2) + .runFoldWhile(Chunk.empty[PushProtocol])(_.forall(_.isInstanceOf[PushProtocol.Unsubscribe] == false))( + _ appended _ + ) + .fork + _ <- pubSubChannels(pattern) + .repeatUntil(_.size >= 2) + _ <- unsubscribe().runDrain.fork + unsubscribeMessages <- stream1.join zip stream2.join + (result1, result2) = unsubscribeMessages + numSubResponses <- pubSubNumSub(channel1, channel2) + } yield assertTrue( + result1.size == 2 && result2.size == 2 + ) && assertTrue( + numSubResponses == Chunk( + NumSubResponse(channel1, 0L), + NumSubResponse(channel2, 0L) + ) + ) + } + ) + ) + + private def generateRandomString(prefix: String = "") = + ZIO.succeed(Random.alphanumeric.take(15).mkString).map(prefix + _.substring((prefix.length - 1) max 0)) + + private def generateRandomString(len: Int) = + ZIO.succeed(Random.alphanumeric.take(len).mkString) +}