Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion logging_utilities/context/thread_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,54 @@ class ThreadMappingContext(BaseContext):

def __init__(self):
self.__local = threading.local()
self.__local.data = {}
self.ensure_data()

def ensure_data(self):
"""Ensure the current thread has a `data` attribute in its local storage.

The `threading.local()` object provides each thread with its own independent attribute
namespace. Attributes created in one thread are not visible to other threads. This means
that even if `data` was initialized in the thread where this object was constructed,
new threads will not automatically have a `data` attribute since the constructor is not
run again.

Calling this method guarantees that `self.__local.data` exists in the *current* thread,
creating an empty dictionary if needed. It must be invoked on every access path
(e.g., __getitem__, __iter__).
"""
if not hasattr(self.__local, 'data'):
self.__local.data = {}

def __str__(self):
self.ensure_data()
return str(self.__local.data)

def __getitem__(self, __key):
self.ensure_data()
return self.__local.data[__key]

def __setitem__(self, __key, __value):
self.ensure_data()
self.__local.data[__key] = __value

def __delitem__(self, __key):
self.ensure_data()
del self.__local.data[__key]

def __len__(self):
self.ensure_data()
return len(self.__local.data)

def __iter__(self):
self.ensure_data()
return self.__local.data.__iter__()

def __contains__(self, __o):
self.ensure_data()
return self.__local.data.__contains__(__o)

def init(self, data=None):
self.ensure_data()
if data is None:
self.__local.data = {}
else:
Expand All @@ -46,18 +70,23 @@ def init(self, data=None):
self.__local.data = data

def get(self, key, default=None):
self.ensure_data()
return self.__local.data.get(key, default)

def pop(self, key, default=__marker):
self.ensure_data()
if default == self.__marker:
return self.__local.data.pop(key)
return self.__local.data.pop(key, default)

def set(self, key, value):
self.ensure_data()
self.__local.data[key] = value

def delete(self, key):
self.ensure_data()
del self.__local.data[key]

def clear(self):
self.ensure_data()
self.__local.data = {}
16 changes: 16 additions & 0 deletions tests/test_logging_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed
from threading import Thread

from logging_utilities.context import get_logging_context
from logging_utilities.context import remove_logging_context
Expand Down Expand Up @@ -125,6 +126,21 @@ def test_thread_context_str(self):
ctx.init({'a': 1, 'b': 2, 'c': 'my string'})
self.assertEqual(str(ctx), "{'a': 1, 'b': 2, 'c': 'my string'}")

def test_thread_context_local_data(self):
ctx = ThreadMappingContext()
ctx['thread'] = 'main'
results = {}

def worker():
assert 'thread' not in ctx
ctx['thread'] = 'worker'

t = Thread(target=worker)
t.start()
t.join()

assert ctx['thread'] == 'main'


class LoggingContextTest(unittest.TestCase):

Expand Down
Loading