Skip to content

Commit

Permalink
Refactor: storage_add
Browse files Browse the repository at this point in the history
  • Loading branch information
1yam committed Aug 25, 2023
1 parent 678360c commit 6507bd7
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 98 deletions.
13 changes: 3 additions & 10 deletions src/aleph/web/controllers/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
get_mq_channel_from_request,
)
from aleph.web.controllers.utils import mq_make_aleph_message_topic_queue, processing_status_to_http_status, \
mq_read_one_message
mq_read_one_message, validate_message_dict

LOGGER = logging.getLogger(__name__)

Expand All @@ -46,13 +46,6 @@ def from_failures(cls, failed_publications: List[Protocol]):
return cls(status=status, failed=failed_publications)


def _validate_message_dict(message_dict: Mapping[str, Any]) -> BasePendingMessage:
try:
return parse_message(message_dict)
except InvalidMessageException as e:
raise web.HTTPUnprocessableEntity(body=str(e))


def _validate_request_data(config: Config, request_data: Dict) -> None:
"""
Validates the content of a JSON pubsub message depending on the channel
Expand Down Expand Up @@ -84,7 +77,7 @@ def _validate_request_data(config: Config, request_data: Dict) -> None:
reason="'data': must be deserializable as JSON."
)

_validate_message_dict(message_dict)
validate_message_dict(message_dict)


async def _pub_on_p2p_topics(
Expand Down Expand Up @@ -160,7 +153,7 @@ async def pub_message(request: web.Request):
except ValidationError as e:
raise web.HTTPUnprocessableEntity(body=e.json(indent=4))

pending_message = _validate_message_dict(request_data.message_dict)
pending_message = validate_message_dict(request_data.message_dict)

# In sync mode, wait for a message processing event. We need to create the queue
# before publishing the message on P2P topics in order to guarantee that the event
Expand Down
165 changes: 82 additions & 83 deletions src/aleph/web/controllers/storage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import base64
import datetime as dt
import functools
Expand All @@ -6,9 +7,11 @@
from typing import Union, Tuple

import aio_pika
from aiohttp.web_response import Response
from eth_account import Account
from eth_account.messages import encode_defunct
from mypy.dmypy_server import MiB
from sqlalchemy.orm import Session

from aleph.chains.chain_service import ChainService
from aleph.chains.common import get_verification_buffer
Expand All @@ -22,8 +25,8 @@
from aleph.db.models import PendingMessageDb
from aleph.exceptions import AlephStorageException, UnknownHashError
from aleph.toolkit import json
from aleph.types.db_session import DbSession
from aleph.types.message_status import MessageProcessingStatus, InvalidSignature
from aleph.types.db_session import DbSession, DbSessionFactory
from aleph.types.message_status import MessageProcessingStatus, InvalidSignature, InvalidMessageException
from aleph.utils import run_in_executor, item_type_from_hash
from aleph.web.controllers.app_state_getters import (
get_session_factory_from_request,
Expand All @@ -37,9 +40,9 @@
multidict_proxy_to_io,
mq_make_aleph_message_topic_queue,
processing_status_to_http_status,
mq_read_one_message,
mq_read_one_message, validate_message_dict,
)
from aleph.schemas.pending_messages import BasePendingMessage
from aleph.schemas.pending_messages import BasePendingMessage, PendingStoreMessage

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,7 +86,7 @@ async def add_storage_json_controller(request: web.Request):


async def get_message_content(
post_data: MultiDictProxy[Union[str, bytes, FileField]]
post_data: MultiDictProxy[Union[str, bytes, FileField]]
) -> Tuple[dict, int]:
message_bytearray = post_data.get("message", b"")
file_size = post_data.get("file_size") or 0
Expand All @@ -98,109 +101,105 @@ async def get_message_content(
return message_dict, int(str(file_size))


async def verify_and_handle_request(
pending_message_db,
file_io,
message,
size,
session_factory,
chain_service: ChainService,
):
async def _verify_message_signature(pending_message_db: PendingMessageDb, chain_service: ChainService) -> None:
try:
await chain_service.verify_signature(pending_message_db)
except InvalidSignature:
raise web.HTTPForbidden()


async def _verify_user_balance(pending_message_db: PendingMessageDb, session: DbSession, size: int) -> None:
current_balance = get_total_balance(
session=session, address=pending_message_db.sender
)
required_balance = (size / MiB) / 3
# Need to merge to get this functions
# current_cost_for_user = get_total_cost_for_address(
# session=session, address=pending_message_db.sender
# )
if current_balance < required_balance:
raise web.HTTPPaymentRequired

async def _verify_user_file(message : dict, size : int, file_io) -> None:
file_io.seek(0)
content = file_io.read(size)
item_content = json.loads(message["item_content"])
actual_item_hash = sha256(content).hexdigest()
c_item_hash = item_content["item_hash"]
client_item_hash = item_content["item_hash"]
if len(content) > (1000 * MiB):
raise web.HTTPRequestEntityTooLarge(actual_size=len(content), max_size=(1000 * MiB))
elif actual_item_hash != client_item_hash:
raise web.HTTPUnprocessableEntity()


async def storage_add_file_with_message(request: web.Request, session: DbSession, chain_service, post, file_io, sync = False):
config = get_config_from_request(request)
message, size = await get_message_content(post)
mq_queue = None

try:
await chain_service.verify_signature(pending_message_db)
except InvalidSignature:
output = {"status": "Forbidden"}
return web.json_response(output, status=403)
valid_message: BasePendingMessage = validate_message_dict(message)
except InvalidMessageException:
output = {"status": "rejected"}
return web.json_response(output, status=422)

with session_factory() as session:
current_balance = get_total_balance(
session=session, address=pending_message_db.sender
pending_store_message = PendingStoreMessage.parse_obj(valid_message)
pending_message_db = PendingMessageDb.from_obj(
obj=pending_store_message, reception_time=dt.datetime.now(), fetched=True
)
await _verify_message_signature(
pending_message_db=pending_message_db, chain_service=chain_service
)
await _verify_user_balance(session=session, pending_message_db=pending_message_db, size=size)
await _verify_user_file(message=message, size=size, file_io=file_io)

if sync:
mq_channel = await get_mq_channel_from_request(request, logger=logger)
mq_queue = await mq_make_aleph_message_topic_queue(
channel=mq_channel,
config=config,
routing_key=f"*.{pending_message_db.item_hash}",
)
if current_balance < len(content):
output = {"status": "Payment Required"}
return web.json_response(output, status=402)
elif len(content) > (1000 * MiB):
output = {"status": "Payload too large"}
return web.json_response(output, status=413)
elif actual_item_hash != c_item_hash:
output = {"status": "Unprocessable Content"}
return web.json_response(output, status=422)

elif len(content) > 25 * MiB and not message:
output = {"status": "Unauthorized"}
return web.json_response(output, status=401)
else:
return None
session.add(pending_message_db)
session.commit()
if sync:
mq_message = await mq_read_one_message(mq_queue, 30)

if mq_message is None:
raise web.HTTPAccepted()
if mq_message.routing_key is not None:
status_str, _item_hash = mq_message.routing_key.split(".")
processing_status = MessageProcessingStatus(status_str)
status_code = processing_status_to_http_status(processing_status)
return web.json_response(status=status_code, text=_item_hash)

async def storage_add_file_with_message(request: web.Request):

async def storage_add_file(request: web.Request):
storage_service = get_storage_service_from_request(request)
session_factory = get_session_factory_from_request(request)
# TODO : Add chainservice to ccn_api_client to be able to call get_chainservice_from_request
chain_service: ChainService = ChainService(
session_factory=session_factory, storage_service=storage_service
)
config = get_config_from_request(request)

post = await request.post()
file_io = multidict_proxy_to_io(post)
message, size = await get_message_content(post)
pending_message_db = PendingMessageDb.from_message_dict(
message_dict=message, reception_time=dt.datetime.now(), fetched=True
)
mq_channel = await get_mq_channel_from_request(request, logger=logger)
mq_queue = await mq_make_aleph_message_topic_queue(
channel=mq_channel,
config=config,
routing_key=f"*.{pending_message_db.item_hash}",
)

is_valid_message = await verify_and_handle_request(
pending_message_db,
file_io,
message,
size,
session_factory,
chain_service,
)
if is_valid_message is not None:
return is_valid_message

with session_factory() as session:
file_hash = await storage_service.add_file(
session=session, fileobject=file_io, engine=ItemType.storage
)
session.add(pending_message_db)
session.commit()
mq_message = await mq_read_one_message(mq_queue, 30)

if mq_message is None:
output = {"status": "accepted"}
return web.json_response(output, status=202)
if mq_message.routing_key is not None:
status_str, _item_hash = mq_message.routing_key.split(".")
processing_status = MessageProcessingStatus(status_str)
status_code = processing_status_to_http_status(processing_status)
return web.json_response(status=status_code, text=file_hash)
post = await request.post()
sync = False

if post.get("message", b"") is None:
raise web.HTTPUnauthorized()

async def storage_add_file(request: web.Request):
post = await request.post()
if post.get("message", b"") is not None and post.get("file_size") is not None:
return await storage_add_file_with_message(request)
sync_value = post.get("sync")
if sync_value is not None:
sync = ast.literal_eval(sync_value)

storage_service = get_storage_service_from_request(request)
session_factory = get_session_factory_from_request(request)
file_io = multidict_proxy_to_io(post)
with session_factory() as session:
file_hash = await storage_service.add_file(
session=session, fileobject=file_io, engine=ItemType.storage
)
if post.get("message", b"") is not None and post.get("file_size") is not None:
await storage_add_file_with_message(request, session, chain_service, post, file_io, sync)
session.commit()
output = {"status": "success", "hash": file_hash}
return web.json_response(output)
Expand Down
14 changes: 11 additions & 3 deletions src/aleph/web/controllers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
from io import BytesIO, StringIO
from math import ceil
from typing import Optional, Union, IO
from typing import Optional, Union, IO, Mapping, Any

import aio_pika
import aiohttp_jinja2
Expand All @@ -11,7 +11,8 @@
from configmanager import Config
from multidict import MultiDictProxy

from aleph.types.message_status import MessageProcessingStatus
from aleph.schemas.pending_messages import BasePendingMessage, parse_message
from aleph.types.message_status import MessageProcessingStatus, InvalidMessageException

DEFAULT_MESSAGES_PER_PAGE = 20
DEFAULT_PAGE = 1
Expand Down Expand Up @@ -194,4 +195,11 @@ async def _process_message(message: aio_pika.abc.AbstractMessage):
except asyncio.TimeoutError:
return None
finally:
await mq_queue.cancel(consumer_tag)
await mq_queue.cancel(consumer_tag)


def validate_message_dict(message_dict: Mapping[str, Any]) -> BasePendingMessage:
try:
return parse_message(message_dict)
except InvalidMessageException as e:
raise web.HTTPUnprocessableEntity(body=str(e))
6 changes: 4 additions & 2 deletions tests/api/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ async def add_file_with_message(

form_data.add_field("file", file_content)
form_data.add_field("message", json_data, content_type="application/json")
form_data.add_field("size", size)
form_data.add_field("file_size", size)

response = await api_client.post(uri, data=form_data)
assert response.status == error_code, await response.text()

Expand Down Expand Up @@ -216,7 +217,8 @@ async def add_file_with_message_202(

form_data.add_field("file", file_content)
form_data.add_field("message", json_data, content_type="application/json")
form_data.add_field("size", size)
form_data.add_field("file_size", size)
form_data.add_field("sync", "True")
response = await api_client.post(uri, data=form_data)
assert response.status == error_code, await response.text()

Expand Down

0 comments on commit 6507bd7

Please sign in to comment.