This repository has been archived by the owner on Jan 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
/
HostedModel.ts
205 lines (193 loc) · 5.98 KB
/
HostedModel.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
/**
* A module for interfacing with Runway Hosted Models.
*/
import { AxiosRequestConfig, AxiosResponse } from 'axios';
import {
InvalidArgumentError,
InvlaidURLError,
ModelError,
NetworkError,
NotFoundError,
PermissionDeniedError,
UnexpectedError,
} from './HTTPErrors';
import { delay, requestWithRetry } from './utils';
/**
 * The config object used to construct a new [[HostedModel]].
 *
 * ```typescript
 * // For private models...
 * {
 *   url: 'https://my-model.hosted-models.runwayml.cloud/v1',
 *   token: 'kN5PB/GGgL5IcXxViegzUA==', // this is a fake token, use your own
 * }
 *
 * // For public models...
 * {
 *   url: 'https://my-model.hosted-models.runwayml.cloud/v1',
 * }
 * ```
 */
export interface HostedModelConfig {
  /**
   * The full URL of your hosted model in the format
   * `https://my-model.hosted-models.runwayml.cloud/v1`.
   *
   * The [[HostedModel]] constructor rejects any URL that does not start with
   * `http://` or `https://` followed by a host ending in `.runwayml.cloud` and
   * the `/v1` path.
   */
  url: string;
  /**
   * The secret token associated with this model. Only required if this model is private.
   * When present, [[HostedModel]] sends it with every request as an
   * `Authorization: Bearer <token>` header.
   */
  token?: string;
}
/**
 * A class representing a Runway Hosted Model. This is the main interface provided by
 * this package. Exposes two main methods for interfacing with a model.
 *
 * - `info()`
 * - `query(input)`
 *
 * Exposes two helper methods for checking the "awake" status of a hosted model.
 *
 * - `isAwake()`
 * - `waitUntilAwake()`
 */
export class HostedModel {
  /** Base URL of the hosted model, e.g. `https://my-model.hosted-models.runwayml.cloud/v1`. */
  private readonly url: string;
  /** Secret token for private models, or `null` when no token was supplied. */
  private readonly token: string | null;
  /** Headers sent with every request (JSON content negotiation plus optional bearer token). */
  private readonly headers: { [name: string]: string };
  /**
   * HTTP status codes handed to `requestWithRetry` on every request.
   * NOTE(review): presumably responses with these codes are retried — confirm in `./utils`.
   */
  private readonly responseCodesToRetry: number[];

  /**
   * ```typescript
   * const model = new HostedModel({
   *   url: 'https://my-model.hosted-models.runwayml.cloud/v1',
   *   token: 'my-secret-token', // token is only required for private models
   * })
   * ```
   *
   * @param config The model URL and, for private models, its secret token.
   * @throws InvalidArgumentError if `config` is not a non-null object.
   * @throws InvlaidURLError if `config.url` does not look like a hosted model v1 URL.
   */
  constructor(config: HostedModelConfig) {
    // `typeof null` is 'object', so null has to be rejected explicitly; otherwise
    // the property reads below would fail with a raw TypeError.
    if (typeof config !== 'object' || config === null) {
      throw new InvalidArgumentError('config');
    }
    this.url = config.url;
    this.token = config.token || null;
    this.headers = {
      Accept: 'application/json',
      'Content-Type': 'application/json',
    };
    if (this.token) this.headers['Authorization'] = `Bearer ${this.token}`;
    this.responseCodesToRetry = [502, 429];
    if (!this.isValidV1URL(this.url)) throw new InvlaidURLError();
    // Wake up the model during construction because it will probably be used soon.
    // This is deliberately fire-and-forget: the outcome is ignored and any failure
    // will resurface on the caller's first real request.
    this.root().catch(() => undefined);
  }

  /**
   * Return info about the input/output spec provided by the model. Makes a GET request
   * to the /v1/info route of a hosted model under the hood.
   *
   * @returns The parsed response body of the `/info` route.
   */
  async info() {
    return this.requestHostedModel({
      url: `${this.url}/info`,
      method: 'GET',
      headers: this.headers,
    });
  }

