Skip to content

Commit 0d133ca

Browse files
committed
Return FileOutput from run() function
This is currently behind the `useFileOutput` flag provided to the Replicate constructor. This allows us to test the feature before rolling it out more widely. When enabled any URLs or data-uris will be converted into a FileOutput type. This is essentially a `ReadableStream` that has two additional methods `url()` to return the underlying URL and `blob()` which will return a `Blob()` object with the file data loaded into memory. The intention here is to make it easier to work with file outputs and allows us to optimize the delivery of file assets to the client in future iterations.
1 parent 16dae4d commit 0d133ca

File tree

3 files changed

+220
-3
lines changed

3 files changed

+220
-3
lines changed

index.d.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ declare module "replicate" {
88
response: Response;
99
}
1010

11+
export interface FileOutput extends ReadableStream {
12+
blob(): Promise<Blob>;
13+
url(): URL;
14+
toString(): string;
15+
}
16+
1117
export interface Account {
1218
type: "user" | "organization";
1319
username: string;
@@ -137,6 +143,7 @@ declare module "replicate" {
137143
init?: RequestInit
138144
) => Promise<Response>;
139145
fileEncodingStrategy?: FileEncodingStrategy;
146+
useFileOutput?: boolean;
140147
});
141148

142149
auth: string;

