Skip to content

Commit 8212a09

Browse files
committed
fix(watsonx): refactor event emission
1 parent c6de57e commit 8212a09

File tree

8 files changed

+215
-173
lines changed

8 files changed

+215
-173
lines changed

packages/opentelemetry-instrumentation-watsonx/opentelemetry/instrumentation/watsonx/__init__.py

Lines changed: 139 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,27 @@
44
import os
55
import time
66
import types
7-
from typing import Collection, Optional
7+
from typing import Collection, Optional, Union
88

99
from opentelemetry import context as context_api
10-
from opentelemetry._events import get_event_logger
10+
from opentelemetry._events import EventLogger, get_event_logger
1111
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
1212
from opentelemetry.instrumentation.utils import (
1313
_SUPPRESS_INSTRUMENTATION_KEY,
1414
unwrap,
1515
)
1616
from opentelemetry.instrumentation.watsonx.config import Config
17-
from opentelemetry.instrumentation.watsonx.event_handler import (
18-
ChoiceEvent,
19-
MessageEvent,
17+
from opentelemetry.instrumentation.watsonx.event_emitter import (
2018
emit_event,
2119
)
22-
from opentelemetry.instrumentation.watsonx.utils import dont_throw
20+
from opentelemetry.instrumentation.watsonx.event_models import ChoiceEvent, MessageEvent
21+
from opentelemetry.instrumentation.watsonx.utils import (
22+
dont_throw,
23+
should_emit_events,
24+
should_send_prompts,
25+
)
2326
from opentelemetry.instrumentation.watsonx.version import __version__
2427
from opentelemetry.metrics import Counter, Histogram, get_meter
25-
from opentelemetry.semconv._incubating.attributes import (
26-
gen_ai_attributes as GenAIAttributes,
27-
)
2828
from opentelemetry.semconv_ai import (
2929
SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY,
3030
LLMRequestTypeValues,
@@ -107,6 +107,8 @@ def _set_span_attribute(span, name, value):
107107

108108

109109
def _set_api_attributes(span):
110+
if not span.is_recording():
111+
return
110112
_set_span_attribute(
111113
span,
112114
WatsonxSpanAttributes.WATSONX_API_BASE,
@@ -115,20 +117,15 @@ def _set_api_attributes(span):
115117
_set_span_attribute(span, WatsonxSpanAttributes.WATSONX_API_TYPE, "watsonx.ai")
116118
_set_span_attribute(span, WatsonxSpanAttributes.WATSONX_API_VERSION, "1.0")
117119

118-
return
119-
120-
121-
def should_send_prompts():
122-
return (
123-
os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
124-
).lower() == "true" or context_api.get_value("override_enable_content_tracing")
125-
126120

127121
def is_metrics_enabled() -> bool:
128122
return (os.getenv("TRACELOOP_METRICS_ENABLED") or "true").lower() == "true"
129123

130124

131125
def _set_input_attributes(span, instance, kwargs):
126+
if not span.is_recording():
127+
return
128+
132129
if should_send_prompts() and kwargs is not None and len(kwargs) > 0:
133130
prompt = kwargs.get("prompt")
134131
if isinstance(prompt, list):
@@ -145,6 +142,11 @@ def _set_input_attributes(span, instance, kwargs):
145142
prompt,
146143
)
147144

145+
146+
def set_model_input_attributes(span, instance):
147+
if not span.is_recording():
148+
return
149+
148150
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, instance.model_id)
149151
# Set other attributes
150152
modelParameters = instance.params
@@ -186,10 +188,20 @@ def _set_input_attributes(span, instance, kwargs):
186188
span, SpanAttributes.LLM_REQUEST_TOP_P, modelParameters.get("top_p", None)
187189
)
188190

189-
return
190-
191191

192192
def _set_stream_response_attributes(span, stream_response):
193+
if not span.is_recording():
194+
return
195+
_set_span_attribute(
196+
span,
197+
f"{SpanAttributes.LLM_COMPLETIONS}.0.content",
198+
stream_response.get("generated_text"),
199+
)
200+
201+
202+
def _set_model_stream_response_attributes(span, stream_response):
203+
if not span.is_recording():
204+
return
193205
_set_span_attribute(
194206
span, SpanAttributes.LLM_RESPONSE_MODEL, stream_response.get("model_id")
195207
)
@@ -211,11 +223,6 @@ def _set_stream_response_attributes(span, stream_response):
211223
SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
212224
total_token,
213225
)
214-
_set_span_attribute(
215-
span,
216-
f"{SpanAttributes.LLM_COMPLETIONS}.0.content",
217-
stream_response.get("generated_text"),
218-
)
219226

