Skip to content

Commit 2493813

Browse files
authored
perf: cache message count queries (#880)
Counting matching messages slows down /messages queries by several seconds because the messages table is too large. This PR mitigates the problem by adding a cache in front of the DB count queries in order to reduce the average load on CCNs while keeping exact message counts. The default cache TTL is 5 minutes but can be configured through the new perf.message_count_cache_ttl config value.
1 parent eff0504 commit 2493813

File tree

15 files changed

+296
-248
lines changed

15 files changed

+296
-248
lines changed

src/aleph/api_entrypoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ async def configure_aiohttp_app(
4242
session_factory = make_session_factory(engine)
4343

4444
node_cache = NodeCache(
45-
redis_host=config.redis.host.value, redis_port=config.redis.port.value
45+
redis_host=config.redis.host.value,
46+
redis_port=config.redis.port.value,
47+
message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
4648
)
4749
# TODO: find a way to close the node cache when exiting the API process, not closing it causes
4850
# a warning.

src/aleph/commands.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def run_db_migrations(config: Config):
6666

6767
async def init_node_cache(config: Config) -> NodeCache:
6868
node_cache = NodeCache(
69-
redis_host=config.redis.host.value, redis_port=config.redis.port.value
69+
redis_host=config.redis.host.value,
70+
redis_port=config.redis.port.value,
71+
message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
7072
)
7173
return node_cache
7274

src/aleph/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ def get_defaults():
248248
# Sentry trace sample rate.
249249
"traces_sample_rate": None,
250250
},
251+
"perf": {
252+
# TTL of the cache in front of DB count queries on the messages table.
253+
"message_count_cache_ttl": 300,
254+
},
251255
}
252256

253257

src/aleph/jobs/fetch_pending_messages.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ async def fetch_messages_task(config: Config):
171171

172172
async with (
173173
NodeCache(
174-
redis_host=config.redis.host.value, redis_port=config.redis.port.value
174+
redis_host=config.redis.host.value,
175+
redis_port=config.redis.port.value,
176+
message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
175177
) as node_cache,
176178
IpfsService.new(config) as ipfs_service,
177179
):

src/aleph/jobs/process_pending_messages.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ async def fetch_and_process_messages_task(config: Config):
159159

160160
async with (
161161
NodeCache(
162-
redis_host=config.redis.host.value, redis_port=config.redis.port.value
162+
redis_host=config.redis.host.value,
163+
redis_port=config.redis.port.value,
164+
message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
163165
) as node_cache,
164166
IpfsService.new(config) as ipfs_service,
165167
):

src/aleph/jobs/process_pending_txs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ async def handle_txs_task(config: Config):
133133

134134
async with (
135135
NodeCache(
136-
redis_host=config.redis.host.value, redis_port=config.redis.port.value
136+
redis_host=config.redis.host.value,
137+
redis_port=config.redis.port.value,
138+
message_count_cache_ttl=config.perf.message_count_cache_ttl.value,
137139
) as node_cache,
138140
IpfsService.new(config) as ipfs_service,
139141
):

src/aleph/schemas/api/accounts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from aleph_message.models import Chain
66
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, field_validator
77

8+
from aleph.schemas.messages_query_params import DEFAULT_PAGE, LIST_FIELD_SEPARATOR
89
from aleph.types.files import FileType
910
from aleph.types.sort_order import SortOrder
10-
from aleph.web.controllers.utils import DEFAULT_PAGE, LIST_FIELD_SEPARATOR
1111

1212

