
Commit fe7b388

Add streaming support for LLM and Chat Model calls with multiple prompts or completions (#1760)
* Add streaming support for LLM and Chat Model calls with multiple prompts or completions
  - multiple prompts are batch calls to generate()
  - multiple completions are calls with n > 1
* Undo unrelated change
* Update docstring
* Rename
* Lint
1 parent 46627f3 commit fe7b388
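
To make the new callback surface concrete, here is a minimal usage sketch (not part of the commit; it assumes the langchain/chat_models/openai, langchain/schema, and langchain/callbacks entrypoints, a configured OPENAI_API_KEY, and an ESM context with top-level await). With streaming enabled, n: 2, and two prompts, each streamed token reports the prompt and completion it belongs to:

import { ChatOpenAI } from "langchain/chat_models/openai";
import { HumanChatMessage } from "langchain/schema";
import { NewTokenIndices } from "langchain/callbacks";

// Sketch only: two prompts, two completions per prompt, streamed token by token.
const chat = new ChatOpenAI({
  modelName: "gpt-3.5-turbo",
  streaming: true,
  n: 2,
  maxTokens: 10,
  callbacks: [
    {
      async handleLLMNewToken(token: string, idx: NewTokenIndices) {
        // idx.prompt: which prompt in the batch; idx.completion: which of the n completions.
        console.log(`prompt ${idx.prompt} / completion ${idx.completion}: ${token}`);
      },
    },
  ],
});

await chat.generate([
  [new HumanChatMessage("Hello!")],
  [new HumanChatMessage("Bye!")],
]);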

File tree

11 files changed: 151 additions & 39 deletions


langchain/src/callbacks/base.ts

Lines changed: 12 additions & 0 deletions
@@ -22,6 +22,11 @@ export interface BaseCallbackHandlerInput {
   ignoreAgent?: boolean;
 }
 
+export interface NewTokenIndices {
+  prompt: number;
+  completion: number;
+}
+
 abstract class BaseCallbackHandlerMethodsClass {
   /**
    * Called at the start of an LLM or Chat Model run, with the prompt(s)
@@ -41,6 +46,13 @@ abstract class BaseCallbackHandlerMethodsClass {
    */
   handleLLMNewToken?(
     token: string,
+    /**
+     * idx.prompt is the index of the prompt that produced the token
+     * (if there are multiple prompts)
+     * idx.completion is the index of the completion that produced the token
+     * (if multiple completions per prompt are requested)
+     */
+    idx: NewTokenIndices,
     runId: string,
     parentRunId?: string
   ): Promise<void> | void;

langchain/src/callbacks/index.ts

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ export {
   BaseCallbackHandler,
   CallbackHandlerMethods,
   BaseCallbackHandlerInput,
+  NewTokenIndices,
 } from "./base.js";
 
 export { Run, RunType, BaseTracer } from "./handlers/tracer.js";

langchain/src/callbacks/manager.ts

Lines changed: 10 additions & 2 deletions
@@ -6,7 +6,11 @@ import {
   ChainValues,
   LLMResult,
 } from "../schema/index.js";
-import { BaseCallbackHandler, CallbackHandlerMethods } from "./base.js";
+import {
+  BaseCallbackHandler,
+  CallbackHandlerMethods,
+  NewTokenIndices,
+} from "./base.js";
 import { ConsoleCallbackHandler } from "./handlers/console.js";
 import {
   getTracingCallbackHandler,
@@ -79,14 +83,18 @@ export class CallbackManagerForLLMRun
   extends BaseRunManager
   implements BaseCallbackManagerMethods
 {
-  async handleLLMNewToken(token: string): Promise<void> {
+  async handleLLMNewToken(
+    token: string,
+    idx: NewTokenIndices = { prompt: 0, completion: 0 }
+  ): Promise<void> {
     await Promise.all(
       this.handlers.map((handler) =>
         consumeCallback(async () => {
           if (!handler.ignoreLLM) {
             try {
               await handler.handleLLMNewToken?.(
                 token,
+                idx,
                 this.runId,
                 this._parentRunId
               );
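
Because idx defaults to { prompt: 0, completion: 0 }, existing call sites that pass only the token keep their old behaviour. A small sketch of both call styles (illustrative only; it assumes CallbackManagerForLLMRun is importable from the langchain/callbacks entrypoint):

import { CallbackManagerForLLMRun } from "langchain/callbacks";

// Hypothetical helper, not from this commit.
async function emitTokens(runManager?: CallbackManagerForLLMRun) {
  // Old-style call: idx falls back to { prompt: 0, completion: 0 }.
  await runManager?.handleLLMNewToken("Hello");
  // New-style call: attribute the token to the second prompt's first completion.
  await runManager?.handleLLMNewToken("world", { prompt: 1, completion: 0 });
}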

langchain/src/chat_models/base.ts

Lines changed: 5 additions & 1 deletion
@@ -90,7 +90,11 @@ export abstract class BaseChatModel extends BaseLanguageModel {
     // generate results
     const results = await Promise.allSettled(
       messages.map((messageList, i) =>
-        this._generate(messageList, parsedOptions, runManagers?.[i])
+        this._generate(
+          messageList,
+          { ...parsedOptions, promptIndex: i },
+          runManagers?.[i]
+        )
       )
     );
     // handle results

langchain/src/chat_models/openai.ts

Lines changed: 15 additions & 8 deletions
@@ -86,6 +86,7 @@ export interface ChatOpenAICallOptions extends OpenAICallOptions {
   function_call?: CreateChatCompletionRequestFunctionCall;
   functions?: ChatCompletionFunctions[];
   tools?: StructuredTool[];
+  promptIndex?: number;
 }
 
 /**
@@ -113,7 +114,15 @@ export class ChatOpenAI
   declare CallOptions: ChatOpenAICallOptions;
 
   get callKeys(): (keyof ChatOpenAICallOptions)[] {
-    return ["stop", "signal", "timeout", "options", "functions", "tools"];
+    return [
+      "stop",
+      "signal",
+      "timeout",
+      "options",
+      "functions",
+      "tools",
+      "promptIndex",
+    ];
   }
 
   lc_serializable = true;
@@ -223,10 +232,6 @@ export class ChatOpenAI
 
     this.streaming = fields?.streaming ?? false;
 
-    if (this.streaming && this.n > 1) {
-      throw new Error("Cannot stream results when n > 1");
-    }
-
     if (this.azureOpenAIApiKey) {
       if (!this.azureOpenAIApiInstanceName) {
         throw new Error("Azure OpenAI API instance name not found");
@@ -408,11 +413,13 @@ export class ChatOpenAI
           choice.message.function_call.arguments +=
             part.delta?.function_call?.arguments ?? "";
         }
-        // TODO this should pass part.index to the callback
-        // when that's supported there
         // eslint-disable-next-line no-void
         void runManager?.handleLLMNewToken(
-          part.delta?.content ?? ""
+          part.delta?.content ?? "",
+          {
+            prompt: options.promptIndex ?? 0,
+            completion: part.index,
+          }
         );
         // TODO we don't currently have a callback method for
         // sending the function call arguments

langchain/src/chat_models/tests/chatopenai.int.test.ts

Lines changed: 36 additions & 3 deletions
@@ -14,6 +14,7 @@ import {
   SystemMessagePromptTemplate,
 } from "../../prompts/index.js";
 import { CallbackManager } from "../../callbacks/index.js";
+import { NewTokenIndices } from "../../callbacks/base.js";
 
 test("Test ChatOpenAI", async () => {
   const chat = new ChatOpenAI({ modelName: "gpt-3.5-turbo", maxTokens: 10 });
@@ -129,11 +130,43 @@ test("Test ChatOpenAI in streaming mode", async () => {
     ],
   });
   const message = new HumanChatMessage("Hello!");
-  const res = await model.call([message]);
-  console.log({ res });
+  const result = await model.call([message]);
+  console.log(result);
+
+  expect(nrNewTokens > 0).toBe(true);
+  expect(result.text).toBe(streamedCompletion);
+});
+
+test("Test ChatOpenAI in streaming mode with n > 1 and multiple prompts", async () => {
+  let nrNewTokens = 0;
+  const streamedCompletions = [
+    ["", ""],
+    ["", ""],
+  ];
+
+  const model = new ChatOpenAI({
+    modelName: "gpt-3.5-turbo",
+    streaming: true,
+    maxTokens: 10,
+    n: 2,
+    callbacks: [
+      {
+        async handleLLMNewToken(token: string, idx: NewTokenIndices) {
+          nrNewTokens += 1;
+          streamedCompletions[idx.prompt][idx.completion] += token;
+        },
+      },
+    ],
+  });
+  const message1 = new HumanChatMessage("Hello!");
+  const message2 = new HumanChatMessage("Bye!");
+  const result = await model.generate([[message1], [message2]]);
+  console.log(result.generations);
 
   expect(nrNewTokens > 0).toBe(true);
-  expect(res.text).toBe(streamedCompletion);
+  expect(result.generations.map((g) => g.map((gg) => gg.text))).toEqual(
+    streamedCompletions
+  );
 });
 
 test("Test ChatOpenAI prompt value", async () => {

langchain/src/llms/base.ts

Lines changed: 4 additions & 2 deletions
@@ -332,8 +332,10 @@ export abstract class LLM extends BaseLLM {
     runManager?: CallbackManagerForLLMRun
   ): Promise<LLMResult> {
     const generations: Generation[][] = await Promise.all(
-      prompts.map((prompt) =>
-        this._call(prompt, options, runManager).then((text) => [{ text }])
+      prompts.map((prompt, promptIndex) =>
+        this._call(prompt, { ...options, promptIndex }, runManager).then(
+          (text) => [{ text }]
+        )
       )
     );
     return { generations };
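
In other words, LLM now tags each prompt's call options with its position in the batch before delegating to _call, which is what lets the OpenAI wrappers below report a prompt index for streamed tokens. A toy illustration of the tagging (values are made up, not from the codebase):

// Toy illustration: each prompt's options get a promptIndex matching its batch position.
const options = { stop: ["\n"] };
const prompts = ["Print hello world", "print hello sea"];

const tagged = prompts.map((_prompt, promptIndex) => ({ ...options, promptIndex }));

console.log(tagged);
// [ { stop: [ '\n' ], promptIndex: 0 }, { stop: [ '\n' ], promptIndex: 1 } ]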

langchain/src/llms/openai-chat.ts

Lines changed: 17 additions & 7 deletions
@@ -20,7 +20,11 @@ import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
 import { Generation, LLMResult } from "../schema/index.js";
 import { promptLayerTrackRequest } from "../util/prompt-layer.js";
 
-export { OpenAICallOptions, OpenAIChatInput, AzureOpenAIInput };
+export { OpenAIChatInput, AzureOpenAIInput };
+
+export interface OpenAIChatCallOptions extends OpenAICallOptions {
+  promptIndex?: number;
+}
 
 /**
  * Wrapper around OpenAI large language models that use the Chat endpoint.
@@ -48,10 +52,10 @@ export class OpenAIChat
   extends LLM
   implements OpenAIChatInput, AzureOpenAIInput
 {
-  declare CallOptions: OpenAICallOptions;
+  declare CallOptions: OpenAIChatCallOptions;
 
-  get callKeys(): (keyof OpenAICallOptions)[] {
-    return ["stop", "signal", "timeout", "options"];
+  get callKeys(): (keyof OpenAIChatCallOptions)[] {
+    return ["stop", "signal", "timeout", "options", "promptIndex"];
   }
 
   lc_serializable = true;
@@ -166,8 +170,10 @@ export class OpenAIChat
 
     this.streaming = fields?.streaming ?? false;
 
-    if (this.streaming && this.n > 1) {
-      throw new Error("Cannot stream results when n > 1");
+    if (this.n > 1) {
+      throw new Error(
+        "Cannot use n > 1 in OpenAIChat LLM. Use ChatOpenAI Chat Model instead."
+      );
     }
 
     if (this.azureOpenAIApiKey) {
@@ -329,7 +335,11 @@ export class OpenAIChat
         choice.message.content += part.delta?.content ?? "";
         // eslint-disable-next-line no-void
         void runManager?.handleLLMNewToken(
-          part.delta?.content ?? ""
+          part.delta?.content ?? "",
+          {
+            prompt: options.promptIndex ?? 0,
+            completion: part.index,
+          }
         );
       }
     }

langchain/src/llms/openai.ts

Lines changed: 6 additions & 9 deletions
@@ -88,7 +88,7 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
 
   n = 1;
 
-  bestOf = 1;
+  bestOf?: number;
 
   logitBias?: Record<string, number>;
 
@@ -179,11 +179,7 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
 
     this.streaming = fields?.streaming ?? false;
 
-    if (this.streaming && this.n > 1) {
-      throw new Error("Cannot stream results when n > 1");
-    }
-
-    if (this.streaming && this.bestOf > 1) {
+    if (this.streaming && this.bestOf && this.bestOf > 1) {
       throw new Error("Cannot stream results when bestOf > 1");
     }
 
@@ -345,10 +341,11 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
         choice.text = (choice.text ?? "") + (part.text ?? "");
         choice.finish_reason = part.finish_reason;
         choice.logprobs = part.logprobs;
-        // TODO this should pass part.index to the callback
-        // when that's supported there
         // eslint-disable-next-line no-void
-        void runManager?.handleLLMNewToken(part.text ?? "");
+        void runManager?.handleLLMNewToken(part.text ?? "", {
+          prompt: Math.floor(part.index / this.n),
+          completion: part.index % this.n,
+        });
       }
     }
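
Here the completions endpoint streams the choices for all prompts over a single connection, and the code assumes the flat choice index is ordered prompt-major, i.e. index = promptIndex * n + completionIndex. A worked example of that arithmetic (illustrative only, not part of the commit):

// Worked example of the index split used above.
function splitChoiceIndex(index: number, n: number): { prompt: number; completion: number } {
  return { prompt: Math.floor(index / n), completion: index % n };
}

// Two prompts with n = 2 produce four streamed choices:
for (let index = 0; index < 4; index += 1) {
  console.log(index, splitChoiceIndex(index, 2));
}
// 0 { prompt: 0, completion: 0 }
// 1 { prompt: 0, completion: 1 }
// 2 { prompt: 1, completion: 0 }
// 3 { prompt: 1, completion: 1 }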

langchain/src/llms/tests/openai.int.test.ts

Lines changed: 44 additions & 6 deletions
@@ -4,6 +4,7 @@ import { OpenAIChat } from "../openai-chat.js";
 import { OpenAI } from "../openai.js";
 import { StringPromptValue } from "../../prompts/index.js";
 import { CallbackManager } from "../../callbacks/index.js";
+import { NewTokenIndices } from "../../callbacks/base.js";
 
 test("Test OpenAI", async () => {
   const model = new OpenAI({ maxTokens: 5, modelName: "text-ada-001" });
@@ -144,26 +145,63 @@ test("Test OpenAI in streaming mode", async () => {
 
 test("Test OpenAI in streaming mode with multiple prompts", async () => {
   let nrNewTokens = 0;
+  const completions = [
+    ["", ""],
+    ["", ""],
+  ];
 
   const model = new OpenAI({
     maxTokens: 5,
     modelName: "text-ada-001",
     streaming: true,
+    n: 2,
     callbacks: CallbackManager.fromHandlers({
-      async handleLLMNewToken(_token: string) {
+      async handleLLMNewToken(token: string, idx: NewTokenIndices) {
        nrNewTokens += 1;
+        completions[idx.prompt][idx.completion] += token;
      },
    }),
  });
   const res = await model.generate(["Print hello world", "print hello sea"]);
-  console.log({ res });
+  console.log(
+    res.generations,
+    res.generations.map((g) => g[0].generationInfo)
+  );
 
   expect(nrNewTokens > 0).toBe(true);
   expect(res.generations.length).toBe(2);
-  expect(res.generations.map((g) => typeof g[0].text === "string")).toEqual([
-    true,
-    true,
-  ]);
+  expect(res.generations.map((g) => g.map((gg) => gg.text))).toEqual(
+    completions
+  );
+});
+
+test("Test OpenAIChat in streaming mode with multiple prompts", async () => {
+  let nrNewTokens = 0;
+  const completions = [[""], [""]];
+
+  const model = new OpenAI({
+    maxTokens: 5,
+    modelName: "gpt-3.5-turbo",
+    streaming: true,
+    n: 1,
+    callbacks: CallbackManager.fromHandlers({
+      async handleLLMNewToken(token: string, idx: NewTokenIndices) {
+        nrNewTokens += 1;
+        completions[idx.prompt][idx.completion] += token;
+      },
+    }),
+  });
+  const res = await model.generate(["Print hello world", "print hello sea"]);
+  console.log(
+    res.generations,
+    res.generations.map((g) => g[0].generationInfo)
+  );
+
+  expect(nrNewTokens > 0).toBe(true);
+  expect(res.generations.length).toBe(2);
+  expect(res.generations.map((g) => g.map((gg) => gg.text))).toEqual(
+    completions
+  );
+});
 
 test("Test OpenAI prompt value", async () => {
