Skip to content

Commit b50fa6b

Browse files
authored
Merge pull request #1 from ovuruska/DeepInfra-embeddings-integration
langchain-community [feature]: DeepInfra embeddings integration
2 parents d3f3f29 + c778987 commit b50fa6b

File tree

6 files changed

+276
-0
lines changed

6 files changed

+276
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import { DeepInfraEmbeddings } from "@langchain/community/embeddings/deepinfra";
2+
3+
4+
// Build the embeddings client. `batchSize` and `modelName` are spelled out
// for illustration even though they equal the library defaults.
const deepInfra = new DeepInfraEmbeddings({
  apiToken: process.env.DEEPINFRA_API_TOKEN!,
  batchSize: 1024, // Default value
  modelName: "sentence-transformers/clip-ViT-B-32", // Default value
});

// Embed a single query string and print the resulting vector.
const query = "Tell me a story about a dragon and a princess.";
const embeddings = await deepInfra.embedQuery(query);
console.log(embeddings);
14+

libs/langchain-community/.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ embeddings/cohere.cjs
138138
embeddings/cohere.js
139139
embeddings/cohere.d.ts
140140
embeddings/cohere.d.cts
141+
embeddings/deepinfra.cjs
142+
embeddings/deepinfra.js
143+
embeddings/deepinfra.d.ts
144+
embeddings/deepinfra.d.cts
141145
embeddings/fireworks.cjs
142146
embeddings/fireworks.js
143147
embeddings/fireworks.d.ts

