Skip to content

Commit 1c98382

Browse files
aronmattt
andauthored
Add support for new deployment endpoints (#223)
* Add support for new deployment endpoints * Align definition for Deployment type to OpenAPI specification * Add test coverage for deployment endpoints --------- Co-authored-by: Mattt Zmuda <mattt@replicate.com>
1 parent f280b0b commit 1c98382

File tree

5 files changed

+234
-3
lines changed

5 files changed

+234
-3
lines changed

index.d.ts

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ declare module "replicate" {
3333
created_by: Account;
3434
configuration: {
3535
hardware: string;
36-
min_instances: number;
37-
max_instances: number;
36+
scaling: {
37+
min_instances: number;
38+
max_instances: number;
39+
};
3840
};
3941
};
4042
}
@@ -194,6 +196,30 @@ declare module "replicate" {
194196
deployment_owner: string,
195197
deployment_name: string
196198
): Promise<Deployment>;
199+
create(deployment_config: {
200+
name: string;
201+
model: string;
202+
version: string;
203+
hardware: string;
204+
min_instances: number;
205+
max_instances: number;
206+
}): Promise<Deployment>;
207+
update(
208+
deployment_owner: string,
209+
deployment_name: string,
210+
deployment_config: {
211+
version?: string;
212+
hardware?: string;
213+
min_instances?: number;
214+
max_instances?: number;
215+
} & (
216+
| { version: string }
217+
| { hardware: string }
218+
| { min_instances: number }
219+
| { max_instances: number }
220+
)
221+
): Promise<Deployment>;
222+
list(): Promise<Page<Deployment>>;
197223
};
198224

