Skip to content

Commit e77a8ca

Browse files
committed
Big tests refacto
1 parent df6e098 commit e77a8ca

File tree

11 files changed

+424
-416
lines changed

11 files changed

+424
-416
lines changed

lib/transfer.py

Lines changed: 158 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import anyio
2-
from typing import AsyncIterator, Optional
2+
from typing import AsyncIterator, Optional, Tuple
33
from fastapi import WebSocketDisconnect
44

55
from lib.store import Store
@@ -114,6 +114,123 @@ async def _wait_for_reconnection(self, peer_type: str) -> bool:
114114
self.warning(f"◆ {peer_type.capitalize()} did not reconnect in time")
115115
return False
116116

117+
async def _get_next_chunk(self, last_chunk_id: str, is_range_request: bool) -> Optional[Tuple[str, bytes]]:
118+
"""Get next chunk from stream. Returns None if no more data available."""
119+
if is_range_request:
120+
result = await self.store.get_chunk_by_range(last_chunk_id)
121+
if not result:
122+
if not await self._should_wait_for_sender():
123+
return None
124+
return ('wait', None)
125+
return result
126+
else:
127+
return await self.store.get_next_chunk(self.STREAM_TIMEOUT, last_id=last_chunk_id)
128+
129+
async def _should_wait_for_sender(self) -> bool:
130+
"""Check if we should wait for sender to reconnect or give up."""
131+
sender_state = await self.store.get_sender_state()
132+
if sender_state == ClientState.COMPLETE:
133+
return False
134+
elif sender_state == ClientState.DISCONNECTED:
135+
if not await self._wait_for_reconnection("sender"):
136+
await self.store.set_receiver_state(ClientState.ERROR)
137+
return False
138+
return True
139+
140+
def _adjust_chunk_for_range(self, chunk_data: bytes, stream_position: int,
141+
start_byte: int, bytes_sent: int, bytes_to_send: int) -> Tuple[Optional[bytes], int]:
142+
"""Adjust chunk data for byte range. Returns (data_to_send, new_stream_position)."""
143+
new_position = stream_position
144+
145+
# Skip bytes before start_byte
146+
if stream_position < start_byte:
147+
skip = min(len(chunk_data), start_byte - stream_position)
148+
chunk_data = chunk_data[skip:]
149+
new_position += skip
150+
151+
# Still haven't reached start? Skip entire chunk
152+
if new_position < start_byte:
153+
new_position += len(chunk_data)
154+
return None, new_position
155+
156+
# Trim to remaining bytes needed
157+
if chunk_data and bytes_sent + len(chunk_data) > bytes_to_send:
158+
chunk_data = chunk_data[:bytes_to_send - bytes_sent]
159+
160+
return chunk_data if chunk_data else None, new_position
161+
162+
async def _save_progress_if_needed(self, stream_position: int, last_chunk_id: str, force: bool = False):
163+
"""Save download progress periodically or when forced."""
164+
if force or stream_position % (64 * 1024) == 0:
165+
await self.store.save_download_progress(
166+
bytes_downloaded=stream_position,
167+
last_read_id=last_chunk_id
168+
)
169+
if force:
170+
self.debug(f"▼ Progress saved: {stream_position} bytes")
171+
172+
async def _initialize_download_state(self, start_byte: int, is_range_request: bool) -> Tuple[int, str]:
173+
"""Initialize download state and return (stream_position, last_chunk_id)."""
174+
stream_position = 0
175+
last_chunk_id = '0'
176+
177+
if start_byte > 0:
178+
self.info(f"▼ Starting download from byte {start_byte}")
179+
if not is_range_request:
180+
progress = await self.store.get_download_progress()
181+
if progress and progress.bytes_downloaded >= start_byte:
182+
last_chunk_id = progress.last_read_id
183+
stream_position = progress.bytes_downloaded
184+
185+
return stream_position, last_chunk_id
186+
187+
async def _finalize_download_status(self, bytes_sent: int, stream_position: int,
188+
start_byte: int, end_byte: Optional[int],
189+
last_chunk_id: str):
190+
"""Update final download status based on what was transferred."""
191+
if end_byte is not None:
192+
self.info(f"▼ Range download complete ({bytes_sent} bytes from {start_byte}-{end_byte})")
193+
return
194+
195+
total_downloaded = start_byte + bytes_sent
196+
if total_downloaded >= self.file.size:
197+
self.info("▼ Full download complete")
198+
await self.store.set_receiver_state(ClientState.COMPLETE)
199+
else:
200+
self.info(f"▼ Download incomplete ({total_downloaded}/{self.file.size} bytes)")
201+
await self._save_progress_if_needed(stream_position, last_chunk_id, force=True)
202+
203+
async def _handle_download_disconnect(self, error: Exception, stream_position: int, last_chunk_id: str):
204+
"""Handle download disconnection errors."""
205+
self.warning(f"▼ Download disconnected: {error}")
206+
await self.store.save_download_progress(
207+
bytes_downloaded=stream_position,
208+
last_read_id=last_chunk_id
209+
)
210+
await self.store.set_receiver_state(ClientState.DISCONNECTED)
211+
212+
if not await self._wait_for_reconnection("receiver"):
213+
await self.store.set_receiver_state(ClientState.ERROR)
214+
await self.set_interrupted()
215+
216+
async def _handle_download_timeout(self, stream_position: int, last_chunk_id: str):
217+
"""Handle download timeout by checking sender state."""
218+
self.info("▼ Timeout waiting for data")
219+
sender_state = await self.store.get_sender_state()
220+
if sender_state == ClientState.DISCONNECTED:
221+
if not await self._wait_for_reconnection("sender"):
222+
await self.store.set_receiver_state(ClientState.ERROR)
223+
return False
224+
else:
225+
raise TimeoutError("Download timeout")
226+
return True
227+
228+
async def _handle_download_fatal_error(self, error: Exception):
229+
"""Handle unexpected download errors."""
230+
self.error(f"▼ Unexpected download error: {error}", exc_info=True)
231+
await self.store.set_receiver_state(ClientState.ERROR)
232+
await self.set_interrupted()
233+
117234
async def collect_upload(self, stream: AsyncIterator[bytes], resume_from: int = 0) -> None:
118235
"""Collect file data from sender and store in Redis stream."""
119236
bytes_uploaded = resume_from
@@ -195,135 +312,61 @@ async def collect_upload(self, stream: AsyncIterator[bytes], resume_from: int =
195312

196313
async def supply_download(self, start_byte: int = 0, end_byte: Optional[int] = None) -> AsyncIterator[bytes]:
197314
"""Stream file data to the receiver."""
198-
stream_position = 0 # Current position in the stream we've read to
199-
bytes_sent = 0 # Bytes sent to client
315+
bytes_sent = 0
200316
bytes_to_send = (end_byte - start_byte + 1) if end_byte else (self.file.size - start_byte)
201-
last_chunk_id = '0'
202317
is_range_request = end_byte is not None
203318

319+
stream_position, last_chunk_id = await self._initialize_download_state(start_byte, is_range_request)
204320
await self.store.set_receiver_state(ClientState.ACTIVE)
205321

206-
if start_byte > 0:
207-
self.info(f"▼ Starting download from byte {start_byte}")
208-
if not is_range_request:
209-
# For live streams starting mid-file, check if we have previous progress
210-
progress = await self.store.get_download_progress()
211-
if progress and progress.bytes_downloaded >= start_byte:
212-
last_chunk_id = progress.last_read_id
213-
stream_position = progress.bytes_downloaded
214-
215322
self.debug(f"▼ Range request: {start_byte}-{end_byte or 'end'}, to_send: {bytes_to_send}")
216323

217324
try:
218325
while bytes_sent < bytes_to_send:
219-
try:
220-
if is_range_request:
221-
# For range requests, use non-blocking reads from existing stream data
222-
result = await self.store.get_chunk_by_range(last_chunk_id)
223-
if not result:
224-
# Check if sender is still uploading
225-
sender_state = await self.store.get_sender_state()
226-
if sender_state == ClientState.COMPLETE:
227-
# Upload is complete but no more chunks - we're done
228-
break
229-
elif sender_state == ClientState.DISCONNECTED:
230-
if not await self._wait_for_reconnection("sender"):
231-
await self.store.set_receiver_state(ClientState.ERROR)
232-
return
233-
await anyio.sleep(0.1)
234-
continue
235-
chunk_id, chunk_data = result
236-
else:
237-
# For live streams, use blocking reads
238-
chunk_id, chunk_data = await self.store.get_next_chunk(
239-
timeout=self.STREAM_TIMEOUT,
240-
last_id=last_chunk_id
241-
)
242-
243-
last_chunk_id = chunk_id
244-
245-
if chunk_data == self.DONE_FLAG:
246-
self.debug("▼ Done marker received")
247-
await self.store.set_receiver_state(ClientState.COMPLETE)
248-
break
249-
elif chunk_data == self.DEAD_FLAG:
250-
self.warning("▼ Dead marker received")
251-
await self.store.set_receiver_state(ClientState.ERROR)
252-
return
326+
# Get next chunk
327+
result = await self._get_next_chunk(last_chunk_id, is_range_request)
328+
if result is None:
329+
break
330+
if result[0] == 'wait':
331+
await anyio.sleep(0.1)
332+
continue
253333

254-
# Skip bytes until we reach start_byte
255-
if stream_position < start_byte:
256-
bytes_in_chunk = len(chunk_data)
257-
skip = min(bytes_in_chunk, start_byte - stream_position)
258-
chunk_data = chunk_data[skip:]
259-
stream_position += skip
260-
261-
# If we still haven't reached start_byte, move to next chunk
262-
if stream_position < start_byte:
263-
stream_position += len(chunk_data)
264-
continue
265-
266-
# Send only the bytes we need for this range
267-
if len(chunk_data) > 0:
268-
remaining = bytes_to_send - bytes_sent
269-
if len(chunk_data) > remaining:
270-
chunk_data = chunk_data[:remaining]
271-
272-
yield chunk_data
273-
bytes_sent += len(chunk_data)
274-
stream_position += len(chunk_data)
275-
276-
# Save progress periodically for resumption
277-
if stream_position % (64 * 1024) == 0:
278-
await self.store.save_download_progress(
279-
bytes_downloaded=stream_position,
280-
last_read_id=last_chunk_id
281-
)
282-
283-
except TimeoutError:
284-
self.info("▼ Timeout waiting for data")
285-
sender_state = await self.store.get_sender_state()
286-
if sender_state == ClientState.DISCONNECTED:
287-
if not await self._wait_for_reconnection("sender"):
288-
await self.store.set_receiver_state(ClientState.ERROR)
289-
return
290-
else:
291-
raise
334+
chunk_id, chunk_data = result
335+
last_chunk_id = chunk_id
292336

293-
# Determine completion status
294-
if is_range_request:
295-
# For range requests, just log completion but don't mark transfer as complete
296-
# Multiple ranges may be downloading different parts of the same file
297-
self.info(f"▼ Range download complete ({bytes_sent} bytes from {start_byte}-{end_byte or 'end'})")
298-
else:
299-
# For full downloads, check if entire file was downloaded
300-
total_downloaded = start_byte + bytes_sent
301-
if total_downloaded >= self.file.size:
302-
self.info("▼ Full download complete")
337+
# Check for control flags
338+
if chunk_data == self.DONE_FLAG:
339+
self.debug("▼ Done marker received")
303340
await self.store.set_receiver_state(ClientState.COMPLETE)
304-
else:
305-
self.info(f"▼ Download incomplete ({total_downloaded}/{self.file.size} bytes)")
306-
await self.store.save_download_progress(
307-
bytes_downloaded=stream_position,
308-
last_read_id=last_chunk_id
309-
)
310-
311-
except (ConnectionError, WebSocketDisconnect) as e:
312-
self.warning(f"▼ Download disconnected: {e}")
313-
await self.store.save_download_progress(
314-
bytes_downloaded=stream_position,
315-
last_read_id=last_chunk_id
341+
break
342+
elif chunk_data == self.DEAD_FLAG:
343+
self.warning("▼ Dead marker received")
344+
await self.store.set_receiver_state(ClientState.ERROR)
345+
return
346+
347+
# Process chunk for byte range
348+
chunk_to_send, stream_position = self._adjust_chunk_for_range(
349+
chunk_data, stream_position, start_byte, bytes_sent, bytes_to_send
350+
)
351+
352+
# Yield data if we have any
353+
if chunk_to_send:
354+
yield chunk_to_send
355+
bytes_sent += len(chunk_to_send)
356+
await self._save_progress_if_needed(stream_position, last_chunk_id)
357+
358+
# Handle completion
359+
await self._finalize_download_status(
360+
bytes_sent, stream_position, start_byte, end_byte, last_chunk_id
316361
)
317-
await self.store.set_receiver_state(ClientState.DISCONNECTED)
318-
319-
if not await self._wait_for_reconnection("receiver"):
320-
await self.store.set_receiver_state(ClientState.ERROR)
321-
await self.set_interrupted()
322362

363+
except TimeoutError:
364+
if not await self._handle_download_timeout(stream_position, last_chunk_id):
365+
return
366+
except (ConnectionError, WebSocketDisconnect) as e:
367+
await self._handle_download_disconnect(e, stream_position, last_chunk_id)
323368
except Exception as e:
324-
self.error(f"▼ Unexpected download error: {e}", exc_info=True)
325-
await self.store.set_receiver_state(ClientState.ERROR)
326-
await self.set_interrupted()
369+
await self._handle_download_fatal_error(e)
327370

328371
async def finalize_download(self):
329372
"""Finalize download and potentially clean up."""

tests/conftest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import AsyncIterator
1010

1111
from tests.ws_client import WebSocketTestClient
12+
from tests.http_client import HTTPTestClient
1213
from lib.logging import get_logger
1314
log = get_logger('setup-tests')
1415

@@ -123,9 +124,9 @@ def live_server():
123124

124125

125126
@pytest.fixture
126-
async def test_client(live_server: str) -> AsyncIterator[httpx.AsyncClient]:
127-
"""HTTP client for testing."""
128-
async with httpx.AsyncClient(base_url=f'http://{live_server}') as client:
127+
async def test_client(live_server: str) -> AsyncIterator[HTTPTestClient]:
128+
"""HTTP client for testing with helper methods."""
129+
async with HTTPTestClient(base_url=f'http://{live_server}') as client:
129130
print()
130131
yield client
131132

@@ -139,7 +140,7 @@ async def websocket_client(live_server: str):
139140

140141

141142
@pytest.mark.anyio
142-
async def test_mocks(test_client: httpx.AsyncClient) -> None:
143+
async def test_mocks(test_client: HTTPTestClient) -> None:
143144
response = await test_client.get("/nonexistent-endpoint")
144145
assert response.status_code == 404, "Expected 404 for nonexistent endpoint"
145146

tests/helpers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import anyio
2+
import httpx
23
from string import ascii_letters
34
from itertools import islice, repeat, chain
4-
from typing import Tuple, Iterable, AsyncIterator
5+
from typing import Tuple, Iterable, AsyncIterator, Dict, Any, Optional
56
from annotated_types import T
67
import anyio.lowlevel
78

@@ -26,3 +27,8 @@ async def chunks(data: bytes, chunk_size: int = 1024) -> AsyncIterator[bytes]:
2627
for i in range(0, len(data), chunk_size):
2728
yield data[i:i + chunk_size]
2829
await anyio.lowlevel.checkpoint()
30+
31+
32+
# All WebSocket and HTTP helper functions have been moved to the respective client classes:
33+
# - WebSocket helpers are now methods in WebSocketWrapper (tests/ws_client.py)
34+
# - HTTP helpers are now methods in HTTPTestClient (tests/http_client.py)

0 commit comments

Comments
 (0)