Commit 7a2e2b5

changes to litellm

committed (1 parent: e2eaaac)

File tree: 1 file changed (+155, -153 lines)

sentry_sdk/integrations/litellm.py (155 additions, 153 deletions)
@@ -1,3 +1,4 @@
+from functools import wraps
 from typing import TYPE_CHECKING
 
 import sentry_sdk
@@ -11,79 +12,40 @@
 from sentry_sdk.consts import SPANDATA
 from sentry_sdk.integrations import DidNotEnable, Integration
 from sentry_sdk.scope import should_send_default_pii
-from sentry_sdk.utils import event_from_exception
+from sentry_sdk.utils import capture_internal_exceptions, event_from_exception
 
 if TYPE_CHECKING:
     from typing import Any, Dict
     from datetime import datetime
 
 try:
     import litellm  # type: ignore[import-not-found]
-    from litellm import input_callback, success_callback, failure_callback
 except ImportError:
     raise DidNotEnable("LiteLLM not installed")
 
 
-def _get_metadata_dict(kwargs: "Dict[str, Any]") -> "Dict[str, Any]":
-    """Get the metadata dictionary from the kwargs."""
-    litellm_params = kwargs.setdefault("litellm_params", {})
-
-    # we need this weird little dance, as metadata might be set but may be None initially
-    metadata = litellm_params.get("metadata")
-    if metadata is None:
-        metadata = {}
-        litellm_params["metadata"] = metadata
-    return metadata
-
-
-def _input_callback(kwargs: "Dict[str, Any]") -> None:
-    """Handle the start of a request."""
-    integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
-
-    if integration is None:
-        return
-
-    # Get key parameters
-    full_model = kwargs.get("model", "")
+def _get_provider_and_model(full_model: str) -> "tuple[str, str]":
+    """Extract provider and model name from full model string."""
     try:
         model, provider, _, _ = litellm.get_llm_provider(full_model)
+        return provider, model
     except Exception:
-        model = full_model
-        provider = "unknown"
-
-    call_type = kwargs.get("call_type", None)
-    if call_type == "embedding":
-        operation = "embeddings"
-    else:
-        operation = "chat"
-
-    # Start a new span/transaction
-    span = get_start_span_function()(
-        op=(
-            consts.OP.GEN_AI_CHAT
-            if operation == "chat"
-            else consts.OP.GEN_AI_EMBEDDINGS
-        ),
-        name=f"{operation} {model}",
-        origin=LiteLLMIntegration.origin,
-    )
-    span.__enter__()
+        return "unknown", full_model
 
-    # Store span for later
-    _get_metadata_dict(kwargs)["_sentry_span"] = span
-
-    # Set basic data
-    set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider)
-    set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation)
 
+def _set_input_data(
+    span: "Any",
+    kwargs: "Dict[str, Any]",
+    operation: str,
+    integration: "LiteLLMIntegration",
+) -> None:
+    """Set input data on the span."""
     # Record input/messages if allowed
     if should_send_default_pii() and integration.include_prompts:
         if operation == "embeddings":
-            # For embeddings, look for the 'input' parameter
            embedding_input = kwargs.get("input")
             if embedding_input:
                 scope = sentry_sdk.get_current_scope()
-                # Normalize to list format
                 input_list = (
                     embedding_input
                     if isinstance(embedding_input, list)
@@ -98,7 +60,6 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None:
                     unpack=False,
                 )
         else:
-            # For chat, look for the 'messages' parameter
             messages = kwargs.get("messages", [])
             if messages:
                 scope = sentry_sdk.get_current_scope()
@@ -111,7 +72,7 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None:
                     unpack=False,
                 )
 
