Commit 3414c66

gustavocidornelas authored and whoseoyster committed

chore: apply formatting to custom files

1 parent 40aa598 commit 3414c66

3 files changed: +90 -28 lines changed

src/openlayer/lib/core/base_model.py

Lines changed: 6 additions & 2 deletions
@@ -42,7 +42,9 @@ class OpenlayerModel(abc.ABC):
     def run_from_cli(self) -> None:
         """Run the model from the command line."""
         parser = argparse.ArgumentParser(description="Run data through a model.")
-        parser.add_argument("--dataset-path", type=str, required=True, help="Path to the dataset")
+        parser.add_argument(
+            "--dataset-path", type=str, required=True, help="Path to the dataset"
+        )
         parser.add_argument(
             "--output-dir",
             type=str,
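
Note: taken together, the `add_argument` calls above define the model-runner CLI. A minimal, self-contained sketch of that surface follows; the `--output-dir` optionality and help text are assumptions, since the hunk cuts off before them:

```python
import argparse

# Sketch of the CLI implied by the diff above; --output-dir's help text and
# optionality are assumptions (not visible in this hunk).
parser = argparse.ArgumentParser(description="Run data through a model.")
parser.add_argument(
    "--dataset-path", type=str, required=True, help="Path to the dataset"
)
parser.add_argument(
    "--output-dir",
    type=str,
    required=False,
    help="Directory to write results to (assumed)",
)

args = parser.parse_args(["--dataset-path", "dataset.json"])
print(args.dataset_path, args.output_dir)  # dataset.json None
```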
@@ -83,7 +85,9 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
             # Filter row_dict to only include keys that are valid parameters
             # for the 'run' method
             row_dict = row.to_dict()
-            filtered_kwargs = {k: v for k, v in row_dict.items() if k in run_signature.parameters}
+            filtered_kwargs = {
+                k: v for k, v in row_dict.items() if k in run_signature.parameters
+            }
 
             # Call the run method with filtered kwargs
             output = self.run(**filtered_kwargs)
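
Note: the reformatted comprehension implements signature-based kwargs filtering, so a DataFrame row can carry extra columns (labels, IDs) without breaking the subclass's `run` method. A standalone sketch of the pattern; the `run` function and row values here are hypothetical:

```python
import inspect

def run(user_query: str, context: str = "") -> dict:  # hypothetical run()
    return {"output": f"{user_query} ({context})"}

run_signature = inspect.signature(run)

# A row dict with an extra column that run() does not accept.
row_dict = {"user_query": "hi", "context": "docs", "label": 1}
filtered_kwargs = {
    k: v for k, v in row_dict.items() if k in run_signature.parameters
}
print(run(**filtered_kwargs))  # "label" is silently dropped
```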

src/openlayer/lib/integrations/langchain_callback.py

Lines changed: 43 additions & 14 deletions
@@ -35,7 +35,9 @@ def __init__(self, **kwargs: Any) -> None:
         self.metatada: Dict[str, Any] = kwargs or {}
 
     # noqa arg002
-    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> Any:
         """Run when LLM starts running."""
         pass
 
@@ -79,32 +81,45 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
         """Run on new LLM token. Only available when streaming is enabled."""
         pass
 
-    def on_llm_end(self, response: langchain_schema.LLMResult, **kwargs: Any) -> Any:  # noqa: ARG002, E501
+    def on_llm_end(
+        self, response: langchain_schema.LLMResult, **kwargs: Any  # noqa: ARG002, E501
+    ) -> Any:
         """Run when LLM ends running."""
         self.end_time = time.time()
         self.latency = (self.end_time - self.start_time) * 1000
 
         if response.llm_output and "token_usage" in response.llm_output:
-            self.prompt_tokens = response.llm_output["token_usage"].get("prompt_tokens", 0)
-            self.completion_tokens = response.llm_output["token_usage"].get("completion_tokens", 0)
+            self.prompt_tokens = response.llm_output["token_usage"].get(
+                "prompt_tokens", 0
+            )
+            self.completion_tokens = response.llm_output["token_usage"].get(
+                "completion_tokens", 0
+            )
             self.cost = self._get_cost_estimate(
                 num_input_tokens=self.prompt_tokens,
                 num_output_tokens=self.completion_tokens,
             )
-            self.total_tokens = response.llm_output["token_usage"].get("total_tokens", 0)
+            self.total_tokens = response.llm_output["token_usage"].get(
+                "total_tokens", 0
+            )
 
         for generations in response.generations:
             for generation in generations:
                 self.output += generation.text.replace("\n", " ")
 
         self._add_to_trace()
 
-    def _get_cost_estimate(self, num_input_tokens: int, num_output_tokens: int) -> float:
+    def _get_cost_estimate(
+        self, num_input_tokens: int, num_output_tokens: int
+    ) -> float:
         """Returns the cost estimate for a given model and number of tokens."""
         if self.model not in constants.OPENAI_COST_PER_TOKEN:
             return None
         cost_per_token = constants.OPENAI_COST_PER_TOKEN[self.model]
-        return cost_per_token["input"] * num_input_tokens + cost_per_token["output"] * num_output_tokens
+        return (
+            cost_per_token["input"] * num_input_tokens
+            + cost_per_token["output"] * num_output_tokens
+        )
 
     def _add_to_trace(self) -> None:
         """Adds to the trace."""
@@ -126,42 +141,56 @@ def _add_to_trace(self) -> None:
             metadata=self.metatada,
         )
 
-    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
+    def on_llm_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> Any:
         """Run when LLM errors."""
         pass
 
-    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> Any:
         """Run when chain starts running."""
         pass
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
         """Run when chain ends running."""
         pass
 
-    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
+    def on_chain_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> Any:
         """Run when chain errors."""
         pass
 
-    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
+    def on_tool_start(
+        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+    ) -> Any:
         """Run when tool starts running."""
         pass
 
     def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         """Run when tool ends running."""
         pass
 
-    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
+    def on_tool_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> Any:
         """Run when tool errors."""
         pass
 
     def on_text(self, text: str, **kwargs: Any) -> Any:
         """Run on arbitrary text."""
         pass
 
-    def on_agent_action(self, action: langchain_schema.AgentAction, **kwargs: Any) -> Any:
+    def on_agent_action(
+        self, action: langchain_schema.AgentAction, **kwargs: Any
+    ) -> Any:
         """Run on agent action."""
         pass
 
-    def on_agent_finish(self, finish: langchain_schema.AgentFinish, **kwargs: Any) -> Any:
+    def on_agent_finish(
+        self, finish: langchain_schema.AgentFinish, **kwargs: Any
+    ) -> Any:
         """Run on agent end."""
         pass
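
Note: a handler of this shape is attached to a LangChain model via `callbacks`. A hedged usage sketch; the `OpenlayerHandler` class name and the `langchain_openai` package are assumptions inferred from this file's module path, not shown in the diff:

```python
from langchain_openai import ChatOpenAI  # assumed package

from openlayer.lib.integrations.langchain_callback import OpenlayerHandler  # assumed class name

handler = OpenlayerHandler()  # extra kwargs become trace metadata
llm = ChatOpenAI(callbacks=[handler])
llm.invoke("What is the meaning of life?")
# on_llm_start/on_llm_end fire around the call; on_llm_end records latency,
# token usage, and the cost estimate, then adds the step to the trace.
```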

src/openlayer/lib/integrations/openai_tracer.py

Lines changed: 41 additions & 12 deletions
@@ -139,12 +139,16 @@ def stream_chunks(
                 if delta.function_call.name:
                     collected_function_call["name"] += delta.function_call.name
                 if delta.function_call.arguments:
-                    collected_function_call["arguments"] += delta.function_call.arguments
+                    collected_function_call[
+                        "arguments"
+                    ] += delta.function_call.arguments
             elif delta.tool_calls:
                 if delta.tool_calls[0].function.name:
                     collected_function_call["name"] += delta.tool_calls[0].function.name
                 if delta.tool_calls[0].function.arguments:
-                    collected_function_call["arguments"] += delta.tool_calls[0].function.arguments
+                    collected_function_call["arguments"] += delta.tool_calls[
+                        0
+                    ].function.arguments
 
             yield chunk
         end_time = time.time()
@@ -155,16 +159,22 @@ def stream_chunks(
     finally:
         # Try to add step to the trace
         try:
-            collected_output_data = [message for message in collected_output_data if message is not None]
+            collected_output_data = [
+                message for message in collected_output_data if message is not None
+            ]
             if collected_output_data:
                 output_data = "".join(collected_output_data)
             else:
-                collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
+                collected_function_call["arguments"] = json.loads(
+                    collected_function_call["arguments"]
+                )
                 output_data = collected_function_call
             completion_cost = estimate_cost(
                 model=kwargs.get("model"),
                 prompt_tokens=0,
-                completion_tokens=(num_of_completion_tokens if num_of_completion_tokens else 0),
+                completion_tokens=(
+                    num_of_completion_tokens if num_of_completion_tokens else 0
+                ),
                 is_azure_openai=is_azure_openai,
             )
 
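
Note: the `finally` block resolves the two streaming shapes: plain text chunks are joined, while a streamed function call's `arguments` string is parsed as JSON once complete. A standalone sketch with hypothetical accumulated state:

```python
import json

# Hypothetical accumulated stream state, mirroring the logic above.
collected_output_data = [None, "Hello", None, " world"]
collected_function_call = {"name": "get_weather", "arguments": '{"city": "Paris"}'}

collected_output_data = [
    message for message in collected_output_data if message is not None
]
if collected_output_data:
    output_data = "".join(collected_output_data)  # plain text completion
else:
    # Function call: argument fragments form one JSON string; parse it.
    collected_function_call["arguments"] = json.loads(
        collected_function_call["arguments"]
    )
    output_data = collected_function_call

print(output_data)  # "Hello world"
```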
@@ -181,7 +191,13 @@ def stream_chunks(
             model_parameters=get_model_parameters(kwargs),
             raw_output=raw_outputs,
             id=inference_id,
-            metadata={"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None)},
+            metadata={
+                "timeToFirstToken": (
+                    (first_token_time - start_time) * 1000
+                    if first_token_time
+                    else None
+                )
+            },
         )
         add_to_trace(
             **trace_args,
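
Note: `timeToFirstToken` is reported in milliseconds and falls back to `None` when no chunk ever arrived. A minimal sketch of the timing; exactly where `first_token_time` is captured is an assumption based on the surrounding function:

```python
import time

start_time = time.time()
first_token_time = None

for i, chunk in enumerate(["He", "llo"]):  # stand-in for streamed chunks
    if i == 0:
        first_token_time = time.time()  # assumed capture point

time_to_first_token = (
    (first_token_time - start_time) * 1000 if first_token_time else None
)
print(time_to_first_token)
```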
@@ -207,7 +223,10 @@ def estimate_cost(
         cost_per_token = constants.AZURE_OPENAI_COST_PER_TOKEN[model]
     elif model in constants.OPENAI_COST_PER_TOKEN:
         cost_per_token = constants.OPENAI_COST_PER_TOKEN[model]
-        return cost_per_token["input"] * prompt_tokens + cost_per_token["output"] * completion_tokens
+        return (
+            cost_per_token["input"] * prompt_tokens
+            + cost_per_token["output"] * completion_tokens
+        )
     return None
 
 
@@ -266,8 +285,12 @@ def create_trace_args(
 def add_to_trace(is_azure_openai: bool = False, **kwargs) -> None:
     """Add a chat completion step to the trace."""
     if is_azure_openai:
-        tracer.add_chat_completion_step_to_trace(**kwargs, name="Azure OpenAI Chat Completion", provider="Azure")
-    tracer.add_chat_completion_step_to_trace(**kwargs, name="OpenAI Chat Completion", provider="OpenAI")
+        tracer.add_chat_completion_step_to_trace(
+            **kwargs, name="Azure OpenAI Chat Completion", provider="Azure"
+        )
+    tracer.add_chat_completion_step_to_trace(
+        **kwargs, name="OpenAI Chat Completion", provider="OpenAI"
+    )
 
 
 def handle_non_streaming_create(
@@ -327,7 +350,9 @@ def handle_non_streaming_create(
         )
     # pylint: disable=broad-except
     except Exception as e:
-        logger.error("Failed to trace the create chat completion request with Openlayer. %s", e)
+        logger.error(
+            "Failed to trace the create chat completion request with Openlayer. %s", e
+        )
 
     return response
 
@@ -369,7 +394,9 @@ def parse_non_streaming_output_data(
 
 
 # --------------------------- OpenAI Assistants API -------------------------- #
-def trace_openai_assistant_thread_run(client: openai.OpenAI, run: "openai.types.beta.threads.run.Run") -> None:
+def trace_openai_assistant_thread_run(
+    client: openai.OpenAI, run: "openai.types.beta.threads.run.Run"
+) -> None:
     """Trace a run from an OpenAI assistant.
 
     Once the run is completed, the thread data is published to Openlayer,
@@ -386,7 +413,9 @@ def trace_openai_assistant_thread_run(client: openai.OpenAI, run: "openai.types.
         metadata = _extract_run_metadata(run)
 
         # Convert thread to prompt
-        messages = client.beta.threads.messages.list(thread_id=run.thread_id, order="asc")
+        messages = client.beta.threads.messages.list(
+            thread_id=run.thread_id, order="asc"
+        )
         prompt = _thread_messages_to_prompt(messages)
 
         # Add step to the trace
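
Note: per the docstring above, the tracer publishes the thread data once a run completes. A hedged end-to-end sketch of driving it; the assistant ID is a placeholder and the polling cadence is an assumption:

```python
import time

import openai

from openlayer.lib.integrations.openai_tracer import trace_openai_assistant_thread_run

client = openai.OpenAI()
thread = client.beta.threads.create(
    messages=[{"role": "user", "content": "Hello!"}]
)
run = client.beta.threads.runs.create(
    thread_id=thread.id, assistant_id="asst_..."  # placeholder ID
)

# Poll until the run leaves the in-flight states.
while run.status in ("queued", "in_progress"):
    time.sleep(1)
    run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)

# Once the run is completed, the thread data is published to Openlayer.
trace_openai_assistant_thread_run(client, run)
```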