220227

221228
def _set_completion_content_attributes(
@@ -263,7 +270,7 @@ def _token_usage_count(responses):
263270
def _set_response_attributes(
264271
span, responses, token_histogram, response_counter, duration_histogram, duration
265272
):
266-
if not isinstance(responses, (list, dict)):
273+
if not isinstance(responses, (list, dict)) or not span.is_recording():
267274
return
268275

269276
if isinstance(responses, list):
@@ -283,6 +290,32 @@ def _set_response_attributes(
283290
return
284291
_set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model_id)
285292

293+
shared_attributes = _metric_shared_attributes(response_model=model_id)
294+
295+
prompt_token, completion_token = _token_usage_count(responses)
296+
297+
if token_histogram:
298+
attributes_with_token_type = {
299+
**shared_attributes,
300+
SpanAttributes.LLM_TOKEN_TYPE: "output",
301+
}
302+
token_histogram.record(completion_token, attributes=attributes_with_token_type)
303+
attributes_with_token_type = {
304+
**shared_attributes,
305+
SpanAttributes.LLM_TOKEN_TYPE: "input",
306+
}
307+
token_histogram.record(prompt_token, attributes=attributes_with_token_type)
308+
309+
if duration and isinstance(duration, (float, int)) and duration_histogram:
310+
duration_histogram.record(duration, attributes=shared_attributes)
311+
312+
313+
def set_model_response_attributes(
314+
span, responses, token_histogram, duration_histogram, duration
315+
):
316+
if not span.is_recording():
317+
return
318+
286319
prompt_token, completion_token = _token_usage_count(responses)
287320
if (prompt_token + completion_token) != 0:
288321
_set_span_attribute(
@@ -301,35 +334,16 @@ def _set_response_attributes(
301334
prompt_token + completion_token,
302335
)
303336

304-
shared_attributes = _metric_shared_attributes(response_model=model_id)
305337

306-
if token_histogram:
307-
attributes_with_token_type = {
308-
**shared_attributes,
309-
SpanAttributes.LLM_TOKEN_TYPE: "output",
310-
}
311-
token_histogram.record(
312-
completion_token, attributes=attributes_with_token_type
313-
)
314-
attributes_with_token_type = {
315-
**shared_attributes,
316-
SpanAttributes.LLM_TOKEN_TYPE: "input",
317-
}
318-
token_histogram.record(prompt_token, attributes=attributes_with_token_type)
319-
320-
if duration and isinstance(duration, (float, int)) and duration_histogram:
321-
duration_histogram.record(duration, attributes=shared_attributes)
322-
323-
324-
def _emit_input_events(args, kwargs):
338+
def _emit_input_events(args, kwargs, event_logger):
325339
prompt = kwargs.get("prompt") or args[0]
326340

327341
if isinstance(prompt, list):
328342
for message in prompt:
329-
emit_event(MessageEvent(content=message, role="user"))
343+
emit_event(MessageEvent(content=message, role="user"), event_logger)
330344

331345
elif isinstance(prompt, str):
332-
emit_event(MessageEvent(content=prompt, role="user"))
346+
emit_event(MessageEvent(content=prompt, role="user"), event_logger)
333347

334348

335349
def _emit_response_events(response: dict):
@@ -345,6 +359,7 @@ def _emit_response_events(response: dict):
345359

346360
def _build_and_set_stream_response(
347361
span,
362+
event_logger,
348363
response,
349364
raw_flag,
350365
token_histogram,
@@ -378,7 +393,9 @@ def _build_and_set_stream_response(
378393
"generated_token_count": stream_generated_token_count,
379394
"input_token_count": stream_input_token_count,
380395
}
381-
_set_stream_response_attributes(span, stream_response)
396+
_handle_stream_response(
397+
span, event_logger, stream_response, stream_generated_text, stream_stop_reason
398+
)
382399
# response counter
383400
if response_counter:
384401
attributes_with_reason = {
@@ -412,16 +429,6 @@ def _build_and_set_stream_response(
412429
if duration and isinstance(duration, (float, int)) and duration_histogram:
413430
duration_histogram.record(duration, attributes=shared_attributes)
414431

415-
_emit_response_events(
416-
{
417-
"results": [
418-
{
419-
"stop_reason": stream_stop_reason,
420-
"generated_text": stream_generated_text,
421-
}
422-
]
423-
},
424-
)
425432
span.set_status(Status(StatusCode.OK))
426433
span.end()
427434

@@ -444,6 +451,7 @@ def _with_tracer(
444451
response_counter,
445452
duration_histogram,
446453
exception_counter,
454+
event_logger,
447455
):
448456
def wrapper(wrapped, instance, args, kwargs):
449457
return func(
@@ -453,6 +461,7 @@ def wrapper(wrapped, instance, args, kwargs):
453461
response_counter,
454462
duration_histogram,
455463
exception_counter,
464+
event_logger,
456465
wrapped,
457466
instance,
458467
args,
@@ -464,6 +473,67 @@ def wrapper(wrapped, instance, args, kwargs):
464473
return _with_tracer
465474

466475

476+
@dont_throw
477+
def _handle_input(span, event_logger, name, instance, response_counter, args, kwargs):
478+
_set_api_attributes(span)
479+
480+
if "generate" in name:
481+
set_model_input_attributes(span, instance)
482+
483+
if should_emit_events() and event_logger:
484+
_emit_input_events(args, kwargs, event_logger)
485+
elif "generate" in name:
486+
_set_input_attributes(span, instance, kwargs)
487+
488+
489+
@dont_throw
490+
def _handle_response(
491+
span,
492+
event_logger,
493+
responses,
494+
response_counter,
495+
token_histogram,
496+
duration_histogram,
497+
duration,
498+
):
499+
set_model_response_attributes(
500+
span, responses, token_histogram, duration_histogram, duration
501+
)
502+
503+
if should_emit_events() and event_logger:
504+
_emit_response_events(responses, event_logger)
505+
else:
506+
_set_response_attributes(
507+
span,
508+
responses,
509+
token_histogram,
510+
response_counter,
511+
duration_histogram,
512+
duration,
513+
)
514+
515+
516+
@dont_throw
517+
def _handle_stream_response(
518+
span, event_logger, stream_response, stream_generated_text, stream_stop_reason
519+
):
520+
_set_model_stream_response_attributes(span, stream_response)
521+
522+
if should_emit_events() and event_logger:
523+
_emit_response_events(
524+
{
525+
"results": [
526+
{
527+
"stop_reason": stream_stop_reason,
528+
"generated_text": stream_generated_text,
529+
}
530+
]
531+
},
532+
)
533+
else:
534+
_set_stream_response_attributes(span, stream_response)
535+
536+
467537
@_with_tracer_wrapper
468538
def _wrap(
469539
tracer,
@@ -472,6 +542,7 @@ def _wrap(
472542
response_counter: Counter,
473543
duration_histogram: Histogram,
474544
exception_counter: Counter,
545+
event_logger: Union[EventLogger, None],
475546
wrapped,
476547
instance,
477548
args,
@@ -494,17 +565,15 @@ def _wrap(
494565
},
495566
)
496567

497-
_set_api_attributes(span)
568+
_handle_input(span, event_logger, name, instance, args, kwargs)
569+
498570
if "generate" in name:
499-
_set_input_attributes(span, instance, kwargs)
500571
if to_wrap.get("method") == "generate_text_stream":
501572
if (raw_flag := kwargs.get("raw_response", None)) is None:
502573
kwargs = {**kwargs, "raw_response": True}
503574
elif raw_flag is False:
504575
kwargs["raw_response"] = True
505576

506-
_emit_input_events(args, kwargs)
507-
508577
try:
509578
start_time = time.time()
510579
response = wrapped(*args, **kwargs)
@@ -528,6 +597,7 @@ def _wrap(
528597
if isinstance(response, types.GeneratorType):
529598
return _build_and_set_stream_response(
530599
span,
600+
event_logger,
531601
response,
532602
raw_flag,
533603
token_histogram,
@@ -537,17 +607,15 @@ def _wrap(
537607
)
538608
else:
539609
duration = end_time - start_time
540-
_set_response_attributes(
610+
_handle_response(
541611
span,
612+
event_logger,
542613
response,
543-
token_histogram,
544614
response_counter,
615+
token_histogram,
545616
duration_histogram,
546617
duration,
547618
)
548-
549-
_emit_response_events(response)
550-
551619
span.end()
552620
return response
553621

@@ -613,9 +681,11 @@ def _instrument(self, **kwargs):
613681
None,
614682
)
615683

684+
event_logger = None
685+
616686
if not Config.use_legacy_attributes:
617687
event_logger_provider = kwargs.get("event_logger_provider")
618-
Config.event_logger = get_event_logger(
688+
event_logger = get_event_logger(
619689
__name__, __version__, event_logger_provider=event_logger_provider
620690
)
621691

@@ -634,6 +704,7 @@ def _instrument(self, **kwargs):
634704
response_counter,
635705
duration_histogram,
636706
exception_counter,
707+
event_logger,
637708
),
638709
)
639710

0 commit comments

Comments (0)