@@ -31,6 +31,27 @@ import { runHandler } from "../../helper";
31
31
import { FULL_ENDPOINT , MINIMAL_V2_ENDPOINT , FULL_OPTIONS , FULL_TRIGGER } from "./fixtures" ;
32
32
import { onInit } from "../../../src/v2/core" ;
33
33
import { Handler } from "express" ;
34
+ import { CallableFlow , z } from "genkit" ;
35
+
36
+ function request ( args : { data ?: any , auth ?: Record < string , string > , headers ?: Record < string , string > , method ?: MockRequest [ "method" ] } ) : any {
37
+ let headers : Record < string , string > = { }
38
+ if ( args . method !== "POST" ) {
39
+ headers [ "content-type" ] = "application/json" ;
40
+ }
41
+ headers = {
42
+ ...headers ,
43
+ ...args . headers ,
44
+ } ;
45
+ if ( args . auth ) {
46
+ headers [ "authorization" ] = `bearer ignored.${ Buffer . from (
47
+ JSON . stringify ( args . auth ) ,
48
+ "utf-8"
49
+ ) . toString ( "base64" ) } .ignored`;
50
+ }
51
+ const ret = new MockRequest ( { data : args . data || { } } , headers ) ;
52
+ ret . method = args . method || "POST" ;
53
+ return ret ;
54
+ }
34
55
35
56
describe ( "onRequest" , ( ) => {
36
57
beforeEach ( ( ) => {
@@ -171,18 +192,8 @@ describe("onRequest", () => {
171
192
res . send ( "Works" ) ;
172
193
} ) ;
173
194
174
- const req = new MockRequest (
175
- {
176
- data : { } ,
177
- } ,
178
- {
179
- "content-type" : "application/json" ,
180
- origin : "example.com" ,
181
- }
182
- ) ;
183
- req . method = "POST" ;
184
-
185
- const resp = await runHandler ( func , req as any ) ;
195
+ const req = request ( { headers : { origin : "example.com" } } ) ;
196
+ const resp = await runHandler ( func , req ) ;
186
197
expect ( resp . body ) . to . equal ( "Works" ) ;
187
198
} ) ;
188
199
@@ -191,17 +202,14 @@ describe("onRequest", () => {
191
202
throw new Error ( "Should not reach here for OPTIONS preflight" ) ;
192
203
} ) ;
193
204
194
- const req = new MockRequest (
195
- {
196
- data : { } ,
197
- } ,
198
- {
205
+ const req = request ( {
206
+ headers : {
199
207
"Access-Control-Request-Method" : "POST" ,
200
208
"Access-Control-Request-Headers" : "origin" ,
201
209
origin : "example.com" ,
202
- }
203
- ) ;
204
- req . method = "OPTIONS" ;
210
+ } ,
211
+ method : "OPTIONS" ,
212
+ } )
205
213
206
214
const resp = await runHandler ( func , req as any ) ;
207
215
expect ( resp . status ) . to . equal ( 204 ) ;
@@ -221,17 +229,14 @@ describe("onRequest", () => {
221
229
throw new Error ( "Should not reach here for OPTIONS preflight" ) ;
222
230
} ) ;
223
231
224
- const req = new MockRequest (
225
- {
226
- data : { } ,
227
- } ,
228
- {
232
+ const req = request ( {
233
+ headers : {
229
234
"Access-Control-Request-Method" : "POST" ,
230
235
"Access-Control-Request-Headers" : "origin" ,
231
236
origin : "localhost" ,
232
- }
233
- ) ;
234
- req . method = "OPTIONS" ;
237
+ } ,
238
+ method : "OPTIONS" ,
239
+ } ) ;
235
240
236
241
const resp = await runHandler ( func , req as any ) ;
237
242
expect ( resp . status ) . to . equal ( 204 ) ;
@@ -253,17 +258,14 @@ describe("onRequest", () => {
253
258
res . status ( 200 ) . send ( "Good" ) ;
254
259
} ) ;
255
260
256
- const req = new MockRequest (
257
- {
258
- data : { } ,
259
- } ,
260
- {
261
+ const req = request ( {
262
+ headers : {
261
263
"Access-Control-Request-Method" : "POST" ,
262
264
"Access-Control-Request-Headers" : "origin" ,
263
265
origin : "example.com" ,
264
- }
265
- ) ;
266
- req . method = "OPTIONS" ;
266
+ } ,
267
+ method : "OPTIONS" ,
268
+ } ) ;
267
269
268
270
const resp = await runHandler ( func , req as any ) ;
269
271
expect ( resp . status ) . to . equal ( 200 ) ;
@@ -277,17 +279,14 @@ describe("onRequest", () => {
277
279
const func = https . onRequest ( ( req , res ) => {
278
280
res . status ( 200 ) . send ( "Good" ) ;
279
281
} ) ;
280
- const req = new MockRequest (
281
- {
282
- data : { } ,
283
- } ,
284
- {
282
+ const req = request ( {
283
+ headers : {
285
284
"Access-Control-Request-Method" : "POST" ,
286
285
"Access-Control-Request-Headers" : "origin" ,
287
286
origin : "example.com" ,
288
- }
289
- ) ;
290
- req . method = "OPTIONS" ;
287
+ } ,
288
+ method : "OPTIONS" ,
289
+ } ) ;
291
290
let hello ;
292
291
onInit ( ( ) => ( hello = "world" ) ) ;
293
292
expect ( hello ) . to . be . undefined ;
@@ -406,16 +405,7 @@ describe("onCall", () => {
406
405
it ( "should be an express handler" , async ( ) => {
407
406
const func = https . onCall ( ( ) => 42 ) ;
408
407
409
- const req = new MockRequest (
410
- {
411
- data : { } ,
412
- } ,
413
- {
414
- "content-type" : "application/json" ,
415
- origin : "example.com" ,
416
- }
417
- ) ;
418
- req . method = "POST" ;
408
+ const req = request ( { headers : { origin : "example.com" } } ) ;
419
409
420
410
const resp = await runHandler ( func , req as any ) ;
421
411
expect ( resp . body ) . to . deep . equal ( JSON . stringify ( { result : 42 } ) ) ;
@@ -426,17 +416,14 @@ describe("onCall", () => {
426
416
throw new Error ( "Should not reach here for OPTIONS preflight" ) ;
427
417
} ) ;
428
418
429
- const req = new MockRequest (
430
- {
431
- data : { } ,
432
- } ,
433
- {
419
+ const req = request ( {
420
+ headers : {
434
421
"Access-Control-Request-Method" : "POST" ,
435
422
"Access-Control-Request-Headers" : "origin" ,
436
423
origin : "example.com" ,
437
- }
438
- ) ;
439
- req . method = "OPTIONS" ;
424
+ } ,
425
+ method : "OPTIONS" ,
426
+ } ) ;
440
427
441
428
const resp = await runHandler ( func , req as any ) ;
442
429
expect ( resp . status ) . to . equal ( 204 ) ;
@@ -455,17 +442,14 @@ describe("onCall", () => {
455
442
const func = https . onCall ( { cors : "example.com" } , ( ) => {
456
443
throw new Error ( "Should not reach here for OPTIONS preflight" ) ;
457
444
} ) ;
458
- const req = new MockRequest (
459
- {
460
- data : { } ,
461
- } ,
462
- {
445
+ const req = request ( {
446
+ headers : {
463
447
"Access-Control-Request-Method" : "POST" ,
464
448
"Access-Control-Request-Headers" : "origin" ,
465
449
origin : "localhost" ,
466
- }
467
- ) ;
468
- req . method = "OPTIONS" ;
450
+ } ,
451
+ method : "OPTIONS" ,
452
+ } ) ;
469
453
470
454
const response = await runHandler ( func , req as any ) ;
471
455
@@ -483,18 +467,8 @@ describe("onCall", () => {
483
467
484
468
it ( "adds CORS headers" , async ( ) => {
485
469
const func = https . onCall ( ( ) => 42 ) ;
486
- const req = new MockRequest (
487
- {
488
- data : { } ,
489
- } ,
490
- {
491
- "content-type" : "application/json" ,
492
- origin : "example.com" ,
493
- }
494
- ) ;
495
- req . method = "POST" ;
496
-
497
- const response = await runHandler ( func , req as any ) ;
470
+ const req = request ( { headers : { origin : "example.com" } } )
471
+ const response = await runHandler ( func , req ) ;
498
472
499
473
expect ( response . status ) . to . equal ( 200 ) ;
500
474
expect ( response . body ) . to . be . deep . equal ( JSON . stringify ( { result : 42 } ) ) ;
@@ -515,17 +489,7 @@ describe("onCall", () => {
515
489
it ( "calls init function" , async ( ) => {
516
490
const func = https . onCall ( ( ) => 42 ) ;
517
491
518
- const req = new MockRequest (
519
- {
520
- data : { } ,
521
- } ,
522
- {
523
- "content-type" : "application/json" ,
524
- origin : "example.com" ,
525
- }
526
- ) ;
527
- req . method = "POST" ;
528
-
492
+ const req = request ( { headers : { origin : "example.com" } } ) ;
529
493
let hello ;
530
494
onInit ( ( ) => ( hello = "world" ) ) ;
531
495
expect ( hello ) . to . be . undefined ;
@@ -534,20 +498,6 @@ describe("onCall", () => {
534
498
} ) ;
535
499
536
500
describe ( "authPolicy" , ( ) => {
537
- function req ( data : any , auth ?: Record < string , string > ) : any {
538
- const headers = {
539
- "content-type" : "application/json" ,
540
- } ;
541
- if ( auth ) {
542
- headers [ "authorization" ] = `bearer ignored.${ Buffer . from (
543
- JSON . stringify ( auth ) ,
544
- "utf-8"
545
- ) . toString ( "base64" ) } .ignored`;
546
- }
547
- const ret = new MockRequest ( { data } , headers ) ;
548
- ret . method = "POST" ;
549
- return ret ;
550
- }
551
501
552
502
before ( ( ) => {
553
503
sinon . stub ( debug , "isDebugFeatureEnabled" ) . withArgs ( "skipTokenVerification" ) . returns ( true ) ;
@@ -565,10 +515,10 @@ describe("onCall", () => {
565
515
( ) => 42
566
516
) ;
567
517
568
- const authResp = await runHandler ( func , req ( null , { sub : "inlined" } ) ) ;
518
+ const authResp = await runHandler ( func , request ( { auth : { sub : "inlined" } } ) ) ;
569
519
expect ( authResp . status ) . to . equal ( 200 ) ;
570
520
571
- const anonResp = await runHandler ( func , req ( null , null ) ) ;
521
+ const anonResp = await runHandler ( func , request ( { } ) ) ;
572
522
expect ( anonResp . status ) . to . equal ( 403 ) ;
573
523
} ) ;
574
524
@@ -586,18 +536,18 @@ describe("onCall", () => {
586
536
( ) => "HHGTG"
587
537
) ;
588
538
589
- const cases : Array < { fn : Handler ; auth : null | Record < string , string > ; status : number } > = [
539
+ const cases : Array < { fn : Handler ; auth ?: Record < string , string > ; status : number } > = [
590
540
{ fn : anyValue , auth : { meaning : "42" } , status : 200 } ,
591
541
{ fn : anyValue , auth : { meaning : "43" } , status : 200 } ,
592
542
{ fn : anyValue , auth : { order : "66" } , status : 403 } ,
593
- { fn : anyValue , auth : null , status : 403 } ,
543
+ { fn : anyValue , status : 403 } ,
594
544
{ fn : specificValue , auth : { meaning : "42" } , status : 200 } ,
595
545
{ fn : specificValue , auth : { meaning : "43" } , status : 403 } ,
596
546
{ fn : specificValue , auth : { order : "66" } , status : 403 } ,
597
- { fn : specificValue , auth : null , status : 403 } ,
547
+ { fn : specificValue , status : 403 } ,
598
548
] ;
599
549
for ( const test of cases ) {
600
- const resp = await runHandler ( test . fn , req ( null , test . auth ) ) ;
550
+ const resp = await runHandler ( test . fn , request ( { auth : test . auth } ) ) ;
601
551
expect ( resp . status ) . to . equal ( test . status ) ;
602
552
}
603
553
} ) ;
@@ -610,10 +560,50 @@ describe("onCall", () => {
610
560
( req ) => req . data / 2
611
561
) ;
612
562
613
- const authorized = await runHandler ( divTwo , req ( 2 ) ) ;
563
+ const authorized = await runHandler ( divTwo , request ( { data : 2 } ) ) ;
614
564
expect ( authorized . status ) . to . equal ( 200 ) ;
615
- const accessDenied = await runHandler ( divTwo , req ( 1 ) ) ;
565
+ const accessDenied = await runHandler ( divTwo , request ( { data : 1 } ) ) ;
616
566
expect ( accessDenied . status ) . to . equal ( 403 ) ;
617
567
} ) ;
618
568
} ) ;
619
569
} ) ;
570
+
571
+ describe ( "onCallGenkit" , ( ) => {
572
+ interface Streamable {
573
+ stream : sinon . SinonStub ;
574
+ }
575
+ it ( "calls with JSON requests" , async ( ) => {
576
+ const flow = sinon . stub ( ) ;
577
+ flow . withArgs ( "answer" ) . returns ( 42 ) ;
578
+ ( flow as any as Streamable ) . stream = sinon . stub ( ) ;
579
+ ( flow as any as Streamable ) . stream . onCall ( 0 ) . throws ( "Unexpected stream" ) ;
580
+ ( flow as any ) . flow = { name : "flows/test" } ;
581
+
582
+ const f = https . onCallGenkit ( flow as any as CallableFlow < z . ZodString , z . ZodNumber > )
583
+
584
+ const req = request ( { data : "answer" } ) ;
585
+ const res = await runHandler ( f , req ) ;
586
+ expect ( JSON . parse ( res . body ) ) . to . deep . equal ( { "result" :42 } ) ;
587
+ } ) ;
588
+
589
+
590
+ it ( "Streams with SSE requests" , async ( ) => {
591
+ const flow = sinon . stub ( ) ;
592
+ flow . onFirstCall ( ) . throws ( ) ;
593
+ ( flow as any as Streamable ) . stream = sinon . stub ( ) ;
594
+ ( flow as any as Streamable ) . stream . withArgs ( "answer" ) . returns ( {
595
+ stream : ( async function * ( ) {
596
+ yield 1 ;
597
+ yield 2 ;
598
+ } ) ( ) ,
599
+ output : Promise . resolve ( 42 ) ,
600
+ } ) ;
601
+ ( flow as any ) . flow = { name : "flows/test" } ;
602
+
603
+ const f = https . onCallGenkit ( flow as any as CallableFlow < z . ZodString , z . ZodNumber , z . ZodNumber > )
604
+
605
+ const req = request ( { data : "answer" , headers : { accept : "text/event-stream" } } ) ;
606
+ const res = await runHandler ( f , req ) ;
607
+ expect ( res . body ) . to . equal ( [ 'data: {"message":1}' , 'data: {"message":2}' , 'data: {"result":42}' , '' ] . join ( "\n" ) ) ;
608
+ } ) ;
609
+ } ) ;
0 commit comments