Skip to content

Commit 8de5779

Browse files
committed
refactor: improve protocol availability
1 parent b9010b0 commit 8de5779

File tree

3 files changed

+31
-17
lines changed

3 files changed

+31
-17
lines changed

app/logic/data.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,34 @@
77
COLDWIRE_LEN_OFFSET,
88
COLDWIRE_DATA_SEP
99
)
10+
import secrets
11+
import base64
1012

1113

1214
redis_client = get_redis()
1315

1416

15-
def get_redis_list(client, key: str) -> list:
16-
data = b""
17-
while True:
18-
raw = client.lpop(key)
19-
if raw is None:
20-
break
17+
def b64u_decode(data: str) -> bytes:
18+
padding = 4 - (len(data) % 4)
19+
if padding != 4:
20+
data += "=" * padding
21+
return base64.urlsafe_b64decode(data)
2122

22-
data += raw
2323

24-
return data
24+
def delete_data(user_id: str, acks: list[str]) -> None:
25+
byte_acks = [b64u_decode(p) for p in acks]
26+
27+
values = redis_client.lrange(user_id, 0, -1)
28+
for v in values:
29+
if any(v.startswith(pref) for pref in byte_acks):
30+
res = redis_client.lrem(user_id, 0, v)
2531

2632
def check_new_data(user_id: str) -> bytes:
27-
return get_redis_list(redis_client, user_id)
33+
data = redis_client.lrange(user_id, 0, -1)
34+
if not data:
35+
return b""
36+
37+
return b"".join(data)
2838

2939
def data_processor(user_id: str, recipient: str, blob: bytes) -> None:
3040
if recipient.isdigit():
@@ -41,10 +51,10 @@ def data_processor(user_id: str, recipient: str, blob: bytes) -> None:
4151
if COLDWIRE_DATA_SEP in user_id:
4252
raise ValueError("User ID cannot have null byte!")
4353

44-
payload = user_id + COLDWIRE_DATA_SEP + blob
54+
payload = user_id + COLDWIRE_DATA_SEP + blob
4555
length_prefix = len(payload).to_bytes(COLDWIRE_LEN_OFFSET, "big")
4656

47-
payload = length_prefix + payload
57+
payload = secrets.token_bytes(32) + length_prefix + payload
4858

4959
redis_client.rpush(recipient, payload)
5060

app/logic/federation_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from base64 import b64encode, b64decode
2222
from datetime import datetime, timezone, timedelta
2323
import json
24+
import secrets
2425

2526
redis_client = get_redis()
2627

@@ -117,7 +118,7 @@ def federation_processor(url: str, sender: str, recipient: str, blob: bytes) ->
117118
payload = sender_with_url + COLDWIRE_DATA_SEP + blob
118119
length_prefix = len(payload).to_bytes(COLDWIRE_LEN_OFFSET, "big")
119120

120-
payload = length_prefix + payload
121+
payload = secrets.token_bytes(32) + length_prefix + payload
121122

122123
redis_client.rpush(recipient, payload)
123124

app/routes/data.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1-
from fastapi import APIRouter, Request, HTTPException, Response, Depends, Form, UploadFile, File
2-
from app.logic.data import check_new_data, data_processor
1+
from fastapi import APIRouter, Request, HTTPException, Response, Depends, Form, UploadFile, File, Query
2+
from app.logic.data import check_new_data, delete_data, data_processor
33
from app.utils.jwt import verify_jwt_token
44
from app.core.constants import LONGPOLL_MAX
5+
from typing import Optional
56
import asyncio
67
import json
78

89
router = APIRouter()
910

1011

1112
@router.get("/data/longpoll")
12-
async def get_data_longpoll(request: Request, response: Response, user=Depends(verify_jwt_token)):
13+
async def get_data_longpoll(request: Request, response: Response, acks: Optional[list[str]] = Query(None), user=Depends(verify_jwt_token)):
14+
if acks:
15+
await asyncio.to_thread(delete_data, user["id"], acks)
16+
1317
for _ in range(LONGPOLL_MAX):
1418
if await request.is_disconnected():
15-
# Don't attempt to check for new data if client disconnects before LONGPOLL_MAX seconds
16-
# This is crucial to perserve data as they get deleted after being read
19+
# Don't bother checking for new data if client disconnects before LONGPOLL_MAX seconds
1720
return Response(content=b'', media_type="application/octet-stream")
1821

1922
data = await asyncio.to_thread(check_new_data, user["id"])

0 commit comments

Comments
 (0)