diff --git a/.changeset/strong-planets-brake.md b/.changeset/strong-planets-brake.md new file mode 100644 index 00000000000..66553ed7d36 --- /dev/null +++ b/.changeset/strong-planets-brake.md @@ -0,0 +1,5 @@ +--- +"@breadboard-ai/build": patch +--- + +Pass NodeHandlerContext to invoke functions diff --git a/packages/build/src/internal/define/define.ts b/packages/build/src/internal/define/define.ts index 2ebaebe564c..741ddda6ae9 100644 --- a/packages/build/src/internal/define/define.ts +++ b/packages/build/src/internal/define/define.ts @@ -7,6 +7,7 @@ import type { NodeDescriberContext, + NodeHandlerContext, NodeHandlerMetadata, } from "@google-labs/breadboard"; import type { CountUnion, Expand, MaybePromise } from "../common/type-util.js"; @@ -179,10 +180,7 @@ export function defineNodeType< params.outputs["*"], primary(params.inputs), primary(params.outputs), - params.invoke as Function as ( - staticParams: Record, - dynamicParams: Record - ) => { [K: string]: JsonSerializable }, + params.invoke as Function as VeryLooseInvokeFn, params.describe as LooseDescribeFn ); return Object.assign(impl.instantiate.bind(impl), { @@ -245,19 +243,24 @@ type LooseInvokeFn> = Expand< ( staticParams: Expand>, dynamicParams: Expand> - ) => - | { [K: string]: JsonSerializable } - | Promise<{ [K: string]: JsonSerializable }> + ) => MaybePromise<{ [K: string]: JsonSerializable }> >; +export type VeryLooseInvokeFn = ( + staticParams: Record, + dynamicParams: Record, + context: NodeHandlerContext +) => { [K: string]: JsonSerializable }; + type StrictInvokeFn< I extends Record, O extends Record, F extends LooseInvokeFn, > = ( staticInputs: Expand>, - dynamicInputs: Expand> -) => StrictInvokeFnReturn | Promise>; + dynamicInputs: Expand>, + context: NodeHandlerContext +) => MaybePromise>; type StrictInvokeFnReturn< I extends Record, diff --git a/packages/build/src/internal/define/definition.ts b/packages/build/src/internal/define/definition.ts index cca5f8b8c33..8e1ba281cd2 100644 --- a/packages/build/src/internal/define/definition.ts +++ b/packages/build/src/internal/define/definition.ts @@ -8,6 +8,7 @@ import type { InputValues, NodeDescriberContext, NodeDescriberResult, + NodeHandlerContext, OutputValues, Schema, } from "@google-labs/breadboard"; @@ -26,7 +27,11 @@ import type { StaticInputPortConfig, StaticOutputPortConfig, } from "./config.js"; -import type { DynamicInputPorts, LooseDescribeFn } from "./define.js"; +import type { + DynamicInputPorts, + LooseDescribeFn, + VeryLooseInvokeFn, +} from "./define.js"; import { Instance } from "./instance.js"; import { portConfigMapToJSONSchema } from "./json-schema.js"; @@ -71,11 +76,7 @@ export class DefinitionImpl< readonly #reflective: boolean; readonly #primaryInput: string | undefined; readonly #primaryOutput: string | undefined; - // TODO(aomarks) Support promises - readonly #invoke: ( - staticParams: Record, - dynamicParams: Record - ) => { [K: string]: JsonSerializable }; + readonly #invoke: VeryLooseInvokeFn; readonly #describe?: LooseDescribeFn; constructor( @@ -86,10 +87,7 @@ export class DefinitionImpl< dynamicOutputs: DynamicOutputPortConfig | undefined, primaryInput: string | undefined, primaryOutput: string | undefined, - invoke: ( - staticParams: Record, - dynamicParams: Record - ) => { [K: string]: JsonSerializable }, + invoke: VeryLooseInvokeFn, describe?: LooseDescribeFn ) { this.#name = name; @@ -130,10 +128,13 @@ export class DefinitionImpl< ); } - invoke(values: InputValues): Promise { + invoke( + values: InputValues, + context: NodeHandlerContext + ): Promise { const { staticValues, dynamicValues } = this.#applyDefaultsAndPartitionRuntimeInputValues(values); - return Promise.resolve(this.#invoke(staticValues, dynamicValues)); + return Promise.resolve(this.#invoke(staticValues, dynamicValues, context)); } /** diff --git a/packages/build/src/test/invoke_test.ts b/packages/build/src/test/invoke_test.ts new file mode 100644 index 00000000000..4c77a03a1cb --- /dev/null +++ b/packages/build/src/test/invoke_test.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2024 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { NodeHandlerContext } from "@google-labs/breadboard"; +import assert from "node:assert/strict"; +import { test } from "node:test"; +import { defineNodeType } from "../internal/define/define.js"; + +test("invoke receives context", async () => { + const expected: NodeHandlerContext = { + base: new URL("http://example.com/"), + outerGraph: { nodes: [], edges: [] }, + }; + let actual: NodeHandlerContext | undefined; + defineNodeType({ + name: "foo", + inputs: {}, + outputs: {}, + invoke: (_staticInputs, _dynamicInputs, context) => { + actual = context; + return {}; + }, + }).invoke({}, expected); + assert.deepEqual(actual, expected); +});