55
66import asyncio
77import atexit
8+ import contextvars
89import io
910import os
1011import sys
1112import threading
1213import traceback
1314import warnings
1415from binascii import b2a_hex
15- from collections import deque
16+ from collections import defaultdict , deque
1617from io import StringIO , TextIOBase
1718from threading import local
1819from typing import Any , Callable , Deque , Dict , Optional
@@ -412,7 +413,7 @@ def __init__(
412413 name : str {'stderr', 'stdout'}
413414 the name of the standard stream to replace
414415 pipe : object
415- the pip object
416+ the pipe object
416417 echo : bool
417418 whether to echo output
418419 watchfd : bool (default, True)
@@ -446,13 +447,19 @@ def __init__(
446447 self .pub_thread = pub_thread
447448 self .name = name
448449 self .topic = b"stream." + name .encode ()
449- self .parent_header = {}
450+ self ._parent_header : contextvars .ContextVar [Dict [str , Any ]] = contextvars .ContextVar (
451+ "parent_header"
452+ )
453+ self ._parent_header .set ({})
454+ self ._thread_to_parent = {}
455+ self ._thread_to_parent_header = {}
456+ self ._parent_header_global = {}
450457 self ._master_pid = os .getpid ()
451458 self ._flush_pending = False
452459 self ._subprocess_flush_pending = False
453460 self ._io_loop = pub_thread .io_loop
454461 self ._buffer_lock = threading .RLock ()
455- self ._buffer = StringIO ( )
462+ self ._buffers = defaultdict ( StringIO )
456463 self .echo = None
457464 self ._isatty = bool (isatty )
458465 self ._should_watch = False
@@ -495,6 +502,30 @@ def __init__(
495502 msg = "echo argument must be a file-like object"
496503 raise ValueError (msg )
497504
505+ @property
506+ def parent_header (self ):
507+ try :
508+ # asyncio-specific
509+ return self ._parent_header .get ()
510+ except LookupError :
511+ try :
512+ # thread-specific
513+ identity = threading .current_thread ().ident
514+ # retrieve the outermost (oldest ancestor,
515+ # discounting the kernel thread) thread identity
516+ while identity in self ._thread_to_parent :
517+ identity = self ._thread_to_parent [identity ]
518+ # use the header of the oldest ancestor
519+ return self ._thread_to_parent_header [identity ]
520+ except KeyError :
521+ # global (fallback)
522+ return self ._parent_header_global
523+
524+ @parent_header .setter
525+ def parent_header (self , value ):
526+ self ._parent_header_global = value
527+ return self ._parent_header .set (value )
528+
498529 def isatty (self ):
499530 """Return a bool indicating whether this is an 'interactive' stream.
500531
@@ -598,28 +629,28 @@ def _flush(self):
598629 if self .echo is not sys .__stderr__ :
599630 print (f"Flush failed: { e } " , file = sys .__stderr__ )
600631
601- data = self ._flush_buffer ()
602- if data :
603- # FIXME: this disables Session's fork-safe check,
604- # since pub_thread is itself fork-safe.
605- # There should be a better way to do this.
606- self .session .pid = os .getpid ()
607- content = {"name" : self .name , "text" : data }
608- msg = self .session .msg ("stream" , content , parent = self . parent_header )
609-
610- # Each transform either returns a new
611- # message or None. If None is returned,
612- # the message has been 'used' and we return.
613- for hook in self ._hooks :
614- msg = hook (msg )
615- if msg is None :
616- return
617-
618- self .session .send (
619- self .pub_thread ,
620- msg ,
621- ident = self .topic ,
622- )
632+ for parent , data in self ._flush_buffers ():
633+ if data :
634+ # FIXME: this disables Session's fork-safe check,
635+ # since pub_thread is itself fork-safe.
636+ # There should be a better way to do this.
637+ self .session .pid = os .getpid ()
638+ content = {"name" : self .name , "text" : data }
639+ msg = self .session .msg ("stream" , content , parent = parent )
640+
641+ # Each transform either returns a new
642+ # message or None. If None is returned,
643+ # the message has been 'used' and we return.
644+ for hook in self ._hooks :
645+ msg = hook (msg )
646+ if msg is None :
647+ return
648+
649+ self .session .send (
650+ self .pub_thread ,
651+ msg ,
652+ ident = self .topic ,
653+ )
623654
624655 def write (self , string : str ) -> Optional [int ]: # type:ignore[override]
625656 """Write to current stream after encoding if necessary
@@ -630,6 +661,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
630661 number of items from input parameter written to stream.
631662
632663 """
664+ parent = self .parent_header
633665
634666 if not isinstance (string , str ):
635667 msg = f"write() argument must be str, not { type (string )} " # type:ignore[unreachable]
@@ -649,7 +681,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
649681 is_child = not self ._is_master_process ()
650682 # only touch the buffer in the IO thread to avoid races
651683 with self ._buffer_lock :
652- self ._buffer .write (string )
684+ self ._buffers [ frozenset ( parent . items ())] .write (string )
653685 if is_child :
654686 # mp.Pool cannot be trusted to flush promptly (or ever),
655687 # and this helps.
@@ -675,19 +707,20 @@ def writable(self):
675707 """Test whether the stream is writable."""
676708 return True
677709
678- def _flush_buffer (self ):
710+ def _flush_buffers (self ):
679711 """clear the current buffer and return the current buffer data."""
680- buf = self ._rotate_buffer ()
681- data = buf .getvalue ()
682- buf .close ()
683- return data
712+ buffers = self ._rotate_buffers ()
713+ for frozen_parent , buffer in buffers .items ():
714+ data = buffer .getvalue ()
715+ buffer .close ()
716+ yield dict (frozen_parent ), data
684717
685- def _rotate_buffer (self ):
718+ def _rotate_buffers (self ):
686719 """Returns the current buffer and replaces it with an empty buffer."""
687720 with self ._buffer_lock :
688- old_buffer = self ._buffer
689- self ._buffer = StringIO ( )
690- return old_buffer
721+ old_buffers = self ._buffers
722+ self ._buffers = defaultdict ( StringIO )
723+ return old_buffers
691724
692725 @property
693726 def _hooks (self ):
0 commit comments