diff --git a/examples/template.ts b/examples/template.ts index a6960e3..0dd075f 100644 --- a/examples/template.ts +++ b/examples/template.ts @@ -63,16 +63,13 @@ const logger = new Logger({ name: "template" }); const modified = original.fork((oldConfig) => ({ ...oldConfig, template: `${oldConfig.template} Your answers must be concise.`, - schema: z.object({ - name: z.string().default("123"), - age: z.number(), - objective: z.string(), - }), + defaults: { + name: "Alex", + }, })); const output = modified.render({ name: undefined, - age: 12, objective: "fulfill the user needs", }); logger.info(output); diff --git a/src/template.ts b/src/template.ts index 4f2cecf..5d889f9 100644 --- a/src/template.ts +++ b/src/template.ts @@ -23,13 +23,14 @@ import { z, ZodType } from "zod"; import { createSchemaValidator, toJsonSchema } from "@/internals/helpers/schema.js"; import type { SchemaObject, ValidateFunction } from "ajv"; import { shallowCopy } from "@/serializer/utils.js"; +import { pickBy } from "remeda"; +import { getProp } from "@/internals/helpers/object.js"; -export type PromptTemplateRenderFn = ( - this: K extends ZodType ? A : never, -) => any; +export type InferValue = T extends ZodType ? A : never; +export type PromptTemplateRenderFn = (this: InferValue) => any; export type PromptTemplateRenderInput = z.input> = { - [K in keyof T2]: T2[K] | PromptTemplateRenderFn; + [K in keyof T2]: T2[K] | PromptTemplateRenderFn | undefined; }; export interface PromptTemplateInput { @@ -37,13 +38,15 @@ export interface PromptTemplateInput { customTags?: [string, string]; escape?: boolean; schema: SchemaObject; + defaults?: Partial>; functions?: Record>; } type PromptTemplateConstructor = N extends ZodType - ? Omit, "schema" | "functions"> & { + ? Omit, "schema" | "functions" | "defaults"> & { schema: N; functions?: Record>; + defaults?: Partial>; } : Omit, "schema"> & { schema: T | SchemaObject }; @@ -77,6 +80,7 @@ export class PromptTemplate extends Serializable { super(); this.config = { ...config, + defaults: (config.defaults ?? {}) as Partial>, schema: toJsonSchema(config.schema), escape: Boolean(config.escape), customTags: config.customTags ?? ["{{", "}}"], @@ -92,8 +96,7 @@ export class PromptTemplate extends Serializable { } protected validateInput(input: unknown): asserts input is T { - const schema = toJsonSchema(this.config.schema); - const validator = createSchemaValidator(schema, { + const validator = createSchemaValidator(this.config.schema, { coerceTypes: false, }) as ValidateFunction; @@ -119,11 +122,17 @@ export class PromptTemplate extends Serializable { return new PromptTemplate(newConfig); } - render(inputs: PromptTemplateRenderInput): string { - this.validateInput(inputs); + render(input: PromptTemplateRenderInput): string { + const updatedInput: typeof input = { ...input }; + Object.assign( + updatedInput, + pickBy(this.config.defaults, (_, k) => getProp(updatedInput, [k]) === undefined), + ); + + this.validateInput(updatedInput); const view: Record = { ...this.config.functions, - ...inputs, + ...updatedInput, }; const output = Mustache.render(