@@ -31,6 +31,27 @@ import { runHandler } from "../../helper";
3131import { FULL_ENDPOINT , MINIMAL_V2_ENDPOINT , FULL_OPTIONS , FULL_TRIGGER } from "./fixtures" ;
3232import { onInit } from "../../../src/v2/core" ;
3333import { Handler } from "express" ;
34+ import { CallableFlow , z } from "genkit" ;
35+
36+ function request ( args : { data ?: any , auth ?: Record < string , string > , headers ?: Record < string , string > , method ?: MockRequest [ "method" ] } ) : any {
37+ let headers : Record < string , string > = { }
38+ if ( args . method !== "POST" ) {
39+ headers [ "content-type" ] = "application/json" ;
40+ }
41+ headers = {
42+ ...headers ,
43+ ...args . headers ,
44+ } ;
45+ if ( args . auth ) {
46+ headers [ "authorization" ] = `bearer ignored.${ Buffer . from (
47+ JSON . stringify ( args . auth ) ,
48+ "utf-8"
49+ ) . toString ( "base64" ) } .ignored`;
50+ }
51+ const ret = new MockRequest ( { data : args . data || { } } , headers ) ;
52+ ret . method = args . method || "POST" ;
53+ return ret ;
54+ }
3455
3556describe ( "onRequest" , ( ) => {
3657 beforeEach ( ( ) => {
@@ -171,18 +192,8 @@ describe("onRequest", () => {
171192 res . send ( "Works" ) ;
172193 } ) ;
173194
174- const req = new MockRequest (
175- {
176- data : { } ,
177- } ,
178- {
179- "content-type" : "application/json" ,
180- origin : "example.com" ,
181- }
182- ) ;
183- req . method = "POST" ;
184-
185- const resp = await runHandler ( func , req as any ) ;
195+ const req = request ( { headers : { origin : "example.com" } } ) ;
196+ const resp = await runHandler ( func , req ) ;
186197 expect ( resp . body ) . to . equal ( "Works" ) ;
187198 } ) ;
188199
@@ -191,17 +202,14 @@ describe("onRequest", () => {
191202 throw new Error ( "Should not reach here for OPTIONS preflight" ) ;
192203 } ) ;
193204
194- const req = new MockRequest (
195- {
196- data : { } ,
197- } ,
198- {
205+ const req = request ( {
206+ headers : {
199207 "Access-Control-Request-Method" : "POST" ,
200208 "Access-Control-Request-Headers" : "origin" ,
201209 origin : "example.com" ,
202- }
203- ) ;
204- req . method = "OPTIONS" ;
210+ } ,
211+ method : "OPTIONS" ,
212+ } )
205213
206214 const resp = await runHandler ( func , req as any ) ;
207215 expect ( resp . status ) . to . equal ( 204 ) ;
@@ -221,17 +229,14 @@ describe("onRequest", () => {
221229 throw new Error ( "Should not reach here for OPTIONS preflight" ) ;
222230 } ) ;
223231
224- const req = new MockRequest (
225- {
226- data : { } ,
227- } ,
228- {
232+ const req = request ( {
233+ headers : {
229234 "Access-Control-Request-Method" : "POST" ,
230235 "Access-Control-Request-Headers" : "origin" ,
231236 origin : "localhost" ,
232- }
233- ) ;
234- req . method = "OPTIONS" ;
237+ } ,
238+ method : "OPTIONS" ,
239+ } ) ;
235240
236241 const resp = await runHandler ( func , req as any ) ;
237242 expect ( resp . status ) . to . equal ( 204 ) ;
@@ -253,17 +258,14 @@ describe("onRequest", () => {
253258 res . status ( 200 ) . send ( "Good" ) ;
254259 } ) ;
255260
256- const req = new MockRequest (
257- {
258- data : { } ,
259- } ,
260- {
261+ const req = request ( {
262+ headers : {
261263 "Access-Control-Request-Method" : "POST" ,
262264 "Access-Control-Request-Headers" : "origin" ,
263265 origin : "example.com" ,
264- }
265- ) ;
266- req . method = "OPTIONS" ;
266+ } ,
267+ method : "OPTIONS" ,
268+ } ) ;
267269
268270 const resp = await runHandler ( func , req as any ) ;
269271 expect ( resp . status ) . to . equal ( 200 ) ;
@@ -277,17 +279,14 @@ describe("onRequest", () => {
277279 const func = https . onRequest ( ( req , res ) => {
278280 res . status ( 200 ) . send ( "Good" ) ;
279281 } ) ;
280- const req = new MockRequest (
281- {
282- data : { } ,
283- } ,
284- {
282+ const req = request ( {
283+ headers : {
285284 "Access-Control-Request-Method" : "POST" ,
286285 "Access-Control-Request-Headers" : "origin" ,
287286 origin : "example.com" ,
288- }
289- ) ;
290- req . method = "OPTIONS" ;
287+ } ,
288+ method : "OPTIONS" ,
289+ } ) ;
291290 let hello ;
292291 onInit ( ( ) => ( hello = "world" ) ) ;
293292 expect ( hello ) . to . be . undefined ;
@@ -406,16 +405,7 @@ describe("onCall", () => {
406405 it ( "should be an express handler" , async ( ) => {
407406 const func = https . onCall ( ( ) => 42 ) ;
408407
409- const req = new MockRequest (
410- {
411- data : { } ,
412- } ,
413- {
414- "content-type" : "application/json" ,
415- origin : "example.com" ,
416- }
417- ) ;
418- req . method = "POST" ;
408+ const req = request ( { headers : { origin : "example.com" } } ) ;
419409
420410 const resp = await runHandler ( func , req as any ) ;
421411 expect ( resp . body ) . to . deep . equal ( JSON . stringify ( { result : 42 } ) ) ;
@@ -426,17 +416,14 @@ describe("onCall", () => {
426416 throw new Error ( "Should not reach here for OPTIONS preflight" ) ;
427417 } ) ;
428418
429- const req = new MockRequest (
430- {
431- data : { } ,
432- } ,
433- {
419+ const req = request ( {
420+ headers : {
434421 "Access-Control-Request-Method" : "POST" ,
435422 "Access-Control-Request-Headers" : "origin" ,
436423 origin : "example.com" ,
437- }
438- ) ;
439- req . method = "OPTIONS" ;
424+ } ,
425+ method : "OPTIONS" ,
426+ } ) ;
440427
441428 const resp = await runHandler ( func , req as any ) ;
442429 expect ( resp . status ) . to . equal ( 204 ) ;
@@ -455,17 +442,14 @@ describe("onCall", () => {
455442 const func = https . onCall ( { cors : "example.com" } , ( ) => {
456443 throw new Error ( "Should not reach here for OPTIONS preflight" ) ;
457444 } ) ;
458- const req = new MockRequest (
459- {
460- data : { } ,
461- } ,
462- {
445+ const req = request ( {
446+ headers : {
463447 "Access-Control-Request-Method" : "POST" ,
464448 "Access-Control-Request-Headers" : "origin" ,
465449 origin : "localhost" ,
466- }
467- ) ;
468- req . method = "OPTIONS" ;
450+ } ,
451+ method : "OPTIONS" ,
452+ } ) ;
469453
470454 const response = await runHandler ( func , req as any ) ;
471455
@@ -483,18 +467,8 @@ describe("onCall", () => {
483467
484468 it ( "adds CORS headers" , async ( ) => {
485469 const func = https . onCall ( ( ) => 42 ) ;
486- const req = new MockRequest (
487- {
488- data : { } ,
489- } ,
490- {
491- "content-type" : "application/json" ,
492- origin : "example.com" ,
493- }
494- ) ;
495- req . method = "POST" ;
496-
497- const response = await runHandler ( func , req as any ) ;
470+ const req = request ( { headers : { origin : "example.com" } } )
471+ const response = await runHandler ( func , req ) ;
498472
499473 expect ( response . status ) . to . equal ( 200 ) ;
500474 expect ( response . body ) . to . be . deep . equal ( JSON . stringify ( { result : 42 } ) ) ;
@@ -515,17 +489,7 @@ describe("onCall", () => {
515489 it ( "calls init function" , async ( ) => {
516490 const func = https . onCall ( ( ) => 42 ) ;
517491
518- const req = new MockRequest (
519- {
520- data : { } ,
521- } ,
522- {
523- "content-type" : "application/json" ,
524- origin : "example.com" ,
525- }
526- ) ;
527- req . method = "POST" ;
528-
492+ const req = request ( { headers : { origin : "example.com" } } ) ;
529493 let hello ;
530494 onInit ( ( ) => ( hello = "world" ) ) ;
531495 expect ( hello ) . to . be . undefined ;
@@ -534,20 +498,6 @@ describe("onCall", () => {
534498 } ) ;
535499
536500 describe ( "authPolicy" , ( ) => {
537- function req ( data : any , auth ?: Record < string , string > ) : any {
538- const headers = {
539- "content-type" : "application/json" ,
540- } ;
541- if ( auth ) {
542- headers [ "authorization" ] = `bearer ignored.${ Buffer . from (
543- JSON . stringify ( auth ) ,
544- "utf-8"
545- ) . toString ( "base64" ) } .ignored`;
546- }
547- const ret = new MockRequest ( { data } , headers ) ;
548- ret . method = "POST" ;
549- return ret ;
550- }
551501
552502 before ( ( ) => {
553503 sinon . stub ( debug , "isDebugFeatureEnabled" ) . withArgs ( "skipTokenVerification" ) . returns ( true ) ;
@@ -565,10 +515,10 @@ describe("onCall", () => {
565515 ( ) => 42
566516 ) ;
567517
568- const authResp = await runHandler ( func , req ( null , { sub : "inlined" } ) ) ;
518+ const authResp = await runHandler ( func , request ( { auth : { sub : "inlined" } } ) ) ;
569519 expect ( authResp . status ) . to . equal ( 200 ) ;
570520
571- const anonResp = await runHandler ( func , req ( null , null ) ) ;
521+ const anonResp = await runHandler ( func , request ( { } ) ) ;
572522 expect ( anonResp . status ) . to . equal ( 403 ) ;
573523 } ) ;
574524
@@ -586,18 +536,18 @@ describe("onCall", () => {
586536 ( ) => "HHGTG"
587537 ) ;
588538
589- const cases : Array < { fn : Handler ; auth : null | Record < string , string > ; status : number } > = [
539+ const cases : Array < { fn : Handler ; auth ?: Record < string , string > ; status : number } > = [
590540 { fn : anyValue , auth : { meaning : "42" } , status : 200 } ,
591541 { fn : anyValue , auth : { meaning : "43" } , status : 200 } ,
592542 { fn : anyValue , auth : { order : "66" } , status : 403 } ,
593- { fn : anyValue , auth : null , status : 403 } ,
543+ { fn : anyValue , status : 403 } ,
594544 { fn : specificValue , auth : { meaning : "42" } , status : 200 } ,
595545 { fn : specificValue , auth : { meaning : "43" } , status : 403 } ,
596546 { fn : specificValue , auth : { order : "66" } , status : 403 } ,
597- { fn : specificValue , auth : null , status : 403 } ,
547+ { fn : specificValue , status : 403 } ,
598548 ] ;
599549 for ( const test of cases ) {
600- const resp = await runHandler ( test . fn , req ( null , test . auth ) ) ;
550+ const resp = await runHandler ( test . fn , request ( { auth : test . auth } ) ) ;
601551 expect ( resp . status ) . to . equal ( test . status ) ;
602552 }
603553 } ) ;
@@ -610,10 +560,50 @@ describe("onCall", () => {
610560 ( req ) => req . data / 2
611561 ) ;
612562
613- const authorized = await runHandler ( divTwo , req ( 2 ) ) ;
563+ const authorized = await runHandler ( divTwo , request ( { data : 2 } ) ) ;
614564 expect ( authorized . status ) . to . equal ( 200 ) ;
615- const accessDenied = await runHandler ( divTwo , req ( 1 ) ) ;
565+ const accessDenied = await runHandler ( divTwo , request ( { data : 1 } ) ) ;
616566 expect ( accessDenied . status ) . to . equal ( 403 ) ;
617567 } ) ;
618568 } ) ;
619569} ) ;
570+
571+ describe ( "onCallGenkit" , ( ) => {
572+ interface Streamable {
573+ stream : sinon . SinonStub ;
574+ }
575+ it ( "calls with JSON requests" , async ( ) => {
576+ const flow = sinon . stub ( ) ;
577+ flow . withArgs ( "answer" ) . returns ( 42 ) ;
578+ ( flow as any as Streamable ) . stream = sinon . stub ( ) ;
579+ ( flow as any as Streamable ) . stream . onCall ( 0 ) . throws ( "Unexpected stream" ) ;
580+ ( flow as any ) . flow = { name : "flows/test" } ;
581+
582+ const f = https . onCallGenkit ( flow as any as CallableFlow < z . ZodString , z . ZodNumber > )
583+
584+ const req = request ( { data : "answer" } ) ;
585+ const res = await runHandler ( f , req ) ;
586+ expect ( JSON . parse ( res . body ) ) . to . deep . equal ( { "result" :42 } ) ;
587+ } ) ;
588+
589+
590+ it ( "Streams with SSE requests" , async ( ) => {
591+ const flow = sinon . stub ( ) ;
592+ flow . onFirstCall ( ) . throws ( ) ;
593+ ( flow as any as Streamable ) . stream = sinon . stub ( ) ;
594+ ( flow as any as Streamable ) . stream . withArgs ( "answer" ) . returns ( {
595+ stream : ( async function * ( ) {
596+ yield 1 ;
597+ yield 2 ;
598+ } ) ( ) ,
599+ output : Promise . resolve ( 42 ) ,
600+ } ) ;
601+ ( flow as any ) . flow = { name : "flows/test" } ;
602+
603+ const f = https . onCallGenkit ( flow as any as CallableFlow < z . ZodString , z . ZodNumber , z . ZodNumber > )
604+
605+ const req = request ( { data : "answer" , headers : { accept : "text/event-stream" } } ) ;
606+ const res = await runHandler ( f , req ) ;
607+ expect ( res . body ) . to . equal ( [ 'data: {"message":1}' , 'data: {"message":2}' , 'data: {"result":42}' , '' ] . join ( "\n" ) ) ;
608+ } ) ;
609+ } ) ;
0 commit comments