diff --git a/opentelemetry-sdk/tests/logs/test_handler.py b/opentelemetry-sdk/tests/logs/test_handler.py index 0f96361712..146e4b95b0 100644 --- a/opentelemetry-sdk/tests/logs/test_handler.py +++ b/opentelemetry-sdk/tests/logs/test_handler.py @@ -19,50 +19,42 @@ from opentelemetry._logs import get_logger as APIGetLogger from opentelemetry.attributes import BoundedAttributes from opentelemetry.sdk import trace -from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler +from opentelemetry.sdk._logs import ( + LogData, + LoggerProvider, + LoggingHandler, + LogRecordProcessor, +) from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import INVALID_SPAN_CONTEXT -def get_logger(level=logging.NOTSET, logger_provider=None): - logger = logging.getLogger(__name__) - handler = LoggingHandler(level=level, logger_provider=logger_provider) - logger.addHandler(handler) - return logger - - class TestLoggingHandler(unittest.TestCase): def test_handler_default_log_level(self): - emitter_provider_mock = Mock(spec=LoggerProvider) - emitter_mock = APIGetLogger( - __name__, logger_provider=emitter_provider_mock - ) - logger = get_logger(logger_provider=emitter_provider_mock) + processor, logger = set_up_test_logging(logging.NOTSET) + # Make sure debug messages are ignored by default logger.debug("Debug message") - self.assertEqual(emitter_mock.emit.call_count, 0) + assert processor.emit_count() == 0 + # Assert emit gets called for warning message with self.assertLogs(level=logging.WARNING): logger.warning("Warning message") - self.assertEqual(emitter_mock.emit.call_count, 1) + self.assertEqual(processor.emit_count(), 1) def test_handler_custom_log_level(self): - emitter_provider_mock = Mock(spec=LoggerProvider) - emitter_mock = APIGetLogger( - __name__, logger_provider=emitter_provider_mock - ) - logger = get_logger( - level=logging.ERROR, logger_provider=emitter_provider_mock - ) + processor, logger = set_up_test_logging(logging.ERROR) + with self.assertLogs(level=logging.WARNING): logger.warning("Warning message test custom log level") # Make sure any log with level < ERROR is ignored - self.assertEqual(emitter_mock.emit.call_count, 0) + assert processor.emit_count() == 0 + with self.assertLogs(level=logging.ERROR): logger.error("Mumbai, we have a major problem") with self.assertLogs(level=logging.CRITICAL): logger.critical("No Time For Caution") - self.assertEqual(emitter_mock.emit.call_count, 2) + self.assertEqual(processor.emit_count(), 2) # pylint: disable=protected-access def test_log_record_emit_noop(self): @@ -77,14 +69,16 @@ def test_log_record_emit_noop(self): logger.addHandler(handler_mock) with self.assertLogs(level=logging.WARNING): logger.warning("Warning message") - handler_mock._translate.assert_not_called() def test_log_flush_noop(self): - no_op_logger_provider = NoOpLoggerProvider() no_op_logger_provider.force_flush = Mock() - logger = get_logger(logger_provider=no_op_logger_provider) + logger = logging.getLogger("foo") + handler = LoggingHandler( + level=logging.NOTSET, logger_provider=no_op_logger_provider + ) + logger.addHandler(handler) with self.assertLogs(level=logging.WARNING): logger.warning("Warning message") @@ -93,16 +87,13 @@ def test_log_flush_noop(self): no_op_logger_provider.force_flush.assert_not_called() def test_log_record_no_span_context(self): - emitter_provider_mock = Mock(spec=LoggerProvider) - emitter_mock = APIGetLogger( - __name__, logger_provider=emitter_provider_mock - ) - logger = get_logger(logger_provider=emitter_provider_mock) + processor, logger = set_up_test_logging(logging.WARNING) + # Assert emit gets called for warning message with self.assertLogs(level=logging.WARNING): logger.warning("Warning message") - args, _ = emitter_mock.emit.call_args_list[0] - log_record = args[0] + + log_record = processor.get_log_record(0) self.assertIsNotNone(log_record) self.assertEqual(log_record.trace_id, INVALID_SPAN_CONTEXT.trace_id) @@ -112,31 +103,23 @@ def test_log_record_no_span_context(self): ) def test_log_record_observed_timestamp(self): - emitter_provider_mock = Mock(spec=LoggerProvider) - emitter_mock = APIGetLogger( - __name__, logger_provider=emitter_provider_mock - ) - logger = get_logger(logger_provider=emitter_provider_mock) - # Assert emit gets called for warning message + processor, logger = set_up_test_logging(logging.WARNING) + with self.assertLogs(level=logging.WARNING): logger.warning("Warning message") - args, _ = emitter_mock.emit.call_args_list[0] - log_record = args[0] + log_record = processor.get_log_record(0) self.assertIsNotNone(log_record.observed_timestamp) def test_log_record_user_attributes(self): """Attributes can be injected into logs by adding them to the LogRecord""" - emitter_provider_mock = Mock(spec=LoggerProvider) - emitter_mock = APIGetLogger( - __name__, logger_provider=emitter_provider_mock - ) - logger = get_logger(logger_provider=emitter_provider_mock) + processor, logger = set_up_test_logging(logging.WARNING) + # Assert emit gets called for warning message with self.assertLogs(level=logging.WARNING): logger.warning("Warning message", extra={"http.status_code": 200}) - args, _ = emitter_mock.emit.call_args_list[0] - log_record = args[0] + + log_record = processor.get_log_record(0) self.assertIsNotNone(log_record) self.assertEqual(len(log_record.attributes), 4) @@ -157,18 +140,15 @@ def test_log_record_user_attributes(self): def test_log_record_exception(self): """Exception information will be included in attributes""" - emitter_provider_mock = Mock(spec=LoggerProvider) - emitter_mock = APIGetLogger( - __name__, logger_provider=emitter_provider_mock - ) - logger = get_logger(logger_provider=emitter_provider_mock) + processor, logger = set_up_test_logging(logging.ERROR) + try: raise ZeroDivisionError("division by zero") except ZeroDivisionError: with self.assertLogs(level=logging.ERROR): logger.exception("Zero Division Error") - args, _ = emitter_mock.emit.call_args_list[0] - log_record = args[0] + + log_record = processor.get_log_record(0) self.assertIsNotNone(log_record) self.assertEqual(log_record.body, "Zero Division Error") @@ -191,18 +171,15 @@ def test_log_record_exception(self): def test_log_exc_info_false(self): """Exception information will be included in attributes""" - emitter_provider_mock = Mock(spec=LoggerProvider) - emitter_mock = APIGetLogger( - __name__, logger_provider=emitter_provider_mock - ) - logger = get_logger(logger_provider=emitter_provider_mock) + processor, logger = set_up_test_logging(logging.NOTSET) + try: raise ZeroDivisionError("division by zero") except ZeroDivisionError: with self.assertLogs(level=logging.ERROR): logger.error("Zero Division Error", exc_info=False) - args, _ = emitter_mock.emit.call_args_list[0] - log_record = args[0] + + log_record = processor.get_log_record(0) self.assertIsNotNone(log_record) self.assertEqual(log_record.body, "Zero Division Error") @@ -215,19 +192,15 @@ def test_log_exc_info_false(self): ) def test_log_record_trace_correlation(self): - emitter_provider_mock = Mock(spec=LoggerProvider) - emitter_mock = APIGetLogger( - __name__, logger_provider=emitter_provider_mock - ) - logger = get_logger(logger_provider=emitter_provider_mock) + processor, logger = set_up_test_logging(logging.WARNING) tracer = trace.TracerProvider().get_tracer(__name__) with tracer.start_as_current_span("test") as span: with self.assertLogs(level=logging.CRITICAL): logger.critical("Critical message within span") - args, _ = emitter_mock.emit.call_args_list[0] - log_record = args[0] + log_record = processor.get_log_record(0) + self.assertEqual(log_record.body, "Critical message within span") self.assertEqual(log_record.severity_text, "CRITICAL") self.assertEqual(log_record.severity_number, SeverityNumber.FATAL) @@ -235,3 +208,33 @@ def test_log_record_trace_correlation(self): self.assertEqual(log_record.trace_id, span_context.trace_id) self.assertEqual(log_record.span_id, span_context.span_id) self.assertEqual(log_record.trace_flags, span_context.trace_flags) + + +def set_up_test_logging(level): + logger_provider = LoggerProvider() + processor = FakeProcessor() + logger_provider.add_log_record_processor(processor) + logger = logging.getLogger("foo") + handler = LoggingHandler(level=level, logger_provider=logger_provider) + logger.addHandler(handler) + return processor, logger + + +class FakeProcessor(LogRecordProcessor): + def __init__(self): + self.log_data_emitted = [] + + def emit(self, log_data: LogData): + self.log_data_emitted.append(log_data) + + def shutdown(self): + pass + + def force_flush(self, timeout_millis: int = 30000): + pass + + def emit_count(self): + return len(self.log_data_emitted) + + def get_log_record(self, i): + return self.log_data_emitted[i].log_record