feat(chains): Return ChatMessage when LLMChain used with ChatModel
davidmigloz committed Aug 5, 2023
1 parent 659783a commit bb5f4d2
Showing 7 changed files with 55 additions and 40 deletions.
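
Note: the net effect of this change is that when LLMChain wraps a chat model, the chain output is the model's ChatMessage rather than a flattened string, and the default output key moves from 'text' to 'output'. A minimal sketch of the new behavior, mirroring the test added below (FakeChatModel is the fake chat model from this repo's test suite; the single-import form is an assumption about what the package exports):

import 'package:langchain/langchain.dart';

Future<void> main() async {
  // FakeChatModel replays canned responses; any BaseChatModel behaves the same.
  final model = FakeChatModel(responses: ['Hello world!']);
  final prompt = PromptTemplate.fromTemplate('Print {foo}');
  final chain = LLMChain(prompt: prompt, llm: model);

  final res = await chain.call({'foo': 'Hello world!'});
  // The value under the (new) default output key is a ChatMessage, not a String.
  assert(res[LLMChain.defaultOutputKey] == ChatMessage.ai('Hello world!'));
}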
4 changes: 2 additions & 2 deletions packages/langchain/lib/src/chains/base.dart
@@ -27,14 +27,14 @@ import 'models/models.dart';
 /// string. This method can only be used for a subset of chains and cannot
 /// return as rich of an output as [call].
 /// {@endtemplate}
-abstract class BaseChain {
+abstract class BaseChain<MemoryType extends BaseMemory> {
   /// {@macro base_chain}
   const BaseChain({
     this.memory,
   });

   /// Memory to use for this chain.
-  final BaseMemory? memory;
+  final MemoryType? memory;

   /// Return the string type key uniquely identifying this class of chain.
   String get chainType;
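
Note: the practical win of BaseChain<MemoryType> is that chain.memory keeps its concrete static type in subclasses. A hedged sketch using LLMChain from this same commit (FakeLLM is the repo's fake LLM test double; treat the import as an assumption):

import 'package:langchain/langchain.dart';

Future<void> main() async {
  final chain = LLMChain(
    llm: FakeLLM(responses: ['hi']),
    prompt: PromptTemplate.fromTemplate('{input}'),
    memory: ConversationBufferMemory(),
  );
  // MemoryType is inferred as ConversationBufferMemory, so no downcast from
  // BaseMemory is needed to reach implementation-specific members.
  final ConversationBufferMemory? memory = chain.memory;
  print(memory);
}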
45 changes: 24 additions & 21 deletions packages/langchain/lib/src/chains/llm_chain.dart
@@ -1,3 +1,4 @@
+import '../memory/base.dart';
 import '../model_io/model_io.dart';
 import 'base.dart';
 import 'models/models.dart';
@@ -15,44 +16,47 @@ import 'models/models.dart';
 /// final res = await chain.run('bad');
 /// ```
 /// {@endtemplate}
-class LLMChain<LLMInput extends Object, LLMOptions extends LanguageModelOptions,
-    LLMOutput extends Object, ParserOutput extends Object> extends BaseChain {
+class LLMChain<
+    LLMType extends BaseLanguageModel,
+    LLMOptions extends LanguageModelOptions,
+    OutputParserType extends BaseLLMOutputParser,
+    MemoryType extends BaseMemory> extends BaseChain<MemoryType> {
   /// {@macro llm_chain}
   const LLMChain({
     required this.llm,
-    this.llmOptions,
     required this.prompt,
-    this.outputKey = defaultOutputKey,
-    super.memory,
     this.outputParser,
+    this.outputKey = defaultOutputKey,
     this.returnFinalOnly = true,
+    this.llmOptions,
+    super.memory,
   });

   /// Language model to call.
-  final BaseLanguageModel<LLMInput, LLMOptions, LLMOutput> llm;
+  final LLMType llm;

-  /// Options to pass to the language model.
-  final LLMOptions? llmOptions;
-
   /// Prompt object to use.
   final BasePromptTemplate prompt;

-  /// Key to use for output.
-  final String outputKey;
-
   /// OutputParser to use.
   ///
   /// Defaults to one that takes the most likely string but does not change it
   /// otherwise.
-  final BaseLLMOutputParser<LLMOutput, ParserOutput>? outputParser;
+  final OutputParserType? outputParser;
+
+  /// Key to use for output.
+  final String outputKey;

   /// Whether to return only the final parsed result.
   /// If false, it will return a bunch of extra information about the
   /// generation.
   final bool returnFinalOnly;

+  /// Options to pass to the language model.
+  final LLMOptions? llmOptions;
+
   /// Default output key.
-  static const defaultOutputKey = 'text';
+  static const defaultOutputKey = 'output';

   /// Output key to use for returning the full generation.
   static const fullGenerationOutputKey = 'full_generation';
@@ -76,13 +80,12 @@ class LLMChain<LLMInput extends Object, LLMOptions extends LanguageModelOptions,

     final response = await llm.generatePrompt(promptValue, options: llmOptions);

-    final res = switch (outputParser) {
-      null => response.firstOutputAsString,
-      _ => await outputParser!.parseResultWithPrompt(
-          response.generations,
-          promptValue,
-        ),
-    };
+    final res = outputParser == null
+        ? response.generations.firstOrNull?.output
+        : await outputParser!.parseResultWithPrompt(
+            response.generations,
+            promptValue,
+          );

     return {
       outputKey: res,
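
Note: since the unparsed branch now returns response.generations.firstOrNull?.output untouched, call sites that accept both plain LLMs and chat models may need a type check. A sketch under that assumption (FakeChatModel as above; ChatMessage is assumed to expose the message text via a content field, as elsewhere in the package):

import 'package:langchain/langchain.dart';

Future<void> main() async {
  final chain = LLMChain(
    llm: FakeChatModel(responses: ['hi']),
    prompt: PromptTemplate.fromTemplate('{input}'),
  );
  final res = await chain.call({'input': 'say hi'});
  final output = res[chain.outputKey];
  // String for text-completion LLMs, ChatMessage for chat models
  // (when no outputParser is configured).
  if (output is ChatMessage) {
    print(output.content);
  } else {
    print(output);
  }
}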
@@ -6,7 +6,7 @@ import 'output_parser.dart';
 /// {@template base_output_functions_parser}
 /// A parser that converts the output of a function call into a specified type.
 /// {@endtemplate}
-abstract class BaseOutputFunctionsParser<O>
+abstract class BaseOutputFunctionsParser<O extends Object>
     extends BaseLLMOutputParser<ChatMessage, O> {
   /// {@macro base_output_functions_parser}
   const BaseOutputFunctionsParser();
@@ -7,7 +7,8 @@ interface class FormatInstructionsOptions {}
 /// {@template base_llm_output_parser}
 /// Class to parse the output of an LLM call.
 /// {@endtemplate}
-abstract class BaseLLMOutputParser<LLMOutput, ParserOutput> {
+abstract class BaseLLMOutputParser<LLMOutput extends Object,
+    ParserOutput extends Object> {
   /// {@macro base_llm_output_parser}
   const BaseLLMOutputParser();

@@ -28,7 +29,8 @@ abstract class BaseLLMOutputParser<LLMOutput, ParserOutput> {
 /// {@template base_output_parser}
 /// Class to parse the output of an LLM call.
 /// {@endtemplate}
-abstract class BaseOutputParser<LLMOutput, ParserOutput>
+abstract class BaseOutputParser<LLMOutput extends Object,
+    ParserOutput extends Object>
     extends BaseLLMOutputParser<LLMOutput, ParserOutput> {
   /// {@macro base_output_parser}
   const BaseOutputParser();
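
Note: the new extends Object bounds simply forbid nullable type arguments. A quick plain-Dart illustration of the constraint (the Parser class is invented for the demo; it is not part of the package):

// `T extends Object` accepts String but rejects String?.
class Parser<T extends Object> {
  const Parser();
}

void main() {
  const ok = Parser<String>(); // fine
  // const bad = Parser<String?>(); // compile-time error: 'String?' does not
  //                                // conform to the bound 'Object' of 'T'.
  print(ok);
}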
15 changes: 12 additions & 3 deletions packages/langchain/test/chains/llm_chain_test.dart
@@ -8,7 +8,7 @@ void main() {
     final prompt = PromptTemplate.fromTemplate('Print {foo}');
     final chain = LLMChain(prompt: prompt, llm: model);
     final res = await chain.call({'foo': 'Hello world!'});
-    expect(res['text'], 'Hello world!');
+    expect(res[LLMChain.defaultOutputKey], 'Hello world!');
     expect(res['foo'], 'Hello world!');
   });

@@ -17,7 +17,7 @@ void main() {
     final prompt = PromptTemplate.fromTemplate('Print {foo}');
     final chain = LLMChain(prompt: prompt, llm: model);
     final res = await chain.call('Hello world!');
-    expect(res['text'], 'Hello world!');
+    expect(res[LLMChain.defaultOutputKey], 'Hello world!');
     expect(res['foo'], 'Hello world!');
   });

@@ -30,7 +30,7 @@ void main() {
       returnOnlyOutputs: true,
     );
     expect(res.length, 1);
-    expect(res['text'], 'Hello world! again!');
+    expect(res[LLMChain.defaultOutputKey], 'Hello world! again!');
   });

   test('Test LLMChain outputKey', () async {
@@ -80,5 +80,14 @@ void main() {
       throwsArgumentError,
     );
   });
+
+  test('Test LLMChain with chat model', () async {
+    final model = FakeChatModel(responses: ['Hello world!']);
+    final prompt = PromptTemplate.fromTemplate('Print {foo}');
+    final chain = LLMChain(prompt: prompt, llm: model);
+    final res = await chain.call({'foo': 'Hello world!'});
+    expect(res[LLMChain.defaultOutputKey], ChatMessage.ai('Hello world!'));
+    expect(res['foo'], 'Hello world!');
+  });
});
}
10 changes: 5 additions & 5 deletions packages/langchain/test/chains/sequential_test.dart
@@ -108,11 +108,11 @@ void main() {
   });

   test('Test memory in one of the internal chains', () async {
-    final memory = ConversationBufferMemory(memoryKey: 'bla')
-      ..saveContext(
-        inputValues: {'input': 'yo'},
-        outputValues: {'output': 'ya'},
-      );
+    final memory = ConversationBufferMemory(memoryKey: 'bla');
+    await memory.saveContext(
+      inputValues: {'input': 'yo'},
+      outputValues: {'output': 'ya'},
+    );

     final chain1 = _FakeChain(
       inputVariables: {'foo', 'bla'},
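
Note: saveContext returns a Future (hence the new await); the old cascade form (..saveContext(...)) discarded that Future, so the write was not guaranteed to complete before the chains ran. Awaiting it makes the ordering explicit.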
13 changes: 7 additions & 6 deletions packages/langchain_openai/lib/src/chains/qa_with_structure.dart
@@ -12,17 +12,18 @@ import '../chat_models/models/models.dart';
 /// a specific structure (e.g. the answer and the sources used to answer the
 /// question).
 /// {@endtemplate}
-class OpenAIQAWithStructureChain<S extends Object>
-    extends LLMChain<List<ChatMessage>, ChatOpenAIOptions, ChatMessage, S> {
+class OpenAIQAWithStructureChain<S extends Object> extends LLMChain<
+    BaseChatOpenAI,
+    ChatOpenAIOptions,
+    BaseOutputFunctionsParser<S>,
+    BaseChatMemory> {
   OpenAIQAWithStructureChain({
-    required final BaseChatOpenAI llm,
+    required super.llm,
     required final ChatFunction function,
-    required final BaseOutputFunctionsParser<S> outputParser,
+    required BaseOutputFunctionsParser<S> super.outputParser,
     final BasePromptTemplate? prompt,
   }) : super(
           prompt: prompt ?? _getPrompt(),
-          llm: llm,
-          outputParser: outputParser,
           llmOptions: ChatOpenAIOptions(
             functions: [function],
           ),
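
Note: with Dart super-parameters, the subclass now forwards llm and outputParser straight to LLMChain. A hedged construction sketch follows; the ChatFunction shape (name/description plus a JSON-schema parameters map) and the JsonOutputFunctionsParser name are assumptions about the surrounding package, not confirmed by this diff.

import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

Future<void> main() async {
  // Hypothetical function schema: force the model to answer with sources.
  const function = ChatFunction(
    name: 'answer',
    description: 'Answer the question, citing sources',
    parameters: {
      'type': 'object',
      'properties': {
        'answer': {'type': 'string'},
        'sources': {
          'type': 'array',
          'items': {'type': 'string'},
        },
      },
      'required': ['answer', 'sources'],
    },
  );

  final chain = OpenAIQAWithStructureChain(
    llm: ChatOpenAI(apiKey: 'your-openai-key'),
    function: function,
    outputParser: JsonOutputFunctionsParser(), // assumed parser name
  );

  // call() with a single string fills the prompt's one input variable
  // (assumed for the default prompt).
  final res = await chain.call('Where was Mozart born?');
  print(res[chain.outputKey]);
}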
