Skip to content

Commit

Permalink
feat: add verify_signature parameter to fetch functions
Browse files Browse the repository at this point in the history
Co-authored-by: Laurent Peuch <cortex@worlddomination.be>
Co-authored-by: Hugo Herter <git@hugoherter.com>
  • Loading branch information
3 people committed Jun 27, 2024
1 parent e362ea2 commit 8225b02
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
26 changes: 26 additions & 0 deletions src/aleph/sdk/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ async def get_posts(
post_filter: Optional[PostFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> PostsResponse:
"""
Fetch a list of posts from the network.
Expand All @@ -83,25 +84,35 @@ async def get_posts(
:param post_filter: Filter to apply to the posts (Default: None)
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
:param verify_signatures: Verify the signatures of the messages (Default: False)
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

async def get_posts_iterator(
self,
post_filter: Optional[PostFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> AsyncIterable[PostMessage]:
"""
Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates
but will always return all posts.
:param post_filter: Filter to apply to the posts (Default: None)
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
:param verify_signatures: Verify the signatures of the messages (Default: False)
"""
page = 1
resp = None
while resp is None or len(resp.posts) > 0:
resp = await self.get_posts(
page=page,
post_filter=post_filter,
ignore_invalid_messages=ignore_invalid_messages,
invalid_messages_log_level=invalid_messages_log_level,
verify_signatures=verify_signatures,
)
page += 1
for post in resp.posts:
Expand Down Expand Up @@ -178,6 +189,7 @@ async def get_messages(
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> MessagesResponse:
"""
Fetch a list of messages from the network.
Expand All @@ -187,25 +199,35 @@ async def get_messages(
:param message_filter: Filter to apply to the messages
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
:param verify_signatures: Verify the signatures of the messages (Default: False)
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

async def get_messages_iterator(
self,
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> AsyncIterable[AlephMessage]:
"""
Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates
but will always return all messages.
:param message_filter: Filter to apply to the messages
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
:param verify_signatures: Whether to verify the signatures of the messages (Default: False)
"""
page = 1
resp = None
while resp is None or len(resp.messages) > 0:
resp = await self.get_messages(
page=page,
message_filter=message_filter,
ignore_invalid_messages=ignore_invalid_messages,
invalid_messages_log_level=invalid_messages_log_level,
verify_signatures=verify_signatures,
)
page += 1
for message in resp.messages:
Expand All @@ -216,24 +238,28 @@ async def get_message(
self,
item_hash: str,
message_type: Optional[Type[GenericMessage]] = None,
verify_signature: bool = False,
) -> GenericMessage:
"""
Get a single message from its `item_hash` and perform some basic validation.
:param item_hash: Hash of the message to fetch
:param message_type: Type of message to fetch
:param verify_signature: Whether to verify the signature of the message (Default: False)
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

@abstractmethod
def watch_messages(
self,
message_filter: Optional[MessageFilter] = None,
verify_signatures: bool = False,
) -> AsyncIterable[AlephMessage]:
"""
Iterate over current and future matching messages asynchronously.
:param message_filter: Filter to apply to the messages
:param verify_signatures: Whether to verify the signatures of the messages (Default: False)
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

Expand Down
19 changes: 17 additions & 2 deletions src/aleph/sdk/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..exceptions import FileTooLarge, ForgottenMessageError, MessageNotFoundError
from ..query.filters import MessageFilter, PostFilter
from ..query.responses import MessagesResponse, Post, PostsResponse
from ..security import verify_message_signature
from ..types import GenericMessage
from ..utils import (
Writable,
Expand Down Expand Up @@ -117,6 +118,7 @@ async def get_posts(
post_filter: Optional[PostFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> PostsResponse:
ignore_invalid_messages = (
True if ignore_invalid_messages is None else ignore_invalid_messages
Expand Down Expand Up @@ -145,12 +147,15 @@ async def get_posts(
posts: List[Post] = []
for post_raw in posts_raw:
try:
posts.append(Post.parse_obj(post_raw))
post = Post.parse_obj(post_raw)
posts.append(post)
except ValidationError as e:
if not ignore_invalid_messages:
raise e
if invalid_messages_log_level:
logger.log(level=invalid_messages_log_level, msg=e)
if verify_signatures:
verify_message_signature(post)
return PostsResponse(
posts=posts,
pagination_page=response_json["pagination_page"],
Expand Down Expand Up @@ -266,6 +271,7 @@ async def get_messages(
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> MessagesResponse:
ignore_invalid_messages = (
True if ignore_invalid_messages is None else ignore_invalid_messages
Expand Down Expand Up @@ -312,6 +318,8 @@ async def get_messages(
raise e
if invalid_messages_log_level:
logger.log(level=invalid_messages_log_level, msg=e)
if verify_signatures:
verify_message_signature(message)

return MessagesResponse(
messages=messages,
Expand All @@ -325,6 +333,7 @@ async def get_message(
self,
item_hash: str,
message_type: Optional[Type[GenericMessage]] = None,
verify_signature: bool = False,
) -> GenericMessage:
async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp:
try:
Expand All @@ -339,6 +348,8 @@ async def get_message(
f"The requested message {message_raw['item_hash']} has been forgotten by {', '.join(message_raw['forgotten_by'])}"
)
message = parse_message(message_raw["message"])
if verify_signature:
verify_message_signature(message)
if message_type:
expected_type = get_message_type_value(message_type)
if message.type != expected_type:
Expand Down Expand Up @@ -374,6 +385,7 @@ async def get_message_error(
async def watch_messages(
self,
message_filter: Optional[MessageFilter] = None,
verify_signatures: bool = False,
) -> AsyncIterable[AlephMessage]:
message_filter = message_filter or MessageFilter()
params = message_filter.as_http_params()
Expand All @@ -389,6 +401,9 @@ async def watch_messages(
break
else:
data = json.loads(msg.data)
yield parse_message(data)
message = parse_message(data)
if verify_signatures:
verify_message_signature(message)
yield message
elif msg.type == aiohttp.WSMsgType.ERROR:
break

0 comments on commit 8225b02

Please sign in to comment.