libs/langchain-community/langchain.config.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export const config = {
6464
"embeddings/bedrock": "embeddings/bedrock",
6565
"embeddings/cloudflare_workersai": "embeddings/cloudflare_workersai",
6666
"embeddings/cohere": "embeddings/cohere",
67+
"embeddings/deepinfra": "embeddings/deepinfra",
6768
"embeddings/fireworks": "embeddings/fireworks",
6869
"embeddings/googlepalm": "embeddings/googlepalm",
6970
"embeddings/googlevertexai": "embeddings/googlevertexai",

libs/langchain-community/package.json

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,15 @@
873873
"import": "./embeddings/cohere.js",
874874
"require": "./embeddings/cohere.cjs"
875875
},
876+
"./embeddings/deepinfra": {
877+
"types": {
878+
"import": "./embeddings/deepinfra.d.ts",
879+
"require": "./embeddings/deepinfra.d.cts",
880+
"default": "./embeddings/deepinfra.d.ts"
881+
},
882+
"import": "./embeddings/deepinfra.js",
883+
"require": "./embeddings/deepinfra.cjs"
884+
},
876885
"./embeddings/fireworks": {
877886
"types": {
878887
"import": "./embeddings/fireworks.d.ts",
@@ -2448,6 +2457,10 @@
24482457
"embeddings/cohere.js",
24492458
"embeddings/cohere.d.ts",
24502459
"embeddings/cohere.d.cts",
2460+
"embeddings/deepinfra.cjs",
2461+
"embeddings/deepinfra.js",
2462+
"embeddings/deepinfra.d.ts",
2463+
"embeddings/deepinfra.d.cts",
24512464
"embeddings/fireworks.cjs",
24522465
"embeddings/fireworks.js",
24532466
"embeddings/fireworks.d.ts",
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import axios, {AxiosInstance, AxiosResponse} from "axios";
2+
3+
import { getEnvironmentVariable } from "@langchain/core/utils/env";
4+
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
5+
import { chunkArray } from "@langchain/core/utils/chunk_array";
6+
7+
/** Model used when the caller does not supply `modelName`. */
const DEFAULT_MODEL_NAME = "sentence-transformers/clip-ViT-B-32";

/**
 * Number of texts sent per request when the caller does not supply
 * `batchSize`. The DeepInfra API caps a single request at 1024 inputs.
 */
const DEFAULT_BATCH_SIZE = 1024;

/** Name of the environment variable that holds the DeepInfra API token. */
const API_TOKEN_ENV_VAR = "DEEPINFRA_API_TOKEN";
22+
23+
24+
/**
 * JSON body posted to the DeepInfra inference endpoint for an embeddings call.
 */
export interface DeepInfraEmbeddingsRequest {
  /** Texts to embed; the caller reads one embedding back per entry. */
  inputs: string[];

  /**
   * NOTE(review): presumably asks the API to normalize the returned
   * vectors — never set in this file; confirm against the DeepInfra API docs.
   */
  normalize?: boolean;

  /**
   * NOTE(review): never set in this file; semantics to be confirmed
   * against the DeepInfra API docs.
   */
  image?: string;

  /**
   * NOTE(review): never set in this file; presumably a callback URL —
   * confirm against the DeepInfra API docs.
   */
  webhook?: string;
}
30+
31+
32+
/**
 * Input parameters for the DeepInfra embeddings.
 */
export interface DeepInfraEmbeddingsParams extends EmbeddingsParams {
  /**
   * The API token to use for authentication.
   * If not provided, it will be read from the `DEEPINFRA_API_TOKEN` environment variable.
   */
  apiToken?: string;

  /**
   * The model ID to use for generating embeddings.
   * Default: `sentence-transformers/clip-ViT-B-32`
   */
  modelName?: string;

  /**
   * The maximum number of texts to embed in a single request. This is
   * limited by the DeepInfra API to a maximum of 1024.
   * Default: `1024`
   */
  batchSize?: number;
}
55+
56+
/**
 * Response from the DeepInfra embeddings API.
 */
export interface DeepInfraEmbeddingsResponse {
  /**
   * The embeddings generated for the input texts.
   */
  embeddings: number[][];
  /**
   * The number of tokens in the input texts.
   */
  input_tokens: number;
  /**
   * The identifier of the request; may be absent from the response.
   */
  request_id?: string;
}
73+
74+
75+
/**
 * A class for generating embeddings using the DeepInfra inference API.
 * @example
 * ```typescript
 * // Embed a query using the DeepInfraEmbeddings class
 * const model = new DeepInfraEmbeddings();
 * const res = await model.embedQuery(
 *   "What would be a good company name for a company that makes colorful socks?",
 * );
 * console.log({ res });
 * ```
 */
export class DeepInfraEmbeddings
  extends Embeddings
  implements DeepInfraEmbeddingsParams
{
  /** HTTP client, created lazily on the first request. */
  private client?: AxiosInstance;

  apiToken: string;

  batchSize: number;

  modelName: string;

  /**
   * Constructor for the DeepInfraEmbeddings class.
   * @param fields - An optional object with properties to configure the instance.
   * @throws Error if no API token is passed and the `DEEPINFRA_API_TOKEN`
   * environment variable is not set.
   */
  constructor(
    fields?: Partial<DeepInfraEmbeddingsParams> & {
      verbose?: boolean;
    }
  ) {
    const fieldsWithDefaults = {
      modelName: DEFAULT_MODEL_NAME,
      batchSize: DEFAULT_BATCH_SIZE,
      ...fields,
    };

    super(fieldsWithDefaults);

    // `||` (not `??`) on purpose: an empty-string token falls back to the
    // environment variable as well.
    const apiKey =
      fieldsWithDefaults.apiToken || getEnvironmentVariable(API_TOKEN_ENV_VAR);

    if (!apiKey) {
      throw new Error("DeepInfra API token not found");
    }

    // An explicit `undefined` in `fields` overrides the spread defaults above,
    // so fall back to the constants again here.
    this.modelName = fieldsWithDefaults.modelName ?? DEFAULT_MODEL_NAME;
    this.batchSize = fieldsWithDefaults.batchSize ?? DEFAULT_BATCH_SIZE;
    this.apiToken = apiKey;
  }

  /**
   * Generates embeddings for an array of texts.
   * The texts are split into batches of at most `batchSize` and the batches
   * are requested concurrently; results are returned in input order.
   * @param inputs - An array of strings to generate embeddings for.
   * @returns A Promise that resolves to an array of embeddings.
   */
  async embedDocuments(inputs: string[]): Promise<number[][]> {
    const batches = chunkArray(inputs, this.batchSize);

    const batchResponses = await Promise.all(
      batches.map((batch) => this.embeddingWithRetry({ inputs: batch }))
    );

    const out: number[][] = [];

    for (let i = 0; i < batchResponses.length; i += 1) {
      const batch = batches[i];
      const { embeddings } = batchResponses[i];
      // Take exactly one embedding per input text of this batch.
      for (let j = 0; j < batch.length; j += 1) {
        out.push(embeddings[j]);
      }
    }

    return out;
  }

  /**
   * Generates an embedding for a single text.
   * @param text - A string to generate an embedding for.
   * @returns A Promise that resolves to an array of numbers representing the embedding.
   */
  async embedQuery(text: string): Promise<number[]> {
    const { embeddings } = await this.embeddingWithRetry({
      inputs: [text],
    });
    return embeddings[0];
  }

  /**
   * Posts an embeddings request through `this.caller`.
   * @param request - An object containing the request parameters for generating embeddings.
   * @returns A Promise that resolves to the API response body.
   */
  private async embeddingWithRetry(
    request: DeepInfraEmbeddingsRequest
  ): Promise<DeepInfraEmbeddingsResponse> {
    const client = this.maybeInitClient();
    // The base URL already points at the model endpoint, hence the empty path.
    const response: AxiosResponse<DeepInfraEmbeddingsResponse> =
      await this.caller.call(() =>
        client.post<DeepInfraEmbeddingsResponse>("", request)
      );
    return response.data;
  }

  /**
   * Creates the DeepInfra HTTP client if it hasn't been created already.
   * @returns The client bound to the inference endpoint of `modelName`.
   */
  private maybeInitClient(): AxiosInstance {
    if (!this.client) {
      this.client = axios.create({
        baseURL: `https://api.deepinfra.com/v1/inference/${this.modelName}`,
        headers: {
          Authorization: `Bearer ${this.apiToken}`,
          "Content-Type": "application/json",
        },
      });
    }
    return this.client;
  }

  /** @ignore */
  static async imports(): Promise<{}> {
    // Axios is imported statically at the top of this file, so there is
    // nothing to load lazily here.
    return {};
  }
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { test, expect } from "@jest/globals";
2+
import { DeepInfraEmbeddings } from "../deepinfra.js";
3+
4+
// Integration test: embedding a single query yields a numeric vector.
test("Test DeepInfraEmbeddings.embedQuery", async () => {
  const model = new DeepInfraEmbeddings();
  const vector = await model.embedQuery("Hello world");
  const firstComponent = vector[0];
  expect(typeof firstComponent).toBe("number");
});
9+
10+
// Integration test: two documents yield two numeric vectors.
test("Test DeepInfraEmbeddings.embedDocuments", async () => {
  const model = new DeepInfraEmbeddings();
  const documents = ["Hello world", "Bye bye"];
  const vectors = await model.embedDocuments(documents);
  expect(vectors).toHaveLength(2);
  expect(typeof vectors[0][0]).toBe("number");
  expect(typeof vectors[1][0]).toBe("number");
});
17+
18+
// Integration test: with a batch size of 1 every document becomes its own
// request, so six documents exercise concurrent requests.
test("Test DeepInfraEmbeddings concurrency", async () => {
  const model = new DeepInfraEmbeddings({ batchSize: 1 });
  const documents = [
    "Hello world",
    "Bye bye",
    "we need",
    "at least",
    "six documents",
    "to test concurrency",
  ];
  const vectors = await model.embedDocuments(documents);
  expect(vectors).toHaveLength(6);
  // No returned vector may start with a non-number.
  const firstInvalid = vectors.find((vector) => typeof vector[0] !== "number");
  expect(firstInvalid).toBe(undefined);
});

0 commit comments

Comments
 (0)