Skip to content

Commit

Permalink
Fix logger not "picklable" causing errors with multiprocessing
Browse files Browse the repository at this point in the history
This required to refactor handling of sinks to remove closure functions
used outside of the method they were declared.
  • Loading branch information
Delgan committed Oct 20, 2019
1 parent 14848a4 commit ffac433
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 122 deletions.
19 changes: 14 additions & 5 deletions loguru/_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ class Handler:
def __init__(
self,
*,
sink_wrapper,
name,
writer,
stopper,
levelno,
formatter,
is_formatter_dynamic,
Expand All @@ -33,8 +32,9 @@ def __init__(
levels_ansi_codes
):
self._name = name
self._writer = writer
self._stopper = stopper
self._sink_wrapper = sink_wrapper
self._writer = sink_wrapper.write
self._stopper = sink_wrapper.stop
self._levelno = levelno
self._formatter = formatter
self._is_formatter_dynamic = is_formatter_dynamic
Expand All @@ -45,7 +45,7 @@ def __init__(
self._enqueue = enqueue
self._exception_formatter = exception_formatter
self._id = id_
self._levels_ansi_codes = levels_ansi_codes
self._levels_ansi_codes = levels_ansi_codes # Warning, reference shared among handlers

self._static_format = None
self._decolorized_format = None
Expand Down Expand Up @@ -278,3 +278,12 @@ def _handle_error(self, record=None):
pass
finally:
del ex_type, ex, tb

def __getstate__(self):
state = self.__dict__.copy()
del state["_lock"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._lock = threading.Lock()
133 changes: 49 additions & 84 deletions loguru/_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import itertools
import logging
import pickle
import re
import sys
import threading
Expand All @@ -14,6 +15,7 @@

from . import _colorama
from . import _defaults
from . import _sink_wrappers
from ._ansimarkup import AnsiMarkup
from ._better_exceptions import ExceptionFormatter
from ._datetime import aware_now
Expand All @@ -32,6 +34,20 @@ def parse_ansi(color):
return AnsiMarkup(strip=False).feed(color.strip(), strict=False)


class FiltererModule:
def __init__(self, module):
self._parent = module + "."
self._length = len(self._parent)

def filter(self, record):
return (record["name"] + ".")[: self._length] == self._parent


class FiltererNull:
def filter(self, record):
return record["name"] is not None


Level = namedtuple("Level", ["no", "color", "icon"])

start_time = aware_now()
Expand Down Expand Up @@ -86,6 +102,19 @@ def __init__(self):

self.lock = threading.Lock()

def __getstate__(self):
state = self.__dict__.copy()
del state["lock"]
try:
return pickle.dumps(state)
except Exception as e:
raise ValueError("The logger can't be pickled") from e

def __setstate__(self, state):
unpickled = pickle.loads(state)
self.__dict__.update(unpickled)
self.lock = threading.Lock()


class Logger:
"""An object to dispatch logging messages to configured handlers.
Expand Down Expand Up @@ -646,6 +675,18 @@ def add(
if colorize is None and serialize:
colorize = False

if isinstance(format, str):
formatter = format + "\n{exception}"
is_formatter_dynamic = False
elif callable(format):
formatter = format
is_formatter_dynamic = True
else:
raise ValueError(
"Invalid format, it should be a string or a function, not: '%s'"
% type(format).__name__
)

if isclass(sink):
sink = sink(**kwargs)
return self.add(
Expand Down Expand Up @@ -686,93 +727,30 @@ def add(
else:
stream = sink

stream_write = stream.write
if kwargs:

def write(m):
return stream_write(m, **kwargs)

else:
write = stream_write

if hasattr(stream, "flush") and callable(stream.flush):
stream_flush = stream.flush

def writer(m):
write(m)
stream_flush()

else:
writer = write

if hasattr(stream, "stop") and callable(stream.stop):
stopper = stream.stop
else:

def stopper():
return None

sink_wrapper = _sink_wrappers.StreamSinkWrapper(stream, kwargs)
elif isinstance(sink, logging.Handler):
name = repr(sink)

def writer(m):
message = str(m)
r = m.record
exc = r["exception"]
if not is_formatter_dynamic:
message = message[:-1]
record = logging.root.makeRecord(
r["name"],
r["level"].no,
r["file"].path,
r["line"],
message,
(),
(exc.type, exc.value, exc.traceback) if exc else None,
r["function"],
r["extra"],
**kwargs
)
if exc:
record.exc_text = "\n"
sink.handle(record)

stopper = sink.close
if colorize is None:
colorize = False

sink_wrapper = _sink_wrappers.StandardSinkWrapper(sink, kwargs, is_formatter_dynamic)
elif callable(sink):
name = getattr(sink, "__name__", repr(sink))

if kwargs:

def writer(m):
return sink(m, **kwargs)

else:
writer = sink

def stopper():
return None

if colorize is None:
colorize = False

sink_wrapper = _sink_wrappers.CallableSinkWrapper(sink, kwargs)
else:
raise ValueError("Cannot log to objects of type '%s'." % type(sink).__name__)

if filter is None:
filter_func = None
elif filter == "":

def filter_func(record):
return record["name"] is not None

filter_func = FiltererNull().filter
elif isinstance(filter, str):
parent = filter + "."
length = len(parent)

def filter_func(record):
return (record["name"] + ".")[:length] == parent

filter_func = FiltererModule(filter).filter
elif callable(filter):
filter_func = filter
else:
Expand All @@ -796,18 +774,6 @@ def filter_func(record):
"Invalid level value, it should be a positive integer, not: %d" % levelno
)

if isinstance(format, str):
formatter = format + "\n{exception}"
is_formatter_dynamic = False
elif callable(format):
formatter = format
is_formatter_dynamic = True
else:
raise ValueError(
"Invalid format, it should be a string or a function, not: '%s'"
% type(format).__name__
)

try:
encoding = sink.encoding
except AttributeError:
Expand All @@ -829,8 +795,7 @@ def filter_func(record):

handler = Handler(
name=name,
writer=writer,
stopper=stopper,
sink_wrapper=sink_wrapper,
levelno=levelno,
formatter=formatter,
is_formatter_dynamic=is_formatter_dynamic,
Expand Down
62 changes: 62 additions & 0 deletions loguru/_sink_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import logging


class StreamSinkWrapper:
def __init__(self, stream, kwargs):
self._stream = stream
self._kwargs = kwargs
self._flushable = hasattr(stream, "flush") and callable(stream.flush)
self._stoppable = hasattr(stream, "stop") and callable(stream.stop)

def write(self, message):
self._stream.write(message, **self._kwargs)
if self._flushable:
self._stream.flush()

def stop(self):
if self._stoppable:
self._stream.stop()


class StandardSinkWrapper:
def __init__(self, handler, kwargs, is_formatter_dynamic):
self._handler = handler
self._kwargs = kwargs
self._is_formatter_dynamic = is_formatter_dynamic

def write(self, message):
record = message.record
message = str(message)
exc = record["exception"]
if not self._is_formatter_dynamic:
message = message[:-1]
record = logging.root.makeRecord(
record["name"],
record["level"].no,
record["file"].path,
record["line"],
message,
(),
(exc.type, exc.value, exc.traceback) if exc else None,
record["function"],
record["extra"],
**self._kwargs
)
if exc:
record.exc_text = "\n"
self._handler.handle(record)

def stop(self):
self._handler.close()


class CallableSinkWrapper:
def __init__(self, function, kwargs):
self._function = function
self._kwargs = kwargs

def write(self, message):
self._function(message, **self._kwargs)

def stop(self):
pass
2 changes: 1 addition & 1 deletion tests/test_filesink_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def creation_time(filepath):
j = logger.add(str(tmpdir.join("test.log")), compression="tar.gz")
logger.debug("test")

filesink = next(iter(logger._core.handlers.values()))._writer.__self__
filesink = next(iter(logger._core.handlers.values()))._sink_wrapper._stream
monkeypatch.setattr(filesink, "_get_creation_time", creation_time)

logger.remove(j)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_filesink_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def creation_time(filepath):
i = logger.add(str(tmpdir.join("test.log")), rotation=10, format="{message}")
logger.debug("X")

filesink = next(iter(logger._core.handlers.values()))._writer.__self__
filesink = next(iter(logger._core.handlers.values()))._sink_wrapper._stream
monkeypatch.setattr(filesink, "_get_creation_time", creation_time)

logger.debug("Y" * 20)
Expand Down
Loading

0 comments on commit ffac433

Please sign in to comment.