1313
class GetAccountQueryParams(BaseModel):
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
from typing import List, Optional
2+
3+
from aleph_message.models import Chain, ItemHash, MessageType
4+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
5+
6+
from aleph.types.message_status import MessageStatus
7+
from aleph.types.sort_order import SortBy, SortOrder
8+
9+
# Default number of historical messages replayed on a new websocket connection.
DEFAULT_WS_HISTORY = 10
# Default page size for the paginated message endpoints.
DEFAULT_MESSAGES_PER_PAGE = 20
# Pages are 1-indexed.
DEFAULT_PAGE = 1
# Separator accepted when list-valued query params arrive as a single string.
LIST_FIELD_SEPARATOR = ","
13+
14+
15+
class BaseMessageQueryParams(BaseModel):
    """Query parameters shared by the message listing endpoints.

    List-valued fields also accept a single comma-separated string; date and
    block ranges are validated so the end never precedes the start.
    """

    sort_by: SortBy = Field(
        default=SortBy.TIME,
        alias="sortBy",
        description="Key to use to sort the messages. "
        "'time' uses the message time field. "
        "'tx-time' uses the first on-chain confirmation time.",
    )
    sort_order: SortOrder = Field(
        default=SortOrder.DESCENDING,
        alias="sortOrder",
        description="Order in which messages should be listed: "
        "-1 means most recent messages first, 1 means older messages first.",
    )
    message_type: Optional[MessageType] = Field(
        default=None,
        alias="msgType",
        description="Message type. Deprecated: use msgTypes instead",
    )
    message_types: Optional[List[MessageType]] = Field(
        default=None, alias="msgTypes", description="Accepted message types."
    )
    message_statuses: Optional[List[MessageStatus]] = Field(
        default=[MessageStatus.PROCESSED, MessageStatus.REMOVING],
        alias="msgStatuses",
        description="Accepted values for the 'status' field.",
    )
    addresses: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'sender' field."
    )
    refs: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'content.ref' field."
    )
    content_hashes: Optional[List[ItemHash]] = Field(
        default=None,
        alias="contentHashes",
        description="Accepted values for the 'content.item_hash' field.",
    )
    content_keys: Optional[List[ItemHash]] = Field(
        default=None,
        alias="contentKeys",
        description="Accepted values for the 'content.keys' field.",
    )
    content_types: Optional[List[str]] = Field(
        default=None,
        alias="contentTypes",
        description="Accepted values for the 'content.type' field.",
    )
    chains: Optional[List[Chain]] = Field(
        default=None, description="Accepted values for the 'chain' field."
    )
    channels: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'channel' field."
    )
    tags: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'content.content.tag' field."
    )
    hashes: Optional[List[ItemHash]] = Field(
        default=None, description="Accepted values for the 'item_hash' field."
    )

    start_date: float = Field(
        default=0,
        ge=0,
        alias="startDate",
        description="Start date timestamp. If specified, only messages with "
        "a time field greater or equal to this value will be returned.",
    )
    end_date: float = Field(
        default=0,
        ge=0,
        alias="endDate",
        description="End date timestamp. If specified, only messages with "
        "a time field lower than this value will be returned.",
    )

    start_block: int = Field(
        default=0,
        ge=0,
        alias="startBlock",
        description="Start block number. If specified, only messages with "
        "a block number greater or equal to this value will be returned.",
    )
    end_block: int = Field(
        default=0,
        ge=0,
        alias="endBlock",
        description="End block number. If specified, only messages with "
        "a block number lower than this value will be returned.",
    )

    @model_validator(mode="after")
    def validate_field_dependencies(self):
        # A range is only valid when its end does not precede its start.
        # A zero bound means "unset" and disables the check.
        if self.start_date and self.end_date and self.end_date < self.start_date:
            raise ValueError("end date cannot be lower than start date.")
        if self.start_block and self.end_block and self.end_block < self.start_block:
            raise ValueError("end block cannot be lower than start block.")
        return self

    @field_validator(
        "hashes",
        "addresses",
        "refs",
        "content_hashes",
        "content_keys",
        "content_types",
        "chains",
        "channels",
        "message_types",
        "message_statuses",
        "tags",
        mode="before",
    )
    def split_str(cls, value):
        # Accept a comma-separated string as shorthand for a list.
        if isinstance(value, str):
            return value.split(LIST_FIELD_SEPARATOR)
        return value

    model_config = ConfigDict(populate_by_name=True)
139+
140+
141+
class MessageQueryParams(BaseMessageQueryParams):
    """Query parameters for the paginated message listing endpoint."""

    pagination: int = Field(
        default=DEFAULT_MESSAGES_PER_PAGE,
        ge=0,
        description="Maximum number of messages to return. Specifying 0 removes this limit.",
    )
    page: int = Field(
        default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
    )
150+
151+
152+
class WsMessageQueryParams(BaseMessageQueryParams):
    """Query parameters for the websocket message stream endpoint."""

    history: Optional[int] = Field(
        default=DEFAULT_WS_HISTORY,
        ge=0,
        lt=200,
        description="Historical elements to send through the websocket.",
    )
159+
160+
161+
class MessageHashesQueryParams(BaseModel):
    """Query parameters for the message hashes endpoint.

    The date range is validated so the end never precedes the start.
    """

    status: Optional[MessageStatus] = Field(
        default=None,
        description="Message status.",
    )
    page: int = Field(
        default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
    )
    pagination: int = Field(
        default=DEFAULT_MESSAGES_PER_PAGE,
        ge=0,
        description="Maximum number of messages to return. Specifying 0 removes this limit.",
    )
    start_date: float = Field(
        default=0,
        ge=0,
        alias="startDate",
        description="Start date timestamp. If specified, only messages with "
        "a time field greater or equal to this value will be returned.",
    )
    end_date: float = Field(
        default=0,
        ge=0,
        alias="endDate",
        description="End date timestamp. If specified, only messages with "
        "a time field lower than this value will be returned.",
    )
    sort_order: SortOrder = Field(
        default=SortOrder.DESCENDING,
        alias="sortOrder",
        description="Order in which messages should be listed: "
        "-1 means most recent messages first, 1 means older messages first.",
    )
    hash_only: bool = Field(
        default=True,
        description="By default, only hashes are returned. "
        "Set this to false to include metadata alongside the hashes in the response.",
    )

    @model_validator(mode="after")
    def validate_field_dependencies(self):
        # A zero bound means "unset" and disables the check.
        if self.start_date and self.end_date and self.end_date < self.start_date:
            raise ValueError("end date cannot be lower than start date.")
        return self

    model_config = ConfigDict(populate_by_name=True)

