@@ -229,42 +229,73 @@ describe("Replicate client", () => {
229
229
} ) ;
230
230
231
231
describe ( "predictions.create" , ( ) => {
232
- test ( "Calls the correct API route with the correct payload" , async ( ) => {
233
- nock ( BASE_URL )
234
- . post ( "/predictions" )
235
- . reply ( 200 , {
236
- id : "ufawqhfynnddngldkgtslldrkq " ,
237
- model : "replicate/hello-world" ,
238
- version :
239
- "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
240
- urls : {
241
- get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq " ,
242
- cancel :
243
- "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
244
- } ,
245
- created_at : "2022-04-26T22:13:06.224088Z" ,
246
- started_at : null ,
247
- completed_at : null ,
248
- status : "starting" ,
249
- input : {
250
- text : "Alice" ,
232
+ const predictionTestCases = [
233
+ {
234
+ description : "String input" ,
235
+ input : {
236
+ text : "Alice " ,
237
+ } ,
238
+ } ,
239
+ { } ,
240
+ {
241
+ description : "Array input " ,
242
+ input : {
243
+ text : [ "Alice" , "Bob" , "Charlie" ] ,
244
+ } ,
245
+ } ,
246
+ {
247
+ description : "Object input" ,
248
+ input : {
249
+ text : {
250
+ name : "Alice" ,
251
251
} ,
252
- output : null ,
253
- error : null ,
254
- logs : null ,
255
- metrics : { } ,
256
- } ) ;
257
- const prediction = await client . predictions . create ( {
252
+ } ,
253
+ } ,
254
+ ] . map ( ( testCase ) => ( {
255
+ ...testCase ,
256
+ expectedResponse : {
257
+ id : "ufawqhfynnddngldkgtslldrkq" ,
258
+ model : "replicate/hello-world" ,
258
259
version :
259
260
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
260
- input : {
261
- text : "Alice" ,
261
+ urls : {
262
+ get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" ,
263
+ cancel :
264
+ "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
262
265
} ,
263
- webhook : "http://test.host/webhook" ,
264
- webhook_events_filter : [ "output" , "completed" ] ,
265
- } ) ;
266
- expect ( prediction . id ) . toBe ( "ufawqhfynnddngldkgtslldrkq" ) ;
267
- } ) ;
266
+ input : testCase . input ,
267
+ created_at : "2022-04-26T22:13:06.224088Z" ,
268
+ started_at : null ,
269
+ completed_at : null ,
270
+ status : "starting" ,
271
+ } ,
272
+ } ) ) ;
273
+
274
+ test . each ( predictionTestCases ) (
275
+ "$description" ,
276
+ async ( { input, expectedResponse } ) => {
277
+ nock ( BASE_URL )
278
+ . post ( "/predictions" , {
279
+ version :
280
+ "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
281
+ input : input as Record < string , any > ,
282
+ webhook : "http://test.host/webhook" ,
283
+ webhook_events_filter : [ "output" , "completed" ] ,
284
+ } )
285
+ . reply ( 200 , expectedResponse ) ;
286
+
287
+ const response = await client . predictions . create ( {
288
+ version :
289
+ "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
290
+ input : input as Record < string , any > ,
291
+ webhook : "http://test.host/webhook" ,
292
+ webhook_events_filter : [ "output" , "completed" ] ,
293
+ } ) ;
294
+
295
+ expect ( response . input ) . toEqual ( input ) ;
296
+ expect ( response . status ) . toBe ( expectedResponse . status ) ;
297
+ }
298
+ ) ;
268
299
269
300
const fileTestCases = [
270
301
// Skip test case if File type is not available
0 commit comments