@@ -35,7 +35,9 @@ def __init__(self, **kwargs: Any) -> None:
35
35
self .metatada : Dict [str , Any ] = kwargs or {}
36
36
37
37
# noqa arg002
38
- def on_llm_start (self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any ) -> Any :
38
+ def on_llm_start (
39
+ self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any
40
+ ) -> Any :
39
41
"""Run when LLM starts running."""
40
42
pass
41
43
@@ -79,32 +81,45 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
79
81
"""Run on new LLM token. Only available when streaming is enabled."""
80
82
pass
81
83
82
- def on_llm_end (self , response : langchain_schema .LLMResult , ** kwargs : Any ) -> Any : # noqa: ARG002, E501
84
+ def on_llm_end (
85
+ self , response : langchain_schema .LLMResult , ** kwargs : Any # noqa: ARG002, E501
86
+ ) -> Any :
83
87
"""Run when LLM ends running."""
84
88
self .end_time = time .time ()
85
89
self .latency = (self .end_time - self .start_time ) * 1000
86
90
87
91
if response .llm_output and "token_usage" in response .llm_output :
88
- self .prompt_tokens = response .llm_output ["token_usage" ].get ("prompt_tokens" , 0 )
89
- self .completion_tokens = response .llm_output ["token_usage" ].get ("completion_tokens" , 0 )
92
+ self .prompt_tokens = response .llm_output ["token_usage" ].get (
93
+ "prompt_tokens" , 0
94
+ )
95
+ self .completion_tokens = response .llm_output ["token_usage" ].get (
96
+ "completion_tokens" , 0
97
+ )
90
98
self .cost = self ._get_cost_estimate (
91
99
num_input_tokens = self .prompt_tokens ,
92
100
num_output_tokens = self .completion_tokens ,
93
101
)
94
- self .total_tokens = response .llm_output ["token_usage" ].get ("total_tokens" , 0 )
102
+ self .total_tokens = response .llm_output ["token_usage" ].get (
103
+ "total_tokens" , 0
104
+ )
95
105
96
106
for generations in response .generations :
97
107
for generation in generations :
98
108
self .output += generation .text .replace ("\n " , " " )
99
109
100
110
self ._add_to_trace ()
101
111
102
- def _get_cost_estimate (self , num_input_tokens : int , num_output_tokens : int ) -> float :
112
+ def _get_cost_estimate (
113
+ self , num_input_tokens : int , num_output_tokens : int
114
+ ) -> float :
103
115
"""Returns the cost estimate for a given model and number of tokens."""
104
116
if self .model not in constants .OPENAI_COST_PER_TOKEN :
105
117
return None
106
118
cost_per_token = constants .OPENAI_COST_PER_TOKEN [self .model ]
107
- return cost_per_token ["input" ] * num_input_tokens + cost_per_token ["output" ] * num_output_tokens
119
+ return (
120
+ cost_per_token ["input" ] * num_input_tokens
121
+ + cost_per_token ["output" ] * num_output_tokens
122
+ )
108
123
109
124
def _add_to_trace (self ) -> None :
110
125
"""Adds to the trace."""
@@ -126,42 +141,56 @@ def _add_to_trace(self) -> None:
126
141
metadata = self .metatada ,
127
142
)
128
143
129
- def on_llm_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
144
+ def on_llm_error (
145
+ self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
146
+ ) -> Any :
130
147
"""Run when LLM errors."""
131
148
pass
132
149
133
- def on_chain_start (self , serialized : Dict [str , Any ], inputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
150
+ def on_chain_start (
151
+ self , serialized : Dict [str , Any ], inputs : Dict [str , Any ], ** kwargs : Any
152
+ ) -> Any :
134
153
"""Run when chain starts running."""
135
154
pass
136
155
137
156
def on_chain_end (self , outputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
138
157
"""Run when chain ends running."""
139
158
pass
140
159
141
- def on_chain_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
160
+ def on_chain_error (
161
+ self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
162
+ ) -> Any :
142
163
"""Run when chain errors."""
143
164
pass
144
165
145
- def on_tool_start (self , serialized : Dict [str , Any ], input_str : str , ** kwargs : Any ) -> Any :
166
+ def on_tool_start (
167
+ self , serialized : Dict [str , Any ], input_str : str , ** kwargs : Any
168
+ ) -> Any :
146
169
"""Run when tool starts running."""
147
170
pass
148
171
149
172
def on_tool_end (self , output : str , ** kwargs : Any ) -> Any :
150
173
"""Run when tool ends running."""
151
174
pass
152
175
153
- def on_tool_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
176
+ def on_tool_error (
177
+ self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
178
+ ) -> Any :
154
179
"""Run when tool errors."""
155
180
pass
156
181
157
182
def on_text (self , text : str , ** kwargs : Any ) -> Any :
158
183
"""Run on arbitrary text."""
159
184
pass
160
185
161
- def on_agent_action (self , action : langchain_schema .AgentAction , ** kwargs : Any ) -> Any :
186
+ def on_agent_action (
187
+ self , action : langchain_schema .AgentAction , ** kwargs : Any
188
+ ) -> Any :
162
189
"""Run on agent action."""
163
190
pass
164
191
165
- def on_agent_finish (self , finish : langchain_schema .AgentFinish , ** kwargs : Any ) -> Any :
192
+ def on_agent_finish (
193
+ self , finish : langchain_schema .AgentFinish , ** kwargs : Any
194
+ ) -> Any :
166
195
"""Run on agent end."""
167
196
pass
0 commit comments