-    # Record other parameters
+    # Record standard parameters
    params = {
         "model": SPANDATA.GEN_AI_REQUEST_MODEL,
         "stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
@@ -126,107 +87,157 @@
         if value is not None:
             set_data_normalized(span, attribute, value)
 
-    # Record LiteLLM-specific parameters
-    litellm_params = {
-        "api_base": kwargs.get("api_base"),
-        "api_version": kwargs.get("api_version"),
-        "custom_llm_provider": kwargs.get("custom_llm_provider"),
-    }
-    for key, value in litellm_params.items():
-        if value is not None:
-            set_data_normalized(span, f"gen_ai.litellm.{key}", value)
-
 
-def _success_callback(
-    kwargs: "Dict[str, Any]",
-    completion_response: "Any",
-    start_time: "datetime",
-    end_time: "datetime",
+def _set_output_data(
+    span: "Any",
+    response: "Any",
+    integration: "LiteLLMIntegration",
 ) -> None:
-    """Handle successful completion."""
+    """Set output data on the span."""
+    # Record model information
+    if hasattr(response, "model"):
+        set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)
+
+    # Record response content if allowed
+    if should_send_default_pii() and integration.include_prompts:
+        if hasattr(response, "choices"):
+            response_messages = []
+            for choice in response.choices:
+                if hasattr(choice, "message"):
+                    if hasattr(choice.message, "model_dump"):
+                        response_messages.append(choice.message.model_dump())
+                    elif hasattr(choice.message, "dict"):
+                        response_messages.append(choice.message.dict())
+                    else:
+                        msg = {}
+                        if hasattr(choice.message, "role"):
+                            msg["role"] = choice.message.role
+                        if hasattr(choice.message, "content"):
+                            msg["content"] = choice.message.content
+                        if hasattr(choice.message, "tool_calls"):
+                            msg["tool_calls"] = choice.message.tool_calls
+                        response_messages.append(msg)
+
+            if response_messages:
+                set_data_normalized(
+                    span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_messages
+                )
 
-    span = _get_metadata_dict(kwargs).get("_sentry_span")
-    if span is None:
-        return
+    # Record token usage
+    if hasattr(response, "usage"):
+        usage = response.usage
+
+        # Extract cached tokens from prompt_tokens_details (OpenAI format used by LiteLLM)
+        cached_tokens = None
+        prompt_tokens_details = getattr(usage, "prompt_tokens_details", None)
+        if prompt_tokens_details is not None:
+            cached_tokens = getattr(prompt_tokens_details, "cached_tokens", None)
+
+        # Extract cache write tokens (Anthropic only)
+        cache_creation_tokens = getattr(usage, "cache_creation_input_tokens", None)
+
+        record_token_usage(
+            span,
+            input_tokens=getattr(usage, "prompt_tokens", None),
+            input_tokens_cached=cached_tokens,
+            input_tokens_cache_write=cache_creation_tokens,
+            output_tokens=getattr(usage, "completion_tokens", None),
+            total_tokens=getattr(usage, "total_tokens", None),
+        )
 
-    integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
-    if integration is None:
-        return
 
-    try:
-        # Record model information
-        if hasattr(completion_response, "model"):
-            set_data_normalized(
-                span, SPANDATA.GEN_AI_RESPONSE_MODEL, completion_response.model
-            )
+def _wrap_completion(original_func: "Any") -> "Any":
+    """Wrap litellm.completion to add instrumentation."""
 
-        # Record response content if allowed
-        if should_send_default_pii() and integration.include_prompts:
-            if hasattr(completion_response, "choices"):
-                response_messages = []
-                for choice in completion_response.choices:
-                    if hasattr(choice, "message"):
-                        if hasattr(choice.message, "model_dump"):
-                            response_messages.append(choice.message.model_dump())
-                        elif hasattr(choice.message, "dict"):
-                            response_messages.append(choice.message.dict())
-                        else:
-                            # Fallback for basic message objects
-                            msg = {}
-                            if hasattr(choice.message, "role"):
-                                msg["role"] = choice.message.role
-                            if hasattr(choice.message, "content"):
-                                msg["content"] = choice.message.content
-                            if hasattr(choice.message, "tool_calls"):
-                                msg["tool_calls"] = choice.message.tool_calls
-                            response_messages.append(msg)
-
-                if response_messages:
-                    set_data_normalized(
-                        span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_messages
-                    )
+    @wraps(original_func)
+    def wrapper(*args: "Any", **kwargs: "Any") -> "Any":
+        integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
+        if integration is None:
+            return original_func(*args, **kwargs)
 
-        # Record token usage
-        if hasattr(completion_response, "usage"):
-            usage = completion_response.usage
-            record_token_usage(
-                span,
-                input_tokens=getattr(usage, "prompt_tokens", None),
-                input_tokens_cached=getattr(usage, "cache_read_input_tokens", None),
-                input_tokens_cache_write=getattr(
-                    usage, "cache_write_input_tokens", None
-                ),
-                output_tokens=getattr(usage, "completion_tokens", None),
-                total_tokens=getattr(usage, "total_tokens", None),
+        # Get model and provider
+        full_model = kwargs.get("model", args[0] if args else "")
+        provider, model = _get_provider_and_model(full_model)
+
+        # Create span
+        span = get_start_span_function()(
+            op=consts.OP.GEN_AI_CHAT,
+            name=f"chat {model}",
+            origin=LiteLLMIntegration.origin,
+        )
+        span.__enter__()
+
+        # Set basic data
+        set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider)
+        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
+
+        with capture_internal_exceptions():
+            _set_input_data(span, kwargs, "chat", integration)
+
+        try:
+            response = original_func(*args, **kwargs)
+            with capture_internal_exceptions():
+                _set_output_data(span, response, integration)
+            return response
+        except Exception as exc:
+            event, hint = event_from_exception(
+                exc,
+                client_options=sentry_sdk.get_client().options,
+                mechanism={"type": "litellm", "handled": False},
             )
+            sentry_sdk.capture_event(event, hint=hint)
+            raise
+        finally:
+            span.__exit__(None, None, None)
 
-    finally:
-        # Always finish the span and clean up
-        span.__exit__(None, None, None)
+    return wrapper
 
 
-def _failure_callback(
-    kwargs: "Dict[str, Any]",
-    exception: Exception,
-    start_time: "datetime",
-    end_time: "datetime",
-) -> None:
-    """Handle request failure."""
-    span = _get_metadata_dict(kwargs).get("_sentry_span")
-    if span is None:
-        return
+def _wrap_acompletion(original_func: "Any") -> "Any":
+    """Wrap litellm.acompletion to add instrumentation."""
 
-    try:
-        # Capture the exception
-        event, hint = event_from_exception(
-            exception,
-            client_options=sentry_sdk.get_client().options,
-            mechanism={"type": "litellm", "handled": False},
+    @wraps(original_func)
+    async def wrapper(*args: "Any", **kwargs: "Any") -> "Any":
+        integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration)
+        if integration is None:
+            return await original_func(*args, **kwargs)
+
+        # Get model and provider
+        full_model = kwargs.get("model", args[0] if args else "")
+        provider, model = _get_provider_and_model(full_model)
+
+        # Create span
+        span = get_start_span_function()(
+            op=consts.OP.GEN_AI_CHAT,
+            name=f"chat {model}",
+            origin=LiteLLMIntegration.origin,
         )
-        sentry_sdk.capture_event(event, hint=hint)
-    finally:
-        # Always finish the span and clean up
-        span.__exit__(type(exception), exception, None)
+        span.__enter__()
+
+        # Set basic data
+        set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider)
+        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
+
+        with capture_internal_exceptions():
+            _set_input_data(span, kwargs, "chat", integration)
+
+        try:
+            response = await original_func(*args, **kwargs)
+            with capture_internal_exceptions():
+                _set_output_data(span, response, integration)
+            return response
+        except Exception as exc:
+            event, hint = event_from_exception(
+                exc,
+                client_options=sentry_sdk.get_client().options,
+                mechanism={"type": "litellm", "handled": False},
            )
+            sentry_sdk.capture_event(event, hint=hint)
+            raise
+        finally:
+            span.__exit__(None, None, None)
+
+    return wrapper
 
 
 class LiteLLMIntegration(Integration):
@@ -282,15 +293,6 @@ def __init__(self: "LiteLLMIntegration", include_prompts: bool = True) -> None:
 
     @staticmethod
     def setup_once() -> None:
-        """Set up LiteLLM callbacks for monitoring."""
-        litellm.input_callback = input_callback or []
-        if _input_callback not in litellm.input_callback:
-            litellm.input_callback.append(_input_callback)
-
-        litellm.success_callback = success_callback or []
-        if _success_callback not in litellm.success_callback:
-            litellm.success_callback.append(_success_callback)
-
-        litellm.failure_callback = failure_callback or []
-        if _failure_callback not in litellm.failure_callback:
-            litellm.failure_callback.append(_failure_callback)
+        """Set up LiteLLM instrumentation by wrapping completion functions."""
+        litellm.completion = _wrap_completion(litellm.completion)
+        litellm.acompletion = _wrap_acompletion(litellm.acompletion)