diff --git a/modules/core/shared/src/main/scala/Session.scala b/modules/core/shared/src/main/scala/Session.scala index dcfd5030..6f4995b6 100644 --- a/modules/core/shared/src/main/scala/Session.scala +++ b/modules/core/shared/src/main/scala/Session.scala @@ -428,9 +428,9 @@ object Session { ssl: SSL = SSL.None, parameters: Map[String, String] = Session.DefaultConnectionParameters, socketOptions: List[SocketOption] = Session.DefaultSocketOptions, - commandCache: Int = 1024, - queryCache: Int = 1024, - parseCache: Int = 1024, + commandCache: Int = 2048, + queryCache: Int = 2048, + parseCache: Int = 2048, readTimeout: Duration = Duration.Inf, redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn, ): Resource[F, Resource[F, Session[F]]] = { @@ -470,9 +470,9 @@ object Session { ssl: SSL = SSL.None, parameters: Map[String, String] = Session.DefaultConnectionParameters, socketOptions: List[SocketOption] = Session.DefaultSocketOptions, - commandCache: Int = 1024, - queryCache: Int = 1024, - parseCache: Int = 1024, + commandCache: Int = 2048, + queryCache: Int = 2048, + parseCache: Int = 2048, readTimeout: Duration = Duration.Inf, redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn, ): Resource[F, Tracer[F] => Resource[F, Session[F]]] = { @@ -508,9 +508,9 @@ object Session { strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly, ssl: SSL = SSL.None, parameters: Map[String, String] = Session.DefaultConnectionParameters, - commandCache: Int = 1024, - queryCache: Int = 1024, - parseCache: Int = 1024, + commandCache: Int = 2048, + queryCache: Int = 2048, + parseCache: Int = 2048, readTimeout: Duration = Duration.Inf, redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn, ): Resource[F, Session[F]] = @@ -532,9 +532,9 @@ object Session { strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly, ssl: SSL = SSL.None, parameters: Map[String, String] = Session.DefaultConnectionParameters, - commandCache: Int = 1024, - queryCache: Int = 1024, - parseCache: Int = 1024, + commandCache: Int = 2048, + queryCache: Int = 2048, + parseCache: Int = 2048, readTimeout: Duration = Duration.Inf, redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn, ): Tracer[F] => Resource[F, Session[F]] = diff --git a/modules/core/shared/src/main/scala/data/Cache.scala b/modules/core/shared/src/main/scala/data/Cache.scala new file mode 100644 index 00000000..990a6ffe --- /dev/null +++ b/modules/core/shared/src/main/scala/data/Cache.scala @@ -0,0 +1,96 @@ +// Copyright (c) 2018-2024 by Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package skunk.data + +/** + * Immutable, least recently used cache. + * + * Entries are stored in the `entries` hash map. A numeric stamp is assigned to + * each entry and stored in the `usages` field, which provides a bidirectional + * mapping between stamp and key, sorted by stamp. The `entries` and `usages` + * fields always have the same size. + * + * Upon put and get of an entry, a new stamp is assigned and `usages` + * is updated. Stamps are assigned in ascending order and each stamp is used only once. + * Hence, the head of `usages` contains the least recently used key. + */ +sealed abstract case class Cache[K, V] private ( + max: Int, + entries: Map[K, V] +)(usages: SortedBiMap[Long, K], + stamp: Long +) { + assert(entries.size == usages.size) + + def size: Int = entries.size + + def contains(k: K): Boolean = entries.contains(k) + + /** + * Gets the value associated with the specified key. + * + * Accessing an entry makes it the most recently used entry, hence a new cache + * is returned with the target entry updated to reflect the recent access. + */ + def get(k: K): Option[(Cache[K, V], V)] = + entries.get(k) match { + case Some(v) => + val newUsages = usages + (stamp -> k) + val newCache = Cache(max, entries, newUsages, stamp + 1) + Some(newCache -> v) + case None => + None + } + + /** + * Returns a new cache with the specified entry added along with the + * entry that was evicted, if any. + * + * The evicted value is defined under two circumstances: + * - the cache already contains a different value for the specified key, + * in which case the old pairing is returned + * - the cache has reeached its max size, in which case the least recently + * used value is evicted + * + * Note: if the cache contains (k, v), calling `put(k, v)` does NOT result + * in an eviction, but calling `put(k, v2)` where `v != v2` does. + */ + def put(k: K, v: V): (Cache[K, V], Option[(K, V)]) = + if (max <= 0) { + // max is 0 so immediately evict the new entry + (this, Some((k, v))) + } else if (entries.size >= max && !contains(k)) { + // at max size already and we need to add a new key, hence we must evict + // the least recently used entry + val (lruStamp, lruKey) = usages.head + val newEntries = entries - lruKey + (k -> v) + val newUsages = usages - lruStamp + (stamp -> k) + val newCache = Cache(max, newEntries, newUsages, stamp + 1) + (newCache, Some(lruKey -> entries(lruKey))) + } else { + // not growing past max size at this point, so only need to evict if + // the new entry is replacing an existing entry with different value + val newEntries = entries + (k -> v) + val newUsages = usages + (stamp -> k) + val newCache = Cache(max, newEntries, newUsages, stamp + 1) + val evicted = entries.get(k).filter(_ != v).map(k -> _) + (newCache, evicted) + } + + def values: Iterable[V] = entries.values + + override def toString: String = + usages.entries.iterator.map { case (_, k) => s"$k -> ${entries(k)}" }.mkString("Cache(", ", ", ")") +} + +object Cache { + private def apply[K, V](max: Int, entries: Map[K, V], usages: SortedBiMap[Long, K], stamp: Long): Cache[K, V] = + new Cache(max, entries)(usages, stamp) {} + + def empty[K, V](max: Int): Cache[K, V] = + apply(max max 0, Map.empty, SortedBiMap.empty, 0L) +} + + diff --git a/modules/core/shared/src/main/scala/data/SemispaceCache.scala b/modules/core/shared/src/main/scala/data/SemispaceCache.scala deleted file mode 100644 index b8e531d2..00000000 --- a/modules/core/shared/src/main/scala/data/SemispaceCache.scala +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2018-2024 by Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). -// For more information see LICENSE or https://opensource.org/licenses/MIT - -package skunk.data - -import cats.syntax.all._ - -/** - * Cache based on a two-generation GC. - * Taken from https://twitter.com/pchiusano/status/1260255494519865346 - */ -sealed abstract case class SemispaceCache[K, V](gen0: Map[K, V], gen1: Map[K, V], max: Int, evicted: SemispaceCache.EvictionSet[V]) { - - assert(max >= 0) - assert(gen0.size <= max) - assert(gen1.size <= max) - - def insert(k: K, v: V): SemispaceCache[K, V] = - if (max == 0) SemispaceCache(gen0, gen1, max, evicted + v) // immediately evict - else if (gen0.size < max) SemispaceCache(gen0 + (k -> v), gen1 - k, max, evicted - v) // room in gen0, remove from gen1, done! - else SemispaceCache(Map(k -> v), gen0, max, evicted ++ gen1.values - v) // no room in gen0, slide it down - - def lookup(k: K): Option[(SemispaceCache[K, V], V)] = - gen0.get(k).tupleLeft(this) orElse // key is in gen0, done! - gen1.get(k).map(v => (insert(k, v), v)) // key is in gen1, copy to gen0 - - def containsKey(k: K): Boolean = - gen0.contains(k) || gen1.contains(k) - - def values: List[V] = - (gen0.values.toSet | gen1.values.toSet).toList - - def evictAll: SemispaceCache[K, V] = - SemispaceCache(Map.empty, Map.empty, max, evicted ++ gen0.values ++ gen1.values) - - def clearEvicted: (SemispaceCache[K, V], List[V]) = - (SemispaceCache(gen0, gen1, max, evicted.clear), evicted.toList) -} - -object SemispaceCache { - - private def apply[K, V](gen0: Map[K, V], gen1: Map[K, V], max: Int, evicted: EvictionSet[V]): SemispaceCache[K, V] = { - val r = new SemispaceCache[K, V](gen0, gen1, max, evicted) {} - val gen0Intersection: Set[V] = (gen0.values.toSet intersect evicted.toList.toSet) - val gen1Intersection: Set[V] = (gen1.values.toSet intersect evicted.toList.toSet) - assert(gen0Intersection.isEmpty, s"gen0 has overlapping values in evicted: ${gen0Intersection}") - assert(gen1Intersection.isEmpty, s"gen1 has overlapping values in evicted: ${gen1Intersection}") - r - } - - def empty[K, V](max: Int, trackEviction: Boolean): SemispaceCache[K, V] = - SemispaceCache[K, V](Map.empty, Map.empty, max max 0, if (trackEviction) EvictionSet.empty else new EvictionSet.ZeroEvictionSet) - - sealed trait EvictionSet[V] { - def +(v: V): EvictionSet[V] - def ++(vs: Iterable[V]): EvictionSet[V] - def -(v: V): EvictionSet[V] - def toList: List[V] - def clear: EvictionSet[V] - } - - private[SemispaceCache] object EvictionSet { - - class ZeroEvictionSet[V] extends EvictionSet[V] { - def +(v: V): EvictionSet[V] = this - def ++(vs: Iterable[V]): EvictionSet[V] = this - def -(v: V): EvictionSet[V] = this - def toList: List[V] = Nil - def clear: EvictionSet[V] = this - } - - class DefaultEvictionSet[V](underlying: Set[V]) extends EvictionSet[V] { - def +(v: V): EvictionSet[V] = new DefaultEvictionSet(underlying + v) - def ++(vs: Iterable[V]): EvictionSet[V] = new DefaultEvictionSet(underlying ++ vs) - def -(v: V): EvictionSet[V] = new DefaultEvictionSet(underlying - v) - def toList: List[V] = underlying.toList - def clear: EvictionSet[V] = new DefaultEvictionSet(Set.empty) - } - - def empty[V]: EvictionSet[V] = new DefaultEvictionSet(Set.empty) - } -} diff --git a/modules/core/shared/src/main/scala/data/SortedBiMap.scala b/modules/core/shared/src/main/scala/data/SortedBiMap.scala new file mode 100644 index 00000000..c4edc9b7 --- /dev/null +++ b/modules/core/shared/src/main/scala/data/SortedBiMap.scala @@ -0,0 +1,48 @@ +// Copyright (c) 2018-2024 by Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package skunk.data + +import scala.collection.immutable.SortedMap +import scala.math.Ordering + +/** Immutable bi-directional map that is sorted by key. */ +sealed abstract case class SortedBiMap[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]) { + assert(entries.size == inverse.size) + + def size: Int = entries.size + + def head: (K, V) = entries.head + + def get(k: K): Option[V] = entries.get(k) + + def put(k: K, v: V): SortedBiMap[K, V] = + // nb: couple important properties here: + // - SortedBiMap(k0 -> v, v -> k0).put(k1, v) == SortedBiMap(k1 -> v, v -> k1) + // - SortedBiMap(k -> v0, v0 -> k).put(k, v1) == SortedBiMap(k -> v1, v1 -> k) + SortedBiMap( + inverse.get(v).fold(entries)(entries - _) + (k -> v), + entries.get(k).fold(inverse)(inverse - _) + (v -> k)) + + def +(kv: (K, V)): SortedBiMap[K, V] = put(kv._1, kv._2) + + def -(k: K): SortedBiMap[K, V] = + get(k) match { + case Some(v) => SortedBiMap(entries - k, inverse - v) + case None => this + } + + def inverseGet(v: V): Option[K] = inverse.get(v) + + override def toString: String = + entries.iterator.map { case (k, v) => s"$k <-> $v" }.mkString("SortedBiMap(", ", ", ")") +} + +object SortedBiMap { + private def apply[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]): SortedBiMap[K, V] = + new SortedBiMap[K, V](entries, inverse) {} + + def empty[K: Ordering, V]: SortedBiMap[K, V] = apply(SortedMap.empty, Map.empty) +} + diff --git a/modules/core/shared/src/main/scala/util/StatementCache.scala b/modules/core/shared/src/main/scala/util/StatementCache.scala index 7c6c4cc4..65c67d0b 100644 --- a/modules/core/shared/src/main/scala/util/StatementCache.scala +++ b/modules/core/shared/src/main/scala/util/StatementCache.scala @@ -8,7 +8,7 @@ import cats.{ Functor, ~> } import cats.syntax.all._ import skunk.Statement import cats.effect.kernel.Ref -import skunk.data.SemispaceCache +import skunk.data.Cache /** An LRU (by access) cache, keyed by statement `CacheKey`. */ sealed trait StatementCache[F[_], V] { outer => @@ -35,31 +35,42 @@ sealed trait StatementCache[F[_], V] { outer => object StatementCache { def empty[F[_]: Functor: Ref.Make, V](max: Int, trackEviction: Boolean): F[StatementCache[F, V]] = - Ref[F].of(SemispaceCache.empty[Statement.CacheKey, V](max, trackEviction)).map { ref => + // State is the cache and a set of evicted values; the evicted set only grows when trackEviction is true + Ref[F].of((Cache.empty[Statement.CacheKey, V](max), Set.empty[V])).map { ref => new StatementCache[F, V] { def get(k: Statement[_]): F[Option[V]] = - ref.modify { c => - c.lookup(k.cacheKey) match { - case Some((c使, v)) => (c使, Some(v)) - case None => (c, None) + ref.modify { case (c, evicted) => + c.get(k.cacheKey) match { + case Some((c使, v)) => (c使 -> evicted, Some(v)) + case None => (c -> evicted, None) } } def put(k: Statement[_], v: V): F[Unit] = - ref.update(_.insert(k.cacheKey, v)) + ref.update { case (c, evicted) => + val (c2, e) = c.put(k.cacheKey, v) + // Remove the value we just inserted from the evicted set and add the newly evicted value, if any + val evicted2 = e.filter(_ => trackEviction).fold(evicted - v) { case (_, ev) => evicted - v + ev } + (c2, evicted2) + } def containsKey(k: Statement[_]): F[Boolean] = - ref.get.map(_.containsKey(k.cacheKey)) + ref.get.map(_._1.contains(k.cacheKey)) def clear: F[Unit] = - ref.update(_.evictAll) + ref.update { case (c, evicted) => + val evicted2 = if (trackEviction) evicted ++ c.values else evicted + (Cache.empty[Statement.CacheKey, V](max), evicted2) + } def values: F[List[V]] = - ref.get.map(_.values) + ref.get.map(_._1.values.toList) - def clearEvicted: F[List[V]] = - ref.modify(_.clearEvicted) + def clearEvicted: F[List[V]] = + ref.modify { case (c, evicted) => + (c, Set.empty[V]) -> evicted.toList + } } } } diff --git a/modules/tests/shared/src/test/scala/PrepareCacheTest.scala b/modules/tests/shared/src/test/scala/PrepareCacheTest.scala index 84c64362..42d9e0aa 100644 --- a/modules/tests/shared/src/test/scala/PrepareCacheTest.scala +++ b/modules/tests/shared/src/test/scala/PrepareCacheTest.scala @@ -8,7 +8,7 @@ import skunk.implicits._ import skunk.codec.numeric.int8 import skunk.codec.text import skunk.codec.boolean -import cats.syntax.all.* +import cats.syntax.all._ class PrepareCacheTest extends SkunkTest { @@ -17,16 +17,8 @@ class PrepareCacheTest extends SkunkTest { private val pgStatementsCountByStatement = sql"select count(*) from pg_prepared_statements where statement = ${text.text}".query(int8) private val pgStatementsCount = sql"select count(*) from pg_prepared_statements".query(int8) private val pgStatements = sql"select statement from pg_prepared_statements order by prepare_time".query(text.text) - - pooledTest("concurrent prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 2) { p => - List.fill(4)( - p.use { s => - s.execute(pgStatementsByName)("foo").void >> s.execute(pgStatementsByStatement)("bar").void >> s.execute(pgStatementsCountByStatement)("baz").void - } - ).sequence - } - - pooledTest("prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 1) { p => + + pooledTest("prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 2) { p => p.use { s => s.execute(pgStatementsByName)("foo").void >> s.execute(pgStatementsByStatement)("bar").void >> @@ -49,7 +41,7 @@ class PrepareCacheTest extends SkunkTest { } } - pooledTest("prepared statements via prepare shouldn't get evicted until they go out of scope", max = 1, parseCacheSize = 1) { p => + pooledTest("prepared statements via prepare shouldn't get evicted until they go out of scope", max = 1, parseCacheSize = 2) { p => p.use { s => // creates entry in cache s.prepare(pgStatementsByName) @@ -97,4 +89,14 @@ class PrepareCacheTest extends SkunkTest { } } } + + pooledTest("concurrent prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 4) { p => + List.fill(8)( + p.use { s => + s.execute(pgStatementsByName)("foo").void >> + s.execute(pgStatementsByStatement)("bar").void >> + s.execute(pgStatementsCountByStatement)("baz").void + } + ).sequence + } } diff --git a/modules/tests/shared/src/test/scala/data/CacheTest.scala b/modules/tests/shared/src/test/scala/data/CacheTest.scala new file mode 100644 index 00000000..72dbaaf6 --- /dev/null +++ b/modules/tests/shared/src/test/scala/data/CacheTest.scala @@ -0,0 +1,79 @@ +// Copyright (c) 2018-2024 by Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package skunk.data + +import munit.ScalaCheckSuite +import org.scalacheck.Gen +import org.scalacheck.Prop.forAll + +class CacheTest extends ScalaCheckSuite { + + val genEmpty: Gen[Cache[Int, String]] = + Gen.choose(-1, 10).map(Cache.empty) + + test("insert on empty cache results in eviction") { + val cache = Cache.empty(0).put("one", 1)._1 + assertEquals(cache.values.toList, Nil) + assert(!cache.contains("one")) + } + + test("max is never negative") { + forAll(genEmpty) { c => + assert(c.max >= 0) + } + } + + test("insert should allow get") { + forAll(genEmpty) { c => + val c使 = c.put(1, "x")._1 + assertEquals(c使.get(1), if (c.max == 0) None else Some((c使, "x"))) + } + } + + test("eviction") { + forAll(genEmpty) { c => + val max = c.max + + // Load up the cache such that one more insert will cause it to overflow + val c1 = (0 until max).foldLeft(c) { case (c, n) => c.put(n, "x")._1 } + assertEquals(c1.values.size, max) + + // Overflow the cache + val (c2, evicted) = c1.put(max, "x") + assertEquals(c2.values.size, max) + assertEquals(evicted, Some(0 -> "x")) + + if (max > 2) { + // Access oldest element + val c3 = c2.get(1).get._1 + + // Insert another element and make sure oldest element is not the element evicted + val (c4, evicted1) = c3.put(max + 1, "x") + assertEquals(evicted1, Some(2 -> "x")) + } + } + } + + test("eviction 2") { + val c1 = Cache.empty(2).put("one", 1)._1 + assertEquals(c1.values.toList, List(1)) + assertEquals(c1.get("one").map(_._2), Some(1)) + + val (c2, evicted2) = c1.put("two", 2) + assert(c2.contains("one")) + assert(c2.contains("two")) + assertEquals(evicted2, None) + + val (c3, evicted3) = c2.put("one", 1) + assert(c3.contains("one")) + assert(c3.contains("two")) + assertEquals(evicted3, None) + + val (c4, evicted4) = c2.put("one", 0) + assert(c4.contains("one")) + assert(c4.contains("two")) + assertEquals(evicted4, Some("one" -> 1)) + } +} diff --git a/modules/tests/shared/src/test/scala/data/SemispaceCacheTest.scala b/modules/tests/shared/src/test/scala/data/SemispaceCacheTest.scala deleted file mode 100644 index cb13cdfc..00000000 --- a/modules/tests/shared/src/test/scala/data/SemispaceCacheTest.scala +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) 2018-2024 by Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). -// For more information see LICENSE or https://opensource.org/licenses/MIT - -package skunk.data - -import munit.ScalaCheckSuite -import org.scalacheck.Gen -import org.scalacheck.Prop._ - -class SemispaceCacheTest extends ScalaCheckSuite { - - val genEmpty: Gen[SemispaceCache[Int, String]] = - Gen.choose(-1, 10).map(SemispaceCache.empty(_, true)) - - test("eviction should never contain values in gen0/gen1") { - val cache = SemispaceCache.empty(2, true).insert("one", 1) - - val i1 = cache.insert("one", 1) - // Two doesn't exist; space in gen0, insert - val i2 = i1.lookup("two").map(_._1).getOrElse(i1.insert("two", 2)) - assertEquals(i2.gen0, Map("one" -> 1, "two" -> 2)) - assertEquals(i2.gen1, Map.empty[String, Int]) - assertEquals(i2.evicted.toList, Nil) - - // Three doesn't exist, hit max; slide gen0 -> gen1 and add to gen0 - val i3 = i2.lookup("three").map(_._1).getOrElse(i2.insert("three", 3)) - assertEquals(i3.gen0, Map("three" -> 3)) - assertEquals(i3.gen1, Map("one" -> 1, "two" -> 2)) - assertEquals(i3.evicted.toList, Nil) - - // One exists in gen1; pull up to gen0 and REMOVE from gen1 - val i4 = i3.lookup("one").map(_._1).getOrElse(i3.insert("one", 1)) - assertEquals(i4.gen0, Map("one" -> 1, "three" -> 3)) - assertEquals(i4.gen1, Map("two" -> 2)) - assertEquals(i4.evicted.toList, Nil) - - // Four doesn't exist; gen0 is full so push to gen1 - // insert four to gen0 and evict gen1 - val i5 = i4.lookup("four").map(_._1).getOrElse(i4.insert("four", 4)) - assertEquals(i5.gen0, Map("four" -> 4)) - assertEquals(i5.gen1, Map("one" -> 1, "three" -> 3)) - assertEquals(i5.evicted.toList, List(2)) - } - - test("insert on empty cache results in eviction") { - val cache = SemispaceCache.empty(0, true).insert("one", 1) - assertEquals(cache.values, Nil) - assert(!cache.containsKey("one")) - assertEquals(cache.clearEvicted._2, List(1)) - } - - test("insert on full cache results in eviction") { - val cache = SemispaceCache.empty(1, true).insert("one", 1) - assertEquals(cache.values, List(1)) - assertEquals(cache.lookup("one").map(_._2), Some(1)) - assertEquals(cache.clearEvicted._2, List.empty) - - // We now have two items (the cache stores up to 2*max entries) - val updated = cache.insert("two", 2) - assert(updated.containsKey("one")) // gen1 - assert(updated.containsKey("two")) // gen0 - assertEquals(updated.clearEvicted._2, List.empty) - - val third = updated.insert("one", 1) - assert(third.containsKey("one")) // gen1 - assert(third.containsKey("two")) // gen0 - assertEquals(third.clearEvicted._2, List.empty) - } - - test("max is never negative") { - forAll(genEmpty) { c => - assert(c.max >= 0) - } - } - - test("insert should allow lookup") { - forAll(genEmpty) { c => - val c使 = c.insert(1, "x") - assertEquals(c使.lookup(1), if (c.max == 0) None else Some((c使, "x"))) - } - } - - test("overflow") { - forAll(genEmpty) { c => - val max = c.max - - // Load up the cache such that one more insert will cause it to overflow - val c使 = (0 until max).foldLeft(c) { case (c, n) => c.insert(n, "x") } - assertEquals(c使.gen0.size, max) - assertEquals(c使.gen1.size, 0) - - // Overflow the cache - val c使使 = c使.insert(max, "x") - assertEquals(c使使.gen0.size, 1 min max) - assertEquals(c使使.gen1.size, max) - - } - } - - test("promotion") { - forAll(genEmpty) { c => - val max = c.max - - // Load up the cache such that it overflows by 1 - val c使 = (0 to max).foldLeft(c) { case (c, n) => c.insert(n, n.toString) } - assertEquals(c使.gen0.size, 1 min max) - assertEquals(c使.gen1.size, max) - - // Look up something that was demoted. - c使.lookup(0) match { - case None => assertEquals(max, 0) - case Some((c使使, _)) => - assertEquals(c使使.gen0.size, 2 min max) - // When we promote 0 to gen0, we remove it from gen1 - assertEquals(c使使.gen1.size, max-1 max 1) - assertEquals(c使使.evicted.toList, Nil) - } - - } - } - -} diff --git a/modules/tests/shared/src/test/scala/data/SortedBiMapTest.scala b/modules/tests/shared/src/test/scala/data/SortedBiMapTest.scala new file mode 100644 index 00000000..65714af2 --- /dev/null +++ b/modules/tests/shared/src/test/scala/data/SortedBiMapTest.scala @@ -0,0 +1,38 @@ +// Copyright (c) 2018-2024 by Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package skunk.data + +import munit.ScalaCheckSuite +import org.scalacheck.Prop + +class SortedBiMapTest extends ScalaCheckSuite { + + test("put handles overwrites") { + val m = SortedBiMap.empty[Int, Int].put(1, 2) + assertEquals(m.size, 1) + assertEquals(m.get(1), Some(2)) + assertEquals(m.inverseGet(2), Some(1)) + + val m2 = m.put(3, 2) + assertEquals(m2.size, 1) + assertEquals(m2.get(3), Some(2)) + assertEquals(m2.inverseGet(2), Some(3)) + assertEquals(m2.get(1), None) + + val m3 = m2.put(3, 4) + assertEquals(m3.size, 1) + assertEquals(m3.get(3), Some(4)) + assertEquals(m3.inverseGet(4), Some(3)) + assertEquals(m3.inverseGet(2), None) + } + + test("entries are sorted") { + Prop.forAll { (s: Set[Int]) => + val m = s.foldLeft(SortedBiMap.empty[Int, String])((acc, i) => acc.put(i, i.toString)) + assertEquals(m.size, s.size) + assertEquals(m.entries.keySet.toList, s.toList.sorted) + } + } +}