Skip to content

Commit

Permalink
community[patch]: #3369 Streaming support for Replicate models (#5365)
Browse files Browse the repository at this point in the history
* langchain-community[patch]: #3369 Streaming support for Replicate models

* lock

* Add streaming test

---------

Co-authored-by: jeasonnow <guyf@seeyon.com>
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
  • Loading branch information
3 people authored May 25, 2024
1 parent 668b0bb commit 0f2be54
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 43 deletions.
4 changes: 2 additions & 2 deletions libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@
"puppeteer": "^19.7.2",
"redis": "^4.6.6",
"release-it": "^15.10.1",
"replicate": "^0.18.0",
"replicate": "^0.29.4",
"rollup": "^3.19.1",
"sonix-speech-recognition": "^2.1.1",
"srt-parser-2": "^1.2.3",
Expand Down Expand Up @@ -317,7 +317,7 @@
"portkey-ai": "^0.1.11",
"puppeteer": "^19.7.2",
"redis": "*",
"replicate": "^0.18.0",
"replicate": "^0.29.4",
"sonix-speech-recognition": "^2.1.1",
"srt-parser-2": "^1.2.3",
"typeorm": "^0.3.12",
Expand Down
117 changes: 82 additions & 35 deletions libs/langchain-community/src/llms/replicate.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { GenerationChunk } from "@langchain/core/outputs";

import type ReplicateInstance from "replicate";

/**
* Interface defining the structure of the input data for the Replicate
Expand Down Expand Up @@ -88,13 +92,85 @@ export class Replicate extends LLM implements ReplicateInput {
prompt: string,
options: this["ParsedCallOptions"]
): Promise<string> {
const replicate = await this._prepareReplicate();
const input = await this._getReplicateInput(replicate, prompt);

const output = await this.caller.callWithOptions(
{ signal: options.signal },
() =>
replicate.run(this.model, {
input,
})
);

if (typeof output === "string") {
return output;
} else if (Array.isArray(output)) {
return output.join("");
} else {
// Note this is a little odd, but the output format is not consistent
// across models, so it makes some amount of sense.
return String(output);
}
}

/**
 * Streams model output from Replicate as it is produced.
 *
 * Yields one `GenerationChunk` per server-sent "output" event, then a final
 * empty chunk with `generationInfo.finished = true` when the "done" event
 * arrives.
 *
 * @param prompt - The text prompt sent to the model.
 * @param options - Parsed call options; only `signal` is used, to support
 *   aborting the in-flight request.
 * @param runManager - Optional callback manager notified of each new token.
 */
async *_streamResponseChunks(
  prompt: string,
  options: this["ParsedCallOptions"],
  runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
  const replicate = await this._prepareReplicate();
  const input = await this._getReplicateInput(replicate, prompt);

  const stream = await this.caller.callWithOptions(
    { signal: options?.signal },
    async () =>
      replicate.stream(this.model, {
        input,
      })
  );
  for await (const chunk of stream) {
    if (chunk.event === "output") {
      // Normalize once: `chunk.data` may be undefined on some events, and the
      // original passed it raw to GenerationChunk.text while defaulting it
      // only for the token callback. Use the same defaulted value for both.
      const text = chunk.data ?? "";
      yield new GenerationChunk({ text, generationInfo: chunk });
      await runManager?.handleLLMNewToken(text);
    }

    // The "done" event closes the stream; emit a terminal marker chunk.
    if (chunk.event === "done") {
      yield new GenerationChunk({
        text: "",
        generationInfo: { finished: true },
      });
    }
  }
}

/** @ignore */
static async imports(): Promise<{
  Replicate: typeof ReplicateInstance;
}> {
  try {
    // Lazily load the optional peer dependency so the module can be
    // imported without `replicate` installed.
    const replicateModule = await import("replicate");
    return { Replicate: replicateModule.default };
  } catch (e) {
    throw new Error(
      "Please install replicate as a dependency with, e.g. `yarn add replicate`"
    );
  }
}

/**
 * Lazily imports the Replicate SDK and constructs an authenticated client.
 *
 * @returns A Replicate client configured with the LangChain user agent and
 *   this instance's API key.
 */
private async _prepareReplicate(): Promise<ReplicateInstance> {
  const imports = await Replicate.imports();

  return new imports.Replicate({
    userAgent: "langchain",
    auth: this.apiKey,
  });
}

