diff --git a/src/aleph/web/controllers/p2p.py b/src/aleph/web/controllers/p2p.py index a51ab7383..109e56672 100644 --- a/src/aleph/web/controllers/p2p.py +++ b/src/aleph/web/controllers/p2p.py @@ -27,7 +27,7 @@ get_mq_channel_from_request, ) from aleph.web.controllers.utils import mq_make_aleph_message_topic_queue, processing_status_to_http_status, \ - mq_read_one_message + mq_read_one_message, validate_message_dict LOGGER = logging.getLogger(__name__) @@ -46,13 +46,6 @@ def from_failures(cls, failed_publications: List[Protocol]): return cls(status=status, failed=failed_publications) -def _validate_message_dict(message_dict: Mapping[str, Any]) -> BasePendingMessage: - try: - return parse_message(message_dict) - except InvalidMessageException as e: - raise web.HTTPUnprocessableEntity(body=str(e)) - - def _validate_request_data(config: Config, request_data: Dict) -> None: """ Validates the content of a JSON pubsub message depending on the channel @@ -84,7 +77,7 @@ def _validate_request_data(config: Config, request_data: Dict) -> None: reason="'data': must be deserializable as JSON." ) - _validate_message_dict(message_dict) + validate_message_dict(message_dict) async def _pub_on_p2p_topics( @@ -160,7 +153,7 @@ async def pub_message(request: web.Request): except ValidationError as e: raise web.HTTPUnprocessableEntity(body=e.json(indent=4)) - pending_message = _validate_message_dict(request_data.message_dict) + pending_message = validate_message_dict(request_data.message_dict) # In sync mode, wait for a message processing event. We need to create the queue # before publishing the message on P2P topics in order to guarantee that the event diff --git a/src/aleph/web/controllers/storage.py b/src/aleph/web/controllers/storage.py index dd8c11924..49037ef61 100644 --- a/src/aleph/web/controllers/storage.py +++ b/src/aleph/web/controllers/storage.py @@ -1,3 +1,4 @@ +import ast import base64 import datetime as dt import functools @@ -6,9 +7,11 @@ from typing import Union, Tuple import aio_pika +from aiohttp.web_response import Response from eth_account import Account from eth_account.messages import encode_defunct from mypy.dmypy_server import MiB +from sqlalchemy.orm import Session from aleph.chains.chain_service import ChainService from aleph.chains.common import get_verification_buffer @@ -22,8 +25,8 @@ from aleph.db.models import PendingMessageDb from aleph.exceptions import AlephStorageException, UnknownHashError from aleph.toolkit import json -from aleph.types.db_session import DbSession -from aleph.types.message_status import MessageProcessingStatus, InvalidSignature +from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.message_status import MessageProcessingStatus, InvalidSignature, InvalidMessageException from aleph.utils import run_in_executor, item_type_from_hash from aleph.web.controllers.app_state_getters import ( get_session_factory_from_request, @@ -37,9 +40,9 @@ multidict_proxy_to_io, mq_make_aleph_message_topic_queue, processing_status_to_http_status, - mq_read_one_message, + mq_read_one_message, validate_message_dict, ) -from aleph.schemas.pending_messages import BasePendingMessage +from aleph.schemas.pending_messages import BasePendingMessage, PendingStoreMessage logger = logging.getLogger(__name__) @@ -83,7 +86,7 @@ async def add_storage_json_controller(request: web.Request): async def get_message_content( - post_data: MultiDictProxy[Union[str, bytes, FileField]] + post_data: MultiDictProxy[Union[str, bytes, FileField]] ) -> Tuple[dict, int]: message_bytearray = post_data.get("message", b"") file_size = post_data.get("file_size") or 0 @@ -98,109 +101,105 @@ async def get_message_content( return message_dict, int(str(file_size)) -async def verify_and_handle_request( - pending_message_db, - file_io, - message, - size, - session_factory, - chain_service: ChainService, -): +async def _verify_message_signature(pending_message_db: PendingMessageDb, chain_service: ChainService) -> None: + try: + await chain_service.verify_signature(pending_message_db) + except InvalidSignature: + raise web.HTTPForbidden() + + +async def _verify_user_balance(pending_message_db: PendingMessageDb, session: DbSession, size: int) -> None: + current_balance = get_total_balance( + session=session, address=pending_message_db.sender + ) + required_balance = (size / MiB) / 3 + # Need to merge to get this functions + # current_cost_for_user = get_total_cost_for_address( + # session=session, address=pending_message_db.sender + # ) + if current_balance < required_balance: + raise web.HTTPPaymentRequired + +async def _verify_user_file(message : dict, size : int, file_io) -> None: + file_io.seek(0) content = file_io.read(size) item_content = json.loads(message["item_content"]) actual_item_hash = sha256(content).hexdigest() - c_item_hash = item_content["item_hash"] + client_item_hash = item_content["item_hash"] + if len(content) > (1000 * MiB): + raise web.HTTPRequestEntityTooLarge(actual_size=len(content), max_size=(1000 * MiB)) + elif actual_item_hash != client_item_hash: + raise web.HTTPUnprocessableEntity() + + +async def storage_add_file_with_message(request: web.Request, session: DbSession, chain_service, post, file_io, sync = False): + config = get_config_from_request(request) + message, size = await get_message_content(post) + mq_queue = None try: - await chain_service.verify_signature(pending_message_db) - except InvalidSignature: - output = {"status": "Forbidden"} - return web.json_response(output, status=403) + valid_message: BasePendingMessage = validate_message_dict(message) + except InvalidMessageException: + output = {"status": "rejected"} + return web.json_response(output, status=422) - with session_factory() as session: - current_balance = get_total_balance( - session=session, address=pending_message_db.sender + pending_store_message = PendingStoreMessage.parse_obj(valid_message) + pending_message_db = PendingMessageDb.from_obj( + obj=pending_store_message, reception_time=dt.datetime.now(), fetched=True + ) + await _verify_message_signature( + pending_message_db=pending_message_db, chain_service=chain_service + ) + await _verify_user_balance(session=session, pending_message_db=pending_message_db, size=size) + await _verify_user_file(message=message, size=size, file_io=file_io) + + if sync: + mq_channel = await get_mq_channel_from_request(request, logger=logger) + mq_queue = await mq_make_aleph_message_topic_queue( + channel=mq_channel, + config=config, + routing_key=f"*.{pending_message_db.item_hash}", ) - if current_balance < len(content): - output = {"status": "Payment Required"} - return web.json_response(output, status=402) - elif len(content) > (1000 * MiB): - output = {"status": "Payload too large"} - return web.json_response(output, status=413) - elif actual_item_hash != c_item_hash: - output = {"status": "Unprocessable Content"} - return web.json_response(output, status=422) - elif len(content) > 25 * MiB and not message: - output = {"status": "Unauthorized"} - return web.json_response(output, status=401) - else: - return None + session.add(pending_message_db) + session.commit() + if sync: + mq_message = await mq_read_one_message(mq_queue, 30) + if mq_message is None: + raise web.HTTPAccepted() + if mq_message.routing_key is not None: + status_str, _item_hash = mq_message.routing_key.split(".") + processing_status = MessageProcessingStatus(status_str) + status_code = processing_status_to_http_status(processing_status) + return web.json_response(status=status_code, text=_item_hash) -async def storage_add_file_with_message(request: web.Request): + +async def storage_add_file(request: web.Request): storage_service = get_storage_service_from_request(request) session_factory = get_session_factory_from_request(request) # TODO : Add chainservice to ccn_api_client to be able to call get_chainservice_from_request chain_service: ChainService = ChainService( session_factory=session_factory, storage_service=storage_service ) - config = get_config_from_request(request) - post = await request.post() file_io = multidict_proxy_to_io(post) - message, size = await get_message_content(post) - pending_message_db = PendingMessageDb.from_message_dict( - message_dict=message, reception_time=dt.datetime.now(), fetched=True - ) - mq_channel = await get_mq_channel_from_request(request, logger=logger) - mq_queue = await mq_make_aleph_message_topic_queue( - channel=mq_channel, - config=config, - routing_key=f"*.{pending_message_db.item_hash}", - ) - - is_valid_message = await verify_and_handle_request( - pending_message_db, - file_io, - message, - size, - session_factory, - chain_service, - ) - if is_valid_message is not None: - return is_valid_message - - with session_factory() as session: - file_hash = await storage_service.add_file( - session=session, fileobject=file_io, engine=ItemType.storage - ) - session.add(pending_message_db) - session.commit() - mq_message = await mq_read_one_message(mq_queue, 30) - - if mq_message is None: - output = {"status": "accepted"} - return web.json_response(output, status=202) - if mq_message.routing_key is not None: - status_str, _item_hash = mq_message.routing_key.split(".") - processing_status = MessageProcessingStatus(status_str) - status_code = processing_status_to_http_status(processing_status) - return web.json_response(status=status_code, text=file_hash) + post = await request.post() + sync = False + if post.get("message", b"") is None: + raise web.HTTPUnauthorized() -async def storage_add_file(request: web.Request): - post = await request.post() - if post.get("message", b"") is not None and post.get("file_size") is not None: - return await storage_add_file_with_message(request) + sync_value = post.get("sync") + if sync_value is not None: + sync = ast.literal_eval(sync_value) - storage_service = get_storage_service_from_request(request) - session_factory = get_session_factory_from_request(request) - file_io = multidict_proxy_to_io(post) with session_factory() as session: file_hash = await storage_service.add_file( session=session, fileobject=file_io, engine=ItemType.storage ) + if post.get("message", b"") is not None and post.get("file_size") is not None: + await storage_add_file_with_message(request, session, chain_service, post, file_io, sync) session.commit() output = {"status": "success", "hash": file_hash} return web.json_response(output) diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py index 1d638e1cf..b03744f49 100644 --- a/src/aleph/web/controllers/utils.py +++ b/src/aleph/web/controllers/utils.py @@ -2,7 +2,7 @@ import json from io import BytesIO, StringIO from math import ceil -from typing import Optional, Union, IO +from typing import Optional, Union, IO, Mapping, Any import aio_pika import aiohttp_jinja2 @@ -11,7 +11,8 @@ from configmanager import Config from multidict import MultiDictProxy -from aleph.types.message_status import MessageProcessingStatus +from aleph.schemas.pending_messages import BasePendingMessage, parse_message +from aleph.types.message_status import MessageProcessingStatus, InvalidMessageException DEFAULT_MESSAGES_PER_PAGE = 20 DEFAULT_PAGE = 1 @@ -194,4 +195,11 @@ async def _process_message(message: aio_pika.abc.AbstractMessage): except asyncio.TimeoutError: return None finally: - await mq_queue.cancel(consumer_tag) \ No newline at end of file + await mq_queue.cancel(consumer_tag) + + +def validate_message_dict(message_dict: Mapping[str, Any]) -> BasePendingMessage: + try: + return parse_message(message_dict) + except InvalidMessageException as e: + raise web.HTTPUnprocessableEntity(body=str(e)) \ No newline at end of file diff --git a/tests/api/test_storage.py b/tests/api/test_storage.py index f13a3b836..425e7d1c9 100644 --- a/tests/api/test_storage.py +++ b/tests/api/test_storage.py @@ -174,7 +174,8 @@ async def add_file_with_message( form_data.add_field("file", file_content) form_data.add_field("message", json_data, content_type="application/json") - form_data.add_field("size", size) + form_data.add_field("file_size", size) + response = await api_client.post(uri, data=form_data) assert response.status == error_code, await response.text() @@ -216,7 +217,8 @@ async def add_file_with_message_202( form_data.add_field("file", file_content) form_data.add_field("message", json_data, content_type="application/json") - form_data.add_field("size", size) + form_data.add_field("file_size", size) + form_data.add_field("sync", "True") response = await api_client.post(uri, data=form_data) assert response.status == error_code, await response.text()