55import socket
66import threading
77import weakref
8+ from io import SEEK_END
89from itertools import chain
910from queue import Empty , Full , LifoQueue
1011from time import time
11- from typing import Optional
12+ from typing import Optional , Union
1213from urllib .parse import parse_qs , unquote , urlparse
1314
1415from redis .backoff import NoBackoff
@@ -163,39 +164,47 @@ def parse_error(self, response):
163164
164165
165166class SocketBuffer :
166- def __init__ (self , socket , socket_read_size , socket_timeout ):
167+ def __init__ (
168+ self , socket : socket .socket , socket_read_size : int , socket_timeout : float
169+ ):
167170 self ._sock = socket
168171 self .socket_read_size = socket_read_size
169172 self .socket_timeout = socket_timeout
170173 self ._buffer = io .BytesIO ()
171- # number of bytes written to the buffer from the socket
172- self .bytes_written = 0
173- # number of bytes read from the buffer
174- self .bytes_read = 0
175174
176- @property
177- def length (self ):
178- return self .bytes_written - self .bytes_read
175+ def unread_bytes (self ) -> int :
176+ """
177+ Remaining unread length of buffer
178+ """
179+ pos = self ._buffer .tell ()
180+ end = self ._buffer .seek (0 , SEEK_END )
181+ self ._buffer .seek (pos )
182+ return end - pos
179183
180- def _read_from_socket (self , length = None , timeout = SENTINEL , raise_on_timeout = True ):
184+ def _read_from_socket (
185+ self ,
186+ length : Optional [int ] = None ,
187+ timeout : Union [float , object ] = SENTINEL ,
188+ raise_on_timeout : Optional [bool ] = True ,
189+ ) -> bool :
181190 sock = self ._sock
182191 socket_read_size = self .socket_read_size
183- buf = self ._buffer
184- buf .seek (self .bytes_written )
185192 marker = 0
186193 custom_timeout = timeout is not SENTINEL
187194
195+ buf = self ._buffer
196+ current_pos = buf .tell ()
197+ buf .seek (0 , SEEK_END )
198+ if custom_timeout :
199+ sock .settimeout (timeout )
188200 try :
189- if custom_timeout :
190- sock .settimeout (timeout )
191201 while True :
192202 data = self ._sock .recv (socket_read_size )
193203 # an empty string indicates the server shutdown the socket
194204 if isinstance (data , bytes ) and len (data ) == 0 :
195205 raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
196206 buf .write (data )
197207 data_length = len (data )
198- self .bytes_written += data_length
199208 marker += data_length
200209
201210 if length is not None and length > marker :
@@ -215,55 +224,53 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
215224 return False
216225 raise ConnectionError (f"Error while reading from socket: { ex .args } " )
217226 finally :
227+ buf .seek (current_pos )
218228 if custom_timeout :
219229 sock .settimeout (self .socket_timeout )
220230
221- def can_read (self , timeout ) :
222- return bool (self .length ) or self ._read_from_socket (
231+ def can_read (self , timeout : float ) -> bool :
232+ return bool (self .unread_bytes () ) or self ._read_from_socket (
223233 timeout = timeout , raise_on_timeout = False
224234 )
225235
226- def read (self , length ) :
236+ def read (self , length : int ) -> bytes :
227237 length = length + 2 # make sure to read the \r\n terminator
228- # make sure we've read enough data from the socket
229- if length > self .length :
230- self ._read_from_socket (length - self .length )
231-
232- self ._buffer .seek (self .bytes_read )
238+ # BufferIO will return less than requested if buffer is short
233239 data = self ._buffer .read (length )
234- self .bytes_read += len (data )
240+ missing = length - len (data )
241+ if missing :
242+ # fill up the buffer and read the remainder
243+ self ._read_from_socket (missing )
244+ data += self ._buffer .read (missing )
235245 return data [:- 2 ]
236246
237- def readline (self ):
247+ def readline (self ) -> bytes :
238248 buf = self ._buffer
239- buf .seek (self .bytes_read )
240249 data = buf .readline ()
241250 while not data .endswith (SYM_CRLF ):
242251 # there's more data in the socket that we need
243252 self ._read_from_socket ()
244- buf .seek (self .bytes_read )
245- data = buf .readline ()
253+ data += buf .readline ()
246254
247- self .bytes_read += len (data )
248255 return data [:- 2 ]
249256
250- def get_pos (self ):
257+ def get_pos (self ) -> int :
251258 """
252259 Get current read position
253260 """
254- return self .bytes_read
261+ return self ._buffer . tell ()
255262
256- def rewind (self , pos ) :
263+ def rewind (self , pos : int ) -> None :
257264 """
258265 Rewind the buffer to a specific position, to re-start reading
259266 """
260- self .bytes_read = pos
267+ self ._buffer . seek ( pos )
261268
262- def purge (self ):
269+ def purge (self ) -> None :
263270 """
264271 After a successful read, purge the read part of buffer
265272 """
266- unread = self .bytes_written - self . bytes_read
273+ unread = self .unread_bytes ()
267274
268275 # Only if we have read all of the buffer do we truncate, to
269276 # reduce the amount of memory thrashing. This heuristic
@@ -276,13 +283,10 @@ def purge(self):
276283 view = self ._buffer .getbuffer ()
277284 view [:unread ] = view [- unread :]
278285 self ._buffer .truncate (unread )
279- self .bytes_written = unread
280- self .bytes_read = 0
281286 self ._buffer .seek (0 )
282287
283- def close (self ):
288+ def close (self ) -> None :
284289 try :
285- self .bytes_written = self .bytes_read = 0
286290 self ._buffer .close ()
287291 except Exception :
288292 # issue #633 suggests the purge/close somehow raised a
@@ -498,6 +502,7 @@ def read_response(self, disable_decoding=False):
498502 return response
499503
500504
505+ DefaultParser : BaseParser
501506if HIREDIS_AVAILABLE :
502507 DefaultParser = HiredisParser
503508else :
0 commit comments