private async _getReplicateInput(
replicate: ReplicateInstance,
prompt: string
) {
if (this.promptKey === undefined) {
const [modelString, versionString] = this.model.split(":");
const version = await replicate.models.versions.get(
Expand All @@ -119,40 +195,11 @@ export class Replicate extends LLM implements ReplicateInput {
this.promptKey = sortedInputProperties[0][0] ?? "prompt";
}
}
const output = await this.caller.callWithOptions(
{ signal: options.signal },
() =>
replicate.run(this.model, {
input: {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
[this.promptKey!]: prompt,
...this.input,
},
})
);

if (typeof output === "string") {
return output;
} else if (Array.isArray(output)) {
return output.join("");
} else {
// Note this is a little odd, but the output format is not consistent
// across models, so it makes some amount of sense.
return String(output);
}
}

/** @ignore */
static async imports(): Promise<{
Replicate: typeof import("replicate").default;
}> {
try {
const { default: Replicate } = await import("replicate");
return { Replicate };
} catch (e) {
throw new Error(
"Please install replicate as a dependency with, e.g. `yarn add replicate`"
);
}
return {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
[this.promptKey!]: prompt,
...this.input,
};
}
}
19 changes: 19 additions & 0 deletions libs/langchain-community/src/llms/tests/replicate.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@ test.skip("Test Replicate", async () => {
expect(typeof res).toBe("string");
});

// Exercises the streaming path: collect every chunk from model.stream() and
// verify more than one chunk was produced.
test.skip("Test Replicate streaming", async () => {
  const model = new Replicate({
    model:
      "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
    input: {
      max_length: 10,
    },
  });

  const stream = await model.stream("Hello, my name is ");

  const collected = [];
  for await (const piece of stream) {
    collected.push(piece);
  }
  console.log(collected);
  expect(collected.length).toBeGreaterThan(1);
});

test.skip("Serialise Replicate", () => {
const model = new Replicate({
model:
Expand Down
30 changes: 24 additions & 6 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -9167,7 +9167,7 @@ __metadata:
puppeteer: ^19.7.2
redis: ^4.6.6
release-it: ^15.10.1
replicate: ^0.18.0
replicate: ^0.29.4
rollup: ^3.19.1
sonix-speech-recognition: ^2.1.1
srt-parser-2: ^1.2.3
Expand Down Expand Up @@ -9290,7 +9290,7 @@ __metadata:
portkey-ai: ^0.1.11
puppeteer: ^19.7.2
redis: "*"
replicate: ^0.18.0
replicate: ^0.29.4
sonix-speech-recognition: ^2.1.1
srt-parser-2: ^1.2.3
typeorm: ^0.3.12
Expand Down Expand Up @@ -32640,6 +32640,19 @@ __metadata:
languageName: node
linkType: hard

"readable-stream@npm:>=4.0.0":
version: 4.5.2
resolution: "readable-stream@npm:4.5.2"
dependencies:
abort-controller: ^3.0.0
buffer: ^6.0.3
events: ^3.3.0
process: ^0.11.10
string_decoder: ^1.3.0
checksum: c4030ccff010b83e4f33289c535f7830190773e274b3fcb6e2541475070bdfd69c98001c3b0cb78763fc00c8b62f514d96c2b10a8bd35d5ce45203a25fa1d33a
languageName: node
linkType: hard

"readable-stream@npm:^2.0.0, readable-stream@npm:^2.0.1, readable-stream@npm:^2.3.0, readable-stream@npm:^2.3.5, readable-stream@npm:~2.3.6":
version: 2.3.8
resolution: "readable-stream@npm:2.3.8"
Expand Down Expand Up @@ -33060,10 +33073,15 @@ __metadata:
languageName: node
linkType: hard

"replicate@npm:^0.18.0":
version: 0.18.0
resolution: "replicate@npm:0.18.0"
checksum: 547a8b386418aedf6e5be2086a63090e5a5f6cda36202a0122c4036a2af8a80efea420393e5efa4810c9cff0616a7df5adbd40fd4a0560f4aa1b4eda60a34794
"replicate@npm:^0.29.4":
version: 0.29.4
resolution: "replicate@npm:0.29.4"
dependencies:
readable-stream: ">=4.0.0"
dependenciesMeta:
readable-stream:
optional: true
checksum: 9405e19f619134a312aa77b3c04156549e4c8ba5e0711a494b99358abd0378646c22cd9bf07e6f9c8ab4a2f80b69ba22ed0a5b8ec0610684e9fa5d413e3b5729
languageName: node
linkType: hard

Expand Down

0 comments on commit 0f2be54

Please sign in to comment.