Skip to content

Commit

Permalink
Refactor SeenCache method names
Browse files Browse the repository at this point in the history
  • Loading branch information
Nashatyrev committed Aug 16, 2023
1 parent b2aebac commit b512f27
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 33 deletions.
2 changes: 1 addition & 1 deletion libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ abstract class AbstractRouter(
val validationResult = seenMessages[subscribedMessage]
if (validationResult != null) {
// Message has been seen
notifySeenMessage(peer, seenMessages.getSeenMessage(subscribedMessage), validationResult)
notifySeenMessage(peer, seenMessages.getSeenMessageCached(subscribedMessage), validationResult)
false
} else {
// Message is unseen
Expand Down
23 changes: 12 additions & 11 deletions libp2p/src/main/kotlin/io/libp2p/pubsub/SeenCache.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@ import java.util.LinkedList
interface SeenCache<TValue> {
val size: Int

fun put(msg: PubsubMessage, value: TValue)
fun get(msg: PubsubMessage): TValue?
fun isSeen(msg: PubsubMessage): Boolean
fun isSeen(messageId: MessageId): Boolean
fun remove(messageId: MessageId)

/**
* Returns the 'matching' message if it exists in the cache or falls back to returning the argument if not
* The returned instance may have some data prepared and cached (e.g. `messageId`) which may
* have positive performance effect
*/
fun getSeenMessage(msg: PubsubMessage): PubsubMessage
fun getValue(msg: PubsubMessage): TValue?
fun isSeen(msg: PubsubMessage): Boolean
fun isSeen(messageId: MessageId): Boolean
fun put(msg: PubsubMessage, value: TValue)
fun remove(messageId: MessageId)
fun getSeenMessageCached(msg: PubsubMessage): PubsubMessage
}

operator fun <TValue> SeenCache<TValue>.get(msg: PubsubMessage) = getValue(msg)
operator fun <TValue> SeenCache<TValue>.get(msg: PubsubMessage) = get(msg)
operator fun <TValue> SeenCache<TValue>.set(msg: PubsubMessage, value: TValue) = put(msg, value)
operator fun <TValue> SeenCache<TValue>.contains(msg: PubsubMessage) = isSeen(msg)
operator fun <TValue> SeenCache<TValue>.minusAssign(messageId: MessageId) = remove(messageId)
Expand All @@ -42,8 +43,8 @@ class SimpleSeenCache<TValue> : SeenCache<TValue> {
override val size: Int
get() = map.size

override fun getSeenMessage(msg: PubsubMessage) = msg
override fun getValue(msg: PubsubMessage) = map[msg.messageId]
override fun getSeenMessageCached(msg: PubsubMessage) = msg
override fun get(msg: PubsubMessage) = map[msg.messageId]
override fun isSeen(msg: PubsubMessage) = msg.messageId in map
override fun isSeen(messageId: MessageId) = messageId in map

Expand Down Expand Up @@ -110,7 +111,7 @@ class FastIdSeenCache<TValue>(private val fastIdFunction: (PubsubMessage) -> Any
override val size: Int
get() = slowIdMap.size

override fun getSeenMessage(msg: PubsubMessage): PubsubMessage {
override fun getSeenMessageCached(msg: PubsubMessage): PubsubMessage {
val slowId = fastIdMap[fastIdFunction(msg)]
return when {
slowId == null -> msg
Expand All @@ -119,7 +120,7 @@ class FastIdSeenCache<TValue>(private val fastIdFunction: (PubsubMessage) -> Any
}
}

override fun getValue(msg: PubsubMessage): TValue? {
override fun get(msg: PubsubMessage): TValue? {
val slowId = fastIdMap[fastIdFunction(msg)] ?: msg.messageId
return slowIdMap[slowId]
}
Expand Down
42 changes: 21 additions & 21 deletions libp2p/src/test/kotlin/io/libp2p/pubsub/SeenCacheTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fun createPubsubMessage(number: Int, fastId: Int) =
fun assertContainsEntry(cache: SeenCache<String>, fakeMsg: Int) {
assertThat(cache.isSeen(createPubsubMessage(fakeMsg))).isTrue()
assertThat(cache.isSeen(createPubsubMessage(fakeMsg).messageId)).isTrue()
assertThat(cache.getValue(createPubsubMessage(fakeMsg))).isEqualTo(fakeMsg.toString())
assertThat(cache.get(createPubsubMessage(fakeMsg))).isEqualTo(fakeMsg.toString())
}
fun assertContainsEntries(cache: SeenCache<String>, vararg fakeMsgs: Int) {
fakeMsgs.forEach {
Expand All @@ -39,7 +39,7 @@ fun assertContainsEntries(cache: SeenCache<String>, vararg fakeMsgs: Int) {
fun assertDoesntContainEntry(cache: SeenCache<String>, fakeMsg: Int) {
assertThat(cache.isSeen(createPubsubMessage(fakeMsg))).isFalse()
assertThat(cache.isSeen(createPubsubMessage(fakeMsg).messageId)).isFalse()
assertThat(cache.getValue(createPubsubMessage(fakeMsg))).isNull()
assertThat(cache.get(createPubsubMessage(fakeMsg))).isNull()
}
fun assertDoesntContainEntries(cache: SeenCache<String>, vararg fakeMsgs: Int) {
fakeMsgs.forEach {
Expand Down Expand Up @@ -81,41 +81,41 @@ fun genericSanityTest(cache: SeenCache<String>) {
assertThat(cache.size).isEqualTo(1)
assertThat(cache.isSeen(createPubsubMessage(1))).isTrue()
assertThat(cache.isSeen(createPubsubMessage(2))).isFalse()
assertThat(cache.getValue(createPubsubMessage(1))).isEqualTo("1")
assertThat(cache.getValue(createPubsubMessage(2))).isNull()
assertThat(cache.getSeenMessage(createPubsubMessage(1))).isEqualTo(createPubsubMessage(1))
assertThat(cache.get(createPubsubMessage(1))).isEqualTo("1")
assertThat(cache.get(createPubsubMessage(2))).isNull()
assertThat(cache.getSeenMessageCached(createPubsubMessage(1))).isEqualTo(createPubsubMessage(1))

cache[createPubsubMessage(1)] = "1-1"

assertThat(cache.size).isEqualTo(1)
assertThat(cache.isSeen(createPubsubMessage(1))).isTrue()
assertThat(cache.isSeen(createPubsubMessage(2))).isFalse()
assertThat(cache.getValue(createPubsubMessage(1))).isEqualTo("1-1")
assertThat(cache.getValue(createPubsubMessage(2))).isNull()
assertThat(cache.get(createPubsubMessage(1))).isEqualTo("1-1")
assertThat(cache.get(createPubsubMessage(2))).isNull()

cache[createPubsubMessage(2)] = "2"

assertThat(cache.size).isEqualTo(2)
assertThat(cache.isSeen(createPubsubMessage(1))).isTrue()
assertThat(cache.isSeen(createPubsubMessage(2))).isTrue()
assertThat(cache.getValue(createPubsubMessage(1))).isEqualTo("1-1")
assertThat(cache.getValue(createPubsubMessage(2))).isEqualTo("2")
assertThat(cache.get(createPubsubMessage(1))).isEqualTo("1-1")
assertThat(cache.get(createPubsubMessage(2))).isEqualTo("2")

cache -= createPubsubMessage(1)

assertThat(cache.size).isEqualTo(1)
assertThat(cache.isSeen(createPubsubMessage(1))).isFalse()
assertThat(cache.isSeen(createPubsubMessage(2))).isTrue()
assertThat(cache.getValue(createPubsubMessage(1))).isNull()
assertThat(cache.getValue(createPubsubMessage(2))).isEqualTo("2")
assertThat(cache.get(createPubsubMessage(1))).isNull()
assertThat(cache.get(createPubsubMessage(2))).isEqualTo("2")

cache -= createPubsubMessage(2)

assertThat(cache.size).isEqualTo(0)
assertThat(cache.isSeen(createPubsubMessage(1))).isFalse()
assertThat(cache.isSeen(createPubsubMessage(2))).isFalse()
assertThat(cache.getValue(createPubsubMessage(1))).isNull()
assertThat(cache.getValue(createPubsubMessage(2))).isNull()
assertThat(cache.get(createPubsubMessage(1))).isNull()
assertThat(cache.get(createPubsubMessage(2))).isNull()
}

class LRUSeenCacheTest {
Expand Down Expand Up @@ -365,14 +365,14 @@ class FastIdSeenCacheTest {
assertThat(m1_1.canonicalId).isNotNull()
assertThat(m1_2.canonicalId).isNull()

val m1_3 = cache.getSeenMessage(m1_2)
val m1_3 = cache.getSeenMessageCached(m1_2)
assertThat(m1_3.messageId).isEqualTo(m1_1.canonicalId)
assertThat(m1_2.canonicalId).isNull()

assertThat(m1_2 in cache).isTrue()
assertThat(m1_2.canonicalId).isNull()

assertThat(cache.getValue(m1_2)).isEqualTo("1-1")
assertThat(cache.get(m1_2)).isEqualTo("1-1")
assertThat(m1_2.canonicalId).isNull()
}

Expand All @@ -385,14 +385,14 @@ class FastIdSeenCacheTest {
cache[m1_1] = "1-1"
assertThat(m1_1 in cache).isTrue()
assertThat(m1_2 in cache).isTrue()
assertThat(cache.getValue(m1_1)).isEqualTo("1-1")
assertThat(cache.getValue(m1_2)).isEqualTo("1-1")
assertThat(cache.get(m1_1)).isEqualTo("1-1")
assertThat(cache.get(m1_2)).isEqualTo("1-1")

cache[m1_2] = "1-2"
assertThat(m1_1 in cache).isTrue()
assertThat(m1_2 in cache).isTrue()
assertThat(cache.getValue(m1_1)).isEqualTo("1-2")
assertThat(cache.getValue(m1_2)).isEqualTo("1-2")
assertThat(cache.get(m1_1)).isEqualTo("1-2")
assertThat(cache.get(m1_2)).isEqualTo("1-2")

val m1_1_1 = createPubsubMessage(1, 1)
val m1_2_1 = createPubsubMessage(1, 2)
Expand All @@ -404,8 +404,8 @@ class FastIdSeenCacheTest {
cache -= m1_1
assertThat(m1_1 in cache).isFalse()
assertThat(m1_2 in cache).isFalse()
assertThat(cache.getValue(m1_1)).isNull()
assertThat(cache.getValue(m1_2)).isNull()
assertThat(cache.get(m1_1)).isNull()
assertThat(cache.get(m1_2)).isNull()

assertThat(cache.fastIdMap.isEmpty()).isTrue()
assertThat(cache.slowIdMap).isEmpty()
Expand Down

0 comments on commit b512f27

Please sign in to comment.