Skip to content

Commit 11d937b

Browse files
committed
typecheck fixes
1 parent c19e4e6 commit 11d937b

File tree

2 files changed

+115
-86
lines changed

2 files changed

+115
-86
lines changed

util/opentelemetry-util-genai/src/opentelemetry/util/genai/data.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,23 @@
1313
# limitations under the License.
1414

1515
from dataclasses import dataclass
16-
from typing import Optional, Type
16+
from typing import List, Literal, Optional, Type, TypedDict
17+
18+
19+
class TextPart(TypedDict):
20+
type: Literal["text"]
21+
content: str
22+
23+
24+
# Keep room for future parts without changing the return type
25+
# addition of tools can use Part = Union[TextPart, ToolPart]
26+
Part = TextPart
27+
28+
29+
class OtelMessage(TypedDict):
30+
role: str
31+
# role: Literal["user", "assistant", "system", "tool", "tool_message"] # TODO: check semconvs for allowed roles
32+
parts: List[Part]
1733

1834

1935
@dataclass
@@ -22,7 +38,7 @@ class Message:
2238
type: str
2339
name: str
2440

25-
def _to_part_dict(self):
41+
def _to_part_dict(self) -> OtelMessage:
2642
"""Convert the message to a dictionary suitable for OpenTelemetry semconvs.
2743
2844
Ref: https://github.com/open-telemetry/semantic-conventions/blob/main/docs/registry/attributes/gen-ai.md#gen-ai-input-messages

util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters.py

Lines changed: 97 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@
3232
3333
"""
3434

35+
import json
3536
from dataclasses import dataclass, field
36-
from typing import Any, Dict, List, Optional, cast
37+
from typing import Any, Dict, List, Mapping, Optional, cast
3738
from uuid import UUID
3839

3940
from opentelemetry import trace
4041
from opentelemetry._logs import Logger, LogRecord
4142
from opentelemetry.context import Context, get_current
42-
from opentelemetry.metrics import Meter
43+
from opentelemetry.metrics import Meter, get_meter
4344
from opentelemetry.semconv._incubating.attributes import (
4445
gen_ai_attributes as GenAI,
4546
)
@@ -54,9 +55,9 @@
5455
use_span,
5556
)
5657
from opentelemetry.trace.status import Status, StatusCode
57-
from opentelemetry.util.types import Attributes
58+
from opentelemetry.util.types import AttributeValue
5859

59-
from .data import ChatGeneration, Error, Message
60+
from .data import ChatGeneration, Error, Message, OtelMessage
6061
from .instruments import Instruments
6162
from .types import LLMInvocation
6263

@@ -72,8 +73,9 @@ class _SpanState:
7273

7374

7475
def _get_property_value(obj: Any, property_name: str) -> Any:
75-
if isinstance(obj, dict):
76-
return cast(Any, obj.get(property_name, None))
76+
if isinstance(obj, Mapping):
77+
m = cast(Mapping[str, Any], obj)
78+
return m.get(property_name, None)
7779

7880
return cast(Any, getattr(obj, property_name, None))
7981