  /**
   * Run the model on your input and produce an output. This is how you "run" the model.
   * Makes a POST request to the /v1/query route of a hosted model under the hood.
   *
   * @param input An object containing input parameters to be sent to the model.
   * Use the [[info]] method to get the correct format for this object, as each model
   * expects different inputs.
   * @returns The parsed response body of the `/query` route.
   * @throws InvalidArgumentError if `input` is not a non-null object.
   */
  async query(input: any) {
    // `typeof null` is 'object', so null has to be rejected explicitly.
    if (typeof input !== 'object' || input === null) {
      throw new InvalidArgumentError('input');
    }
    return this.requestHostedModel({
      url: `${this.url}/query`,
      method: 'POST',
      headers: this.headers,
      data: input,
    });
  }

  /**
   * Returns `true` if this model is awake, `false` if it is still waking up.
   * See Asleep, Awakening, and Awake in the
   * [Hosted Models docs](https://learn.runwayml.com/#/how-to/hosted-models?id=asleep-awakening-and-awake-states).
   */
  async isAwake() {
    const root = await this.root();
    return root.status === 'running';
  }

  /**
   * Returns a promise that will resolve once the model is awake. This method is never
   * required, as [[info]] and [[query]] will always return results eventually, but it can be
   * useful for managing UI if you want to postpone making [[info]] and [[query]] requests
   * until you know that they will resolve more quickly.
   *
   * ```typescript
   * // This is pseudo code
   * const model = new HostedModel({
   *   url: 'https://my-model.hosted-models.runwayml.cloud/v1',
   *   token: 'my-secret-token', // token is only required for private models
   * })
   * // Enter some loading state in the UI.
   * loading(true)
   * await model.waitUntilAwake() // This method is never required, but can sometimes be useful
   * loading(false)
   *
   * while (true) {
   *   const input = getSomeInput()
   *   const output = await model.query(input)
   *   doSomething(output)
   * }
   * ```
   *
   * The returned promise rejects with whatever error [[isAwake]] rejects with.
   *
   * @param pollIntervalMillis The delay, in milliseconds, between successive checks of
   * whether the hosted model is awake yet.
   */
  async waitUntilAwake(pollIntervalMillis = 1000): Promise<void> {
    // Check first, then sleep: an already-awake model resolves without any delay.
    while (!(await this.isAwake())) {
      await delay(pollIntervalMillis);
    }
  }

  /**
   * GET the root route of the hosted model. Its response body carries the `status`
   * field that [[isAwake]] inspects.
   */
  private async root() {
    return this.requestHostedModel({
      url: `${this.url}/`,
      method: 'GET',
      headers: this.headers,
    });
  }

  /**
   * Send a request via `requestWithRetry` and translate failures into this
   * package's error classes.
   *
   * @param config The axios request to perform.
   * @returns The response body (`data`) of a successful JSON response.
   * @throws NetworkError if `requestWithRetry` rejects; the rejection's `code` is forwarded.
   * @throws PermissionDeniedError on an error response with status 401.
   * @throws NotFoundError on an error response with status 404.
   * @throws ModelError on an error response with status 500.
   * @throws UnexpectedError on any other error response, including a 2xx response
   * that is not JSON.
   */
  private async requestHostedModel(config: AxiosRequestConfig) {
    let result: AxiosResponse;
    try {
      result = await requestWithRetry(this.responseCodesToRetry, config);
    } catch (err) {
      throw new NetworkError(err.code);
    }
    if (this.isHostedModelResponseError(result)) {
      if (result.status === 401) throw new PermissionDeniedError();
      if (result.status === 404) throw new NotFoundError();
      if (result.status === 500) throw new ModelError();
      throw new UnexpectedError();
    }
    return result.data;
  }

  /**
   * A response counts as an error unless it is both JSON (per its `content-type`
   * header) and has a 2xx status.
   */
  private isHostedModelResponseError(response: AxiosResponse) {
    // The header may be absent (or the header bag itself missing); treat that as
    // "not JSON" rather than crashing on `.includes` of undefined.
    const contentType = response.headers ? response.headers['content-type'] : undefined;
    const isJSON = typeof contentType === 'string' && contentType.includes('application/json');
    const isSuccess = response.status >= 200 && response.status < 300;
    return !isJSON || !isSuccess;
  }

  /**
   * Returns `true` if `url` starts with `http://` or `https://`, followed by a host
   * ending in `.runwayml.cloud` and the `/v1` path.
   */
  private isValidV1URL(url: string) {
    return /^https{0,1}:\/\/.+\.runwayml\.cloud\/v1/.test(url);
  }
}