|
| 1 | +from logging import getLogger |
1 | 2 | from typing import Any, Union |
2 | 3 |
|
3 | 4 | from ..exceptions import ConnectionError, InvalidResponse, ResponseError |
|
9 | 10 | class _RESP3Parser(_RESPBase): |
10 | 11 | """RESP3 protocol implementation""" |
11 | 12 |
|
12 | | - def read_response(self, disable_decoding=False): |
| 13 | + def __init__(self, socket_read_size): |
| 14 | + super().__init__(socket_read_size) |
| 15 | + self.push_handler_func = self.handle_push_response |
| 16 | + |
| 17 | + def handle_push_response(self, response): |
| 18 | + logger = getLogger("push_response") |
| 19 | + logger.info("Push response: " + str(response)) |
| 20 | + return response |
| 21 | + |
| 22 | + def read_response(self, disable_decoding=False, push_request=False): |
13 | 23 | pos = self._buffer.get_pos() |
14 | 24 | try: |
15 | | - result = self._read_response(disable_decoding=disable_decoding) |
| 25 | + result = self._read_response( |
| 26 | + disable_decoding=disable_decoding, push_request=push_request |
| 27 | + ) |
16 | 28 | except BaseException: |
17 | 29 | self._buffer.rewind(pos) |
18 | 30 | raise |
19 | 31 | else: |
20 | 32 | self._buffer.purge() |
21 | 33 | return result |
22 | 34 |
|
23 | | - def _read_response(self, disable_decoding=False): |
| 35 | + def _read_response(self, disable_decoding=False, push_request=False): |
24 | 36 | raw = self._buffer.readline() |
25 | 37 | if not raw: |
26 | 38 | raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) |
@@ -77,31 +89,64 @@ def _read_response(self, disable_decoding=False): |
77 | 89 | response = { |
78 | 90 | self._read_response( |
79 | 91 | disable_decoding=disable_decoding |
80 | | - ): self._read_response(disable_decoding=disable_decoding) |
| 92 | + ): self._read_response( |
| 93 | + disable_decoding=disable_decoding, push_request=push_request |
| 94 | + ) |
81 | 95 | for _ in range(int(response)) |
82 | 96 | } |
| 97 | + # push response |
| 98 | + elif byte == b">": |
| 99 | + response = [ |
| 100 | + self._read_response( |
| 101 | + disable_decoding=disable_decoding, push_request=push_request |
| 102 | + ) |
| 103 | + for _ in range(int(response)) |
| 104 | + ] |
| 105 | + res = self.push_handler_func(response) |
| 106 | + if not push_request: |
| 107 | + return self._read_response( |
| 108 | + disable_decoding=disable_decoding, push_request=push_request |
| 109 | + ) |
| 110 | + else: |
| 111 | + return res |
83 | 112 | else: |
84 | 113 | raise InvalidResponse(f"Protocol Error: {raw!r}") |
85 | 114 |
|
86 | 115 | if isinstance(response, bytes) and disable_decoding is False: |
87 | 116 | response = self.encoder.decode(response) |
88 | 117 | return response |
89 | 118 |
|
| 119 | + def set_push_handler(self, push_handler_func): |
| 120 | + self.push_handler_func = push_handler_func |
| 121 | + |
90 | 122 |
|
91 | 123 | class _AsyncRESP3Parser(_AsyncRESPBase): |
92 | | - async def read_response(self, disable_decoding: bool = False): |
| 124 | + def __init__(self, socket_read_size): |
| 125 | + super().__init__(socket_read_size) |
| 126 | + self.push_handler_func = self.handle_push_response |
| 127 | + |
| 128 | + def handle_push_response(self, response): |
| 129 | + logger = getLogger("push_response") |
| 130 | + logger.info("Push response: " + str(response)) |
| 131 | + return response |
| 132 | + |
| 133 | + async def read_response( |
| 134 | + self, disable_decoding: bool = False, push_request: bool = False |
| 135 | + ): |
93 | 136 | if self._chunks: |
94 | 137 | # augment parsing buffer with previously read data |
95 | 138 | self._buffer += b"".join(self._chunks) |
96 | 139 | self._chunks.clear() |
97 | 140 | self._pos = 0 |
98 | | - response = await self._read_response(disable_decoding=disable_decoding) |
| 141 | + response = await self._read_response( |
| 142 | + disable_decoding=disable_decoding, push_request=push_request |
| 143 | + ) |
99 | 144 | # Successfully parsing a response allows us to clear our parsing buffer |
100 | 145 | self._clear() |
101 | 146 | return response |
102 | 147 |
|
103 | 148 | async def _read_response( |
104 | | - self, disable_decoding: bool = False |
| 149 | + self, disable_decoding: bool = False, push_request: bool = False |
105 | 150 | ) -> Union[EncodableT, ResponseError, None]: |
106 | 151 | if not self._stream or not self.encoder: |
107 | 152 | raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) |
@@ -166,9 +211,31 @@ async def _read_response( |
166 | 211 | ) |
167 | 212 | for _ in range(int(response)) |
168 | 213 | } |
| 214 | + # push response |
| 215 | + elif byte == b">": |
| 216 | + response = [ |
| 217 | + ( |
| 218 | + await self._read_response( |
| 219 | + disable_decoding=disable_decoding, push_request=push_request |
| 220 | + ) |
| 221 | + ) |
| 222 | + for _ in range(int(response)) |
| 223 | + ] |
| 224 | + res = self.push_handler_func(response) |
| 225 | + if not push_request: |
| 226 | + return await ( |
| 227 | + self._read_response( |
| 228 | + disable_decoding=disable_decoding, push_request=push_request |
| 229 | + ) |
| 230 | + ) |
| 231 | + else: |
| 232 | + return res |
169 | 233 | else: |
170 | 234 | raise InvalidResponse(f"Protocol Error: {raw!r}") |
171 | 235 |
|
172 | 236 | if isinstance(response, bytes) and disable_decoding is False: |
173 | 237 | response = self.encoder.decode(response) |
174 | 238 | return response |
| 239 | + |
| 240 | + def set_push_handler(self, push_handler_func): |
| 241 | + self.push_handler_func = push_handler_func |
0 commit comments