
Commit cacd89f

feat: python async function tracing in dev mode, closing OPEN-6157
1 parent: d0e7934

File tree

6 files changed, +126 -64 lines


src/openlayer/lib/core/base_model.py

Lines changed: 6 additions & 10 deletions

@@ -1,13 +1,13 @@
 """Base class for an Openlayer model."""
 
-import os
 import abc
+import argparse
+import inspect
 import json
+import os
 import time
-import inspect
-import argparse
+from dataclasses import dataclass, field
 from typing import Any, Dict, Tuple
-from dataclasses import field, dataclass
 
 import pandas as pd
 
@@ -42,9 +42,7 @@ class OpenlayerModel(abc.ABC):
     def run_from_cli(self) -> None:
         """Run the model from the command line."""
         parser = argparse.ArgumentParser(description="Run data through a model.")
-        parser.add_argument(
-            "--dataset-path", type=str, required=True, help="Path to the dataset"
-        )
+        parser.add_argument("--dataset-path", type=str, required=True, help="Path to the dataset")
         parser.add_argument(
             "--output-dir",
             type=str,
@@ -85,9 +83,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
             # Filter row_dict to only include keys that are valid parameters
             # for the 'run' method
             row_dict = row.to_dict()
-            filtered_kwargs = {
-                k: v for k, v in row_dict.items() if k in run_signature.parameters
-            }
+            filtered_kwargs = {k: v for k, v in row_dict.items() if k in run_signature.parameters}
 
             # Call the run method with filtered kwargs
             output = self.run(**filtered_kwargs)
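
The kwargs filtering reformatted above is the pattern that lets run_batch_from_df feed each DataFrame row to the subclass's run method without tripping over extra columns. A minimal standalone sketch of that pattern; the run function and the row contents here are illustrative, not part of the SDK:

import inspect

def run(user_query: str, temperature: float = 0.0) -> dict:
    # Stand-in for an OpenlayerModel subclass's `run` method.
    return {"output": f"answer to {user_query!r} at temperature {temperature}"}

run_signature = inspect.signature(run)

# A dataset row usually carries extra columns (labels, ids, ...) that `run`
# does not accept; filtering them out avoids a TypeError on unexpected kwargs.
row_dict = {"user_query": "What is tracing?", "temperature": 0.7, "label": 1}
filtered_kwargs = {k: v for k, v in row_dict.items() if k in run_signature.parameters}

print(run(**filtered_kwargs))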

src/openlayer/lib/integrations/langchain_callback.py

Lines changed: 15 additions & 44 deletions

@@ -2,7 +2,7 @@
 
 # pylint: disable=unused-argument
 import time
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from langchain import schema as langchain_schema
 from langchain.callbacks.base import BaseCallbackHandler
@@ -35,9 +35,7 @@ def __init__(self, **kwargs: Any) -> None:
         self.metatada: Dict[str, Any] = kwargs or {}
 
     # noqa arg002
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> Any:
+    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
         """Run when LLM starts running."""
         pass
 
@@ -81,45 +79,32 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
         """Run on new LLM token. Only available when streaming is enabled."""
         pass
 
-    def on_llm_end(
-        self, response: langchain_schema.LLMResult, **kwargs: Any  # noqa: ARG002, E501
-    ) -> Any:
+    def on_llm_end(self, response: langchain_schema.LLMResult, **kwargs: Any) -> Any:  # noqa: ARG002, E501
         """Run when LLM ends running."""
         self.end_time = time.time()
         self.latency = (self.end_time - self.start_time) * 1000
 
         if response.llm_output and "token_usage" in response.llm_output:
-            self.prompt_tokens = response.llm_output["token_usage"].get(
-                "prompt_tokens", 0
-            )
-            self.completion_tokens = response.llm_output["token_usage"].get(
-                "completion_tokens", 0
-            )
+            self.prompt_tokens = response.llm_output["token_usage"].get("prompt_tokens", 0)
+            self.completion_tokens = response.llm_output["token_usage"].get("completion_tokens", 0)
             self.cost = self._get_cost_estimate(
                 num_input_tokens=self.prompt_tokens,
                 num_output_tokens=self.completion_tokens,
             )
-            self.total_tokens = response.llm_output["token_usage"].get(
-                "total_tokens", 0
-            )
+            self.total_tokens = response.llm_output["token_usage"].get("total_tokens", 0)
 
         for generations in response.generations:
            for generation in generations:
                self.output += generation.text.replace("\n", " ")
 
         self._add_to_trace()
 
-    def _get_cost_estimate(
-        self, num_input_tokens: int, num_output_tokens: int
-    ) -> float:
+    def _get_cost_estimate(self, num_input_tokens: int, num_output_tokens: int) -> float:
         """Returns the cost estimate for a given model and number of tokens."""
         if self.model not in constants.OPENAI_COST_PER_TOKEN:
             return None
         cost_per_token = constants.OPENAI_COST_PER_TOKEN[self.model]
-        return (
-            cost_per_token["input"] * num_input_tokens
-            + cost_per_token["output"] * num_output_tokens
-        )
+        return cost_per_token["input"] * num_input_tokens + cost_per_token["output"] * num_output_tokens
 
     def _add_to_trace(self) -> None:
         """Adds to the trace."""
@@ -141,56 +126,42 @@ def _add_to_trace(self) -> None:
             metadata=self.metatada,
         )
 
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> Any:
+    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
         """Run when LLM errors."""
         pass
 
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> Any:
+    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
         """Run when chain starts running."""
         pass
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
         """Run when chain ends running."""
         pass
 
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> Any:
+    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
         """Run when chain errors."""
         pass
 
-    def on_tool_start(
-        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
-    ) -> Any:
+    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
         """Run when tool starts running."""
         pass
 
     def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         """Run when tool ends running."""
         pass
 
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> Any:
+    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
         """Run when tool errors."""
         pass
 
     def on_text(self, text: str, **kwargs: Any) -> Any:
         """Run on arbitrary text."""
         pass
 
-    def on_agent_action(
-        self, action: langchain_schema.AgentAction, **kwargs: Any
-    ) -> Any:
+    def on_agent_action(self, action: langchain_schema.AgentAction, **kwargs: Any) -> Any:
         """Run on agent action."""
         pass
 
-    def on_agent_finish(
-        self, finish: langchain_schema.AgentFinish, **kwargs: Any
-    ) -> Any:
+    def on_agent_finish(self, finish: langchain_schema.AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
         pass
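
The cost estimate collapsed onto one line above is a straight lookup-and-multiply. A self-contained sketch of the same arithmetic, with a hypothetical price table standing in for the SDK's constants.OPENAI_COST_PER_TOKEN; the prices below are placeholders, not real rates:

from typing import Optional

# Placeholder price table; the SDK reads real per-token rates from its
# constants.OPENAI_COST_PER_TOKEN. These numbers are made up.
OPENAI_COST_PER_TOKEN = {
    "gpt-3.5-turbo": {"input": 0.5e-6, "output": 1.5e-6},
}

def get_cost_estimate(model: str, num_input_tokens: int, num_output_tokens: int) -> Optional[float]:
    """Mirrors _get_cost_estimate: None for unknown models, else summed input/output cost."""
    if model not in OPENAI_COST_PER_TOKEN:
        return None
    cost_per_token = OPENAI_COST_PER_TOKEN[model]
    return cost_per_token["input"] * num_input_tokens + cost_per_token["output"] * num_output_tokens

# E.g. 1,000 prompt tokens and 500 completion tokens:
print(get_cost_estimate("gpt-3.5-turbo", num_input_tokens=1000, num_output_tokens=500))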

src/openlayer/lib/integrations/openai_tracer.py

Lines changed: 2 additions & 2 deletions

@@ -1,10 +1,10 @@
 """Module with methods used to trace OpenAI / Azure OpenAI LLMs."""
 
 import json
-import time
 import logging
-from typing import Any, Dict, List, Union, Iterator, Optional
+import time
 from functools import wraps
+from typing import Any, Dict, Iterator, List, Optional, Union
 
 import openai
 
src/openlayer/lib/tracing/steps.py

Lines changed: 1 addition & 1 deletion

@@ -4,8 +4,8 @@
 import uuid
 from typing import Any, Dict, Optional
 
-from . import enums
 from .. import utils
+from . import enums
 
 
 class Step:

src/openlayer/lib/tracing/tracer.py

Lines changed: 101 additions & 6 deletions

@@ -1,17 +1,18 @@
 """Module with the logic to create and manage traces and steps."""
 
-import time
+import asyncio
+import contextvars
 import inspect
 import logging
-import contextvars
-from typing import Any, Dict, List, Tuple, Optional, Generator
-from functools import wraps
+import time
 from contextlib import contextmanager
+from functools import wraps
+from typing import Any, Awaitable, Dict, Generator, List, Optional, Tuple
 
-from . import enums, steps, traces
-from .. import utils
 from ..._client import Openlayer
 from ...types.inference_pipelines.data_stream_params import ConfigLlmData
+from .. import utils
+from . import enums, steps, traces
 
 logger = logging.getLogger(__name__)
 
@@ -195,6 +196,100 @@ def wrapper(*func_args, **func_kwargs):
     return decorator
 
 
+def trace_async(*step_args, **step_kwargs):
+    """Decorator to trace async functions.
+
+    Examples
+    --------
+
+    To trace an async function, simply decorate it with the ``@trace_async()`` decorator.
+    By doing so, the function's inputs, outputs, and metadata will be automatically
+    logged to your Openlayer project.
+
+    >>> import os
+    >>> from openlayer.tracing import tracer
+    >>>
+    >>> # Set the environment variables
+    >>> os.environ["OPENLAYER_API_KEY"] = "YOUR_OPENLAYER_API_KEY_HERE"
+    >>> os.environ["OPENLAYER_PROJECT_NAME"] = "YOUR_OPENLAYER_PROJECT_NAME_HERE"
+    >>>
+    >>> # Decorate all the functions you want to trace
+    >>> @tracer.trace_async()
+    >>> async def main(user_query: str) -> str:
+    >>>     context = await retrieve_context(user_query)
+    >>>     answer = await generate_answer(user_query, context)
+    >>>     return answer
+    >>>
+    >>> @tracer.trace_async()
+    >>> async def retrieve_context(user_query: str) -> str:
+    >>>     return "Some context"
+    >>>
+    >>> @tracer.trace_async()
+    >>> async def generate_answer(user_query: str, context: str) -> str:
+    >>>     return "Some answer"
+    >>>
+    >>> # Every time the main function is called, the data is automatically
+    >>> # streamed to your Openlayer project. E.g.:
+    >>> tracer.run_async_func(main("What is the meaning of life?"))
+    """
+
+    def decorator(func):
+        func_signature = inspect.signature(func)
+
+        @wraps(func)
+        async def wrapper(*func_args, **func_kwargs):
+            if step_kwargs.get("name") is None:
+                step_kwargs["name"] = func.__name__
+            with create_step(*step_args, **step_kwargs) as step:
+                output = exception = None
+                try:
+                    output = await func(*func_args, **func_kwargs)
+                # pylint: disable=broad-except
+                except Exception as exc:
+                    step.log(metadata={"Exceptions": str(exc)})
+                    exception = exc
+                end_time = time.time()
+                latency = (end_time - step.start_time) * 1000  # in ms
+
+                bound = func_signature.bind(*func_args, **func_kwargs)
+                bound.apply_defaults()
+                inputs = dict(bound.arguments)
+                inputs.pop("self", None)
+                inputs.pop("cls", None)
+
+                step.log(
+                    inputs=inputs,
+                    output=output,
+                    end_time=end_time,
+                    latency=latency,
+                )
+
+            if exception is not None:
+                raise exception
+            return output
+
+        return wrapper
+
+    return decorator
+
+
+async def _invoke_with_context(coroutine: Awaitable[Any]) -> Tuple[contextvars.Context, Any]:
+    """Runs a coroutine and preserves the context variables set within it."""
+    result = await coroutine
+    context = contextvars.copy_context()
+    return context, result
+
+
+def run_async_func(coroutine: Awaitable[Any]) -> Any:
+    """Runs an async function while preserving the context. This is needed
+    for tracing async functions.
+    """
+    context, result = asyncio.run(_invoke_with_context(coroutine))
+    for key, value in context.items():
+        key.set(value)
+    return result
+
+
 # --------------------- Helper post-processing functions --------------------- #
 def post_process_trace(
     trace_obj: traces.Trace,
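
A note on why the _invoke_with_context / run_async_func pair added above exists: asyncio.run executes the coroutine in a fresh task context, so contextvars set inside it (such as the tracer's current-trace state) would be discarded when the event loop exits. Copying the context inside the loop and re-applying it in the caller keeps those values visible afterwards. A minimal standalone demonstration of the same pattern; the _trace_var variable here is hypothetical, standing in for whatever context variable the tracer actually uses:

import asyncio
import contextvars
from typing import Any, Awaitable, Tuple

# Hypothetical context variable, for illustration only.
_trace_var = contextvars.ContextVar("trace", default="<unset>")

async def _invoke_with_context(coroutine: Awaitable[Any]) -> Tuple[contextvars.Context, Any]:
    # Snapshot the context *inside* the event loop, after the coroutine
    # has set its context variables.
    result = await coroutine
    return contextvars.copy_context(), result

def run_async_func(coroutine: Awaitable[Any]) -> Any:
    context, result = asyncio.run(_invoke_with_context(coroutine))
    # Re-apply every variable captured inside the loop to the caller's context.
    for key, value in context.items():
        key.set(value)
    return result

async def traced() -> str:
    _trace_var.set("trace-123")  # simulates a step recording the current trace
    return "done"

print(run_async_func(traced()), _trace_var.get())  # -> done trace-123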

src/openlayer/lib/utils.py

Lines changed: 1 addition & 1 deletion

@@ -2,8 +2,8 @@
 Openlayer SDK.
 """
 
-import os
 import json
+import os
 from typing import Optional
 
 