|  | 
| 1 | 1 | import anyio | 
| 2 |  | -from typing import AsyncIterator, Optional | 
|  | 2 | +from typing import AsyncIterator, Optional, Tuple | 
| 3 | 3 | from fastapi import WebSocketDisconnect | 
| 4 | 4 | 
 | 
| 5 | 5 | from lib.store import Store | 
| @@ -114,6 +114,123 @@ async def _wait_for_reconnection(self, peer_type: str) -> bool: | 
| 114 | 114 |             self.warning(f"◆ {peer_type.capitalize()} did not reconnect in time") | 
| 115 | 115 |             return False | 
| 116 | 116 | 
 | 
|  | 117 | +    async def _get_next_chunk(self, last_chunk_id: str, is_range_request: bool) -> Optional[Tuple[str, bytes]]: | 
|  | 118 | +        """Get next chunk from stream. Returns None if no more data available.""" | 
|  | 119 | +        if is_range_request: | 
|  | 120 | +            result = await self.store.get_chunk_by_range(last_chunk_id) | 
|  | 121 | +            if not result: | 
|  | 122 | +                if not await self._should_wait_for_sender(): | 
|  | 123 | +                    return None | 
|  | 124 | +                return ('wait', None) | 
|  | 125 | +            return result | 
|  | 126 | +        else: | 
|  | 127 | +            return await self.store.get_next_chunk(self.STREAM_TIMEOUT, last_id=last_chunk_id) | 
|  | 128 | + | 
|  | 129 | +    async def _should_wait_for_sender(self) -> bool: | 
|  | 130 | +        """Check if we should wait for sender to reconnect or give up.""" | 
|  | 131 | +        sender_state = await self.store.get_sender_state() | 
|  | 132 | +        if sender_state == ClientState.COMPLETE: | 
|  | 133 | +            return False | 
|  | 134 | +        elif sender_state == ClientState.DISCONNECTED: | 
|  | 135 | +            if not await self._wait_for_reconnection("sender"): | 
|  | 136 | +                await self.store.set_receiver_state(ClientState.ERROR) | 
|  | 137 | +                return False | 
|  | 138 | +        return True | 
|  | 139 | + | 
|  | 140 | +    def _adjust_chunk_for_range(self, chunk_data: bytes, stream_position: int, | 
|  | 141 | +                                start_byte: int, bytes_sent: int, bytes_to_send: int) -> Tuple[Optional[bytes], int]: | 
|  | 142 | +        """Adjust chunk data for byte range. Returns (data_to_send, new_stream_position).""" | 
|  | 143 | +        new_position = stream_position | 
|  | 144 | + | 
|  | 145 | +        # Skip bytes before start_byte | 
|  | 146 | +        if stream_position < start_byte: | 
|  | 147 | +            skip = min(len(chunk_data), start_byte - stream_position) | 
|  | 148 | +            chunk_data = chunk_data[skip:] | 
|  | 149 | +            new_position += skip | 
|  | 150 | + | 
|  | 151 | +            # Still haven't reached start? Skip entire chunk | 
|  | 152 | +            if new_position < start_byte: | 
|  | 153 | +                new_position += len(chunk_data) | 
|  | 154 | +                return None, new_position | 
|  | 155 | + | 
|  | 156 | +        # Trim to remaining bytes needed | 
|  | 157 | +        if chunk_data and bytes_sent + len(chunk_data) > bytes_to_send: | 
|  | 158 | +            chunk_data = chunk_data[:bytes_to_send - bytes_sent] | 
|  | 159 | + | 
|  | 160 | +        return chunk_data if chunk_data else None, new_position | 
|  | 161 | + | 
|  | 162 | +    async def _save_progress_if_needed(self, stream_position: int, last_chunk_id: str, force: bool = False): | 
|  | 163 | +        """Save download progress periodically or when forced.""" | 
|  | 164 | +        if force or stream_position % (64 * 1024) == 0: | 
|  | 165 | +            await self.store.save_download_progress( | 
|  | 166 | +                bytes_downloaded=stream_position, | 
|  | 167 | +                last_read_id=last_chunk_id | 
|  | 168 | +            ) | 
|  | 169 | +            if force: | 
|  | 170 | +                self.debug(f"▼ Progress saved: {stream_position} bytes") | 
|  | 171 | + | 
|  | 172 | +    async def _initialize_download_state(self, start_byte: int, is_range_request: bool) -> Tuple[int, str]: | 
|  | 173 | +        """Initialize download state and return (stream_position, last_chunk_id).""" | 
|  | 174 | +        stream_position = 0 | 
|  | 175 | +        last_chunk_id = '0' | 
|  | 176 | + | 
|  | 177 | +        if start_byte > 0: | 
|  | 178 | +            self.info(f"▼ Starting download from byte {start_byte}") | 
|  | 179 | +            if not is_range_request: | 
|  | 180 | +                progress = await self.store.get_download_progress() | 
|  | 181 | +                if progress and progress.bytes_downloaded >= start_byte: | 
|  | 182 | +                    last_chunk_id = progress.last_read_id | 
|  | 183 | +                    stream_position = progress.bytes_downloaded | 
|  | 184 | + | 
|  | 185 | +        return stream_position, last_chunk_id | 
|  | 186 | + | 
|  | 187 | +    async def _finalize_download_status(self, bytes_sent: int, stream_position: int, | 
|  | 188 | +                                       start_byte: int, end_byte: Optional[int], | 
|  | 189 | +                                       last_chunk_id: str): | 
|  | 190 | +        """Update final download status based on what was transferred.""" | 
|  | 191 | +        if end_byte is not None: | 
|  | 192 | +            self.info(f"▼ Range download complete ({bytes_sent} bytes from {start_byte}-{end_byte})") | 
|  | 193 | +            return | 
|  | 194 | + | 
|  | 195 | +        total_downloaded = start_byte + bytes_sent | 
|  | 196 | +        if total_downloaded >= self.file.size: | 
|  | 197 | +            self.info("▼ Full download complete") | 
|  | 198 | +            await self.store.set_receiver_state(ClientState.COMPLETE) | 
|  | 199 | +        else: | 
|  | 200 | +            self.info(f"▼ Download incomplete ({total_downloaded}/{self.file.size} bytes)") | 
|  | 201 | +            await self._save_progress_if_needed(stream_position, last_chunk_id, force=True) | 
|  | 202 | + | 
|  | 203 | +    async def _handle_download_disconnect(self, error: Exception, stream_position: int, last_chunk_id: str): | 
|  | 204 | +        """Handle download disconnection errors.""" | 
|  | 205 | +        self.warning(f"▼ Download disconnected: {error}") | 
|  | 206 | +        await self.store.save_download_progress( | 
|  | 207 | +            bytes_downloaded=stream_position, | 
|  | 208 | +            last_read_id=last_chunk_id | 
|  | 209 | +        ) | 
|  | 210 | +        await self.store.set_receiver_state(ClientState.DISCONNECTED) | 
|  | 211 | + | 
|  | 212 | +        if not await self._wait_for_reconnection("receiver"): | 
|  | 213 | +            await self.store.set_receiver_state(ClientState.ERROR) | 
|  | 214 | +            await self.set_interrupted() | 
|  | 215 | + | 
|  | 216 | +    async def _handle_download_timeout(self, stream_position: int, last_chunk_id: str): | 
|  | 217 | +        """Handle download timeout by checking sender state.""" | 
|  | 218 | +        self.info("▼ Timeout waiting for data") | 
|  | 219 | +        sender_state = await self.store.get_sender_state() | 
|  | 220 | +        if sender_state == ClientState.DISCONNECTED: | 
|  | 221 | +            if not await self._wait_for_reconnection("sender"): | 
|  | 222 | +                await self.store.set_receiver_state(ClientState.ERROR) | 
|  | 223 | +                return False | 
|  | 224 | +        else: | 
|  | 225 | +            raise TimeoutError("Download timeout") | 
|  | 226 | +        return True | 
|  | 227 | + | 
|  | 228 | +    async def _handle_download_fatal_error(self, error: Exception): | 
|  | 229 | +        """Handle unexpected download errors.""" | 
|  | 230 | +        self.error(f"▼ Unexpected download error: {error}", exc_info=True) | 
|  | 231 | +        await self.store.set_receiver_state(ClientState.ERROR) | 
|  | 232 | +        await self.set_interrupted() | 
|  | 233 | + | 
| 117 | 234 |     async def collect_upload(self, stream: AsyncIterator[bytes], resume_from: int = 0) -> None: | 
| 118 | 235 |         """Collect file data from sender and store in Redis stream.""" | 
| 119 | 236 |         bytes_uploaded = resume_from | 
| @@ -195,135 +312,61 @@ async def collect_upload(self, stream: AsyncIterator[bytes], resume_from: int = | 
| 195 | 312 | 
 | 
| 196 | 313 |     async def supply_download(self, start_byte: int = 0, end_byte: Optional[int] = None) -> AsyncIterator[bytes]: | 
| 197 | 314 |         """Stream file data to the receiver.""" | 
| 198 |  | -        stream_position = 0  # Current position in the stream we've read to | 
| 199 |  | -        bytes_sent = 0       # Bytes sent to client | 
|  | 315 | +        bytes_sent = 0 | 
| 200 | 316 |         bytes_to_send = (end_byte - start_byte + 1) if end_byte else (self.file.size - start_byte) | 
| 201 |  | -        last_chunk_id = '0' | 
| 202 | 317 |         is_range_request = end_byte is not None | 
| 203 | 318 | 
 | 
|  | 319 | +        stream_position, last_chunk_id = await self._initialize_download_state(start_byte, is_range_request) | 
| 204 | 320 |         await self.store.set_receiver_state(ClientState.ACTIVE) | 
| 205 | 321 | 
 | 
| 206 |  | -        if start_byte > 0: | 
| 207 |  | -            self.info(f"▼ Starting download from byte {start_byte}") | 
| 208 |  | -            if not is_range_request: | 
| 209 |  | -                # For live streams starting mid-file, check if we have previous progress | 
| 210 |  | -                progress = await self.store.get_download_progress() | 
| 211 |  | -                if progress and progress.bytes_downloaded >= start_byte: | 
| 212 |  | -                    last_chunk_id = progress.last_read_id | 
| 213 |  | -                    stream_position = progress.bytes_downloaded | 
| 214 |  | - | 
| 215 | 322 |         self.debug(f"▼ Range request: {start_byte}-{end_byte or 'end'}, to_send: {bytes_to_send}") | 
| 216 | 323 | 
 | 
| 217 | 324 |         try: | 
| 218 | 325 |             while bytes_sent < bytes_to_send: | 
| 219 |  | -                try: | 
| 220 |  | -                    if is_range_request: | 
| 221 |  | -                        # For range requests, use non-blocking reads from existing stream data | 
| 222 |  | -                        result = await self.store.get_chunk_by_range(last_chunk_id) | 
| 223 |  | -                        if not result: | 
| 224 |  | -                            # Check if sender is still uploading | 
| 225 |  | -                            sender_state = await self.store.get_sender_state() | 
| 226 |  | -                            if sender_state == ClientState.COMPLETE: | 
| 227 |  | -                                # Upload is complete but no more chunks - we're done | 
| 228 |  | -                                break | 
| 229 |  | -                            elif sender_state == ClientState.DISCONNECTED: | 
| 230 |  | -                                if not await self._wait_for_reconnection("sender"): | 
| 231 |  | -                                    await self.store.set_receiver_state(ClientState.ERROR) | 
| 232 |  | -                                    return | 
| 233 |  | -                            await anyio.sleep(0.1) | 
| 234 |  | -                            continue | 
| 235 |  | -                        chunk_id, chunk_data = result | 
| 236 |  | -                    else: | 
| 237 |  | -                        # For live streams, use blocking reads | 
| 238 |  | -                        chunk_id, chunk_data = await self.store.get_next_chunk( | 
| 239 |  | -                            timeout=self.STREAM_TIMEOUT, | 
| 240 |  | -                            last_id=last_chunk_id | 
| 241 |  | -                        ) | 
| 242 |  | - | 
| 243 |  | -                    last_chunk_id = chunk_id | 
| 244 |  | - | 
| 245 |  | -                    if chunk_data == self.DONE_FLAG: | 
| 246 |  | -                        self.debug("▼ Done marker received") | 
| 247 |  | -                        await self.store.set_receiver_state(ClientState.COMPLETE) | 
| 248 |  | -                        break | 
| 249 |  | -                    elif chunk_data == self.DEAD_FLAG: | 
| 250 |  | -                        self.warning("▼ Dead marker received") | 
| 251 |  | -                        await self.store.set_receiver_state(ClientState.ERROR) | 
| 252 |  | -                        return | 
|  | 326 | +                # Get next chunk | 
|  | 327 | +                result = await self._get_next_chunk(last_chunk_id, is_range_request) | 
|  | 328 | +                if result is None: | 
|  | 329 | +                    break | 
|  | 330 | +                if result[0] == 'wait': | 
|  | 331 | +                    await anyio.sleep(0.1) | 
|  | 332 | +                    continue | 
| 253 | 333 | 
 | 
| 254 |  | -                    # Skip bytes until we reach start_byte | 
| 255 |  | -                    if stream_position < start_byte: | 
| 256 |  | -                        bytes_in_chunk = len(chunk_data) | 
| 257 |  | -                        skip = min(bytes_in_chunk, start_byte - stream_position) | 
| 258 |  | -                        chunk_data = chunk_data[skip:] | 
| 259 |  | -                        stream_position += skip | 
| 260 |  | - | 
| 261 |  | -                        # If we still haven't reached start_byte, move to next chunk | 
| 262 |  | -                        if stream_position < start_byte: | 
| 263 |  | -                            stream_position += len(chunk_data) | 
| 264 |  | -                            continue | 
| 265 |  | - | 
| 266 |  | -                    # Send only the bytes we need for this range | 
| 267 |  | -                    if len(chunk_data) > 0: | 
| 268 |  | -                        remaining = bytes_to_send - bytes_sent | 
| 269 |  | -                        if len(chunk_data) > remaining: | 
| 270 |  | -                            chunk_data = chunk_data[:remaining] | 
| 271 |  | - | 
| 272 |  | -                        yield chunk_data | 
| 273 |  | -                        bytes_sent += len(chunk_data) | 
| 274 |  | -                        stream_position += len(chunk_data) | 
| 275 |  | - | 
| 276 |  | -                        # Save progress periodically for resumption | 
| 277 |  | -                        if stream_position % (64 * 1024) == 0: | 
| 278 |  | -                            await self.store.save_download_progress( | 
| 279 |  | -                                bytes_downloaded=stream_position, | 
| 280 |  | -                                last_read_id=last_chunk_id | 
| 281 |  | -                            ) | 
| 282 |  | - | 
| 283 |  | -                except TimeoutError: | 
| 284 |  | -                    self.info("▼ Timeout waiting for data") | 
| 285 |  | -                    sender_state = await self.store.get_sender_state() | 
| 286 |  | -                    if sender_state == ClientState.DISCONNECTED: | 
| 287 |  | -                        if not await self._wait_for_reconnection("sender"): | 
| 288 |  | -                            await self.store.set_receiver_state(ClientState.ERROR) | 
| 289 |  | -                            return | 
| 290 |  | -                    else: | 
| 291 |  | -                        raise | 
|  | 334 | +                chunk_id, chunk_data = result | 
|  | 335 | +                last_chunk_id = chunk_id | 
| 292 | 336 | 
 | 
| 293 |  | -            # Determine completion status | 
| 294 |  | -            if is_range_request: | 
| 295 |  | -                # For range requests, just log completion but don't mark transfer as complete | 
| 296 |  | -                # Multiple ranges may be downloading different parts of the same file | 
| 297 |  | -                self.info(f"▼ Range download complete ({bytes_sent} bytes from {start_byte}-{end_byte or 'end'})") | 
| 298 |  | -            else: | 
| 299 |  | -                # For full downloads, check if entire file was downloaded | 
| 300 |  | -                total_downloaded = start_byte + bytes_sent | 
| 301 |  | -                if total_downloaded >= self.file.size: | 
| 302 |  | -                    self.info("▼ Full download complete") | 
|  | 337 | +                # Check for control flags | 
|  | 338 | +                if chunk_data == self.DONE_FLAG: | 
|  | 339 | +                    self.debug("▼ Done marker received") | 
| 303 | 340 |                     await self.store.set_receiver_state(ClientState.COMPLETE) | 
| 304 |  | -                else: | 
| 305 |  | -                    self.info(f"▼ Download incomplete ({total_downloaded}/{self.file.size} bytes)") | 
| 306 |  | -                    await self.store.save_download_progress( | 
| 307 |  | -                        bytes_downloaded=stream_position, | 
| 308 |  | -                        last_read_id=last_chunk_id | 
| 309 |  | -                    ) | 
| 310 |  | - | 
| 311 |  | -        except (ConnectionError, WebSocketDisconnect) as e: | 
| 312 |  | -            self.warning(f"▼ Download disconnected: {e}") | 
| 313 |  | -            await self.store.save_download_progress( | 
| 314 |  | -                bytes_downloaded=stream_position, | 
| 315 |  | -                last_read_id=last_chunk_id | 
|  | 341 | +                    break | 
|  | 342 | +                elif chunk_data == self.DEAD_FLAG: | 
|  | 343 | +                    self.warning("▼ Dead marker received") | 
|  | 344 | +                    await self.store.set_receiver_state(ClientState.ERROR) | 
|  | 345 | +                    return | 
|  | 346 | + | 
|  | 347 | +                # Process chunk for byte range | 
|  | 348 | +                chunk_to_send, stream_position = self._adjust_chunk_for_range( | 
|  | 349 | +                    chunk_data, stream_position, start_byte, bytes_sent, bytes_to_send | 
|  | 350 | +                ) | 
|  | 351 | + | 
|  | 352 | +                # Yield data if we have any | 
|  | 353 | +                if chunk_to_send: | 
|  | 354 | +                    yield chunk_to_send | 
|  | 355 | +                    bytes_sent += len(chunk_to_send) | 
|  | 356 | +                    await self._save_progress_if_needed(stream_position, last_chunk_id) | 
|  | 357 | + | 
|  | 358 | +            # Handle completion | 
|  | 359 | +            await self._finalize_download_status( | 
|  | 360 | +                bytes_sent, stream_position, start_byte, end_byte, last_chunk_id | 
| 316 | 361 |             ) | 
| 317 |  | -            await self.store.set_receiver_state(ClientState.DISCONNECTED) | 
| 318 |  | - | 
| 319 |  | -            if not await self._wait_for_reconnection("receiver"): | 
| 320 |  | -                await self.store.set_receiver_state(ClientState.ERROR) | 
| 321 |  | -                await self.set_interrupted() | 
| 322 | 362 | 
 | 
|  | 363 | +        except TimeoutError: | 
|  | 364 | +            if not await self._handle_download_timeout(stream_position, last_chunk_id): | 
|  | 365 | +                return | 
|  | 366 | +        except (ConnectionError, WebSocketDisconnect) as e: | 
|  | 367 | +            await self._handle_download_disconnect(e, stream_position, last_chunk_id) | 
| 323 | 368 |         except Exception as e: | 
| 324 |  | -            self.error(f"▼ Unexpected download error: {e}", exc_info=True) | 
| 325 |  | -            await self.store.set_receiver_state(ClientState.ERROR) | 
| 326 |  | -            await self.set_interrupted() | 
|  | 369 | +            await self._handle_download_fatal_error(e) | 
| 327 | 370 | 
 | 
| 328 | 371 |     async def finalize_download(self): | 
| 329 | 372 |         """Finalize download and potentially clean up.""" | 
|  | 
0 commit comments