@@ -4,6 +4,7 @@ import { OpenAIChat } from "../openai-chat.js";
 import { OpenAI } from "../openai.js";
 import { StringPromptValue } from "../../prompts/index.js";
 import { CallbackManager } from "../../callbacks/index.js";
+import { NewTokenIndices } from "../../callbacks/base.js";

 test("Test OpenAI", async () => {
   const model = new OpenAI({ maxTokens: 5, modelName: "text-ada-001" });
@@ -144,26 +145,63 @@ test("Test OpenAI in streaming mode", async () => {

 test("Test OpenAI in streaming mode with multiple prompts", async () => {
   let nrNewTokens = 0;
+  const completions = [
+    ["", ""],
+    ["", ""],
+  ];

   const model = new OpenAI({
     maxTokens: 5,
     modelName: "text-ada-001",
     streaming: true,
+    n: 2,
     callbacks: CallbackManager.fromHandlers({
-      async handleLLMNewToken(_token: string) {
+      async handleLLMNewToken(token: string, idx: NewTokenIndices) {
         nrNewTokens += 1;
+        completions[idx.prompt][idx.completion] += token;
       },
     }),
   });
   const res = await model.generate(["Print hello world", "print hello sea"]);
-  console.log({ res });
+  console.log(
+    res.generations,
+    res.generations.map((g) => g[0].generationInfo)
+  );

   expect(nrNewTokens > 0).toBe(true);
   expect(res.generations.length).toBe(2);
-  expect(res.generations.map((g) => typeof g[0].text === "string")).toEqual([
-    true,
-    true,
-  ]);
+  expect(res.generations.map((g) => g.map((gg) => gg.text))).toEqual(
+    completions
+  );
+});
+
+test("Test OpenAIChat in streaming mode with multiple prompts", async () => {
+  let nrNewTokens = 0;
+  const completions = [[""], [""]];
+
+  const model = new OpenAI({
+    maxTokens: 5,
+    modelName: "gpt-3.5-turbo",
+    streaming: true,
+    n: 1,
+    callbacks: CallbackManager.fromHandlers({
+      async handleLLMNewToken(token: string, idx: NewTokenIndices) {
+        nrNewTokens += 1;
+        completions[idx.prompt][idx.completion] += token;
+      },
+    }),
+  });
+  const res = await model.generate(["Print hello world", "print hello sea"]);
+  console.log(
+    res.generations,
+    res.generations.map((g) => g[0].generationInfo)
+  );
+
+  expect(nrNewTokens > 0).toBe(true);
+  expect(res.generations.length).toBe(2);
+  expect(res.generations.map((g) => g.map((gg) => gg.text))).toEqual(
+    completions
+  );
 });

 test("Test OpenAI prompt value", async () => {