community[minor]: Added ChatTogetherAI integration #4215

Merged · 3 commits · Jan 31, 2024
Changes from 1 commit
34 changes: 34 additions & 0 deletions docs/core_docs/docs/integrations/chat/togetherai.mdx
@@ -0,0 +1,34 @@
---
sidebar_label: TogetherAI
---

import CodeBlock from "@theme/CodeBlock";

# ChatTogetherAI

## Setup

1. Create a TogetherAI account and get your API key [here](http://api.together.ai/).
2. Export your API key as an environment variable, or set it inline when instantiating the class. `ChatTogetherAI` defaults to `process.env.TOGETHER_AI_API_KEY`.

```bash
export TOGETHER_AI_API_KEY=your-api-key
```

You can use models provided by TogetherAI as follows:

import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";

<IntegrationInstallTooltip></IntegrationInstallTooltip>

```bash npm2yarn
npm install @langchain/community
```

import TogetherAI from "@examples/models/chat/integration_togetherai.ts";

<CodeBlock language="typescript">{TogetherAI}</CodeBlock>

Behind the scenes, ChatTogetherAI uses the OpenAI SDK and TogetherAI's OpenAI-compatible API, with some caveats:

- Certain properties (for example `frequencyPenalty`, `presencePenalty`, `logitBias`, and `functions`) are not supported by the TogetherAI API; see the [chat completions reference](https://docs.together.ai/reference/chat-completions).
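
Putting the pieces together, a minimal end-to-end sketch might look like the following. The prompt text is illustrative, and the model name shown is simply the class default, so it could be omitted:

```typescript
import { ChatTogetherAI } from "@langchain/community/chat_models/togetherai";
import { HumanMessage } from "@langchain/core/messages";

// With no key passed, the class falls back to process.env.TOGETHER_AI_API_KEY.
const chat = new ChatTogetherAI({
  modelName: "mistralai/Mixtral-8x7B-Instruct-v0.1", // the class default
  temperature: 0.9,
});

const res = await chat.invoke([new HumanMessage("Hello!")]);
console.log(res.content);
```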
7 changes: 7 additions & 0 deletions examples/src/models/chat/integration_togetherai.ts
@@ -0,0 +1,7 @@
import { ChatTogetherAI } from "@langchain/community/chat_models/togetherai";

const model = new ChatTogetherAI({
  temperature: 0.9,
  // In Node.js defaults to process.env.TOGETHER_AI_API_KEY
  togetherAIApiKey: "YOUR-API-KEY",
});
3 changes: 3 additions & 0 deletions libs/langchain-community/.gitignore
@@ -355,6 +355,9 @@ chat_models/ollama.d.ts
chat_models/portkey.cjs
chat_models/portkey.js
chat_models/portkey.d.ts
chat_models/togetherai.cjs
chat_models/togetherai.js
chat_models/togetherai.d.ts
chat_models/yandex.cjs
chat_models/yandex.js
chat_models/yandex.d.ts
1 change: 1 addition & 0 deletions libs/langchain-community/langchain.config.js
@@ -151,6 +151,7 @@ export const config = {
"chat_models/minimax": "chat_models/minimax",
"chat_models/ollama": "chat_models/ollama",
"chat_models/portkey": "chat_models/portkey",
"chat_models/togetherai": "chat_models/togetherai",
"chat_models/yandex": "chat_models/yandex",
// callbacks
"callbacks/handlers/llmonitor": "callbacks/handlers/llmonitor",
8 changes: 8 additions & 0 deletions libs/langchain-community/package.json
@@ -1091,6 +1091,11 @@
"import": "./chat_models/portkey.js",
"require": "./chat_models/portkey.cjs"
},
"./chat_models/togetherai": {
"types": "./chat_models/togetherai.d.ts",
"import": "./chat_models/togetherai.js",
"require": "./chat_models/togetherai.cjs"
},
"./chat_models/yandex": {
"types": "./chat_models/yandex.d.ts",
"import": "./chat_models/yandex.js",
@@ -1702,6 +1707,9 @@
"chat_models/portkey.cjs",
"chat_models/portkey.js",
"chat_models/portkey.d.ts",
"chat_models/togetherai.cjs",
"chat_models/togetherai.js",
"chat_models/togetherai.d.ts",
"chat_models/yandex.cjs",
"chat_models/yandex.js",
"chat_models/yandex.d.ts",
@@ -0,0 +1,78 @@
import { describe, test } from "@jest/globals";
import { ChatMessage, HumanMessage } from "@langchain/core/messages";
import {
  PromptTemplate,
  ChatPromptTemplate,
  AIMessagePromptTemplate,
  HumanMessagePromptTemplate,
  SystemMessagePromptTemplate,
} from "@langchain/core/prompts";
import { ChatTogetherAI } from "../togetherai.js";

describe("ChatTogetherAI", () => {
test("invoke", async () => {
const chat = new ChatTogetherAI();
const message = new HumanMessage("Hello!");
const res = await chat.invoke([message]);
console.log({ res });
expect(res.content.length).toBeGreaterThan(10);
});

test("generate", async () => {
const chat = new ChatTogetherAI();
const message = new HumanMessage("Hello!");
const res = await chat.generate([[message]]);
console.log(JSON.stringify(res, null, 2));
expect(res.generations[0][0].text.length).toBeGreaterThan(10);
});

test("custom messages", async () => {
const chat = new ChatTogetherAI();
const res = await chat.invoke([new ChatMessage("Hello!", "user")]);
console.log({ res });
expect(res.content.length).toBeGreaterThan(10);
});

test("prompt templates", async () => {
const chat = new ChatTogetherAI();

// PaLM doesn't support translation yet
const systemPrompt = PromptTemplate.fromTemplate(
"You are a helpful assistant who must always respond like a {job}."
);

const chatPrompt = ChatPromptTemplate.fromMessages([
new SystemMessagePromptTemplate(systemPrompt),
HumanMessagePromptTemplate.fromTemplate("{text}"),
]);

const responseA = await chat.generatePrompt([
await chatPrompt.formatPromptValue({
job: "pirate",
text: "What would be a good company name a company that makes colorful socks?",
}),
]);

console.log(responseA.generations);
expect(responseA.generations[0][0].text.length).toBeGreaterThan(10);
});

test("longer chain of messages", async () => {
const chat = new ChatTogetherAI();

const chatPrompt = ChatPromptTemplate.fromMessages([
HumanMessagePromptTemplate.fromTemplate(`Hi, my name is Joe!`),
AIMessagePromptTemplate.fromTemplate(`Nice to meet you, Joe!`),
HumanMessagePromptTemplate.fromTemplate("{text}"),
]);

const responseA = await chat.generatePrompt([
await chatPrompt.formatPromptValue({
text: "What did I just say my name was?",
}),
]);

console.log(responseA.generations);
expect(responseA.generations[0][0].text.length).toBeGreaterThan(10);
});
});
141 changes: 141 additions & 0 deletions libs/langchain-community/src/chat_models/togetherai.ts
@@ -0,0 +1,141 @@
import type { BaseChatModelParams } from "@langchain/core/language_models/chat_models";
> **Review comment:** Hey there! I noticed that the recent code changes explicitly handle the environment variable for the TogetherAI API key. I've flagged this for your review to ensure it aligns with the intended functionality. Keep up the great work!

import {
  type OpenAIClient,
  type ChatOpenAICallOptions,
  type OpenAIChatInput,
  type OpenAICoreRequestOptions,
  ChatOpenAI,
} from "@langchain/openai";

import { getEnvironmentVariable } from "@langchain/core/utils/env";

type TogetherAIUnsupportedArgs =
  | "frequencyPenalty"
  | "presencePenalty"
  | "logitBias"
  | "functions";

type TogetherAIUnsupportedCallOptions = "functions" | "function_call" | "tools";

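// Call options mirror ChatOpenAI's, minus the function-calling options that
// the TogetherAI endpoint does not support (see TogetherAIUnsupportedCallOptions).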
export type ChatTogetherAICallOptions = Partial<
  Omit<ChatOpenAICallOptions, TogetherAIUnsupportedCallOptions>
>;

/**
 * Wrapper around the TogetherAI API for large language models fine-tuned for chat.
 *
 * The TogetherAI API is compatible with the OpenAI API, with some limitations. View
 * the full API reference at:
 * @link {https://docs.together.ai/reference/chat-completions}
 *
 * To use, you should have the `openai` package installed and
 * the `TOGETHER_AI_API_KEY` environment variable set.
 * @example
 * ```typescript
 * const model = new ChatTogetherAI({
 *   temperature: 0.9,
 *   togetherAIApiKey: "YOUR-API-KEY",
 * });
 *
 * const response = await model.invoke("Hello, how are you?");
 * console.log(response);
 * ```
 */
export class ChatTogetherAI extends ChatOpenAI<ChatTogetherAICallOptions> {
  static lc_name() {
    return "ChatTogetherAI";
  }

  _llmType() {
    return "togetherAI";
  }

  get lc_secrets(): { [key: string]: string } | undefined {
    return {
      togetherAIApiKey: "TOGETHER_AI_API_KEY",
    };
  }

  lc_serializable = true;

  togetherAIApiKey?: string;

  constructor(
    fields?: Partial<
      Omit<OpenAIChatInput, "openAIApiKey" | TogetherAIUnsupportedArgs>
    > &
      BaseChatModelParams & { togetherAIApiKey?: string }
  ) {
    const togetherAIApiKey =
      fields?.togetherAIApiKey || getEnvironmentVariable("TOGETHER_AI_API_KEY");

    if (!togetherAIApiKey) {
      throw new Error(
        `TogetherAI API key not found. Please set the TOGETHER_AI_API_KEY environment variable or provide the key via the "togetherAIApiKey" field.`
      );
    }

    super({
      ...fields,
      modelName: fields?.modelName || "mistralai/Mixtral-8x7B-Instruct-v0.1",
      openAIApiKey: togetherAIApiKey,
      configuration: {
        baseURL: "https://api.together.xyz/v1/",
      },
    });

    this.togetherAIApiKey = togetherAIApiKey;
  }

  toJSON() {
    const result = super.toJSON();

    if (
      "kwargs" in result &&
      typeof result.kwargs === "object" &&
      result.kwargs != null
    ) {
      delete result.kwargs.openai_api_key;
      delete result.kwargs.configuration;
    }

    return result;
  }

  async completionWithRetry(
    request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<AsyncIterable<OpenAIClient.Chat.Completions.ChatCompletionChunk>>;

  async completionWithRetry(
    request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<OpenAIClient.Chat.Completions.ChatCompletion>;

  /**
   * Calls the TogetherAI API with retry logic in case of failures.
   * @param request The request to send to the TogetherAI API.
   * @param options Optional configuration for the API call.
   * @returns The response from the TogetherAI API.
   */
  async completionWithRetry(
    request:
      | OpenAIClient.Chat.ChatCompletionCreateParamsStreaming
      | OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<
    | AsyncIterable<OpenAIClient.Chat.Completions.ChatCompletionChunk>
    | OpenAIClient.Chat.Completions.ChatCompletion
  > {
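    // Strip parameters that the TogetherAI endpoint rejects
    // (mirrors TogetherAIUnsupportedArgs above).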
    delete request.frequency_penalty;
    delete request.presence_penalty;
    delete request.logit_bias;
    delete request.functions;

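    // Both branches call the parent implementation; the explicit check lets
    // TypeScript narrow `request` and select the matching overload.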
    if (request.stream === true) {
      return super.completionWithRetry(request, options);
    }

    return super.completionWithRetry(request, options);
  }
}
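
For completeness, a minimal sketch of the streaming path, which exercises the streaming overload above. It assumes `TOGETHER_AI_API_KEY` is set, the prompt text is illustrative, and `.stream()` is the standard runnable method inherited from `ChatOpenAI`:

```typescript
import { ChatTogetherAI } from "@langchain/community/chat_models/togetherai";

const chat = new ChatTogetherAI({ temperature: 0.9 });

// Each chunk is an AIMessageChunk; content is typically a string for chat models.
const stream = await chat.stream("What is the capital of France?");
for await (const chunk of stream) {
  console.log(chunk.content);
}
```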