From 0533e4a582cff3262fa03bd736410ea00ce8b65b Mon Sep 17 00:00:00 2001 From: mhh Date: Thu, 27 Jun 2024 18:50:40 +0200 Subject: [PATCH] feat: add verify_signature parameter to fetch functions Co-authored-by: Laurent Peuch Co-authored-by: Hugo Herter --- src/aleph/sdk/client/abstract.py | 26 ++++++++++++++++++++++++++ src/aleph/sdk/client/http.py | 19 +++++++++++++++++-- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index 9fce5469..301388ec 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -74,6 +74,7 @@ async def get_posts( post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> PostsResponse: """ Fetch a list of posts from the network. @@ -83,18 +84,25 @@ async def get_posts( :param post_filter: Filter to apply to the posts (Default: None) :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + :param verify_signatures: Verify the signatures of the messages (Default: False) """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") async def get_posts_iterator( self, post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> AsyncIterable[PostMessage]: """ Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates but will always return all posts. :param post_filter: Filter to apply to the posts (Default: None) + :param ignore_invalid_messages: Ignore invalid messages (Default: True) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + :param verify_signatures: Verify the signatures of the messages (Default: False) """ page = 1 resp = None @@ -102,6 +110,9 @@ async def get_posts_iterator( resp = await self.get_posts( page=page, post_filter=post_filter, + ignore_invalid_messages=ignore_invalid_messages, + invalid_messages_log_level=invalid_messages_log_level, + verify_signatures=verify_signatures, ) page += 1 for post in resp.posts: @@ -178,6 +189,7 @@ async def get_messages( message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> MessagesResponse: """ Fetch a list of messages from the network. @@ -187,18 +199,25 @@ async def get_messages( :param message_filter: Filter to apply to the messages :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + :param verify_signatures: Verify the signatures of the messages (Default: False) """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") async def get_messages_iterator( self, message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> AsyncIterable[AlephMessage]: """ Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates but will always return all messages. :param message_filter: Filter to apply to the messages + :param ignore_invalid_messages: Ignore invalid messages (Default: True) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + :param verify_signatures: Whether to verify the signatures of the messages (Default: False) """ page = 1 resp = None @@ -206,6 +225,9 @@ async def get_messages_iterator( resp = await self.get_messages( page=page, message_filter=message_filter, + ignore_invalid_messages=ignore_invalid_messages, + invalid_messages_log_level=invalid_messages_log_level, + verify_signatures=verify_signatures, ) page += 1 for message in resp.messages: @@ -216,12 +238,14 @@ async def get_message( self, item_hash: str, message_type: Optional[Type[GenericMessage]] = None, + verify_signature: bool = False, ) -> GenericMessage: """ Get a single message from its `item_hash` and perform some basic validation. :param item_hash: Hash of the message to fetch :param message_type: Type of message to fetch + :param verify_signature: Whether to verify the signature of the message (Default: False) """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") @@ -229,11 +253,13 @@ async def get_message( def watch_messages( self, message_filter: Optional[MessageFilter] = None, + verify_signatures: bool = False, ) -> AsyncIterable[AlephMessage]: """ Iterate over current and future matching messages asynchronously. :param message_filter: Filter to apply to the messages + :param verify_signatures: Whether to verify the signatures of the messages (Default: False) """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index ae98b0d1..9ded2bcd 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -15,6 +15,7 @@ from ..exceptions import FileTooLarge, ForgottenMessageError, MessageNotFoundError from ..query.filters import MessageFilter, PostFilter from ..query.responses import MessagesResponse, Post, PostsResponse +from ..security import verify_message_signature from ..types import GenericMessage from ..utils import ( Writable, @@ -117,6 +118,7 @@ async def get_posts( post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> PostsResponse: ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages @@ -145,12 +147,15 @@ async def get_posts( posts: List[Post] = [] for post_raw in posts_raw: try: - posts.append(Post.parse_obj(post_raw)) + post = Post.parse_obj(post_raw) + posts.append(post) except ValidationError as e: if not ignore_invalid_messages: raise e if invalid_messages_log_level: logger.log(level=invalid_messages_log_level, msg=e) + if verify_signatures: + verify_message_signature(post) return PostsResponse( posts=posts, pagination_page=response_json["pagination_page"], @@ -266,6 +271,7 @@ async def get_messages( message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> MessagesResponse: ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages @@ -312,6 +318,8 @@ async def get_messages( raise e if invalid_messages_log_level: logger.log(level=invalid_messages_log_level, msg=e) + if verify_signatures: + verify_message_signature(message) return MessagesResponse( messages=messages, @@ -325,6 +333,7 @@ async def get_message( self, item_hash: str, message_type: Optional[Type[GenericMessage]] = None, + verify_signature: bool = False, ) -> GenericMessage: async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp: try: @@ -339,6 +348,8 @@ async def get_message( f"The requested message {message_raw['item_hash']} has been forgotten by {', '.join(message_raw['forgotten_by'])}" ) message = parse_message(message_raw["message"]) + if verify_signature: + verify_message_signature(message) if message_type: expected_type = get_message_type_value(message_type) if message.type != expected_type: @@ -374,6 +385,7 @@ async def get_message_error( async def watch_messages( self, message_filter: Optional[MessageFilter] = None, + verify_signatures: bool = False, ) -> AsyncIterable[AlephMessage]: message_filter = message_filter or MessageFilter() params = message_filter.as_http_params() @@ -389,6 +401,9 @@ async def watch_messages( break else: data = json.loads(msg.data) - yield parse_message(data) + message = parse_message(data) + if verify_signatures: + verify_message_signature(message) + yield message elif msg.type == aiohttp.WSMsgType.ERROR: break