@@ -114,50 +114,6 @@ describe("Replicate client", () => {
114
114
const collections = await client . collections . list ( ) ;
115
115
expect ( collections . results . length ) . toBe ( 2 ) ;
116
116
} ) ;
117
-
118
- describe ( "predictions.create" , ( ) => {
119
- test ( "Handles array input correctly" , async ( ) => {
120
- const inputArray = [ "Alice" , "Bob" , "Charlie" ] ;
121
-
122
- nock ( BASE_URL )
123
- . post ( "/predictions" , {
124
- version :
125
- "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
126
- input : {
127
- text : inputArray ,
128
- } ,
129
- } )
130
- . reply ( 200 , {
131
- id : "ufawqhfynnddngldkgtslldrkq" ,
132
- model : "replicate/hello-world" ,
133
- version :
134
- "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
135
- urls : {
136
- get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" ,
137
- cancel :
138
- "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
139
- } ,
140
- created_at : "2022-04-26T22:13:06.224088Z" ,
141
- started_at : null ,
142
- completed_at : null ,
143
- status : "starting" ,
144
- input : {
145
- text : inputArray ,
146
- } ,
147
- } ) ;
148
-
149
- const response = await client . predictions . create ( {
150
- version :
151
- "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
152
- input : {
153
- text : inputArray ,
154
- } ,
155
- } ) ;
156
-
157
- expect ( response . input ) . toEqual ( { text : inputArray } ) ;
158
- expect ( response . status ) . toBe ( "starting" ) ;
159
- } ) ;
160
- } ) ;
161
117
// Add more tests for error handling, edge cases, etc.
162
118
} ) ;
163
119
@@ -229,42 +185,73 @@ describe("Replicate client", () => {
229
185
} ) ;
230
186
231
187
describe ( "predictions.create" , ( ) => {
232
- test ( "Calls the correct API route with the correct payload" , async ( ) => {
233
- nock ( BASE_URL )
234
- . post ( "/predictions" )
235
- . reply ( 200 , {
236
- id : "ufawqhfynnddngldkgtslldrkq " ,
237
- model : "replicate/hello-world" ,
238
- version :
239
- "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
240
- urls : {
241
- get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq " ,
242
- cancel :
243
- "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
244
- } ,
245
- created_at : "2022-04-26T22:13:06.224088Z" ,
246
- started_at : null ,
247
- completed_at : null ,
248
- status : "starting" ,
249
- input : {
250
- text : "Alice" ,
188
+ const predictionTestCases = [
189
+ {
190
+ description : "String input" ,
191
+ input : {
192
+ text : "Alice " ,
193
+ } ,
194
+ } ,
195
+ { } ,
196
+ {
197
+ description : "Array input " ,
198
+ input : {
199
+ text : [ "Alice" , "Bob" , "Charlie" ] ,
200
+ } ,
201
+ } ,
202
+ {
203
+ description : "Object input" ,
204
+ input : {
205
+ text : {
206
+ name : "Alice" ,
251
207
} ,
252
- output : null ,
253
- error : null ,
254
- logs : null ,
255
- metrics : { } ,
256
- } ) ;
257
- const prediction = await client . predictions . create ( {
208
+ } ,
209
+ } ,
210
+ ] . map ( ( testCase ) => ( {
211
+ ...testCase ,
212
+ expectedResponse : {
213
+ id : "ufawqhfynnddngldkgtslldrkq" ,
214
+ model : "replicate/hello-world" ,
258
215
version :
259
216
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
260
- input : {
261
- text : "Alice" ,
217
+ urls : {
218
+ get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" ,
219
+ cancel :
220
+ "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
262
221
} ,
263
- webhook : "http://test.host/webhook" ,
264
- webhook_events_filter : [ "output" , "completed" ] ,
265
- } ) ;
266
- expect ( prediction . id ) . toBe ( "ufawqhfynnddngldkgtslldrkq" ) ;
267
- } ) ;
222
+ input : testCase . input ,
223
+ created_at : "2022-04-26T22:13:06.224088Z" ,
224
+ started_at : null ,
225
+ completed_at : null ,
226
+ status : "starting" ,
227
+ } ,
228
+ } ) ) ;
229
+
230
+ test . each ( predictionTestCases ) (
231
+ "$description" ,
232
+ async ( { input, expectedResponse } ) => {
233
+ nock ( BASE_URL )
234
+ . post ( "/predictions" , {
235
+ version :
236
+ "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
237
+ input : input as Record < string , any > ,
238
+ webhook : "http://test.host/webhook" ,
239
+ webhook_events_filter : [ "output" , "completed" ] ,
240
+ } )
241
+ . reply ( 200 , expectedResponse ) ;
242
+
243
+ const response = await client . predictions . create ( {
244
+ version :
245
+ "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
246
+ input : input as Record < string , any > ,
247
+ webhook : "http://test.host/webhook" ,
248
+ webhook_events_filter : [ "output" , "completed" ] ,
249
+ } ) ;
250
+
251
+ expect ( response . input ) . toEqual ( input ) ;
252
+ expect ( response . status ) . toBe ( expectedResponse . status ) ;
253
+ }
254
+ ) ;
268
255
269
256
const fileTestCases = [
270
257
// Skip test case if File type is not available
0 commit comments