Skip to content

Commit

Permalink
feat(custom-tool): add ssl support for code interpreter (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
matoushavlena authored Sep 4, 2024
1 parent 3f20727 commit 7de84d1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 18 deletions.
12 changes: 6 additions & 6 deletions src/tools/custom.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ describe("CustomTool", () => {
},
});

const customTool = await CustomTool.fromSourceCode("http://localhost", "source code");
const customTool = await CustomTool.fromSourceCode({ url: "http://localhost" }, "source code");

expect(customTool.name).toBe("test");
expect(customTool.description).toBe("A test tool");
Expand All @@ -76,9 +76,9 @@ describe("CustomTool", () => {
},
});

await expect(CustomTool.fromSourceCode("http://localhost", "source code")).rejects.toThrow(
"Error parsing tool",
);
await expect(
CustomTool.fromSourceCode({ url: "http://localhost" }, "source code"),
).rejects.toThrow("Error parsing tool");
});

it("should run the custom tool", async () => {
Expand All @@ -101,7 +101,7 @@ describe("CustomTool", () => {
});

const customTool = await CustomTool.fromSourceCode(
"http://localhost",
{ url: "http://localhost" },
"source code",
"executor-id",
);
Expand Down Expand Up @@ -148,7 +148,7 @@ describe("CustomTool", () => {
});

const customTool = await CustomTool.fromSourceCode(
"http://localhost",
{ url: "http://localhost" },
"source code",
"executor-id",
);
Expand Down
28 changes: 20 additions & 8 deletions src/tools/custom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@ import { FrameworkError } from "@/errors.js";
import { z } from "zod";
import { validate } from "@/internals/helpers/general.js";
import { CodeInterpreterService } from "bee-proto/code_interpreter/v1/code_interpreter_service_connect";
import { CodeInterpreterOptions } from "./python/python.js";

export class CustomToolCreateError extends FrameworkError {}
export class CustomToolExecuteError extends FrameworkError {}

const toolOptionsSchema = z
.object({
codeInterpreterUrl: z.string().url(),
codeInterpreter: z.object({
url: z.string().url(),
connectionOptions: z.any().optional(),
}),
sourceCode: z.string().min(1),
name: z.string().min(1),
description: z.string().min(1),
Expand All @@ -38,10 +42,14 @@ const toolOptionsSchema = z

export type CustomToolOptions = z.output<typeof toolOptionsSchema> & BaseToolOptions;

function createCodeInterpreterClient(url: string) {
function createCodeInterpreterClient(codeInterpreter: CodeInterpreterOptions) {
return createPromiseClient(
CodeInterpreterService,
createGrpcTransport({ baseUrl: url, httpVersion: "2" }),
createGrpcTransport({
baseUrl: codeInterpreter.url,
httpVersion: "2",
nodeOptions: codeInterpreter.connectionOptions,
}),
);
}

Expand All @@ -65,7 +73,7 @@ export class CustomTool extends Tool<StringToolOutput, CustomToolOptions> {
) {
validate(options, toolOptionsSchema);
super(options);
this.client = client || createCodeInterpreterClient(options.codeInterpreterUrl);
this.client = client || createCodeInterpreterClient(options.codeInterpreter);
this.name = options.name;
this.description = options.description;
}
Expand All @@ -89,11 +97,15 @@ export class CustomTool extends Tool<StringToolOutput, CustomToolOptions> {

loadSnapshot(snapshot: ReturnType<typeof this.createSnapshot>): void {
super.loadSnapshot(snapshot);
this.client = createCodeInterpreterClient(this.options.codeInterpreterUrl);
this.client = createCodeInterpreterClient(this.options.codeInterpreter);
}

static async fromSourceCode(codeInterpreterUrl: string, sourceCode: string, executorId?: string) {
const client = createCodeInterpreterClient(codeInterpreterUrl);
static async fromSourceCode(
codeInterpreter: CodeInterpreterOptions,
sourceCode: string,
executorId?: string,
) {
const client = createCodeInterpreterClient(codeInterpreter);
const response = await client.parseCustomTool({ toolSourceCode: sourceCode });

if (response.response.case === "error") {
Expand All @@ -104,7 +116,7 @@ export class CustomTool extends Tool<StringToolOutput, CustomToolOptions> {

return new CustomTool(
{
codeInterpreterUrl,
codeInterpreter,
sourceCode,
name: toolName,
description: toolDescription,
Expand Down
10 changes: 6 additions & 4 deletions src/tools/python/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ import { ValidationError } from "ajv";
import { ConnectionOptions } from "node:tls";
import { AnySchemaLike } from "@/internals/helpers/schema.js";

export interface CodeInterpreterOptions {
url: string;
connectionOptions?: ConnectionOptions;
}

export interface PythonToolOptions extends BaseToolOptions {
codeInterpreter: {
url: string;
connectionOptions?: ConnectionOptions;
};
codeInterpreter: CodeInterpreterOptions;
executorId?: string;
preprocess?: { llm: LLM<BaseLLMOutput>; promptTemplate: PromptTemplate<"input"> };
storage: PythonStorage;
Expand Down

0 comments on commit 7de84d1

Please sign in to comment.