@@ -79,7 +79,13 @@ class TeeingReceiveChannel(ReceiveChannel[T_co], typing.Generic[T_co]):
79
79
80
80
def __init__ (self , source : trio .abc .ReceiveChannel [T_co ], buffer_size : int = 0 , * ,
81
81
_shared : typing .Optional [_TeeingChannelShared [T_co ]] = None ):
82
- super ().__init__ ()
82
+ # Try to copy extra attributes from source channel
83
+ super ().__init__ (
84
+ count = getattr (source , "count" , None ),
85
+ atime = getattr (source , "atime" , None ),
86
+ mtime = getattr (source , "mtime" , None ),
87
+ btime = getattr (source , "btime" , None ),
88
+ )
83
89
84
90
if _shared is not None :
85
91
self ._shared = _shared
@@ -89,11 +95,6 @@ def __init__(self, source: trio.abc.ReceiveChannel[T_co], buffer_size: int = 0,
89
95
self ._shared .bufsize = buffer_size
90
96
self ._shared .source = source
91
97
92
- # Try to copy extra attributes from source channel
93
- self .count = getattr (source , "count" , None )
94
- self .atime = getattr (source , "atime" , None )
95
- self .mtime = getattr (source , "mtime" , None )
96
- self .btime = getattr (source , "btime" , None )
97
98
98
99
self ._closed = False
99
100
@@ -276,8 +277,9 @@ class _WrapingIterReceiveChannelBase(ReceiveChannel[T_co], typing.Generic[T_co,
276
277
_shared : _WrapingChannelShared [U_co ]
277
278
278
279
def __init__ (self , source : typing .Optional [U_co ], * ,
280
+ count : typing .Optional [int ] = None ,
279
281
_shared : typing .Optional [_WrapingChannelShared [U_co ]] = None ):
280
- super ().__init__ ()
282
+ super ().__init__ (count = count )
281
283
282
284
assert source is not None or _shared is not None
283
285
@@ -366,9 +368,8 @@ class _WrapingAsyncIterReceiveChannel(
366
368
_WrapingIterReceiveChannelBase [T_co , typing .AsyncIterator [T_co ]],
367
369
typing .Generic [T_co ]
368
370
):
369
- def __init__ (self , source : typing .Optional [typing .AsyncIterable [T_co ]],
370
- ** kwargs : typing .Any ):
371
- super ().__init__ (source .__aiter__ () if source is not None else None , ** kwargs )
371
+ def __init__ (self , source : typing .Optional [typing .AsyncIterable [T_co ]]):
372
+ super ().__init__ (source .__aiter__ () if source is not None else None )
372
373
373
374
374
375
async def _receive (self ) -> T_co :
@@ -397,11 +398,9 @@ class _WrapingSyncIterReceiveChannel(
397
398
_WrapingIterReceiveChannelBase [T_co , typing .Iterator [T_co ]],
398
399
typing .Generic [T_co ]
399
400
):
400
- def __init__ (self , source : typing .Optional [typing .Iterable [T_co ]],
401
- count_hint : typing .Optional [int ] = None , ** kwargs : typing .Any ):
402
- super ().__init__ (iter (source ) if source is not None else None , ** kwargs )
403
-
404
- self .count = count_hint
401
+ def __init__ (self , source : typing .Optional [typing .Iterable [T_co ]], * ,
402
+ count : typing .Optional [int ] = None ):
403
+ super ().__init__ (iter (source ) if source is not None else None , count = count )
405
404
406
405
407
406
async def _receive (self ) -> T_co :
@@ -465,7 +464,7 @@ async def await_iter_wrapper(channel: typing.Awaitable[T_co]) \
465
464
if isinstance (source2 , collections .abc .Sequence ):
466
465
count = len (source2 )
467
466
468
- return _WrapingSyncIterReceiveChannel (source2 , count )
467
+ return _WrapingSyncIterReceiveChannel (source2 , count = count )
469
468
470
469
assert False , "Unreachable code"
471
470
@@ -510,17 +509,19 @@ class TeeingReceiveStream(ReceiveStream):
510
509
_source : typing .Optional [trio .abc .ReceiveStream ]
511
510
512
511
def __init__ (self , source : trio .abc .ReceiveStream , buffer_size : int = 0 ):
513
- super ().__init__ ()
512
+ # Try to copy extra attributes from source stream
513
+ super ().__init__ (
514
+ size = getattr (source , "size" , None ),
515
+ atime = getattr (source , "atime" , None ),
516
+ mtime = getattr (source , "mtime" , None ),
517
+ btime = getattr (source , "btime" , None ),
518
+ )
514
519
515
520
self ._bufsize = buffer_size
516
521
self ._source = source
517
522
self ._closed = False
518
523
519
- # Try to copy extra attributes from source stream
520
- self .size = getattr (source , "size" , None )
521
- self .atime = getattr (source , "atime" , None )
522
- self .mtime = getattr (source , "mtime" , None )
523
- self .btime = getattr (source , "btime" , None )
524
+
524
525
525
526
# Create nursery without using a `async with`-statement
526
527
# (Only works because the `__aenter__`-call does not actually block on anything.)
@@ -649,8 +650,8 @@ class _WrapingIterReceiveStreamBase(ReceiveStream, typing.Generic[T_co]):
649
650
650
651
_source : typing .Optional [T_co ]
651
652
652
- def __init__ (self , source : T_co ):
653
- super ().__init__ ()
653
+ def __init__ (self , source : T_co , * , size : typing . Optional [ int ] = None ):
654
+ super ().__init__ (size = size )
654
655
655
656
self ._source = source
656
657
@@ -756,10 +757,8 @@ async def _close_source(self) -> None:
756
757
757
758
758
759
class _WrapingSyncIterReceiveStream (_WrapingIterReceiveStreamBase [typing .Iterator [bytes ]]):
759
- def __init__ (self , source : typing .Iterable [bytes ], size_hint : typing .Optional [int ]):
760
- super ().__init__ (iter (source ))
761
-
762
- self .size = size_hint
760
+ def __init__ (self , source : typing .Iterable [bytes ], * , size : typing .Optional [int ] = None ):
761
+ super ().__init__ (iter (source ), size = size )
763
762
764
763
765
764
async def _receive (self , _ : typing .Optional [int ]) -> bytes :
@@ -833,6 +832,6 @@ async def await_iter_wrapper(stream: typing.Awaitable[bytes]) \
833
832
size = source2 .tell () - pos
834
833
source2 .seek (pos , io .SEEK_SET )
835
834
836
- return _WrapingSyncIterReceiveStream (source2 , size )
835
+ return _WrapingSyncIterReceiveStream (source2 , size = size )
837
836
838
837
assert False , "Unreachable code"
0 commit comments