199225
hardware: {

index.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ class Replicate {
6767

6868
this.deployments = {
6969
get: deployments.get.bind(this),
70+
create: deployments.create.bind(this),
71+
update: deployments.update.bind(this),
72+
list: deployments.list.bind(this),
7073
predictions: {
7174
create: deployments.predictions.create.bind(this),
7275
},

index.test.ts

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,135 @@ describe("Replicate client", () => {
811811
// Add more tests for error handling, edge cases, etc.
812812
});
813813

814+
describe("deployments.create", () => {
815+
test("Calls the correct API route with the correct payload", async () => {
816+
nock(BASE_URL)
817+
.post("/deployments")
818+
.reply(200, {
819+
owner: "acme",
820+
name: "my-app-image-generator",
821+
current_release: {
822+
number: 1,
823+
model: "stability-ai/sdxl",
824+
version:
825+
"da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
826+
created_at: "2024-02-15T16:32:57.018467Z",
827+
created_by: {
828+
type: "organization",
829+
username: "acme",
830+
name: "Acme Corp, Inc.",
831+
github_url: "https://github.com/acme",
832+
},
833+
configuration: {
834+
hardware: "gpu-t4",
835+
scaling: {
836+
min_instances: 1,
837+
max_instances: 5,
838+
},
839+
},
840+
},
841+
});
842+
843+
const deployment = await client.deployments.create({
844+
name: "my-app-image-generator",
845+
model: "stability-ai/sdxl",
846+
version:
847+
"da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
848+
hardware: "gpu-t4",
849+
min_instances: 1,
850+
max_instances: 5,
851+
});
852+
853+
expect(deployment.owner).toBe("acme");
854+
expect(deployment.name).toBe("my-app-image-generator");
855+
expect(deployment.current_release.model).toBe("stability-ai/sdxl");
856+
});
857+
// Add more tests for error handling, edge cases, etc.
858+
});
859+
860+
describe("deployments.update", () => {
861+
test("Calls the correct API route with the correct payload", async () => {
862+
nock(BASE_URL)
863+
.patch("/deployments/acme/my-app-image-generator")
864+
.reply(200, {
865+
owner: "acme",
866+
name: "my-app-image-generator",
867+
current_release: {
868+
number: 2,
869+
model: "stability-ai/sdxl",
870+
version:
871+
"632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
872+
created_at: "2024-02-16T08:14:22.345678Z",
873+
created_by: {
874+
type: "organization",
875+
username: "acme",
876+
name: "Acme Corp, Inc.",
877+
github_url: "https://github.com/acme",
878+
},
879+
configuration: {
880+
hardware: "gpu-a40-large",
881+
scaling: {
882+
min_instances: 3,
883+
max_instances: 10,
884+
},
885+
},
886+
},
887+
});
888+
889+
const deployment = await client.deployments.update(
890+
"acme",
891+
"my-app-image-generator",
892+
{
893+
version:
894+
"632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
895+
hardware: "gpu-a40-large",
896+
min_instances: 3,
897+
max_instances: 10,
898+
}
899+
);
900+
901+
expect(deployment.current_release.number).toBe(2);
902+
expect(deployment.current_release.version).toBe(
903+
"632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532"
904+
);
905+
expect(deployment.current_release.configuration.hardware).toBe(
906+
"gpu-a40-large"
907+
);
908+
expect(
909+
deployment.current_release.configuration.scaling?.min_instances
910+
).toBe(3);
911+
expect(
912+
deployment.current_release.configuration.scaling?.max_instances
913+
).toBe(10);
914+
});
915+
// Add more tests for error handling, edge cases, etc.
916+
});
917+
918+
describe("deployments.list", () => {
919+
test("Calls the correct API route", async () => {
920+
nock(BASE_URL)
921+
.get("/deployments")
922+
.reply(200, {
923+
next: null,
924+
previous: null,
925+
results: [
926+
{
927+
owner: "acme",
928+
name: "my-app-image-generator",
929+
current_release: {
930+
// ...
931+
},
932+
},
933+
// ...
934+
],
935+
});
936+
937+
const deployments = await client.deployments.list();
938+
expect(deployments.results.length).toBe(1)
939+
});
940+
// Add more tests for pagination, error handling, edge cases, etc.
941+
});
942+
814943
describe("predictions.create with model", () => {
815944
test("Calls the correct API route with the correct payload", async () => {
816945
nock(BASE_URL)

lib/deployments.js

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,82 @@ async function getDeployment(deployment_owner, deployment_name) {
5757
return response.json();
5858
}
5959

60+
/**
61+
* @typedef {Object} DeploymentCreateRequest - Request body for `deployments.create`
62+
* @property {string} name - the name of the deployment
63+
* @property {string} model - the full name of the model that you want to deploy e.g. stability-ai/sdxl
64+
* @property {string} version - the 64-character string ID of the model version that you want to deploy
65+
* @property {string} hardware - the SKU for the hardware used to run the model, via `replicate.hardware.list()`
66+
* @property {number} min_instances - the minimum number of instances for scaling
67+
* @property {number} max_instances - the maximum number of instances for scaling
68+
*/
69+
70+
/**
71+
* Create a deployment
72+
*
73+
* @param {DeploymentCreateRequest} config - Required. The deployment config.
74+
* @returns {Promise<object>} Resolves with the deployment data
75+
*/
76+
async function createDeployment(deployment_config) {
77+
const response = await this.request("/deployments", {
78+
method: "POST",
79+
data: deployment_config,
80+
});
81+
82+
return response.json();
83+
}
84+
85+
/**
86+
* @typedef {Object} DeploymentUpdateRequest - Request body for `deployments.update`
87+
* @property {string} version - the 64-character string ID of the model version that you want to deploy
88+
* @property {string} hardware - the SKU for the hardware used to run the model, via `replicate.hardware.list()`
89+
* @property {number} min_instances - the minimum number of instances for scaling
90+
* @property {number} max_instances - the maximum number of instances for scaling
91+
*/
92+
93+
/**
94+
* Update an existing deployment
95+
*
96+
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
97+
* @param {string} deployment_name - Required. The name of the deployment
98+
* @param {DeploymentUpdateRequest} deployment_config - Required. The deployment changes.
99+
* @returns {Promise<object>} Resolves with the deployment data
100+
*/
101+
async function updateDeployment(
102+
deployment_owner,
103+
deployment_name,
104+
deployment_config
105+
) {
106+
const response = await this.request(
107+
`/deployments/${deployment_owner}/${deployment_name}`,
108+
{
109+
method: "PATCH",
110+
data: deployment_config,
111+
}
112+
);
113+
114+
return response.json();
115+
}
116+
117+
/**
118+
* List all deployments
119+
*
120+
* @returns {Promise<object>} - Resolves with a page of deployments
121+
*/
122+
async function listDeployments() {
123+
const response = await this.request("/deployments", {
124+
method: "GET",
125+
});
126+
127+
return response.json();
128+
}
129+
60130
module.exports = {
61131
predictions: {
62132
create: createPrediction,
63133
},
64134
get: getDeployment,
135+
create: createDeployment,
136+
update: updateDeployment,
137+
list: listDeployments,
65138
};

tsconfig.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
"strict": true,
66
"allowJs": true
77
},
8-
"exclude": ["integration/**", "**/node_modules"]
8+
"exclude": ["**/node_modules", "integration"]
99
}

0 commit comments

Comments
 (0)