@@ -149,11 +151,11 @@ def _get_metric_attributes(
149151
operation_name: Optional[str],
150152
system: Optional[str],
151153
framework: Optional[str],
152-
) -> Dict:
153-
attributes = {
154-
# TODO: add below to opentelemetry.semconv._incubating.attributes.gen_ai_attributes
155-
"gen_ai.framework": framework,
156-
}
154+
) -> Dict[str, AttributeValue]:
155+
attributes: Dict[str, AttributeValue] = {}
156+
# TODO: add below to opentelemetry.semconv._incubating.attributes.gen_ai_attributes
157+
if framework is not None:
158+
attributes["gen_ai.framework"] = framework
157159
if system:
158160
attributes["gen_ai.provider.name"] = system
159161
if operation_name:
@@ -171,13 +173,13 @@ class BaseEmitter:
171173
Abstract base for emitters mapping GenAI types -> OpenTelemetry.
172174
"""
173175

174-
def init(self, invocation: LLMInvocation):
176+
def init(self, invocation: LLMInvocation) -> None:
175177
raise NotImplementedError
176178

177-
def emit(self, invocation: LLMInvocation):
179+
def emit(self, invocation: LLMInvocation) -> None:
178180
raise NotImplementedError
179181

180-
def error(self, error: Error, invocation: LLMInvocation):
182+
def error(self, error: Error, invocation: LLMInvocation) -> None:
181183
raise NotImplementedError
182184

183185

@@ -188,16 +190,17 @@ class SpanMetricEventEmitter(BaseEmitter):
188190

189191
def __init__(
190192
self,
191-
logger: Logger = None,
192-
tracer: Tracer = None,
193-
meter: Meter = None,
193+
logger: Optional[Logger] = None,
194+
tracer: Optional[Tracer] = None,
195+
meter: Optional[Meter] = None,
194196
capture_content: bool = False,
195197
):
196-
self._tracer = tracer or trace.get_tracer(__name__)
197-
instruments = Instruments(meter)
198+
self._tracer: Tracer = tracer or trace.get_tracer(__name__)
199+
_meter: Meter = meter or get_meter(__name__)
200+
instruments = Instruments(_meter)
198201
self._duration_histogram = instruments.operation_duration_histogram
199202
self._token_histogram = instruments.token_usage_histogram
200-
self._logger = logger
203+
self._logger: Optional[Logger] = logger
201204
self._capture_content = capture_content
202205

203206
# Map from run_id -> _SpanState, to keep track of spans and parent/child relationships
@@ -289,7 +292,7 @@ def emit(self, invocation: LLMInvocation):
289292
# TODO: add below to opentelemetry.semconv._incubating.attributes.gen_ai_attributes
290293
span.set_attribute("gen_ai.provider.name", system)
291294

292-
finish_reasons = []
295+
finish_reasons: List[str] = []
293296
for index, chat_generation in enumerate(
294297
invocation.chat_generations
295298
):
@@ -302,9 +305,10 @@ def emit(self, invocation: LLMInvocation):
302305
)
303306
if log and self._logger:
304307
self._logger.emit(log)
305-
finish_reasons.append(chat_generation.finish_reason)
308+
if chat_generation.finish_reason is not None:
309+
finish_reasons.append(chat_generation.finish_reason)
306310

307-
if finish_reasons is not None and len(finish_reasons) > 0:
311+
if finish_reasons:
308312
span.set_attribute(
309313
GenAI.GEN_AI_RESPONSE_FINISH_REASONS, finish_reasons
310314
)
@@ -319,13 +323,13 @@ def emit(self, invocation: LLMInvocation):
319323

320324
# usage
321325
prompt_tokens = invocation.attributes.get("input_tokens")
322-
if prompt_tokens is not None:
326+
if isinstance(prompt_tokens, (int, float)):
323327
span.set_attribute(
324328
GenAI.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens
325329
)
326330

327331
completion_tokens = invocation.attributes.get("output_tokens")
328-
if completion_tokens is not None:
332+
if isinstance(completion_tokens, (int, float)):
329333
span.set_attribute(
330334
GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens
331335
)
@@ -339,30 +343,33 @@ def emit(self, invocation: LLMInvocation):
339343
)
340344

341345
# Record token usage metrics
342-
prompt_tokens_attributes = {
346+
prompt_tokens_attributes: Dict[str, AttributeValue] = {
343347
GenAI.GEN_AI_TOKEN_TYPE: GenAI.GenAiTokenTypeValues.INPUT.value,
344348
}
345349
prompt_tokens_attributes.update(metric_attributes)
346-
self._token_histogram.record(
347-
prompt_tokens, attributes=prompt_tokens_attributes
348-
)
350+
if isinstance(prompt_tokens, (int, float)):
351+
self._token_histogram.record(
352+
prompt_tokens, attributes=prompt_tokens_attributes
353+
)
349354

350-
completion_tokens_attributes = {
355+
completion_tokens_attributes: Dict[str, AttributeValue] = {
351356
GenAI.GEN_AI_TOKEN_TYPE: GenAI.GenAiTokenTypeValues.COMPLETION.value
352357
}
353358
completion_tokens_attributes.update(metric_attributes)
354-
self._token_histogram.record(
355-
completion_tokens, attributes=completion_tokens_attributes
356-
)
359+
if isinstance(completion_tokens, (int, float)):
360+
self._token_histogram.record(
361+
completion_tokens, attributes=completion_tokens_attributes
362+
)
357363

358364
# End the LLM span
359365
self._end_span(invocation.run_id)
360366

361367
# Record overall duration metric
362-
elapsed = invocation.end_time - invocation.start_time
363-
self._duration_histogram.record(
364-
elapsed, attributes=metric_attributes
365-
)
368+
if invocation.end_time is not None:
369+
elapsed: float = invocation.end_time - invocation.start_time
370+
self._duration_histogram.record(
371+
elapsed, attributes=metric_attributes
372+
)
366373

367374
def error(self, error: Error, invocation: LLMInvocation):
368375
system = invocation.attributes.get("system")
@@ -408,10 +415,11 @@ def error(self, error: Error, invocation: LLMInvocation):
408415
)
409416

410417
# Record overall duration metric
411-
elapsed = invocation.end_time - invocation.start_time
412-
self._duration_histogram.record(
413-
elapsed, attributes=metric_attributes
414-
)
418+
if invocation.end_time is not None:
419+
elapsed: float = invocation.end_time - invocation.start_time
420+
self._duration_histogram.record(
421+
elapsed, attributes=metric_attributes
422+
)
415423

416424

417425
class SpanMetricEmitter(BaseEmitter):
@@ -421,12 +429,13 @@ class SpanMetricEmitter(BaseEmitter):
421429

422430
def __init__(
423431
self,
424-
tracer: Tracer = None,
425-
meter: Meter = None,
432+
tracer: Optional[Tracer] = None,
433+
meter: Optional[Meter] = None,
426434
capture_content: bool = False,
427435
):
428-
self._tracer = tracer or trace.get_tracer(__name__)
429-
instruments = Instruments(meter)
436+
self._tracer: Tracer = tracer or trace.get_tracer(__name__)
437+
_meter: Meter = meter or get_meter(__name__)
438+
instruments = Instruments(_meter)
430439
self._duration_histogram = instruments.operation_duration_histogram
431440
self._token_histogram = instruments.token_usage_histogram
432441
self._capture_content = capture_content
@@ -454,10 +463,9 @@ def _end_span(self, run_id: UUID):
454463
state = self.spans[run_id]
455464
for child_id in state.children:
456465
child_state = self.spans.get(child_id)
457-
if child_state and child_state.span._end_time is None:
466+
if child_state:
458467
child_state.span.end()
459-
if state.span._end_time is None:
460-
state.span.end()
468+
state.span.end()
461469

462470
def init(self, invocation: LLMInvocation):
463471
if (
@@ -502,17 +510,19 @@ def emit(self, invocation: LLMInvocation):
502510
framework = invocation.attributes.get("framework")
503511
if framework is not None:
504512
span.set_attribute("gen_ai.framework", framework)
505-
span.set_attribute(
506-
GenAI.GEN_AI_SYSTEM, system
507-
) # Deprecated: use "gen_ai.provider.name"
508-
# TODO: add below to opentelemetry.semconv._incubating.attributes.gen_ai_attributes
509-
span.set_attribute("gen_ai.provider.name", system)
513+
if system is not None:
514+
span.set_attribute(
515+
GenAI.GEN_AI_SYSTEM, system
516+
) # Deprecated: use "gen_ai.provider.name"
517+
# TODO: add below to opentelemetry.semconv._incubating.attributes.gen_ai_attributes
518+
span.set_attribute("gen_ai.provider.name", system)
510519

511-
finish_reasons: list[str] = []
520+
finish_reasons: List[str] = []
512521
for index, chat_generation in enumerate(
513522
invocation.chat_generations
514523
):
515-
finish_reasons.append(chat_generation.finish_reason)
524+
if chat_generation.finish_reason is not None:
525+
finish_reasons.append(chat_generation.finish_reason)
516526
if finish_reasons and len(finish_reasons) > 0:
517527
span.set_attribute(
518528
GenAI.GEN_AI_RESPONSE_FINISH_REASONS, finish_reasons
@@ -528,29 +538,28 @@ def emit(self, invocation: LLMInvocation):
528538

529539
# usage
530540
prompt_tokens = invocation.attributes.get("input_tokens")
531-
if prompt_tokens is not None:
541+
if isinstance(prompt_tokens, (int, float)):
532542
span.set_attribute(
533543
GenAI.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens
534544
)
535545

536546
completion_tokens = invocation.attributes.get("output_tokens")
537-
if completion_tokens is not None:
547+
if isinstance(completion_tokens, (int, float)):
538548
span.set_attribute(
539549
GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens
540550
)
541551

542-
message_parts: List[Attributes] = []
543-
for index, message in enumerate(invocation.messages):
544-
message_parts.append(message._to_part_dict())
545-
546-
if len(message_parts) > 0:
547-
span.set_attribute("gen_ai.input.messages", message_parts)
552+
if self._capture_content:
553+
message_parts: List[OtelMessage] = []
554+
for index, message in enumerate(invocation.messages):
555+
# ref: https://github.com/open-telemetry/semantic-conventions/blob/main/docs/registry/attributes/gen-ai.md#gen-ai-input-messages
556+
# when recording prompt messages, use a json encoded string if structured form is not available.
557+
message_parts.append(message._to_part_dict())
548558

549-
# for index, message in enumerate(invocation.messages):
550-
# content = message.content
551-
# # Set these attributes to upcoming semconv: https://github.com/open-telemetry/semantic-conventions/pull/2179
552-
# span.set_attribute(f"gen_ai.input.messages.{index}.content", [content._to_part_dict()])
553-
# span.set_attribute(f"gen_ai.input.messages.{index}.role", message.type)
559+
if len(message_parts) > 0:
560+
span.set_attribute(
561+
"gen_ai.input.messages", json.dumps(message_parts)
562+
)
554563

555564
for index, chat_generation in enumerate(
556565
invocation.chat_generations
@@ -573,30 +582,33 @@ def emit(self, invocation: LLMInvocation):
573582
)
574583

575584
# Record token usage metrics
576-
prompt_tokens_attributes = {
585+
prompt_tokens_attributes: Dict[str, AttributeValue] = {
577586
GenAI.GEN_AI_TOKEN_TYPE: GenAI.GenAiTokenTypeValues.INPUT.value
578587
}
579588
prompt_tokens_attributes.update(metric_attributes)
580-
self._token_histogram.record(
581-
prompt_tokens, attributes=prompt_tokens_attributes
582-
)
589+
if isinstance(prompt_tokens, (int, float)):
590+
self._token_histogram.record(
591+
prompt_tokens, attributes=prompt_tokens_attributes
592+
)
583593

584-
completion_tokens_attributes = {
594+
completion_tokens_attributes: Dict[str, AttributeValue] = {
585595
GenAI.GEN_AI_TOKEN_TYPE: GenAI.GenAiTokenTypeValues.COMPLETION.value
586596
}
587597
completion_tokens_attributes.update(metric_attributes)
588-
self._token_histogram.record(
589-
completion_tokens, attributes=completion_tokens_attributes
590-
)
598+
if isinstance(completion_tokens, (int, float)):
599+
self._token_histogram.record(
600+
completion_tokens, attributes=completion_tokens_attributes
601+
)
591602

592603
# End the LLM span
593604
self._end_span(invocation.run_id)
594605

595606
# Record overall duration metric
596-
elapsed = invocation.end_time - invocation.start_time
597-
self._duration_histogram.record(
598-
elapsed, attributes=metric_attributes
599-
)
607+
if invocation.end_time is not None:
608+
elapsed: float = invocation.end_time - invocation.start_time
609+
self._duration_histogram.record(
610+
elapsed, attributes=metric_attributes
611+
)
600612

601613
def error(self, error: Error, invocation: LLMInvocation):
602614
system = invocation.attributes.get("system")
@@ -642,7 +654,8 @@ def error(self, error: Error, invocation: LLMInvocation):
642654
)
643655

644656
# Record overall duration metric
645-
elapsed = invocation.end_time - invocation.start_time
646-
self._duration_histogram.record(
647-
elapsed, attributes=metric_attributes
648-
)
657+
if invocation.end_time is not None:
658+
elapsed: float = invocation.end_time - invocation.start_time
659+
self._duration_histogram.record(
660+
elapsed, attributes=metric_attributes
661+
)

0 commit comments

Comments (0)