# coding=utf-8
-import warnings
-from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union
+from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping

-import openai
-from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
-from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \
+    SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
from langchain_core.runnables import RunnableConfig, ensure_config
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.base import _create_usage_metadata

from common.config.tokenizer_manage_config import TokenizerManage
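Note on the last added import: `_create_usage_metadata` is a private helper in `langchain_openai.chat_models.base`, so it can move or disappear between releases. A minimal defensive sketch (an assumed fallback, mirroring the upstream helper's mapping of OpenAI-style usage keys onto the `UsageMetadata` TypedDict):

```python
try:
    from langchain_openai.chat_models.base import _create_usage_metadata
except ImportError:
    from langchain_core.messages.ai import UsageMetadata

    def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
        # Assumed fallback: map OpenAI usage keys onto langchain's TypedDict.
        input_tokens = token_usage.get("prompt_tokens", 0)
        output_tokens = token_usage.get("completion_tokens", 0)
        return UsageMetadata(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=token_usage.get("total_tokens", input_tokens + output_tokens),
        )
```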
@@ -19,14 +20,78 @@ def custom_get_token_ids(text: str):
    return tokenizer.encode(text)


+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    id_ = _dict.get("id")
+    reasoning_content = cast(str, _dict.get("reasoning_content") or "")
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: dict = {'reasoning_content': reasoning_content}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    tool_call_chunks = []
+    if raw_tool_calls := _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = raw_tool_calls
+        try:
+            tool_call_chunks = [
+                tool_call_chunk(
+                    name=rtc["function"].get("name"),
+                    args=rtc["function"].get("arguments"),
+                    id=rtc.get("id"),
+                    index=rtc["index"],
+                )
+                for rtc in raw_tool_calls
+            ]
+        except KeyError:
+            pass
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content, id=id_)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            id=id_,
+            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+        )
+    elif role in ("system", "developer") or default_class == SystemMessageChunk:
+        if role == "developer":
+            additional_kwargs = {"__openai_role__": "developer"}
+        else:
+            additional_kwargs = {}
+        return SystemMessageChunk(
+            content=content, id=id_, additional_kwargs=additional_kwargs
+        )
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(
+            content=content, tool_call_id=_dict["tool_call_id"], id=id_
+        )
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role, id=id_)
+    else:
+        return default_class(content=content, id=id_)  # type: ignore
+
+
class BaseChatOpenAI(ChatOpenAI):
    usage_metadata: dict = {}
    custom_get_token_ids = custom_get_token_ids

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.usage_metadata

-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+        self,
+        messages: list[BaseMessage],
+        tools: Optional[
+            Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
+        ] = None,
+    ) -> int:
        if self.usage_metadata is None or self.usage_metadata == {}:
            try:
                return super().get_num_tokens_from_messages(messages)
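The widened signature matches newer `langchain-core`, which passes `tools` through to this hook. A usage sketch (model name and key are placeholders): before any generation the count comes from the local tokenizer fallback; after a call, the cached usage from the last generation is reported instead:

```python
from langchain_core.messages import HumanMessage

llm = BaseChatOpenAI(model="gpt-4o-mini", api_key="sk-...")  # placeholders
msgs = [HumanMessage(content="Hello!")]
print(llm.get_num_tokens_from_messages(msgs))  # local tokenizer estimate
llm.invoke(msgs)
print(llm.get_num_tokens_from_messages(msgs))  # cached provider usage
```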
@@ -44,114 +109,77 @@ def get_num_tokens(self, text: str) -> int:
            return len(tokenizer.encode(text))
        return self.get_last_generation_info().get('output_tokens', 0)

-    def _stream(
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        kwargs['stream_usage'] = True
+        for chunk in super()._stream(*args, **kwargs):
+            if chunk.message.usage_metadata is not None:
+                self.usage_metadata = chunk.message.usage_metadata
+            yield chunk
+
+    def _convert_chunk_to_generation_chunk(
        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        kwargs["stream"] = True
-        kwargs["stream_options"] = {"include_usage": True}
-        """Set default stream_options."""
-        stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs)
-        # Note: stream_options is not a valid parameter for Azure OpenAI.
-        # To support users proxying Azure through ChatOpenAI, here we only specify
-        # stream_options if include_usage is set to True.
-        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-        # for release notes.
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
-
-        payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        base_generation_info = {}
-
-        if "response_format" in payload and is_basemodel_subclass(
-            payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = self._generate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
+        chunk: dict,
+        default_chunk_class: type,
+        base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
            )
-            return
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
-        else:
-            response = self.client.create(**payload)
-        with response:
-            is_first_chunk = True
-            for chunk in response:
-                if not isinstance(chunk, dict):
-                    chunk = chunk.model_dump()
-
-                generation_chunk = super()._convert_chunk_to_generation_chunk(
-                    chunk,
-                    default_chunk_class,
-                    base_generation_info if is_first_chunk else {},
-                )
-                if generation_chunk is None:
-                    continue
-
-                # custom code
-                if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']:
-                    generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][
-                        'reasoning_content']
-
-                default_chunk_class = generation_chunk.message.__class__
-                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
-                    )
-                is_first_chunk = False
-                # custom code
-                if generation_chunk.message.usage_metadata is not None:
-                    self.usage_metadata = generation_chunk.message.usage_metadata
-                yield generation_chunk
-
-    def _create_chat_result(self,
-                            response: Union[dict, openai.BaseModel],
-                            generation_info: Optional[Dict] = None):
-        result = super()._create_chat_result(response, generation_info)
-        try:
-            reasoning_content = ''
-            reasoning_content_enable = False
-            for res in response.choices:
-                if 'reasoning_content' in res.message.model_extra:
-                    reasoning_content_enable = True
-                    _reasoning_content = res.message.model_extra.get('reasoning_content')
-                    if _reasoning_content is not None:
-                        reasoning_content += _reasoning_content
-            if reasoning_content_enable:
-                result.llm_output['reasoning_content'] = reasoning_content
-        except Exception as e:
-            pass
-        return result
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+            if model_name := chunk.get("model"):
+                generation_info["model_name"] = model_name
+            if system_fingerprint := chunk.get("system_fingerprint"):
+                generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
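For reference, this override feeds each streamed `choice["delta"]` through the module-level `_convert_delta_to_message_chunk` above, which is what carries `reasoning_content` along. Illustrative check with a hand-built delta (`reasoning_content` is an extension exposed by some OpenAI-compatible reasoning models, not part of the official OpenAI API):

```python
delta = {
    "id": "chatcmpl-123",
    "role": "assistant",
    "content": "Hi",
    "reasoning_content": "First, consider...",
}
chunk = _convert_delta_to_message_chunk(delta, AIMessageChunk)
assert isinstance(chunk, AIMessageChunk)
# Reasoning text rides along in additional_kwargs on every chunk.
assert chunk.additional_kwargs["reasoning_content"] == "First, consider..."
```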

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        config = ensure_config(config)
        chat_result = cast(
-            ChatGeneration,
+            "ChatGeneration",
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
@@ -162,7 +190,9 @@ def invoke(
                run_id=config.pop("run_id", None),
                **kwargs,
            ).generations[0][0],
+
        ).message
+
        self.usage_metadata = chat_result.response_metadata[
            'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
        return chat_result
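End-to-end sketch of the intended behavior (endpoint, model name, and key are placeholders for any OpenAI-compatible reasoning model):

```python
llm = BaseChatOpenAI(
    model="deepseek-reasoner",              # placeholder model
    base_url="https://api.example.com/v1",  # placeholder endpoint
    api_key="sk-...",
)
for chunk in llm.stream("Why is the sky blue?"):
    # Reasoning tokens arrive in additional_kwargs; answer tokens in content.
    text = chunk.additional_kwargs.get("reasoning_content", "") or chunk.content
    print(text, end="", flush=True)
print()
print(llm.get_last_generation_info())  # usage captured from the final stream chunk
```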