Skip to content

Commit

Permalink
feat(llm): make drivers serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D committed Sep 23, 2024
1 parent a91b8db commit 05d6dbb
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/llms/drivers/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ import { Retryable } from "@/internals/helpers/retryable.js";
import { PromptTemplate } from "@/template.js";
import { SchemaObject } from "ajv";
import { z } from "zod";
import { Serializable } from "@/internals/serializable.js";

export interface GenerateSchemaInput<T> {
maxRetries?: number;
options?: T;
}

export abstract class BaseDriver<TGenerateOptions extends GenerateOptions = GenerateOptions> {
export abstract class BaseDriver<
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends Serializable<any> {
protected abstract template: PromptTemplate.infer<{ schema: string }>;
protected errorTemplate = new PromptTemplate({
schema: z.object({
Expand All @@ -45,7 +48,9 @@ export abstract class BaseDriver<TGenerateOptions extends GenerateOptions = Gene
Validation Errors: "{{errors}}"`,
});

constructor(protected llm: ChatLLM<ChatLLMOutput, TGenerateOptions>) {}
constructor(protected readonly llm: ChatLLM<ChatLLMOutput, TGenerateOptions>) {
super();
}

protected abstract parseResponse(textResponse: string): unknown;
protected abstract schemaToString(schema: SchemaObject): Promise<string> | string;
Expand Down Expand Up @@ -123,4 +128,15 @@ Validation Errors: "{{errors}}"`,
},
}).get();
}

createSnapshot() {
return {
template: this.template,
errorTemplate: this.errorTemplate,
};
}

loadSnapshot(snapshot: ReturnType<typeof this.createSnapshot>) {
Object.assign(this, snapshot);
}
}
4 changes: 4 additions & 0 deletions src/llms/drivers/json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ IMPORTANT: Every message must be a parsable JSON string without additional outpu
`,
});

static {
this.register();
}

protected parseResponse(textResponse: string): unknown {
return parseBrokenJson(textResponse);
}
Expand Down
4 changes: 4 additions & 0 deletions src/llms/drivers/typescript.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ IMPORTANT: Every message must be a parsable JSON string without additional outpu
`,
});

static {
this.register();
}

protected parseResponse(textResponse: string): unknown {
return parseBrokenJson(textResponse);
}
Expand Down
4 changes: 4 additions & 0 deletions src/llms/drivers/yaml.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ IMPORTANT: Every message must be a parsable YAML string without additional outpu
`,
});

static {
this.register();
}

protected parseResponse(textResponse: string): unknown {
return yaml.load(textResponse);
}
Expand Down

0 comments on commit 05d6dbb

Please sign in to comment.