From b0da53d28b041338e04982f3a9715926887b5cad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Neum=C3=BCller?= Date: Wed, 9 Oct 2019 10:43:41 +0200 Subject: [PATCH] Fix some "errors" found by mypy. (#204) Fix some errors found by mypy (split from #201). --- .../src/opentelemetry/context/__init__.py | 7 +-- .../src/opentelemetry/trace/__init__.py | 2 +- .../sdk/context/propagation/b3_format.py | 21 +++++---- .../src/opentelemetry/sdk/metrics/__init__.py | 7 +-- .../src/opentelemetry/sdk/trace/__init__.py | 47 ++++++++++--------- .../sdk/trace/export/__init__.py | 13 +++-- .../src/opentelemetry/sdk/util.py | 6 +-- 7 files changed, 55 insertions(+), 48 deletions(-) diff --git a/opentelemetry-api/src/opentelemetry/context/__init__.py b/opentelemetry-api/src/opentelemetry/context/__init__.py index cf6c72dd8d..43a7722f88 100644 --- a/opentelemetry-api/src/opentelemetry/context/__init__.py +++ b/opentelemetry-api/src/opentelemetry/context/__init__.py @@ -138,19 +138,14 @@ async def main(): asyncio.run(main()) """ -import typing - from .base_context import BaseRuntimeContext __all__ = ["Context"] - -Context = None # type: typing.Optional[BaseRuntimeContext] - try: from .async_context import AsyncRuntimeContext - Context = AsyncRuntimeContext() + Context = AsyncRuntimeContext() # type: BaseRuntimeContext except ImportError: from .thread_local_context import ThreadLocalRuntimeContext diff --git a/opentelemetry-api/src/opentelemetry/trace/__init__.py b/opentelemetry-api/src/opentelemetry/trace/__init__.py index 094255aa1e..18eced4504 100644 --- a/opentelemetry-api/src/opentelemetry/trace/__init__.py +++ b/opentelemetry-api/src/opentelemetry/trace/__init__.py @@ -142,7 +142,7 @@ class SpanKind(enum.Enum): class Span: """A span represents a single operation within a trace.""" - def start(self, start_time: int = None) -> None: + def start(self, start_time: typing.Optional[int] = None) -> None: """Sets the current time as the span's start time. Each span represents a single operation. The span's start time is the diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py index 9e8f7d3f19..2eca8afaa1 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py @@ -35,7 +35,7 @@ class B3Format(HTTPTextFormat): def extract(cls, get_from_carrier, carrier): trace_id = format_trace_id(trace.INVALID_TRACE_ID) span_id = format_span_id(trace.INVALID_SPAN_ID) - sampled = 0 + sampled = "0" flags = None single_header = _extract_first_element( @@ -95,8 +95,8 @@ def extract(cls, get_from_carrier, carrier): # trace an span ids are encoded in hex, so must be converted trace_id=int(trace_id, 16), span_id=int(span_id, 16), - trace_options=options, - trace_state={}, + trace_options=trace.TraceOptions(options), + trace_state=trace.TraceState(), ) @classmethod @@ -111,17 +111,20 @@ def inject(cls, context, set_in_carrier, carrier): set_in_carrier(carrier, cls.SAMPLED_KEY, "1" if sampled else "0") -def format_trace_id(trace_id: int): +def format_trace_id(trace_id: int) -> str: """Format the trace id according to b3 specification.""" return format(trace_id, "032x") -def format_span_id(span_id: int): +def format_span_id(span_id: int) -> str: """Format the span id according to b3 specification.""" return format(span_id, "016x") -def _extract_first_element(list_object: list) -> typing.Optional[object]: - if list_object: - return list_object[0] - return None +_T = typing.TypeVar("_T") + + +def _extract_first_element(items: typing.Iterable[_T]) -> typing.Optional[_T]: + if items is None: + return None + return next(iter(items), None) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/metrics/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/metrics/__init__.py index f80a72c770..041d0e5dcd 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/metrics/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/metrics/__init__.py @@ -27,12 +27,12 @@ def __init__( enabled: bool, monotonic: bool, ): - self.data = 0 + self.data = value_type() self.value_type = value_type self.enabled = enabled self.monotonic = monotonic - def _validate_update(self, value: metrics_api.ValueT): + def _validate_update(self, value: metrics_api.ValueT) -> bool: if not self.enabled: return False if not isinstance(value, self.value_type): @@ -232,7 +232,8 @@ def create_metric( monotonic: bool = False, ) -> metrics_api.MetricT: """See `opentelemetry.metrics.Meter.create_metric`.""" - return metric_type( + # Ignore type b/c of mypy bug in addition to missing annotations + return metric_type( # type: ignore name, description, unit, diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py index 1cee1933e2..eb754fadb8 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py @@ -73,10 +73,10 @@ class MultiSpanProcessor(SpanProcessor): def __init__(self): # use a tuple to avoid race conditions when adding a new span and # iterating through it on "on_start" and "on_end". - self._span_processors = () + self._span_processors = () # type: typing.Tuple[SpanProcessor, ...] self._lock = threading.Lock() - def add_span_processor(self, span_processor: SpanProcessor): + def add_span_processor(self, span_processor: SpanProcessor) -> None: """Adds a SpanProcessor to the list handled by this instance.""" with self._lock: self._span_processors = self._span_processors + (span_processor,) @@ -122,11 +122,11 @@ class Span(trace_api.Span): def __init__( self, name: str, - context: "trace_api.SpanContext", + context: trace_api.SpanContext, parent: trace_api.ParentSpan = None, - sampler=None, # TODO - trace_config=None, # TODO - resource=None, # TODO + sampler: None = None, # TODO + trace_config: None = None, # TODO + resource: None = None, # TODO attributes: types.Attributes = None, # TODO events: typing.Sequence[trace_api.Event] = None, # TODO links: typing.Sequence[trace_api.Link] = None, # TODO @@ -140,9 +140,6 @@ def __init__( self.sampler = sampler self.trace_config = trace_config self.resource = resource - self.attributes = attributes - self.events = events - self.links = links self.kind = kind self.span_processor = span_processor @@ -165,8 +162,8 @@ def __init__( else: self.links = BoundedList.from_seq(MAX_NUM_LINKS, links) - self.end_time = None - self.start_time = None + self.end_time = None # type: typing.Optional[int] + self.start_time = None # type: typing.Optional[int] def __repr__(self): return '{}(name="{}", context={})'.format( @@ -203,9 +200,13 @@ def set_attribute(self, key: str, value: types.AttributeValue) -> None: def add_event( self, name: str, attributes: types.Attributes = None ) -> None: - if attributes is None: - attributes = Span.empty_attributes - self.add_lazy_event(trace_api.Event(name, time_ns(), attributes)) + self.add_lazy_event( + trace_api.Event( + name, + time_ns(), + Span.empty_attributes if attributes is None else attributes, + ) + ) def add_lazy_event(self, event: trace_api.Event) -> None: with self._lock: @@ -226,7 +227,9 @@ def add_link( attributes: types.Attributes = None, ) -> None: if attributes is None: - attributes = Span.empty_attributes + attributes = ( + Span.empty_attributes + ) # TODO: empty_attributes is not a Dict. Use Mapping? self.add_lazy_link(trace_api.Link(link_target_context, attributes)) def add_lazy_link(self, link: "trace_api.Link") -> None: @@ -242,7 +245,7 @@ def add_lazy_link(self, link: "trace_api.Link") -> None: return self.links.append(link) - def start(self, start_time: int = None): + def start(self, start_time: typing.Optional[int] = None) -> None: with self._lock: if not self.is_recording_events(): return @@ -256,7 +259,7 @@ def start(self, start_time: int = None): return self.span_processor.on_start(self) - def end(self, end_time: int = None): + def end(self, end_time: int = None) -> None: with self._lock: if not self.is_recording_events(): return @@ -283,7 +286,7 @@ def is_recording_events(self) -> bool: return True -def generate_span_id(): +def generate_span_id() -> int: """Get a new random span ID. Returns: @@ -292,7 +295,7 @@ def generate_span_id(): return random.getrandbits(64) -def generate_trace_id(): +def generate_trace_id() -> int: """Get a new random trace ID. Returns: @@ -325,7 +328,7 @@ def start_span( name: str, parent: trace_api.ParentSpan = trace_api.Tracer.CURRENT_SPAN, kind: trace_api.SpanKind = trace_api.SpanKind.INTERNAL, - ) -> typing.Iterator["Span"]: + ) -> typing.Iterator[trace_api.Span]: """See `opentelemetry.trace.Tracer.start_span`.""" span = self.create_span(name, parent, kind) @@ -368,8 +371,8 @@ def create_span( @contextmanager def use_span( - self, span: Span, end_on_exit: bool = False - ) -> typing.Iterator[Span]: + self, span: trace_api.Span, end_on_exit: bool = False + ) -> typing.Iterator[trace_api.Span]: """See `opentelemetry.trace.Tracer.use_span`.""" try: span_snapshot = self._current_span_slot.get() diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py index ce362813ec..a76a658b3a 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py @@ -118,7 +118,9 @@ def __init__( ) self.span_exporter = span_exporter - self.queue = collections.deque([], max_queue_size) + self.queue = collections.deque( + [], max_queue_size + ) # type: typing.Deque[Span] self.worker_thread = threading.Thread(target=self.worker, daemon=True) self.condition = threading.Condition(threading.Lock()) self.schedule_delay_millis = schedule_delay_millis @@ -128,7 +130,9 @@ def __init__( # flag that indicates that spans are being dropped self._spans_dropped = False # precallocated list to send spans to exporter - self.spans_list = [None] * self.max_export_batch_size + self.spans_list = [ + None + ] * self.max_export_batch_size # type: typing.List[typing.Optional[Span]] self.worker_thread.start() def on_start(self, span: Span) -> None: @@ -172,7 +176,7 @@ def worker(self): # be sure that all spans are sent self._flush() - def export(self) -> bool: + def export(self) -> None: """Exports at most max_export_batch_size spans.""" idx = 0 @@ -184,7 +188,8 @@ def export(self) -> bool: suppress_instrumentation = Context.suppress_instrumentation try: Context.suppress_instrumentation = True - self.span_exporter.export(self.spans_list[:idx]) + # Ignore type b/c the Optional[None]+slicing is too "clever" for mypy + self.span_exporter.export(self.spans_list[:idx]) # type: ignore # pylint: disable=broad-except except Exception: logger.exception("Exception while exporting data.") diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/util.py b/opentelemetry-sdk/src/opentelemetry/sdk/util.py index da6ada90c3..2265c29460 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/util.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/util.py @@ -41,7 +41,7 @@ class BoundedList(Sequence): def __init__(self, maxlen): self.dropped = 0 - self._dq = deque(maxlen=maxlen) + self._dq = deque(maxlen=maxlen) # type: deque self._lock = threading.Lock() def __repr__(self): @@ -97,8 +97,8 @@ def __init__(self, maxlen): raise ValueError self.maxlen = maxlen self.dropped = 0 - self._dict = OrderedDict() - self._lock = threading.Lock() + self._dict = OrderedDict() # type: OrderedDict + self._lock = threading.Lock() # type: threading.Lock def __repr__(self): return "{}({}, maxlen={})".format(