Skip to content

Commit 48e79b6

Browse files
committed
Tests
1 parent 658b575 commit 48e79b6

File tree

2 files changed

+111
-119
lines changed

2 files changed

+111
-119
lines changed

spec/v2/providers/https.spec.ts

Lines changed: 105 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,27 @@ import { runHandler } from "../../helper";
3131
import { FULL_ENDPOINT, MINIMAL_V2_ENDPOINT, FULL_OPTIONS, FULL_TRIGGER } from "./fixtures";
3232
import { onInit } from "../../../src/v2/core";
3333
import { Handler } from "express";
34+
import { CallableFlow, z } from "genkit";
35+
36+
function request(args: { data?: any, auth?: Record<string, string>, headers?: Record<string, string>, method?: MockRequest["method"] }): any {
37+
let headers: Record<string, string> = {}
38+
if (args.method !== "POST") {
39+
headers["content-type"] = "application/json";
40+
}
41+
headers = {
42+
...headers,
43+
...args.headers,
44+
};
45+
if (args.auth) {
46+
headers["authorization"] = `bearer ignored.${Buffer.from(
47+
JSON.stringify(args.auth),
48+
"utf-8"
49+
).toString("base64")}.ignored`;
50+
}
51+
const ret = new MockRequest({ data: args.data || {} }, headers);
52+
ret.method = args.method || "POST";
53+
return ret;
54+
}
3455