index.js

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
const ApiError = require("./lib/error");
22
const ModelVersionIdentifier = require("./lib/identifier");
3-
const { createReadableStream } = require("./lib/stream");
3+
const { createReadableStream, createFileOutput } = require("./lib/stream");
44
const {
5+
transform,
56
withAutomaticRetries,
67
validateWebhook,
78
parseProgressFromLogs,
@@ -47,6 +48,7 @@ class Replicate {
4748
* @param {string} options.userAgent - Identifier of your app
4849
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
4950
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
51+
* @param {boolean} [options.useFileOutput] - Set to `true` to return `FileOutput` objects from `run` instead of URLs, defaults to false.
5052
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
5153
*/
5254
constructor(options = {}) {
@@ -58,6 +60,7 @@ class Replicate {
5860
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
5961
this.fetch = options.fetch || globalThis.fetch;
6062
this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default";
63+
this.useFileOutput = options.useFileOutput ?? false;
6164

6265
this.accounts = {
6366
current: accounts.current.bind(this),
@@ -196,7 +199,17 @@ class Replicate {
196199
throw new Error(`Prediction failed: ${prediction.error}`);
197200
}
198201

199-
return prediction.output;
202+
return transform(prediction.output, (value) => {
203+
if (
204+
typeof value === "string" &&
205+
(value.startsWith("https:") || value.startsWith("data:"))
206+
) {
207+
return this.useFileOutput
208+
? createFileOutput({ url: value, fetch: this.fetch })
209+
: value;
210+
}
211+
return value;
212+
});
200213
}
201214

202215
/**

index.test.ts

Lines changed: 198 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import { expect, jest, test } from "@jest/globals";
22
import Replicate, {
33
ApiError,
4+
FileOutput,
45
Model,
56
Prediction,
67
validateWebhook,
78
parseProgressFromLogs,
89
} from "replicate";
910
import nock from "nock";
10-
import { Readable } from "node:stream";
1111
import { createReadableStream } from "./lib/stream";
1212

1313
let client: Replicate;
@@ -1562,6 +1562,203 @@ describe("Replicate client", () => {
15621562

15631563
scope.done();
15641564
});
1565+
1566+
test("returns FileOutput for URLs when useFileOutput is true", async () => {
1567+
client = new Replicate({ auth: "foo", useFileOutput: true });
1568+
1569+
nock(BASE_URL)
1570+
.post("/predictions")
1571+
.reply(201, {
1572+
id: "ufawqhfynnddngldkgtslldrkq",
1573+
status: "starting",
1574+
logs: null,
1575+
})
1576+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1577+
.reply(200, {
1578+
id: "ufawqhfynnddngldkgtslldrkq",
1579+
status: "processing",
1580+
logs: [].join("\n"),
1581+
})
1582+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1583+
.reply(200, {
1584+
id: "ufawqhfynnddngldkgtslldrkq",
1585+
status: "processing",
1586+
logs: [].join("\n"),
1587+
})
1588+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1589+
.reply(200, {
1590+
id: "ufawqhfynnddngldkgtslldrkq",
1591+
status: "succeeded",
1592+
output: "https://example.com",
1593+
logs: [].join("\n"),
1594+
});
1595+
1596+
nock("https://example.com")
1597+
.get("/")
1598+
.reply(200, "hello world", { "Content-Type": "text/plain" });
1599+
1600+
const output = (await client.run(
1601+
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1602+
{
1603+
input: { text: "Hello, world!" },
1604+
}
1605+
)) as FileOutput;
1606+
1607+
expect(output).toBeInstanceOf(ReadableStream);
1608+
expect(output.url()).toEqual(new URL("https://example.com"));
1609+
1610+
const blob = await output.blob();
1611+
expect(blob.type).toEqual("text/plain");
1612+
expect(blob.arrayBuffer()).toEqual(
1613+
new Blob(["Hello, world!"]).arrayBuffer()
1614+
);
1615+
});
1616+
1617+
test("returns FileOutput for URLs when useFileOutput is true - acts like string", async () => {
1618+
client = new Replicate({ auth: "foo", useFileOutput: true });
1619+
1620+
nock(BASE_URL)
1621+
.post("/predictions")
1622+
.reply(201, {
1623+
id: "ufawqhfynnddngldkgtslldrkq",
1624+
status: "starting",
1625+
logs: null,
1626+
})
1627+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1628+
.reply(200, {
1629+
id: "ufawqhfynnddngldkgtslldrkq",
1630+
status: "processing",
1631+
logs: [].join("\n"),
1632+
})
1633+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1634+
.reply(200, {
1635+
id: "ufawqhfynnddngldkgtslldrkq",
1636+
status: "processing",
1637+
logs: [].join("\n"),
1638+
})
1639+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1640+
.reply(200, {
1641+
id: "ufawqhfynnddngldkgtslldrkq",
1642+
status: "succeeded",
1643+
output: "https://example.com",
1644+
logs: [].join("\n"),
1645+
});
1646+
1647+
nock("https://example.com")
1648+
.get("/")
1649+
.reply(200, "hello world", { "Content-Type": "text/plain" });
1650+
1651+
const output = (await client.run(
1652+
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1653+
{
1654+
input: { text: "Hello, world!" },
1655+
}
1656+
)) as unknown as string;
1657+
1658+
expect(fetch(output).then((r) => r.text())).resolves.toEqual(
1659+
"hello world"
1660+
);
1661+
});
1662+
1663+
test("returns FileOutput for URLs when useFileOutput is true - array output", async () => {
1664+
client = new Replicate({ auth: "foo", useFileOutput: true });
1665+
1666+
nock(BASE_URL)
1667+
.post("/predictions")
1668+
.reply(201, {
1669+
id: "ufawqhfynnddngldkgtslldrkq",
1670+
status: "starting",
1671+
logs: null,
1672+
})
1673+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1674+
.reply(200, {
1675+
id: "ufawqhfynnddngldkgtslldrkq",
1676+
status: "processing",
1677+
logs: [].join("\n"),
1678+
})
1679+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1680+
.reply(200, {
1681+
id: "ufawqhfynnddngldkgtslldrkq",
1682+
status: "processing",
1683+
logs: [].join("\n"),
1684+
})
1685+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1686+
.reply(200, {
1687+
id: "ufawqhfynnddngldkgtslldrkq",
1688+
status: "succeeded",
1689+
output: ["https://example.com"],
1690+
logs: [].join("\n"),
1691+
});
1692+
1693+
nock("https://example.com")
1694+
.get("/")
1695+
.reply(200, "hello world", { "Content-Type": "text/plain" });
1696+
1697+
const [output] = (await client.run(
1698+
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1699+
{
1700+
input: { text: "Hello, world!" },
1701+
}
1702+
)) as FileOutput[];
1703+
1704+
expect(output).toBeInstanceOf(ReadableStream);
1705+
expect(output.url()).toEqual(new URL("https://example.com"));
1706+
1707+
const blob = await output.blob();
1708+
expect(blob.type).toEqual("text/plain");
1709+
expect(blob.arrayBuffer()).toEqual(
1710+
new Blob(["Hello, world!"]).arrayBuffer()
1711+
);
1712+
});
1713+
1714+
test("returns FileOutput for URLs when useFileOutput is true - data uri", async () => {
1715+
client = new Replicate({ auth: "foo", useFileOutput: true });
1716+
1717+
nock(BASE_URL)
1718+
.post("/predictions")
1719+
.reply(201, {
1720+
id: "ufawqhfynnddngldkgtslldrkq",
1721+
status: "starting",
1722+
logs: null,
1723+
})
1724+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1725+
.reply(200, {
1726+
id: "ufawqhfynnddngldkgtslldrkq",
1727+
status: "processing",
1728+
logs: [].join("\n"),
1729+
})
1730+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1731+
.reply(200, {
1732+
id: "ufawqhfynnddngldkgtslldrkq",
1733+
status: "processing",
1734+
logs: [].join("\n"),
1735+
})
1736+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
1737+
.reply(200, {
1738+
id: "ufawqhfynnddngldkgtslldrkq",
1739+
status: "succeeded",
1740+
output: "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==",
1741+
logs: [].join("\n"),
1742+
});
1743+
1744+
const output = (await client.run(
1745+
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1746+
{
1747+
input: { text: "Hello, world!" },
1748+
}
1749+
)) as FileOutput;
1750+
1751+
expect(output).toBeInstanceOf(ReadableStream);
1752+
expect(output.url()).toEqual(
1753+
new URL("data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==")
1754+
);
1755+
1756+
const blob = await output.blob();
1757+
expect(blob.type).toEqual("text/plain");
1758+
expect(blob.arrayBuffer()).toEqual(
1759+
new Blob(["Hello, world!"]).arrayBuffer()
1760+
);
1761+
});
15651762
});
15661763

15671764
describe("webhooks.default.secret.get", () => {

0 commit comments

Comments
 (0)