Skip to content

Commit 9f500dd

Browse files
authored
Add support for models.predictions.create endpoint (#163)
1 parent eefbcd4 commit 9f500dd

File tree

4 files changed

+71
-0
lines changed

4 files changed

+71
-0
lines changed

index.d.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,18 @@ declare module 'replicate' {
164164
version_id: string
165165
): Promise<ModelVersion>;
166166
};
167+
predictions: {
168+
create(
169+
model_owner: string,
170+
model_name: string,
171+
options: {
172+
input: object;
173+
stream?: boolean;
174+
webhook?: string;
175+
webhook_events_filter?: WebhookEventType[];
176+
}
177+
): Promise<Prediction>;
178+
};
167179
};
168180

169181
predictions: {

index.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class Replicate {
6868
list: models.versions.list.bind(this),
6969
get: models.versions.get.bind(this),
7070
},
71+
predictions: {
72+
create: models.predictions.create.bind(this),
73+
},
7174
};
7275

7376
this.predictions = {

index.test.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,38 @@ describe('Replicate client', () => {
668668
// Add more tests for error handling, edge cases, etc.
669669
});
670670

671+
describe('models.predictions.create', () => {
672+
test('Calls the correct API route with the correct payload', async () => {
673+
nock(BASE_URL)
674+
.post('/models/meta/llama-2-70b-chat/predictions')
675+
.reply(200, {
676+
id: "heat2o3bzn3ahtr6bjfftvbaci",
677+
model: "replicate/lifeboat-70b",
678+
version: "d-c6559c5791b50af57b69f4a73f8e021c",
679+
input: {
680+
prompt: "Please write a haiku about llamas."
681+
},
682+
logs: "",
683+
error: null,
684+
status: "starting",
685+
created_at: "2023-11-27T13:35:45.99397566Z",
686+
urls: {
687+
cancel: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel",
688+
get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"
689+
}
690+
});
691+
const prediction = await client.models.predictions.create("meta", "llama-2-70b-chat", {
692+
input: {
693+
prompt: "Please write a haiku about llamas."
694+
},
695+
webhook: 'http://test.host/webhook',
696+
webhook_events_filter: [ 'output', 'completed' ],
697+
});
698+
expect(prediction.id).toBe('heat2o3bzn3ahtr6bjfftvbaci');
699+
});
700+
// Add more tests for error handling, edge cases, etc.
701+
});
702+
671703
describe('hardware.list', () => {
672704
test('Calls the correct API route', async () => {
673705
nock(BASE_URL)

lib/models.js

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,33 @@ async function createModel(model_owner, model_name, options) {
8383
return response.json();
8484
}
8585

86+
/**
87+
* Create a new prediction
88+
*
89+
* @param {string} model_owner - Required. The name of the user or organization that owns the model
90+
* @param {string} model_name - Required. The name of the model
91+
* @param {object} options
92+
* @param {object} options.input - Required. An object with the model inputs
93+
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
94+
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
95+
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
96+
* @returns {Promise<object>} Resolves with the created prediction
97+
*/
98+
async function createPrediction(model_owner, model_name, options) {
99+
const { stream, ...data } = options;
100+
101+
const response = await this.request(`/models/${model_owner}/${model_name}/predictions`, {
102+
method: 'POST',
103+
data: { ...data, stream },
104+
});
105+
106+
return response.json();
107+
}
108+
86109
module.exports = {
87110
get: getModel,
88111
list: listModels,
89112
create: createModel,
90113
versions: { list: listModelVersions, get: getModelVersion },
114+
predictions: { create: createPrediction },
91115
};

0 commit comments

Comments
 (0)