From 5abb61483727f9f98cd2a84059b293ce089e91f5 Mon Sep 17 00:00:00 2001
From: Mahmoud Abughali
Date: Thu, 12 Sep 2024 16:43:30 -0400
Subject: [PATCH] feat(groq): add llm adapter (#20)

---
 .env.template                        |   3 +
 README.md                            |   1 +
 examples/llms/providers/groq.ts      |  21 +++
 package.json                         |   2 +
 src/adapters/groq/chat.test.ts       |  37 +++++
 src/adapters/groq/chat.ts            | 229 +++++++++++++++++++++++++++
 tests/e2e/adapters/groq/chat.test.ts |  72 +++++++++
 tests/e2e/utils.ts                   |   3 +-
 yarn.lock                            |  17 ++
 9 files changed, 384 insertions(+), 1 deletion(-)
 create mode 100644 examples/llms/providers/groq.ts
 create mode 100644 src/adapters/groq/chat.test.ts
 create mode 100644 src/adapters/groq/chat.ts
 create mode 100644 tests/e2e/adapters/groq/chat.test.ts

diff --git a/.env.template b/.env.template
index a584e16..ddab454 100644
--- a/.env.template
+++ b/.env.template
@@ -11,3 +11,6 @@ CODE_INTERPRETER_URL=http://127.0.0.1:50051
 
 # For OpenAI LLM Adapter
 # OPENAI_API_KEY=
+
+# For Groq LLM Adapter
+# GROQ_API_KEY=
\ No newline at end of file
diff --git a/README.md b/README.md
index bc31e51..a6462d3 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,7 @@ To run this example, be sure that you have installed [ollama](https://ollama.com
 | `OpenAI`                                                                  | LLM + ChatLLM support ([example](./examples/llms/providers/openai.ts))                  |
 | `LangChain`                                                               | Use any LLM that LangChain supports ([example](./examples/llms/providers/langchain.ts)) |
 | `WatsonX`                                                                 | LLM + ChatLLM support ([example](./examples/llms/providers/watsonx.ts))                 |
+| `Groq`                                                                    | ChatLLM support ([example](./examples/llms/providers/groq.ts))                          |
 | `BAM (Internal)`                                                          | LLM + ChatLLM support ([example](./examples/llms/providers/bam.ts))                     |
 | ➕ [Request](https://github.com/i-am-bee/bee-agent-framework/discussions) |                                                                                         |
diff --git a/examples/llms/providers/groq.ts b/examples/llms/providers/groq.ts
new file mode 100644
index 0000000..5edec49
--- /dev/null
+++ b/examples/llms/providers/groq.ts
@@ -0,0 +1,21 @@
+import "dotenv/config";
+import { BaseMessage } from "bee-agent-framework/llms/primitives/message";
+import { GroqChatLLM } from "bee-agent-framework/adapters/groq/chat";
+
+const llm = new GroqChatLLM({
+  modelId: "gemma2-9b-it",
+  parameters: {
+    temperature: 0.7,
+    max_tokens: 1024,
+    top_p: 1,
+  },
+});
+
+console.info("Meta", await llm.meta());
+const response = await llm.generate([
+  BaseMessage.of({
+    role: "user",
+    text: "Hello world!",
+  }),
+]);
+console.info(response.getTextContent());
diff --git a/package.json b/package.json
index bc46ea7..2bc57c5 100644
--- a/package.json
+++ b/package.json
@@ -110,6 +110,7 @@
     "@langchain/community": "~0.2.28",
     "@langchain/core": "~0.2.27",
     "@langchain/langgraph": "~0.0.34",
+    "groq-sdk": "^0.7.0",
     "ollama": "^0.5.8",
     "openai": "^4.56.0",
     "openai-chat-tokens": "^0.2.8"
@@ -139,6 +140,7 @@
     "eslint-config-prettier": "^9.1.0",
     "eslint-plugin-unused-imports": "^4.1.3",
     "glob": "^11.0.0",
+    "groq-sdk": "^0.7.0",
     "husky": "^9.1.5",
     "langchain": "~0.2.16",
     "lint-staged": "^15.2.9",
diff --git a/src/adapters/groq/chat.test.ts b/src/adapters/groq/chat.test.ts
new file mode 100644
index 0000000..c684880
--- /dev/null
+++ b/src/adapters/groq/chat.test.ts
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2024 IBM Corp.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { verifyDeserialization } from "@tests/e2e/utils.js";
+import { GroqChatLLM } from "@/adapters/groq/chat.js";
+import { Groq } from "groq-sdk";
+
+describe("Groq ChatLLM", () => {
+  const getInstance = () => {
+    return new GroqChatLLM({
+      modelId: "gemma2-9b-it",
+      client: new Groq({
+        apiKey: "123",
+      }),
+    });
+  };
+
+  it("Serializes", async () => {
+    const instance = getInstance();
+    const serialized = instance.serialize();
+    const deserialized = GroqChatLLM.fromSerialized(serialized);
+    verifyDeserialization(instance, deserialized);
+  });
+});
diff --git a/src/adapters/groq/chat.ts b/src/adapters/groq/chat.ts
new file mode 100644
index 0000000..d2fe14c
--- /dev/null
+++ b/src/adapters/groq/chat.ts
@@ -0,0 +1,229 @@
+/**
+ * Copyright 2024 IBM Corp.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {
+  AsyncStream,
+  BaseLLMTokenizeOutput,
+  ExecutionOptions,
+  GenerateCallbacks,
+  GenerateOptions,
+  LLMMeta,
+  StreamGenerateOptions,
+} from "@/llms/base.js";
+import { shallowCopy } from "@/serializer/utils.js";
+import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
+import { BaseMessage, RoleType } from "@/llms/primitives/message.js";
+import { Emitter } from "@/emitter/emitter.js";
+import { ClientOptions, Groq as Client } from "groq-sdk";
+import { GetRunContext } from "@/context.js";
+import { Serializer } from "@/serializer/serializer.js";
+import { getPropStrict } from "@/internals/helpers/object.js";
+import { ChatCompletionChunk, ChatCompletionCreateParams } from "groq-sdk/resources/chat/completions";
+
+type Parameters = Omit<ChatCompletionCreateParams, "stream" | "messages" | "model">;
+type Response = Omit<ChatCompletionChunk, "object">;
+
+export class ChatGroqOutput extends ChatLLMOutput {
+  public readonly responses: Response[];
+
+  constructor(response: Response) {
+    super();
+    this.responses = [response];
+  }
+
+  static {
+    this.register();
+  }
+
+  get messages() {
+    return this.responses
+      .flatMap((response) => response.choices)
+      .flatMap((choice) =>
+        BaseMessage.of({
+          role: choice.delta.role as RoleType,
+          text: choice.delta.content!,
+        }),
+      );
+  }
+
+  getTextContent(): string {
+    return this.messages.map((msg) => msg.text).join("\n");
+  }
+
+  merge(other: ChatGroqOutput): void {
+    this.responses.push(...other.responses);
+  }
+
+  toString(): string {
+    return this.getTextContent();
+  }
+
+  createSnapshot() {
+    return {
+      responses: shallowCopy(this.responses),
+    };
+  }
+
+  loadSnapshot(snapshot: ReturnType<typeof this.createSnapshot>): void {
+    Object.assign(this, snapshot);
+  }
+}
+
+interface Input {
+  modelId?: string;
+  client?: Client;
+  parameters?: Parameters;
+  executionOptions?: ExecutionOptions;
+}
+
+export class GroqChatLLM extends ChatLLM<ChatGroqOutput> {
+  public readonly emitter = Emitter.root.child<GenerateCallbacks>({
+    namespace: ["groq", "chat_llm"],
+    creator: this,
+  });
+
+  public readonly client: Client;
+  public readonly parameters: Partial<Parameters>;
+
+  constructor({
+    client,
+    modelId = "llama3-8b-8192",
+    parameters,
+    executionOptions = {},
+  }: Input = {}) {
+    super(modelId, executionOptions);
+    this.client = client ?? new Client();
+    this.parameters = parameters ?? {};
+  }
+
+  static {
+    this.register();
+    Serializer.register(Client, {
+      toPlain: (value) => ({
+        options: getPropStrict(value, "_options") as ClientOptions,
+      }),
+      fromPlain: (value) => new Client(value.options),
+    });
+  }
+
+  async meta(): Promise<LLMMeta> {
+    if (
+      this.modelId.includes("gemma") ||
+      this.modelId.includes("llama3") ||
+      this.modelId.includes("llama-guard")
+    ) {
+      return { tokenLimit: 8 * 1024 };
+    } else if (this.modelId.includes("llava-v1.5")) {
+      return { tokenLimit: 4 * 1024 };
+    } else if (this.modelId.includes("llama-3.1-70b") || this.modelId.includes("llama-3.1-8b")) {
+      return { tokenLimit: 128 * 1024 };
+    } else if (this.modelId.includes("mixtral-8x7b")) {
+      return { tokenLimit: 32 * 1024 };
+    }
+
+    return {
+      tokenLimit: Infinity,
+    };
+  }
+
+  async tokenize(input: BaseMessage[]): Promise<BaseLLMTokenizeOutput> {
+    const contentLength = input.reduce((acc, msg) => acc + msg.text.length, 0);
+
+    return {
+      tokensCount: Math.ceil(contentLength / 4),
+    };
+  }
+
+  protected _prepareRequest(
+    input: BaseMessage[],
+    options: GenerateOptions,
+  ): ChatCompletionCreateParams {
+    return {
+      ...this.parameters,
+      model: this.modelId,
+      stream: false,
+      messages: input.map(
+        (message) =>
+          ({
+            role: message.role,
+            content: message.text,
+          }) as Client.Chat.ChatCompletionMessageParam,
+      ),
+      ...(options?.guided?.json && {
+        response_format: {
+          type: "json_object",
+        },
+      }),
+    };
+  }
+
+  protected async _generate(
+    input: BaseMessage[],
+    options: GenerateOptions,
+    run: GetRunContext<typeof this>,
+  ): Promise<ChatGroqOutput> {
+    const response = await this.client.chat.completions.create(
+      {
+        ...this._prepareRequest(input, options),
+        stream: false,
+      },
+      {
+        signal: run.signal,
+      },
+    );
+    return new ChatGroqOutput({
+      id: response.id,
+      model: response.model,
+      created: response.created,
+      system_fingerprint: response.system_fingerprint,
+      choices: response.choices.map(
+        (choice) =>
+          ({
+            delta: choice.message,
+            index: choice.index,
+            logprobs: choice.logprobs,
+            finish_reason: choice.finish_reason,
+          }) as Client.Chat.ChatCompletionChunk.Choice,
+      ),
+    });
+  }
+
+  protected async *_stream(
+    input: BaseMessage[],
+    options: StreamGenerateOptions,
+    run: GetRunContext<typeof this>,
+  ): AsyncStream<ChatGroqOutput> {
+    for await (const chunk of await this.client.chat.completions.create(
+      {
+        ...this._prepareRequest(input, options),
+        stream: true,
+      },
+      {
+        signal: run.signal,
+      },
+    )) {
+      yield new ChatGroqOutput(chunk);
+    }
+  }
+
+  createSnapshot() {
+    return {
+      ...super.createSnapshot(),
+      parameters: shallowCopy(this.parameters),
+      client: this.client,
+    };
+  }
+}
diff --git a/tests/e2e/adapters/groq/chat.test.ts b/tests/e2e/adapters/groq/chat.test.ts
new file mode 100644
index 0000000..350b8be
--- /dev/null
+++ b/tests/e2e/adapters/groq/chat.test.ts
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2024 IBM Corp.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { BaseMessage, Role } from "@/llms/primitives/message.js";
+import { GroqChatLLM } from "@/adapters/groq/chat.js";
+
+const apiKey = process.env.GROQ_API_KEY;
+
+describe.runIf(Boolean(apiKey))("Adapter Groq Chat LLM", () => {
+  const createChatLLM = () => {
+    const model = new GroqChatLLM({
+      modelId: "llama3-8b-8192",
+      parameters: {
+        temperature: 0,
+        max_tokens: 1024,
+        top_p: 1,
+      },
+    });
+    return model;
+  };
+
+  it("Generates", async () => {
+    const conversation = [
+      BaseMessage.of({
+        role: Role.SYSTEM,
+        text: `You are a helpful and respectful and honest assistant. Your answer should be short and concise.`,
+      }),
+    ];
+    const llm = createChatLLM();
+
+    for (const { question, answer } of [
+      {
+        question: `What is the coldest continent? Response must be a single word without any punctuation.`,
+        answer: "Antarctica",
+      },
+      {
+        question:
+          "What is the most common typical animal that lives there? Response must be a single word without any punctuation.",
+        answer: "Penguin",
+      },
+    ]) {
+      conversation.push(
+        BaseMessage.of({
+          role: Role.USER,
+          text: question,
+        }),
+      );
+      const response = await llm.generate(conversation);
+      expect(response.messages.length).toBeGreaterThan(0);
+      expect(response.getTextContent()).toBe(answer);
+      conversation.push(
+        BaseMessage.of({
+          role: Role.ASSISTANT,
+          text: response.getTextContent(),
+        }),
+      );
+    }
+  });
+});
diff --git a/tests/e2e/utils.ts b/tests/e2e/utils.ts
index cb145fb..f72521f 100644
--- a/tests/e2e/utils.ts
+++ b/tests/e2e/utils.ts
@@ -28,6 +28,7 @@ import { RunContext } from "@/context.js";
 import { Emitter } from "@/emitter/emitter.js";
 import { toJsonSchema } from "@/internals/helpers/schema.js";
 import { OpenAI } from "openai";
+import { Groq } from "groq-sdk";
 
 interface CallbackOptions {
   required?: boolean;
@@ -121,7 +122,7 @@ verifyDeserialization.isIgnored = (value: unknown, parent?: any) => {
     return true;
   }
 
-  if (parent && parent instanceof OpenAI) {
+  if (parent && (parent instanceof OpenAI || parent instanceof Groq)) {
     try {
       Serializer.findFactory(value);
       return false;
diff --git a/yarn.lock b/yarn.lock
index f99c258..11d1359 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -2543,6 +2543,7 @@ __metadata:
    eslint-plugin-unused-imports: "npm:^4.1.3"
    fast-xml-parser: "npm:^4.4.1"
    glob: "npm:^11.0.0"
+    groq-sdk: "npm:^0.7.0"
    header-generator: "npm:^2.1.54"
    husky: "npm:^9.1.5"
    joplin-turndown-plugin-gfm: "npm:^1.0.12"
@@ -2589,6 +2590,7 @@ __metadata:
    "@langchain/community": ~0.2.28
    "@langchain/core": ~0.2.27
    "@langchain/langgraph": ~0.0.34
+    groq-sdk: ^0.7.0
    ollama: ^0.5.8
    openai: ^4.56.0
    openai-chat-tokens: ^0.2.8
@@ -4747,6 +4749,21 @@
   languageName: node
   linkType: hard
 
+"groq-sdk@npm:^0.7.0":
+  version: 0.7.0
+  resolution: "groq-sdk@npm:0.7.0"
+  dependencies:
+    "@types/node": "npm:^18.11.18"
+    "@types/node-fetch": "npm:^2.6.4"
+    abort-controller: "npm:^3.0.0"
+    agentkeepalive: "npm:^4.2.1"
+    form-data-encoder: "npm:1.7.2"
+    formdata-node: "npm:^4.3.2"
+    node-fetch: "npm:^2.6.7"
+  checksum: 10c0/53fad6f5492b4682cdcc75d994e7bc6c29be910f727da35639a729519003f433813ac47a0ca6a9acb27dd7d938a0a6b6677c22bd68067f2aba7a5cdbb5bb20c2
+  languageName: node
+  linkType: hard
+
 "handlebars@npm:^4.7.7":
   version: 4.7.8
   resolution: "handlebars@npm:4.7.8"
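-- 
Usage sketch (streaming): the bundled example only exercises `generate()`.
The snippet below is a minimal sketch of driving the adapter's protected
`_stream()`, assuming the `ChatLLM` base class exposes the public `stream()`
helper that the other adapters in this framework rely on, and that
GROQ_API_KEY is set in the environment:

    import "dotenv/config";
    import { BaseMessage } from "bee-agent-framework/llms/primitives/message";
    import { GroqChatLLM } from "bee-agent-framework/adapters/groq/chat";

    // Streaming counterpart to examples/llms/providers/groq.ts; each yielded
    // chunk is a ChatGroqOutput wrapping a single ChatCompletionChunk, so
    // getTextContent() returns just that chunk's delta text.
    const llm = new GroqChatLLM({ modelId: "llama3-8b-8192" });
    for await (const chunk of llm.stream([
      BaseMessage.of({ role: "user", text: "Hello world!" }),
    ])) {
      process.stdout.write(chunk.getTextContent());
    }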