From 16c04dd090e78fd470e89600949734f11d4ff35f Mon Sep 17 00:00:00 2001 From: Daniel Lee Date: Wed, 16 Oct 2024 13:09:40 -0700 Subject: [PATCH] Support streaming request/response for callable functions (v2 only). --- spec/common/providers/https.spec.ts | 97 +++++++++++++++++++++++------ spec/helper.ts | 10 ++- src/common/providers/https.ts | 64 +++++++++++++++---- src/v1/providers/https.ts | 8 +-- src/v2/providers/https.ts | 16 ++--- 5 files changed, 153 insertions(+), 42 deletions(-) diff --git a/spec/common/providers/https.spec.ts b/spec/common/providers/https.spec.ts index 9ea0f4f79..7156628d1 100644 --- a/spec/common/providers/https.spec.ts +++ b/spec/common/providers/https.spec.ts @@ -49,25 +49,33 @@ async function runCallableTest(test: CallTest): Promise { cors: { origin: true, methods: "POST" }, ...test.callableOption, }; - const callableFunctionV1 = https.onCallHandler(opts, (data, context) => { - expect(data).to.deep.equal(test.expectedData); - return test.callableFunction(data, context); - }); + const callableFunctionV1 = https.onCallHandler( + opts, + (data, context) => { + expect(data).to.deep.equal(test.expectedData); + return test.callableFunction(data, context); + }, + "v1" + ); const responseV1 = await runHandler(callableFunctionV1, test.httpRequest); - expect(responseV1.body).to.deep.equal(test.expectedHttpResponse.body); + expect(responseV1.body).to.deep.equal(JSON.stringify(test.expectedHttpResponse.body)); expect(responseV1.headers).to.deep.equal(test.expectedHttpResponse.headers); expect(responseV1.status).to.equal(test.expectedHttpResponse.status); - const callableFunctionV2 = https.onCallHandler(opts, (request) => { - expect(request.data).to.deep.equal(test.expectedData); - return test.callableFunction2(request); - }); + const callableFunctionV2 = https.onCallHandler( + opts, + (request) => { + expect(request.data).to.deep.equal(test.expectedData); + return test.callableFunction2(request); + }, + "v2" + ); const responseV2 = await runHandler(callableFunctionV2, test.httpRequest); - expect(responseV2.body).to.deep.equal(test.expectedHttpResponse.body); + expect(responseV2.body).to.deep.equal(JSON.stringify(test.expectedHttpResponse.body)); expect(responseV2.headers).to.deep.equal(test.expectedHttpResponse.headers); expect(responseV2.status).to.equal(test.expectedHttpResponse.status); } @@ -165,7 +173,7 @@ describe("onCallHandler", () => { status: 400, headers: expectedResponseHeaders, body: { - error: { status: "INVALID_ARGUMENT", message: "Bad Request" }, + error: { message: "Bad Request", status: "INVALID_ARGUMENT" }, }, }, }); @@ -203,7 +211,7 @@ describe("onCallHandler", () => { status: 400, headers: expectedResponseHeaders, body: { - error: { status: "INVALID_ARGUMENT", message: "Bad Request" }, + error: { message: "Bad Request", status: "INVALID_ARGUMENT" }, }, }, }); @@ -225,7 +233,7 @@ describe("onCallHandler", () => { status: 400, headers: expectedResponseHeaders, body: { - error: { status: "INVALID_ARGUMENT", message: "Bad Request" }, + error: { message: "Bad Request", status: "INVALID_ARGUMENT" }, }, }, }); @@ -244,7 +252,7 @@ describe("onCallHandler", () => { expectedHttpResponse: { status: 500, headers: expectedResponseHeaders, - body: { error: { status: "INTERNAL", message: "INTERNAL" } }, + body: { error: { message: "INTERNAL", status: "INTERNAL" } }, }, }); }); @@ -262,7 +270,7 @@ describe("onCallHandler", () => { expectedHttpResponse: { status: 500, headers: expectedResponseHeaders, - body: { error: { status: "INTERNAL", message: "INTERNAL" } }, + body: { error: { message: "INTERNAL", status: "INTERNAL" } }, }, }); }); @@ -280,7 +288,7 @@ describe("onCallHandler", () => { expectedHttpResponse: { status: 404, headers: expectedResponseHeaders, - body: { error: { status: "NOT_FOUND", message: "i am error" } }, + body: { error: { message: "i am error", status: "NOT_FOUND" } }, }, }); }); @@ -364,8 +372,8 @@ describe("onCallHandler", () => { headers: expectedResponseHeaders, body: { error: { - status: "UNAUTHENTICATED", message: "Unauthenticated", + status: "UNAUTHENTICATED", }, }, }, @@ -391,8 +399,8 @@ describe("onCallHandler", () => { headers: expectedResponseHeaders, body: { error: { - status: "UNAUTHENTICATED", message: "Unauthenticated", + status: "UNAUTHENTICATED", }, }, }, @@ -461,8 +469,8 @@ describe("onCallHandler", () => { headers: expectedResponseHeaders, body: { error: { - status: "UNAUTHENTICATED", message: "Unauthenticated", + status: "UNAUTHENTICATED", }, }, }, @@ -748,6 +756,57 @@ describe("onCallHandler", () => { }); }); }); + + describe("Streaming callables", () => { + it("returns data in SSE format for requests Accept: text/event-stream header", async () => { + const mockReq = mockRequest( + { message: "hello streaming" }, + "application/json", {}, + { accept: "text/event-stream" } + ) as any; + const fn = https.onCallHandler( + { + cors: { origin: true, methods: "POST" }, + }, + (req, resp) => { + resp.write("hello") + return 'world'; + }, + "v2" + ); + + const resp = await runHandler(fn, mockReq); + const data = [ + `data: {"message":"hello"}`, + `data: {"result":"world"}`, + ] + expect(resp.body).to.equal([...data, ""].join("\n")); + }); + + it("returns error in SSE format", async () => { + const mockReq = mockRequest( + { message: "hello streaming" }, + "application/json", + {}, + { accept: "text/event-stream" } + ) as any; + const fn = https.onCallHandler( + { + cors: { origin: true, methods: "POST" }, + }, + (req, resp) => { + throw new Error("BOOM") + }, + "v2" + ); + + const resp = await runHandler(fn, mockReq); + const data = [ + `data: {"error":{"message":"INTERNAL","status":"INTERNAL"}}`, + ] + expect(resp.body).to.equal([...data, ""].join("\n")); + }); + }); }); describe("encoding/decoding", () => { diff --git a/spec/helper.ts b/spec/helper.ts index 8dd78d82c..8529ffa9c 100644 --- a/spec/helper.ts +++ b/spec/helper.ts @@ -47,6 +47,7 @@ export function runHandler( // MockResponse mocks an express.Response. // This class lives here so it can reference resolve and reject. class MockResponse { + private sentBody = ""; private statusCode = 0; private headers: { [name: string]: string } = {}; private callback: () => void; @@ -65,7 +66,10 @@ export function runHandler( return this.headers[name]; } - public send(body: any) { + public send(sendBody: any) { + const toSend = typeof sendBody === "object" ? JSON.stringify(sendBody) : sendBody; + const body = this.sentBody ? this.sentBody + (toSend || "") : toSend; + resolve({ status: this.statusCode, headers: this.headers, @@ -76,6 +80,10 @@ export function runHandler( } } + public write(writeBody: any) { + this.sentBody += typeof writeBody === "object" ? JSON.stringify(writeBody) : writeBody; + } + public end() { this.send(undefined); } diff --git a/src/common/providers/https.ts b/src/common/providers/https.ts index 2f0e56538..b878b2624 100644 --- a/src/common/providers/https.ts +++ b/src/common/providers/https.ts @@ -141,6 +141,15 @@ export interface CallableRequest { rawRequest: Request; } +/** + * CallableProxyResponse exposes subset of express.Response object + * to allow writing partial, streaming responses back to the client. + */ +export interface CallableProxyResponse { + write: express.Response["write"]; + acceptsStreaming: boolean; +} + /** * The set of Firebase Functions status codes. The codes are the same at the * ones exposed by {@link https://github.com/grpc/grpc/blob/master/doc/statuscodes.md | gRPC}. @@ -673,7 +682,10 @@ async function checkAppCheckToken( } type v1CallableHandler = (data: any, context: CallableContext) => any | Promise; -type v2CallableHandler = (request: CallableRequest) => Res; +type v2CallableHandler = ( + request: CallableRequest, + response?: CallableProxyResponse +) => Res; /** @internal **/ export interface CallableOptions { @@ -685,9 +697,10 @@ export interface CallableOptions { /** @internal */ export function onCallHandler( options: CallableOptions, - handler: v1CallableHandler | v2CallableHandler + handler: v1CallableHandler | v2CallableHandler, + version: "v1" | "v2" ): (req: Request, res: express.Response) => Promise { - const wrapped = wrapOnCallHandler(options, handler); + const wrapped = wrapOnCallHandler(options, handler, version); return (req: Request, res: express.Response) => { return new Promise((resolve) => { res.on("finish", resolve); @@ -698,10 +711,15 @@ export function onCallHandler( }; } +function encodeSSE(data: unknown): string { + return `data: ${JSON.stringify(data)}\n`; +} + /** @internal */ function wrapOnCallHandler( options: CallableOptions, - handler: v1CallableHandler | v2CallableHandler + handler: v1CallableHandler | v2CallableHandler, + version: "v1" | "v2" ): (req: Request, res: express.Response) => Promise { return async (req: Request, res: express.Response): Promise => { try { @@ -719,7 +737,7 @@ function wrapOnCallHandler( // The original monkey-patched code lived in the functionsEmulatorRuntime // (link: https://github.com/firebase/firebase-tools/blob/accea7abda3cc9fa6bb91368e4895faf95281c60/src/emulator/functionsEmulatorRuntime.ts#L480) // and was not compatible with how monorepos separate out packages (see https://github.com/firebase/firebase-tools/issues/5210). - if (isDebugFeatureEnabled("skipTokenVerification") && handler.length === 2) { + if (isDebugFeatureEnabled("skipTokenVerification") && version === "v1") { const authContext = context.rawRequest.header(CALLABLE_AUTH_HEADER); if (authContext) { logger.debug("Callable functions auth override", { @@ -763,18 +781,34 @@ function wrapOnCallHandler( context.instanceIdToken = req.header("Firebase-Instance-ID-Token"); } + const acceptsStreaming = version === "v2" && req.header("accept") === "text/event-stream"; const data: Req = decode(req.body.data); let result: Res; - if (handler.length === 2) { - result = await handler(data, context); + if (version === "v1") { + result = await (handler as v1CallableHandler)(data, context); } else { const arg: CallableRequest = { ...context, data, }; + // TODO: set up optional heartbeat + const responseProxy: CallableProxyResponse = { + write(chunk): boolean { + if (acceptsStreaming) { + const formattedData = encodeSSE({ message: chunk }); + return res.write(formattedData); + } + // if client doesn't accept sse-protocol, response.write() is no-op. + }, + acceptsStreaming, + }; + if (acceptsStreaming) { + // SSE always responds with 200 + res.status(200) + } // For some reason the type system isn't picking up that the handler // is a one argument function. - result = await (handler as any)(arg); + result = await (handler as v2CallableHandler)(arg, responseProxy); } // Encode the result as JSON to preserve types like Dates. @@ -782,7 +816,12 @@ function wrapOnCallHandler( // If there was some result, encode it in the body. const responseBody: HttpResponseBody = { result }; - res.status(200).send(responseBody); + if (acceptsStreaming) { + res.write(encodeSSE(responseBody)) + res.end(); + } else { + res.status(200).send(responseBody); + } } catch (err) { let httpErr = err; if (!(err instanceof HttpsError)) { @@ -793,8 +832,11 @@ function wrapOnCallHandler( const { status } = httpErr.httpErrorCode; const body = { error: httpErr.toJSON() }; - - res.status(status).send(body); + if (req.header("accept") === "text/event-stream") { + res.send(encodeSSE(body)); + } else { + res.status(status).send(body); + } } }; } diff --git a/src/v1/providers/https.ts b/src/v1/providers/https.ts index e9cd5d132..3703f0502 100644 --- a/src/v1/providers/https.ts +++ b/src/v1/providers/https.ts @@ -102,9 +102,8 @@ export function _onCallWithOptions( handler: (data: any, context: CallableContext) => any | Promise, options: DeploymentOptions ): HttpsFunction & Runnable { - // onCallHandler sniffs the function length of the passed-in callback - // and the user could have only tried to listen to data. Wrap their handler - // in another handler to avoid accidentally triggering the v2 API + // fix the length of handler to make the call to handler consistent + // in the onCallHandler const fixedLen = (data: any, context: CallableContext) => { return withInit(handler)(data, context); }; @@ -115,7 +114,8 @@ export function _onCallWithOptions( consumeAppCheckToken: options.consumeAppCheckToken, cors: { origin: true, methods: "POST" }, }, - fixedLen + fixedLen, + "v1" ) ); diff --git a/src/v2/providers/https.ts b/src/v2/providers/https.ts index 16ad9038c..c28f0504c 100644 --- a/src/v2/providers/https.ts +++ b/src/v2/providers/https.ts @@ -33,6 +33,7 @@ import { isDebugFeatureEnabled } from "../../common/debug"; import { ResetValue } from "../../common/options"; import { CallableRequest, + CallableProxyResponse, FunctionsErrorCode, HttpsError, onCallHandler, @@ -347,7 +348,7 @@ export function onRequest( */ export function onCall>( opts: CallableOptions, - handler: (request: CallableRequest) => Return + handler: (request: CallableRequest, response?: CallableProxyResponse) => Return ): CallableFunction ? Return : Promise>; /** @@ -356,11 +357,11 @@ export function onCall>( * @returns A function that you can export and deploy. */ export function onCall>( - handler: (request: CallableRequest) => Return + handler: (request: CallableRequest, response?: CallableProxyResponse) => Return ): CallableFunction ? Return : Promise>; export function onCall>( optsOrHandler: CallableOptions | ((request: CallableRequest) => Return), - handler?: (request: CallableRequest) => Return + handler?: (request: CallableRequest, response?: CallableProxyResponse) => Return ): CallableFunction ? Return : Promise> { let opts: CallableOptions; if (arguments.length === 1) { @@ -378,16 +379,17 @@ export function onCall>( origin = origin[0]; } - // onCallHandler sniffs the function length to determine which API to present. - // fix the length to prevent api versions from being mismatched. - const fixedLen = (req: CallableRequest) => withInit(handler)(req); + // fix the length of handler to make the call to handler consistent + const fixedLen = (req: CallableRequest, resp?: CallableProxyResponse) => + withInit(handler)(req, resp); let func: any = onCallHandler( { cors: { origin, methods: "POST" }, enforceAppCheck: opts.enforceAppCheck ?? options.getGlobalOptions().enforceAppCheck, consumeAppCheckToken: opts.consumeAppCheckToken, }, - fixedLen + fixedLen, + "v2" ); func = wrapTraceContext(func);