3556
describe("onRequest", () => {
3657
beforeEach(() => {
@@ -171,18 +192,8 @@ describe("onRequest", () => {
171192
res.send("Works");
172193
});
173194

174-
const req = new MockRequest(
175-
{
176-
data: {},
177-
},
178-
{
179-
"content-type": "application/json",
180-
origin: "example.com",
181-
}
182-
);
183-
req.method = "POST";
184-
185-
const resp = await runHandler(func, req as any);
195+
const req = request({ headers: { origin: "example.com" }});
196+
const resp = await runHandler(func, req);
186197
expect(resp.body).to.equal("Works");
187198
});
188199

@@ -191,17 +202,14 @@ describe("onRequest", () => {
191202
throw new Error("Should not reach here for OPTIONS preflight");
192203
});
193204

194-
const req = new MockRequest(
195-
{
196-
data: {},
197-
},
198-
{
205+
const req = request({
206+
headers: {
199207
"Access-Control-Request-Method": "POST",
200208
"Access-Control-Request-Headers": "origin",
201209
origin: "example.com",
202-
}
203-
);
204-
req.method = "OPTIONS";
210+
},
211+
method: "OPTIONS",
212+
})
205213

206214
const resp = await runHandler(func, req as any);
207215
expect(resp.status).to.equal(204);
@@ -221,17 +229,14 @@ describe("onRequest", () => {
221229
throw new Error("Should not reach here for OPTIONS preflight");
222230
});
223231

224-
const req = new MockRequest(
225-
{
226-
data: {},
227-
},
228-
{
232+
const req = request({
233+
headers: {
229234
"Access-Control-Request-Method": "POST",
230235
"Access-Control-Request-Headers": "origin",
231236
origin: "localhost",
232-
}
233-
);
234-
req.method = "OPTIONS";
237+
},
238+
method: "OPTIONS",
239+
});
235240

236241
const resp = await runHandler(func, req as any);
237242
expect(resp.status).to.equal(204);
@@ -253,17 +258,14 @@ describe("onRequest", () => {
253258
res.status(200).send("Good");
254259
});
255260

256-
const req = new MockRequest(
257-
{
258-
data: {},
259-
},
260-
{
261+
const req = request({
262+
headers: {
261263
"Access-Control-Request-Method": "POST",
262264
"Access-Control-Request-Headers": "origin",
263265
origin: "example.com",
264-
}
265-
);
266-
req.method = "OPTIONS";
266+
},
267+
method: "OPTIONS",
268+
});
267269

268270
const resp = await runHandler(func, req as any);
269271
expect(resp.status).to.equal(200);
@@ -277,17 +279,14 @@ describe("onRequest", () => {
277279
const func = https.onRequest((req, res) => {
278280
res.status(200).send("Good");
279281
});
280-
const req = new MockRequest(
281-
{
282-
data: {},
283-
},
284-
{
282+
const req = request({
283+
headers: {
285284
"Access-Control-Request-Method": "POST",
286285
"Access-Control-Request-Headers": "origin",
287286
origin: "example.com",
288-
}
289-
);
290-
req.method = "OPTIONS";
287+
},
288+
method: "OPTIONS",
289+
});
291290
let hello;
292291
onInit(() => (hello = "world"));
293292
expect(hello).to.be.undefined;
@@ -406,16 +405,7 @@ describe("onCall", () => {
406405
it("should be an express handler", async () => {
407406
const func = https.onCall(() => 42);
408407

409-
const req = new MockRequest(
410-
{
411-
data: {},
412-
},
413-
{
414-
"content-type": "application/json",
415-
origin: "example.com",
416-
}
417-
);
418-
req.method = "POST";
408+
const req = request({ headers: { origin: "example.com" }});
419409

420410
const resp = await runHandler(func, req as any);
421411
expect(resp.body).to.deep.equal(JSON.stringify({ result: 42 }));
@@ -426,17 +416,14 @@ describe("onCall", () => {
426416
throw new Error("Should not reach here for OPTIONS preflight");
427417
});
428418

429-
const req = new MockRequest(
430-
{
431-
data: {},
432-
},
433-
{
419+
const req = request({
420+
headers: {
434421
"Access-Control-Request-Method": "POST",
435422
"Access-Control-Request-Headers": "origin",
436423
origin: "example.com",
437-
}
438-
);
439-
req.method = "OPTIONS";
424+
},
425+
method: "OPTIONS",
426+
});
440427

441428
const resp = await runHandler(func, req as any);
442429
expect(resp.status).to.equal(204);
@@ -455,17 +442,14 @@ describe("onCall", () => {
455442
const func = https.onCall({ cors: "example.com" }, () => {
456443
throw new Error("Should not reach here for OPTIONS preflight");
457444
});
458-
const req = new MockRequest(
459-
{
460-
data: {},
461-
},
462-
{
445+
const req = request({
446+
headers: {
463447
"Access-Control-Request-Method": "POST",
464448
"Access-Control-Request-Headers": "origin",
465449
origin: "localhost",
466-
}
467-
);
468-
req.method = "OPTIONS";
450+
},
451+
method: "OPTIONS",
452+
});
469453

470454
const response = await runHandler(func, req as any);
471455

@@ -483,18 +467,8 @@ describe("onCall", () => {
483467

484468
it("adds CORS headers", async () => {
485469
const func = https.onCall(() => 42);
486-
const req = new MockRequest(
487-
{
488-
data: {},
489-
},
490-
{
491-
"content-type": "application/json",
492-
origin: "example.com",
493-
}
494-
);
495-
req.method = "POST";
496-
497-
const response = await runHandler(func, req as any);
470+
const req = request({ headers: { origin: "example.com" }})
471+
const response = await runHandler(func, req);
498472

499473
expect(response.status).to.equal(200);
500474
expect(response.body).to.be.deep.equal(JSON.stringify({ result: 42 }));
@@ -515,17 +489,7 @@ describe("onCall", () => {
515489
it("calls init function", async () => {
516490
const func = https.onCall(() => 42);
517491

518-
const req = new MockRequest(
519-
{
520-
data: {},
521-
},
522-
{
523-
"content-type": "application/json",
524-
origin: "example.com",
525-
}
526-
);
527-
req.method = "POST";
528-
492+
const req = request({ headers: { origin: "example.com" }});
529493
let hello;
530494
onInit(() => (hello = "world"));
531495
expect(hello).to.be.undefined;
@@ -534,20 +498,6 @@ describe("onCall", () => {
534498
});
535499

536500
describe("authPolicy", () => {
537-
function req(data: any, auth?: Record<string, string>): any {
538-
const headers = {
539-
"content-type": "application/json",
540-
};
541-
if (auth) {
542-
headers["authorization"] = `bearer ignored.${Buffer.from(
543-
JSON.stringify(auth),
544-
"utf-8"
545-
).toString("base64")}.ignored`;
546-
}
547-
const ret = new MockRequest({ data }, headers);
548-
ret.method = "POST";
549-
return ret;
550-
}
551501

552502
before(() => {
553503
sinon.stub(debug, "isDebugFeatureEnabled").withArgs("skipTokenVerification").returns(true);
@@ -565,10 +515,10 @@ describe("onCall", () => {
565515
() => 42
566516
);
567517

568-
const authResp = await runHandler(func, req(null, { sub: "inlined" }));
518+
const authResp = await runHandler(func, request({ auth: { sub: "inlined" }}));
569519
expect(authResp.status).to.equal(200);
570520

571-
const anonResp = await runHandler(func, req(null, null));
521+
const anonResp = await runHandler(func, request({}));
572522
expect(anonResp.status).to.equal(403);
573523
});
574524

@@ -586,18 +536,18 @@ describe("onCall", () => {
586536
() => "HHGTG"
587537
);
588538

589-
const cases: Array<{ fn: Handler; auth: null | Record<string, string>; status: number }> = [
539+
const cases: Array<{ fn: Handler; auth?: Record<string, string>; status: number }> = [
590540
{ fn: anyValue, auth: { meaning: "42" }, status: 200 },
591541
{ fn: anyValue, auth: { meaning: "43" }, status: 200 },
592542
{ fn: anyValue, auth: { order: "66" }, status: 403 },
593-
{ fn: anyValue, auth: null, status: 403 },
543+
{ fn: anyValue, status: 403 },
594544
{ fn: specificValue, auth: { meaning: "42" }, status: 200 },
595545
{ fn: specificValue, auth: { meaning: "43" }, status: 403 },
596546
{ fn: specificValue, auth: { order: "66" }, status: 403 },
597-
{ fn: specificValue, auth: null, status: 403 },
547+
{ fn: specificValue, status: 403 },
598548
];
599549
for (const test of cases) {
600-
const resp = await runHandler(test.fn, req(null, test.auth));
550+
const resp = await runHandler(test.fn, request({ auth: test.auth }));
601551
expect(resp.status).to.equal(test.status);
602552
}
603553
});
@@ -610,10 +560,50 @@ describe("onCall", () => {
610560
(req) => req.data / 2
611561
);
612562

613-
const authorized = await runHandler(divTwo, req(2));
563+
const authorized = await runHandler(divTwo, request({ data: 2 }));
614564
expect(authorized.status).to.equal(200);
615-
const accessDenied = await runHandler(divTwo, req(1));
565+
const accessDenied = await runHandler(divTwo, request({ data: 1 }));
616566
expect(accessDenied.status).to.equal(403);
617567
});
618568
});
619569
});
570+
571+
describe("onCallGenkit", () => {
572+
interface Streamable {
573+
stream: sinon.SinonStub;
574+
}
575+
it("calls with JSON requests", async () => {
576+
const flow = sinon.stub();
577+
flow.withArgs("answer").returns(42);
578+
(flow as any as Streamable).stream = sinon.stub();
579+
(flow as any as Streamable).stream.onCall(0).throws("Unexpected stream");
580+
(flow as any).flow = { name: "flows/test" };
581+
582+
const f = https.onCallGenkit(flow as any as CallableFlow<z.ZodString, z.ZodNumber>)
583+
584+
const req = request({ data: "answer" });
585+
const res = await runHandler(f, req);
586+
expect(JSON.parse(res.body)).to.deep.equal({"result":42});
587+
});
588+
589+
590+
it("Streams with SSE requests", async () => {
591+
const flow = sinon.stub();
592+
flow.onFirstCall().throws();
593+
(flow as any as Streamable).stream = sinon.stub();
594+
(flow as any as Streamable).stream.withArgs("answer").returns({
595+
stream: (async function*() {
596+
yield 1;
597+
yield 2;
598+
})(),
599+
output: Promise.resolve(42),
600+
});
601+
(flow as any).flow = { name: "flows/test" };
602+
603+
const f = https.onCallGenkit(flow as any as CallableFlow<z.ZodString, z.ZodNumber, z.ZodNumber>)
604+
605+
const req = request({ data: "answer", headers: { accept: "text/event-stream" }});
606+
const res = await runHandler(f, req);
607+
expect(res.body).to.equal(['data: {"message":1}', 'data: {"message":2}', 'data: {"result":42}',''].join("\n"));
608+
});
609+
});

src/v2/providers/https.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
import * as cors from "cors";
2929
import * as express from "express";
3030
import { type CallableFlow } from "genkit";
31-
import * as g from "genkit";
32-
import { convertIfPresent, convertInvoker } from "../../common/encoding";
31+
import { convertIfPresent, convertInvoker, copyIfPresent } from "../../common/encoding";
3332
import { wrapTraceContext } from "../trace";
3433
import { isDebugFeatureEnabled } from "../../common/debug";
3534
import { ResetValue } from "../../common/options";
@@ -515,14 +514,17 @@ export function onCallGenkit<F extends CallableFlow<any, any, any>>(optsOrFlow:
515514
logger.debug(`Genkit function for ${flow.flow.name} is not bound to any secret. This may mean that you are not storing API keys as a secret or that you are not binding your secret to this function. See https://firebase.google.com/docs/functions/config-env?gen=2nd#secret_parameters for more information.`);
516515
}
517516
const cloudFunction = onCall<FlowInput<F>, Promise<FlowOutput<F>>, FlowStream<F>>(opts, async (req, res) => {
517+
let context: Omit<CallableRequest, "data" | "rawRequest" | "acceptsStreaming"> = {};
518+
copyIfPresent(context, req, "auth", "app", "instanceIdToken");
519+
518520
if (req.acceptsStreaming) {
519-
const { stream, output } = flow.stream(req.data);
521+
const { stream, output } = flow.stream(req.data, { context });
520522
for await (const chunk of stream) {
521523
await res.sendChunk(chunk);
522524
}
523525
return output;
524526
}
525-
return flow(req.data);
527+
return flow(req.data, { context });
526528
});
527529

528530
cloudFunction.__endpoint.callableTrigger.genkitAction = flow.flow.name;

0 commit comments

Comments
 (0)