Commit e17a06c

[Messages] Use Pydantic message models everywhere (#301)
Propagated the use of Pydantic message models to all the functions related to message processing.

Extended the Pydantic message models by introducing a new message class, `ValidatedBaseMessage`. This schema represents fully loaded messages, including their content field.

Split the message validation function, `check_message`, in two parts:
* `parse_message` checks that the message is semantically valid, i.e. that all the required fields are present, are of the correct type and have sensible values.
* `verify_signature` checks the signature of the message using the public key of the sender.

We now call `parse_message` as early as possible in the process. This simplifies our hypotheses about the presence/absence of specific fields and simplifies the codebase.

Improved the typing of the message models by using generics to represent the message and item type fields.

Modified the tests to use the message classes, and removed a few tests that covered corner cases linked to incomplete message dictionaries.
1 parent ecca54b commit e17a06c

30 files changed: +865 −651 lines
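The split of `check_message` into `parse_message` and `verify_signature` is the backbone of the change: raw dictionaries are parsed once, at the edge of the system, and everything downstream works with typed `BasePendingMessage` objects. `parse_message` itself does not appear in the two files below, so the following is only a sketch of the intended two-step flow; the wrapper function and the import paths are assumptions, while the function names come from the commit message and the diff.

```python
# Sketch of the two-step validation described in the commit message.
# The wrapper function and the import paths marked below are assumptions;
# only parse_message, verify_signature and InvalidMessageError are named
# in the commit itself.
from aleph.exceptions import InvalidMessageError  # assumed module path
from aleph.network import verify_signature
from aleph.schemas.pending_messages import BasePendingMessage, parse_message  # parse_message location assumed


async def load_pending_message(message_dict: dict) -> BasePendingMessage:
    # Step 1: semantic validation, done as early as possible. All required
    # fields must be present, of the correct type and with sensible values
    # (raising InvalidMessageError on failure is an assumption here).
    pending_message = parse_message(message_dict)

    # Step 2: cryptographic validation of the signature against the
    # sender's public key, as done in incoming() below.
    await verify_signature(pending_message)

    return pending_message
```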

src/aleph/chains/common.py

Lines changed: 59 additions & 114 deletions
```diff
@@ -5,7 +5,7 @@
 from enum import IntEnum
 from typing import Dict, Optional, Tuple, List
 
-from aleph_message.models import MessageType, ItemType
+from aleph_message.models import MessageConfirmation
 from bson import ObjectId
 from pymongo import UpdateOne
 
@@ -23,12 +23,20 @@
 from aleph.model.filepin import PermanentPin
 from aleph.model.messages import CappedMessage, Message
 from aleph.model.pending import PendingMessage, PendingTX
-from aleph.network import check_message as check_message_fn
+from aleph.network import verify_signature
 from aleph.permissions import check_sender_authorization
 from aleph.storage import get_json, pin_hash, add_json, get_message_content
 from .tx_context import TxContext
-from ..schemas.pending_messages import BasePendingMessage
-from ..utils import item_type_from_hash
+from aleph.schemas.pending_messages import (
+    BasePendingMessage,
+)
+from aleph.schemas.validated_message import (
+    validate_pending_message,
+    ValidatedStoreMessage,
+    ValidatedForgetMessage,
+    make_confirmation_update_query,
+    make_message_upsert_query,
+)
 
 LOGGER = logging.getLogger("chains.common")
 
@@ -54,12 +62,17 @@ async def mark_confirmed_data(chain_name, tx_hash, height):
     }
 
 
-async def delayed_incoming(message, chain_name=None, tx_hash=None, height=None):
+async def delayed_incoming(
+    message: BasePendingMessage,
+    chain_name: Optional[str] = None,
+    tx_hash: Optional[str] = None,
+    height: Optional[int] = None,
+):
     if message is None:
         return
     await PendingMessage.collection.insert_one(
         {
-            "message": message,
+            "message": message.dict(exclude={"content"}),
             "source": dict(
                 chain_name=chain_name,
                 tx_hash=tx_hash,
@@ -77,18 +90,20 @@ class IncomingStatus(IntEnum):
 
 
 async def mark_message_for_retry(
-    message: Dict,
+    message: BasePendingMessage,
     chain_name: Optional[str],
     tx_hash: Optional[str],
     height: Optional[int],
    check_message: bool,
    retrying: bool,
    existing_id,
 ):
+    message_dict = message.dict(exclude={"content"})
+
     if not retrying:
         await PendingMessage.collection.insert_one(
             {
-                "message": message,
+                "message": message_dict,
                 "source": dict(
                     chain_name=chain_name,
                     tx_hash=tx_hash,
@@ -105,25 +120,8 @@ async def mark_message_for_retry(
     LOGGER.debug(f"Update result {result}")
 
 
-def update_message_item_type(message_dict: Dict) -> Dict:
-    """
-    Ensures that the item_type field of a message is present.
-    Sets it to the default value if the field is not specified.
-    """
-    if "item_type" in message_dict:
-        return message_dict
-
-    if "item_content" in message_dict:
-        item_type = ItemType.inline
-    else:
-        item_type = item_type_from_hash(message_dict["item_hash"])
-
-    message_dict["item_type"] = item_type
-    return message_dict
-
-
 async def incoming(
-    message: Dict,
+    pending_message: BasePendingMessage,
     chain_name: Optional[str] = None,
     tx_hash: Optional[str] = None,
     height: Optional[int] = None,
@@ -138,77 +136,47 @@
     if existing in database, created if not.
     """
 
-    # TODO: this is a temporary fix to set the item_type of the message to the correct
-    # value. This should be replaced by a full use of Pydantic models.
-    message = update_message_item_type(message)
-
-    item_hash = message["item_hash"]
-    sender = message["sender"]
+    item_hash = pending_message.item_hash
+    sender = pending_message.sender
+    confirmations = []
     ids_key = (item_hash, sender, chain_name)
 
-    if chain_name and tx_hash and height and seen_ids is not None:
-        if ids_key in seen_ids.keys():
-            if height > seen_ids[ids_key]:
-                return IncomingStatus.MESSAGE_HANDLED, []
+    if chain_name and tx_hash and height:
+        if seen_ids is not None:
+            if ids_key in seen_ids.keys():
+                if height > seen_ids[ids_key]:
+                    return IncomingStatus.MESSAGE_HANDLED, []
+
+        confirmations.append(
+            MessageConfirmation(chain=chain_name, hash=tx_hash, height=height)
+        )
 
     filters = {
         "item_hash": item_hash,
-        "chain": message["chain"],
-        "sender": message["sender"],
-        "type": message["type"],
+        "chain": pending_message.chain,
+        "sender": pending_message.sender,
+        "type": pending_message.type,
     }
     existing = await Message.collection.find_one(
         filters,
         projection={"confirmed": 1, "confirmations": 1, "time": 1, "signature": 1},
     )
 
     if check_message:
-        if existing is None or (existing["signature"] != message["signature"]):
+        if existing is None or (existing["signature"] != pending_message.signature):
             # check/sanitize the message if needed
             try:
-                message = await check_message_fn(
-                    message, from_chain=(chain_name is not None)
-                )
+                await verify_signature(pending_message)
             except InvalidMessageError:
                 return IncomingStatus.FAILED_PERMANENTLY, []
 
-    if message is None:
-        return IncomingStatus.MESSAGE_HANDLED, []
-
     if retrying:
         LOGGER.debug("(Re)trying %s." % item_hash)
     else:
         LOGGER.info("Incoming %s." % item_hash)
 
-    # we set the incoming chain as default for signature
-    message["chain"] = message.get("chain", chain_name)
-
-    # if existing is None:
-    #     # TODO: verify if search key is ok. do we need an unique key for messages?
-    #     existing = await Message.collection.find_one(
-    #         filters, projection={'confirmed': 1, 'confirmations': 1, 'time': 1})
-
-    if chain_name and tx_hash and height:
-        # We are getting a confirmation here
-        new_values = await mark_confirmed_data(chain_name, tx_hash, height)
-
-        updates = {
-            "$set": {
-                "confirmed": True,
-            },
-            "$min": {"time": message["time"]},
-            "$addToSet": {"confirmations": new_values["confirmations"][0]},
-        }
-    else:
-        updates = {
-            "$max": {
-                "confirmed": False,
-            },
-            "$min": {"time": message["time"]},
-        }
+    updates: Dict[str, Dict] = {}
 
-    # new_values = {'confirmed': False}  # this should be our default.
-    should_commit = False
     if existing:
         if seen_ids is not None and height is not None:
             if ids_key in seen_ids.keys():
@@ -219,25 +187,14 @@ async def incoming(
             else:
                 seen_ids[ids_key] = height
 
-        # THIS CODE SHOULD BE HERE...
-        # But, if a race condition appeared, we might have the message twice.
-        # if (existing['confirmed'] and
-        #         chain_name in [c['chain'] for c in existing['confirmations']]):
-        #     return
-
         LOGGER.debug("Updating %s." % item_hash)
 
-        if chain_name and tx_hash and height:
-            # we need to update messages adding the confirmation
-            # await Message.collection.update_many(filters, updates)
-            should_commit = True
+        if confirmations:
+            updates = make_confirmation_update_query(confirmations)
 
     else:
-        # if not (chain_name and tx_hash and height):
-        #     new_values = {'confirmed': False}  # this should be our default.
-
         try:
-            content = await get_message_content(message)
+            content = await get_message_content(pending_message)
 
         except InvalidContent:
             LOGGER.warning("Can't get content of object %r, won't retry." % item_hash)
@@ -247,7 +204,7 @@
             if not isinstance(e, ContentCurrentlyUnavailable):
                 LOGGER.exception("Can't get content of object %r" % item_hash)
             await mark_message_for_retry(
-                message=message,
+                message=pending_message,
                 chain_name=chain_name,
                 tx_hash=tx_hash,
                 height=height,
@@ -257,26 +214,23 @@
             )
             return IncomingStatus.RETRYING_LATER, []
 
-        json_content = content.value
-        if json_content.get("address", None) is None:
-            json_content["address"] = message["sender"]
-
-        if json_content.get("time", None) is None:
-            json_content["time"] = message["time"]
+        validated_message = validate_pending_message(
+            pending_message=pending_message, content=content, confirmations=confirmations
+        )
 
         # warning: those handlers can modify message and content in place
         # and return a status. None has to be retried, -1 is discarded, True is
         # handled and kept.
        # TODO: change this, it's messy.
         try:
-            if message["type"] == MessageType.store:
-                handling_result = await handle_new_storage(message, json_content)
-            elif message["type"] == MessageType.forget:
+            if isinstance(validated_message, ValidatedStoreMessage):
+                handling_result = await handle_new_storage(validated_message)
+            elif isinstance(validated_message, ValidatedForgetMessage):
                 # Handling it here means that there we ensure that the message
                 # has been forgotten before it is saved on the node.
                 # We may want the opposite instead: ensure that the message has
                 # been saved before it is forgotten.
-                handling_result = await handle_forget_message(message, json_content)
+                handling_result = await handle_forget_message(validated_message)
             else:
                 handling_result = True
         except UnknownHashError:
@@ -289,7 +243,7 @@
         if handling_result is None:
             LOGGER.debug("Message type handler has failed, retrying later.")
             await mark_message_for_retry(
-                message=message,
+                message=pending_message,
                 chain_name=chain_name,
                 tx_hash=tx_hash,
                 height=height,
@@ -306,7 +260,7 @@
             )
             return IncomingStatus.FAILED_PERMANENTLY, []
 
-        if not await check_sender_authorization(message, json_content):
+        if not await check_sender_authorization(validated_message):
             LOGGER.warning("Invalid sender for %s" % item_hash)
             return IncomingStatus.MESSAGE_HANDLED, []
 
@@ -320,19 +274,10 @@
                 seen_ids[ids_key] = height
 
         LOGGER.debug("New message to store for %s." % item_hash)
-        # message.update(new_values)
-        updates["$set"] = {
-            "content": json_content,
-            "size": len(content.raw_value),
-            "item_content": message.get("item_content"),
-            "item_type": message.get("item_type"),
-            "channel": message.get("channel"),
-            "signature": message.get("signature"),
-            **updates.get("$set", {}),
-        }
-        should_commit = True
 
-    if should_commit:
+        updates = make_message_upsert_query(validated_message)
+
+    if updates:
         update_op = UpdateOne(filters, updates, upsert=True)
         bulk_ops = [DbBulkOperation(Message, update_op)]
 
@@ -346,7 +291,7 @@
     return IncomingStatus.MESSAGE_HANDLED, []
 
 
-async def process_one_message(message: Dict, *args, **kwargs):
+async def process_one_message(message: BasePendingMessage, *args, **kwargs):
     """
     Helper function to process a message on the spot.
     """
```
src/aleph/handlers/forget.py

Lines changed: 5 additions & 15 deletions
```diff
@@ -6,13 +6,12 @@
 
 from aioipfs.api import RepoAPI
 from aioipfs.exceptions import NotPinnedError
-from aleph_message.models import ForgetMessage, MessageType
-from aleph_message.models import ItemType
-from pydantic import ValidationError
+from aleph_message.models import ItemType, MessageType
 
 from aleph.model.filepin import PermanentPin
 from aleph.model.hashes import delete_value
 from aleph.model.messages import Message
+from aleph.schemas.validated_message import ValidatedForgetMessage
 from aleph.services.ipfs.common import get_ipfs_api
 from aleph.utils import item_type_from_hash
 
@@ -118,7 +117,7 @@ async def garbage_collect(storage_hash: str, storage_type: ItemType):
 
 
 async def is_allowed_to_forget(
-    target_info: TargetMessageInfo, by: ForgetMessage
+    target_info: TargetMessageInfo, by: ValidatedForgetMessage
 ) -> bool:
     """Check if a forget message is allowed to 'forget' the target message given its hash."""
     # Both senders are identical:
@@ -136,7 +135,7 @@ async def is_allowed_to_forget(
 
 
 async def forget_if_allowed(
-    target_info: TargetMessageInfo, forget_message: ForgetMessage
+    target_info: TargetMessageInfo, forget_message: ValidatedForgetMessage
 ) -> None:
     """Forget a message.
 
@@ -210,17 +209,8 @@ async def get_target_message_info(target_hash: str) -> Optional[TargetMessageInf
     return TargetMessageInfo.from_db_object(message_dict)
 
 
-async def handle_forget_message(message: Dict, content: Dict) -> bool:
+async def handle_forget_message(forget_message: ValidatedForgetMessage) -> bool:
     # Parsing and validation
-    # TODO: this is a temporary fix to release faster, finish od-message-models-in-pipeline
-    message["content"] = content
-
-    try:
-        forget_message = ForgetMessage(**message)
-    except ValidationError as e:
-        logger.error("Invalid forget message: %s", e)
-        return False
-
     logger.debug(f"Handling forget message {forget_message.item_hash}")
     hashes_to_forget = forget_message.content.hashes
 
```
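`handle_forget_message` now receives a `ValidatedForgetMessage` whose `content` field is already a parsed object, which is why the `ForgetMessage(**message)` round-trip and its `ValidationError` handling disappear. The model definitions live in `aleph/schemas/validated_message.py`, outside this diff; the sketch below only illustrates the general shape the commit message describes (a fully loaded `ValidatedBaseMessage` made generic over the content type), and the field list, type-variable name and Pydantic v1-style generics are assumptions.

```python
# Illustrative sketch only: the real ValidatedBaseMessage/ValidatedForgetMessage
# definitions are not part of this diff. Field names are taken from how the
# models are used above; the generic structure follows the commit message.
from typing import Generic, List, Optional, TypeVar

from aleph_message.models import ForgetContent, MessageConfirmation, MessageType
from pydantic import BaseModel
from pydantic.generics import GenericModel  # Pydantic v1 generics, assumed

ContentType = TypeVar("ContentType", bound=BaseModel)  # assumed type variable


class ValidatedBaseMessage(GenericModel, Generic[ContentType]):
    """A fully loaded message: fields parsed, signature verified, content set."""

    item_hash: str
    chain: str
    sender: str
    type: MessageType
    signature: str
    time: float
    channel: Optional[str] = None
    content: ContentType
    confirmations: List[MessageConfirmation] = []


class ValidatedForgetMessage(ValidatedBaseMessage[ForgetContent]):
    """Concrete model consumed by handle_forget_message() above."""
```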
