// LLMChain.ts
import { BaseLanguageModel, BaseLanguageModelCallOptions } from '@langchain/core/language_models/base'
import { BaseLLMOutputParser, BaseOutputParser } from '@langchain/core/output_parsers'
import { HumanMessage } from '@langchain/core/messages'
import { ChatPromptTemplate, FewShotPromptTemplate, HumanMessagePromptTemplate, PromptTemplate } from '@langchain/core/prompts'
import { OutputFixingParser } from 'langchain/output_parsers'
import { LLMChain } from 'langchain/chains'
import {
IVisionChatModal,
ICommonObject,
INode,
INodeData,
INodeOutputsValue,
INodeParams,
IServerSideEventStreamer
} from '../../../src/Interface'
import { additionalCallbacks, ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler'
import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils'
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
import { formatResponse, injectOutputParser } from '../../outputparsers/OutputParserHelpers'
import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils'
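
/**
 * Flowise node that wraps LangChain's LLMChain, wiring together a language model,
 * a prompt template, optional output parsing, and optional input moderation.
 */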
class LLMChain_Chains implements INode {
label: string
name: string
version: number
type: string
icon: string
category: string
baseClasses: string[]
description: string
inputs: INodeParams[]
outputs: INodeOutputsValue[]
outputParser: BaseOutputParser
constructor() {
this.label = 'LLM Chain'
this.name = 'llmChain'
this.version = 3.0
this.type = 'LLMChain'
this.icon = 'LLM_Chain.svg'
this.category = 'Chains'
this.description = 'Chain to run queries against LLMs'
this.baseClasses = [this.type, ...getBaseClasses(LLMChain)]
this.inputs = [
{
label: 'Language Model',
name: 'model',
type: 'BaseLanguageModel'
},
{
label: 'Prompt',
name: 'prompt',
type: 'BasePromptTemplate'
},
{
label: 'Output Parser',
name: 'outputParser',
type: 'BaseLLMOutputParser',
optional: true
},
{
label: 'Input Moderation',
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
name: 'inputModeration',
type: 'Moderation',
optional: true,
list: true
},
{
label: 'Chain Name',
name: 'chainName',
type: 'string',
placeholder: 'Name Your Chain',
optional: true
}
]
this.outputs = [
{
label: 'LLM Chain',
name: 'llmChain',
baseClasses: [this.type, ...getBaseClasses(LLMChain)]
},
{
label: 'Output Prediction',
name: 'outputPrediction',
baseClasses: ['string', 'json']
}
]
}
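
/**
 * Builds the LLMChain. If the selected output is the chain itself, the instance is
 * returned for downstream nodes; if it is 'outputPrediction', the chain is run
 * immediately (with streaming disabled) and the escaped text result is returned.
 */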
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
const model = nodeData.inputs?.model as BaseLanguageModel
const prompt = nodeData.inputs?.prompt
const output = nodeData.outputs?.output as string
let promptValues: ICommonObject | undefined = nodeData.inputs?.prompt.promptValues as ICommonObject
const llmOutputParser = nodeData.inputs?.outputParser as BaseOutputParser
this.outputParser = llmOutputParser
if (llmOutputParser) {
const autoFix = (llmOutputParser as any).autoFix
if (autoFix === true) {
this.outputParser = OutputFixingParser.fromLLM(model, llmOutputParser)
}
}
const chain = new LLMChain({
llm: model,
outputParser: this.outputParser as BaseLLMOutputParser<string | object>,
prompt,
verbose: process.env.DEBUG === 'true'
})
if (output === this.name) {
return chain
} else if (output === 'outputPrediction') {
const inputVariables = chain.prompt.inputVariables as string[] // ["product"]
promptValues = injectOutputParser(this.outputParser, chain, promptValues)
// Disable streaming because this is not the final chain
const disableStreaming = true
const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData, disableStreaming)
// eslint-disable-next-line no-console
console.log('\x1b[92m\x1b[1m\n*****OUTPUT PREDICTION*****\n\x1b[0m\x1b[0m')
// eslint-disable-next-line no-console
console.log(res)
let finalRes = res
if (this.outputParser && typeof res === 'object' && Object.prototype.hasOwnProperty.call(res, 'json')) {
finalRes = (res as ICommonObject).json
}
/**
* Apply string transformation to convert special chars:
* FROM: hello i am ben\n\n\thow are you?
* TO: hello i am benFLOWISE_NEWLINEFLOWISE_NEWLINEFLOWISE_TABhow are you?
*/
return handleEscapeCharacters(finalRes, false)
}
}
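
/**
 * Runs the initialized LLMChain against the user's input and returns the final,
 * formatted result.
 */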
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
const inputVariables = nodeData.instance.prompt.inputVariables as string[] // ["product"]
const chain = nodeData.instance as LLMChain
let promptValues: ICommonObject | undefined = nodeData.inputs?.prompt.promptValues as ICommonObject
const outputParser = nodeData.inputs?.outputParser as BaseOutputParser
if (!this.outputParser && outputParser) {
this.outputParser = outputParser
}
promptValues = injectOutputParser(this.outputParser, chain, promptValues)
const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData)
// eslint-disable-next-line no-console
console.log('\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m')
// eslint-disable-next-line no-console
console.log(res)
return res
}
}
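
/**
 * Executes the chain: applies input moderation, rewrites the prompt for
 * vision-capable models when images are uploaded, resolves prompt variables,
 * and calls the chain with logging, streaming, and any additional callbacks.
 */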
const runPrediction = async (
inputVariables: string[],
chain: LLMChain<string | object | BaseLanguageModel<any, BaseLanguageModelCallOptions>>,
input: string,
promptValuesRaw: ICommonObject | undefined,
options: ICommonObject,
nodeData: INodeData,
disableStreaming?: boolean
) => {
const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)
const moderations = nodeData.inputs?.inputModeration as Moderation[]
// True when the prediction is external and the client has requested streaming='true'
const shouldStreamResponse = !disableStreaming && options.shouldStreamResponse
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
const chatId = options.chatId
if (moderations && moderations.length > 0) {
try {
// Use the output of the moderation chain as input for the LLM chain
input = await checkInputs(moderations, input)
} catch (e) {
await new Promise((resolve) => setTimeout(resolve, 500))
if (shouldStreamResponse) {
streamResponse(sseStreamer, chatId, e.message)
}
return formatResponse(e.message)
}
}
/**
* Apply string transformation to reverse converted special chars:
* FROM: { "value": "hello i am benFLOWISE_NEWLINEFLOWISE_NEWLINEFLOWISE_TABhow are you?" }
* TO: { "value": "hello i am ben\n\n\thow are you?" }
*/
const promptValues = handleEscapeCharacters(promptValuesRaw, true)
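// If the model supports image inputs, rewrite the prompt so uploaded images are
// sent alongside the existing text template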
if (llmSupportsVision(chain.llm)) {
const visionChatModel = chain.llm as IVisionChatModal
const messageContent = await addImagesToMessages(nodeData, options, visionChatModel.multiModalOption)
if (messageContent?.length) {
// Switch to the vision-capable model (e.g. gpt-4-vision) and raise the max token limit
visionChatModel.setVisionModel()
// Add image to the message
if (chain.prompt instanceof PromptTemplate) {
const existingPromptTemplate = chain.prompt.template as string
const msg = HumanMessagePromptTemplate.fromTemplate([
...messageContent,
{
text: existingPromptTemplate
}
])
msg.inputVariables = chain.prompt.inputVariables
chain.prompt = ChatPromptTemplate.fromMessages([msg])
} else if (chain.prompt instanceof ChatPromptTemplate) {
if (chain.prompt.promptMessages.at(-1) instanceof HumanMessagePromptTemplate) {
const lastMessage = chain.prompt.promptMessages.pop() as HumanMessagePromptTemplate
const template = (lastMessage.prompt as PromptTemplate).template as string
const msg = HumanMessagePromptTemplate.fromTemplate([
...messageContent,
{
text: template
}
])
msg.inputVariables = lastMessage.inputVariables
chain.prompt.promptMessages.push(msg)
} else {
chain.prompt.promptMessages.push(new HumanMessage({ content: messageContent }))
}
} else if (chain.prompt instanceof FewShotPromptTemplate) {
const existingFewShotPromptTemplate = chain.prompt.examplePrompt.template as string
const newFewShotPromptTemplate = ChatPromptTemplate.fromMessages([
HumanMessagePromptTemplate.fromTemplate(existingFewShotPromptTemplate)
])
newFewShotPromptTemplate.promptMessages.push(new HumanMessage({ content: messageContent }))
// @ts-ignore
chain.prompt.examplePrompt = newFewShotPromptTemplate
}
} else {
// Revert to the original model if no images were uploaded
visionChatModel.revertToOriginalModel()
}
}
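// Collect the input variables that have no value supplied via promptValues;
// the user's input can stand in for at most one missing variable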
if (promptValues && inputVariables.length > 0) {
const seen: string[] = []
for (const variable of inputVariables) {
seen.push(variable)
if (promptValues[variable] != null) {
seen.pop()
}
}
if (seen.length === 0) {
// All inputVariables have fixed values specified
const callOptions = { ...promptValues }
if (shouldStreamResponse) {
const handler = new CustomChainHandler(sseStreamer, chatId)
const res = await chain.call(callOptions, [loggerHandler, handler, ...callbacks])
return formatResponse(res?.text)
} else {
const res = await chain.call(callOptions, [loggerHandler, ...callbacks])
return formatResponse(res?.text)
}
} else if (seen.length === 1) {
// If exactly one input variable is unspecified, use the user's input as its value
const lastValue = seen.pop()
if (!lastValue) throw new Error('Please provide Prompt Values')
const callOptions = {
...promptValues,
[lastValue]: input
}
if (shouldStreamResponse) {
const handler = new CustomChainHandler(sseStreamer, chatId)
const res = await chain.call(callOptions, [loggerHandler, handler, ...callbacks])
return formatResponse(res?.text)
} else {
const res = await chain.call(callOptions, [loggerHandler, ...callbacks])
return formatResponse(res?.text)
}
} else {
throw new Error(`Please provide Prompt Values for: ${seen.join(', ')}`)
}
} else {
if (shouldStreamResponse) {
const handler = new CustomChainHandler(sseStreamer, chatId)
const res = await chain.run(input, [loggerHandler, handler, ...callbacks])
return formatResponse(res)
} else {
const res = await chain.run(input, [loggerHandler, ...callbacks])
return formatResponse(res)
}
}
}
module.exports = { nodeClass: LLMChain_Chains }