src/aleph/services/cache/node_cache.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1-
from typing import Any, List, Optional, Set
1+
from hashlib import sha256
2+
from typing import Any, Dict, List, Optional, Set
23

34
import redis.asyncio as redis_asyncio
45

6+
import aleph.toolkit.json as aleph_json
7+
from aleph.db.accessors.messages import count_matching_messages
8+
from aleph.schemas.messages_query_params import MessageQueryParams
9+
from aleph.types.db_session import DbSession
10+
511
CacheKey = Any
612
CacheValue = bytes
713

@@ -10,9 +16,10 @@ class NodeCache:
1016
API_SERVERS_KEY = "api_servers"
1117
PUBLIC_ADDRESSES_KEY = "public_addresses"
1218

13-
def __init__(self, redis_host: str, redis_port: int, message_count_cache_ttl: int):
    """Create a node cache backed by a Redis instance.

    :param redis_host: Hostname of the Redis server.
    :param redis_port: Port of the Redis server.
    :param message_count_cache_ttl: TTL in seconds for cached message count
        query results (used as the expiration passed to Redis by
        ``count_messages``).
    """
    self.redis_host = redis_host
    self.redis_port = redis_port
    # NOTE(review): the attribute name swaps "count"/"cache" relative to the
    # parameter name; kept as-is because count_messages reads this exact name.
    self.message_cache_count_ttl = message_count_cache_ttl

    self._redis_client: Optional[redis_asyncio.Redis] = None
1825

@@ -52,8 +59,8 @@ async def reset(self):
5259
async def get(self, key: CacheKey) -> Optional[CacheValue]:
5360
return await self.redis_client.get(key)
5461

55-
async def set(self, key: CacheKey, value: Any):
56-
await self.redis_client.set(key, value)
62+
async def set(self, key: CacheKey, value: Any, expiration: Optional[int] = None):
63+
await self.redis_client.set(key, value, ex=expiration)
5764

5865
async def incr(self, key: CacheKey):
5966
await self.redis_client.incr(key)
@@ -82,3 +89,25 @@ async def add_public_address(self, public_address: str) -> None:
8289
async def get_public_addresses(self) -> List[str]:
8390
addresses = await self.redis_client.smembers(self.PUBLIC_ADDRESSES_KEY)
8491
return [addr.decode() for addr in addresses]
92+
93+
@staticmethod
def _message_filter_id(filters: Dict[str, Any]) -> str:
    """Derive a stable cache-key suffix from a dict of message filters.

    The filters are serialized to JSON with sorted keys so that equivalent
    filter dicts always produce the same SHA-256 hex digest.
    """
    filters_json = aleph_json.dumps(filters, sort_keys=True)
    return sha256(filters_json).hexdigest()
97+
98+
async def count_messages(
99+
self, session: DbSession, query_params: MessageQueryParams
100+
) -> int:
101+
filters = query_params.model_dump(exclude_none=True)
102+
cache_key = f"message_count:{self._message_filter_id(filters)}"
103+
104+
cached_result = await self.get(cache_key)
105+
if cached_result is not None:
106+
return int(cached_result.decode())
107+
108+
# Slow, can take a few seconds
109+
n_matches = count_matching_messages(session, **filters)
110+
111+
await self.set(cache_key, n_matches, expiration=self.message_cache_count_ttl)
112+
113+
return n_matches

src/aleph/toolkit/json.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
# serializer changes easier.
1919
SerializedJsonInput = Union[bytes, str]
2020

21-
2221
# Note: JSONDecodeError is a subclass of ValueError, but the JSON module sometimes throws
2322
# raw value errors, including on NaN because of our custom parse_constant.
2423
DecodeError = orjson.JSONDecodeError
@@ -55,8 +54,11 @@ def extended_json_encoder(obj: Any) -> Any:
5554
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
5655

5756

58-
def dumps(obj: Any, sort_keys: bool = True) -> bytes:
    """Serialize ``obj`` to JSON bytes.

    Uses orjson when possible and falls back to the stdlib ``json`` module
    with the extended encoder for types orjson cannot serialize.

    :param sort_keys: When True (default), dictionary keys are emitted in
        sorted order so equal dicts always serialize to identical bytes.
        NOTE(review): with orjson this flag also enables OPT_NON_STR_KEYS
        (non-string dict keys become accepted) — confirm callers expect
        both behaviors to toggle together.
    """
    options = (orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS) if sort_keys else 0
    try:
        return orjson.dumps(obj, option=options)
    except TypeError:
        fallback = json.dumps(obj, default=extended_json_encoder, sort_keys=sort_keys)
        return fallback.encode()

0 commit comments

Comments
 (0)