feat(groq): add llm adapter (#20)

abughali authored Sep 12, 2024
1 parent 8a93496 commit 5abb614
Showing 9 changed files with 384 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .env.template
@@ -11,3 +11,6 @@ CODE_INTERPRETER_URL=http://127.0.0.1:50051

# For OpenAI LLM Adapter
# OPENAI_API_KEY=

# For Groq LLM Adapter
# GROQ_API_KEY=
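
The groq-sdk client picks up `GROQ_API_KEY` from the environment automatically when no key is passed, so uncommenting the line above is enough for the adapter to authenticate. Alternatively, a preconfigured client can be handed to the adapter directly; a minimal sketch (the apiKey value is a placeholder):

import { Groq } from "groq-sdk";
import { GroqChatLLM } from "bee-agent-framework/adapters/groq/chat";

// Explicit client; omit it entirely to fall back to process.env.GROQ_API_KEY.
const llm = new GroqChatLLM({
  client: new Groq({ apiKey: "your-api-key" }),
});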
1 change: 1 addition & 0 deletions README.md
@@ -116,6 +116,7 @@ To run this example, be sure that you have installed [ollama](https://ollama.com
| `OpenAI` | LLM + ChatLLM support ([example](./examples/llms/providers/openai.ts)) |
| `LangChain` | Use any LLM that LangChain supports ([example](./examples/llms/providers/langchain.ts)) |
| `WatsonX` | LLM + ChatLLM support ([example](./examples/llms/providers/watsonx.ts)) |
| `Groq` | ChatLLM support ([example](./examples/llms/providers/groq.ts)) |
| `BAM (Internal)` | LLM + ChatLLM support ([example](./examples/llms/providers/bam.ts)) |
| [Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | |

21 changes: 21 additions & 0 deletions examples/llms/providers/groq.ts
@@ -0,0 +1,21 @@
import "dotenv/config";
import { BaseMessage } from "bee-agent-framework/llms/primitives/message";
import { GroqChatLLM } from "bee-agent-framework/adapters/groq/chat";

const llm = new GroqChatLLM({
  modelId: "gemma2-9b-it",
  parameters: {
    temperature: 0.7,
    max_tokens: 1024,
    top_p: 1,
  },
});

console.info("Meta", await llm.meta());
const response = await llm.generate([
  BaseMessage.of({
    role: "user",
    text: "Hello world!",
  }),
]);
console.info(response.getTextContent());
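
The adapter also implements streaming (see `_stream` in src/adapters/groq/chat.ts below). Assuming the framework's standard `stream` method on chat models, incremental output might be consumed like this sketch:

for await (const chunk of llm.stream([
  BaseMessage.of({ role: "user", text: "Hello world!" }),
])) {
  process.stdout.write(chunk.getTextContent());
}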
2 changes: 2 additions & 0 deletions package.json
@@ -110,6 +110,7 @@
"@langchain/community": "~0.2.28",
"@langchain/core": "~0.2.27",
"@langchain/langgraph": "~0.0.34",
"groq-sdk": "^0.7.0",
"ollama": "^0.5.8",
"openai": "^4.56.0",
"openai-chat-tokens": "^0.2.8"
@@ -139,6 +140,7 @@
"eslint-config-prettier": "^9.1.0",
"eslint-plugin-unused-imports": "^4.1.3",
"glob": "^11.0.0",
"groq-sdk": "^0.7.0",
"husky": "^9.1.5",
"langchain": "~0.2.16",
"lint-staged": "^15.2.9",
37 changes: 37 additions & 0 deletions src/adapters/groq/chat.test.ts
@@ -0,0 +1,37 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { verifyDeserialization } from "@tests/e2e/utils.js";
import { GroqChatLLM } from "@/adapters/groq/chat.js";
import { Groq } from "groq-sdk";

describe("Groq ChatLLM", () => {
const getInstance = () => {
return new GroqChatLLM({
modelId: "gemma2-9b-it",
client: new Groq({
apiKey: "123",
}),
});
};

it("Serializes", async () => {
const instance = getInstance();
const serialized = instance.serialize();
const deserialized = GroqChatLLM.fromSerialized(serialized);
verifyDeserialization(instance, deserialized);
});
});
229 changes: 229 additions & 0 deletions src/adapters/groq/chat.ts
@@ -0,0 +1,229 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {
  AsyncStream,
  BaseLLMTokenizeOutput,
  ExecutionOptions,
  GenerateCallbacks,
  GenerateOptions,
  LLMMeta,
  StreamGenerateOptions,
} from "@/llms/base.js";
import { shallowCopy } from "@/serializer/utils.js";
import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, RoleType } from "@/llms/primitives/message.js";
import { Emitter } from "@/emitter/emitter.js";
import { ClientOptions, Groq as Client } from "groq-sdk";
import { GetRunContext } from "@/context.js";
import { Serializer } from "@/serializer/serializer.js";
import { getPropStrict } from "@/internals/helpers/object.js";
import { ChatCompletionCreateParams } from "groq-sdk/resources/chat/completions";

type Parameters = Omit<ChatCompletionCreateParams, "stream" | "messages" | "model">;
type Response = Omit<Client.Chat.ChatCompletionChunk, "object">;

export class ChatGroqOutput extends ChatLLMOutput {
  public readonly responses: Response[];

  constructor(response: Response) {
    super();
    this.responses = [response];
  }

  static {
    this.register();
  }

  get messages() {
    return this.responses
      .flatMap((response) => response.choices)
      .flatMap((choice) =>
        BaseMessage.of({
          role: choice.delta.role as RoleType,
          text: choice.delta.content!,
        }),
      );
  }

  getTextContent(): string {
    return this.messages.map((msg) => msg.text).join("\n");
  }

  merge(other: ChatGroqOutput): void {
    this.responses.push(...other.responses);
  }

  toString(): string {
    return this.getTextContent();
  }

  createSnapshot() {
    return {
      responses: shallowCopy(this.responses),
    };
  }

  loadSnapshot(snapshot: ReturnType<typeof this.createSnapshot>): void {
    Object.assign(this, snapshot);
  }
}

interface Input {
  modelId?: string;
  client?: Client;
  parameters?: Parameters;
  executionOptions?: ExecutionOptions;
}

export class GroqChatLLM extends ChatLLM<ChatGroqOutput> {
  public readonly emitter = Emitter.root.child<GenerateCallbacks>({
    namespace: ["groq", "chat_llm"],
    creator: this,
  });

  public readonly client: Client;
  public readonly parameters: Partial<Parameters>;

  constructor({
    client,
    modelId = "llama3-8b-8192",
    parameters,
    executionOptions = {},
  }: Input = {}) {
    super(modelId, executionOptions);
    this.client = client ?? new Client();
    this.parameters = parameters ?? {};
  }

  static {
    this.register();
    Serializer.register(Client, {
      toPlain: (value) => ({
        options: getPropStrict(value, "_options") as ClientOptions,
      }),
      fromPlain: (value) => new Client(value.options),
    });
  }

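  // Context-window sizes are hard-coded per model-family substring;
  // models that match none of the branches fall back to an unbounded limit.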
  async meta(): Promise<LLMMeta> {
    if (
      this.modelId.includes("gemma") ||
      this.modelId.includes("llama3") ||
      this.modelId.includes("llama-guard")
    ) {
      return { tokenLimit: 8 * 1024 };
    } else if (this.modelId.includes("llava-v1.5")) {
      return { tokenLimit: 4 * 1024 };
    } else if (this.modelId.includes("llama-3.1-70b") || this.modelId.includes("llama-3.1-8b")) {
      return { tokenLimit: 128 * 1024 };
    } else if (this.modelId.includes("mixtral-8x7b")) {
      return { tokenLimit: 32 * 1024 };
    }

    return {
      tokenLimit: Infinity,
    };
  }

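  // No tokenizer is wired up for Groq; token usage is approximated at
  // roughly four characters per token.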
  async tokenize(input: BaseMessage[]): Promise<BaseLLMTokenizeOutput> {
    const contentLength = input.reduce((acc, msg) => acc + msg.text.length, 0);

    return {
      tokensCount: Math.ceil(contentLength / 4),
    };
  }

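  // Merges instance-level parameters into the request; a truthy `guided.json`
  // option switches on Groq's JSON mode via `response_format`.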
  protected _prepareRequest(
    input: BaseMessage[],
    options: GenerateOptions,
  ): ChatCompletionCreateParams {
    return {
      ...this.parameters,
      model: this.modelId,
      stream: false,
      messages: input.map(
        (message) =>
          ({
            role: message.role,
            content: message.text,
          }) as Client.Chat.ChatCompletionMessageParam,
      ),
      ...(options?.guided?.json && {
        response_format: {
          type: "json_object",
        },
      }),
    };
  }

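  // The non-streaming completion is reshaped into the chunk format so the
  // streaming and non-streaming paths share one Response type.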
  protected async _generate(
    input: BaseMessage[],
    options: GenerateOptions,
    run: GetRunContext<typeof this>,
  ): Promise<ChatGroqOutput> {
    const response = await this.client.chat.completions.create(
      {
        ...this._prepareRequest(input, options),
        stream: false,
      },
      {
        signal: run.signal,
      },
    );
    return new ChatGroqOutput({
      id: response.id,
      model: response.model,
      created: response.created,
      system_fingerprint: response.system_fingerprint,
      choices: response.choices.map(
        (choice) =>
          ({
            delta: choice.message,
            index: choice.index,
            logprobs: choice.logprobs,
            finish_reason: choice.finish_reason,
          }) as Client.Chat.ChatCompletionChunk.Choice,
      ),
    });
  }

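  // Each SDK chunk becomes its own ChatGroqOutput; outputs can be
  // combined via merge().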
  protected async *_stream(
    input: BaseMessage[],
    options: StreamGenerateOptions,
    run: GetRunContext<typeof this>,
  ): AsyncStream<ChatGroqOutput> {
    for await (const chunk of await this.client.chat.completions.create(
      {
        ...this._prepareRequest(input, options),
        stream: true,
      },
      {
        signal: run.signal,
      },
    )) {
      yield new ChatGroqOutput(chunk);
    }
  }

  createSnapshot() {
    return {
      ...super.createSnapshot(),
      parameters: shallowCopy(this.parameters),
      client: this.client,
    };
  }
}
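
Given the `guided.json` handling in `_prepareRequest`, JSON-mode output might be requested as in this sketch (the exact shape of `GenerateOptions.guided` is defined by the framework's base LLM types and is assumed here):

const output = await llm.generate(
  [BaseMessage.of({ role: "user", text: "Return a JSON object with a single key 'greeting'." })],
  { guided: { json: {} } },
);
console.info(output.getTextContent());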
72 changes: 72 additions & 0 deletions tests/e2e/adapters/groq/chat.test.ts
@@ -0,0 +1,72 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { GroqChatLLM } from "@/adapters/groq/chat.js";

const apiKey = process.env.GROQ_API_KEY;

describe.runIf(Boolean(apiKey))("Adapter Groq Chat LLM", () => {
  const createChatLLM = () => {
    return new GroqChatLLM({
      modelId: "llama3-8b-8192",
      parameters: {
        temperature: 0,
        max_tokens: 1024,
        top_p: 1,
      },
    });
  };

it("Generates", async () => {
const conversation = [
BaseMessage.of({
role: Role.SYSTEM,
text: `You are a helpful and respectful and honest assistant. Your answer should be short and concise.`,
}),
];
const llm = createChatLLM();

for (const { question, answer } of [
{
question: `What is the coldest continent? Response must be a single word without any punctuation.`,
answer: "Antarctica",
},
{
question:
"What is the most common typical animal that lives there? Response must be a single word without any punctuation.",
answer: "Penguin",
},
]) {
conversation.push(
BaseMessage.of({
role: Role.USER,
text: question,
}),
);
const response = await llm.generate(conversation);
expect(response.messages.length).toBeGreaterThan(0);
expect(response.getTextContent()).toBe(answer);
conversation.push(
BaseMessage.of({
role: Role.ASSISTANT,
text: response.getTextContent(),
}),
);
}
});
});