Skip to content

Commit

Permalink
fix: allow passing KeepWith to tool calls (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
connor4312 authored Feb 2, 2025
1 parent fa634ca commit c600ee3
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 12 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ class MyPromptElement extends PromptElement {

Unlike `<Chunk />`, which prevents pruning of any children and simply removes them as a block, `<KeepWith />` in this case will allow the `ToolCallResponse` to be pruned, and if it's fully pruned it will also remove the `ToolCallRequest`.

You can also pass the `KeepWith` instance to `toolCalls` in `AssistantMessage`s.

#### Debugging Budgeting

You can set a `tracer` property on the `PromptElement` to debug how your elements are rendered and how this library allocates your budget. We include a basic `HTMLTracer` you can use, which can be served on an address:
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@vscode/prompt-tsx",
"version": "0.3.0-alpha.17",
"version": "0.3.0-alpha.18",
"description": "Declare LLM prompts with TSX",
"main": "./dist/base/index.js",
"types": "./dist/base/index.d.ts",
Expand Down
57 changes: 49 additions & 8 deletions src/base/materialized.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
*--------------------------------------------------------------------------------------------*/

import { once } from './once';
import { ChatCompletionContentPart, ChatMessage, ChatMessageToolCall, ChatRole } from './openai';
import {
AssistantChatMessage,
ChatCompletionContentPart,
ChatMessage,
ChatMessageToolCall,
ChatRole,
} from './openai';
import { ToolCall } from './promptElements';
import { PromptMetadata } from './results';
import { ITokenizer } from './tokenizer/tokenizer';

Expand Down Expand Up @@ -208,7 +215,7 @@ export class MaterializedChatMessage implements IMaterializedNode {
public readonly id: number,
public readonly role: ChatRole,
public readonly name: string | undefined,
public readonly toolCalls: ChatMessageToolCall[] | undefined,
public toolCalls: readonly ToolCall[] | undefined,
public readonly toolCallId: string | undefined,
public readonly priority: number,
public readonly metadata: PromptMetadata[],
Expand Down Expand Up @@ -353,12 +360,18 @@ export class MaterializedChatMessage implements IMaterializedNode {
...(this.name ? { name: this.name } : {}),
};
} else if (this.role === ChatRole.Assistant) {
return {
role: this.role,
content,
...(this.toolCalls ? { tool_calls: this.toolCalls } : {}),
...(this.name ? { name: this.name } : {}),
};
const msg: AssistantChatMessage = { role: this.role, content };
if (this.name) {
msg.name = this.name;
}
if (this.toolCalls?.length) {
msg.tool_calls = this.toolCalls.map(tc => ({
function: tc.function,
id: tc.id,
type: tc.type,
}));
}
return msg;
} else if (this.role === ChatRole.User) {
return {
role: this.role,
Expand Down Expand Up @@ -639,6 +652,15 @@ function removeOtherKeepWiths(nodeThatWasRemoved: MaterializedNode) {
for (const node of forEachNode(root)) {
if (isKeepWith(node) && removeKeepWithIds.has(node.keepWithId)) {
removeNode(node);
} else if (node instanceof MaterializedChatMessage && node.toolCalls) {
node.toolCalls = filterIfDifferent(
node.toolCalls,
c => !(c.keepWith && removeKeepWithIds.has(c.keepWith.id))
);

if (node.isEmpty) { // may have become empty if it only contained tool calls
removeNode(node);
}
}
}
} finally {
Expand Down Expand Up @@ -700,3 +722,22 @@ function getEncodedBase64(base64String: string): string {

return base64String;
}

/** Like Array.filter(), but only clones the array if a change is made */
function filterIfDifferent<T>(arr: readonly T[], predicate: (item: T) => boolean): readonly T[] {
for (let i = 0; i < arr.length; i++) {
if (predicate(arr[i])) {
continue;
}

const newArr = arr.slice(0, i);
for (let k = i + 1; k < arr.length; k++) {
if (predicate(arr[k])) {
newArr.push(arr[k]);
}
}
return newArr;
}

return arr;
}
12 changes: 11 additions & 1 deletion src/base/promptElements.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ export interface ToolCall {
id: string;
function: ToolFunction;
type: 'function';
/**
* A `<KeepWith />` element, created from {@link useKeepWith}, that wraps
* the tool result. This will ensure that if the tool result is pruned,
* the tool call is also pruned to avoid errors.
*/
keepWith?: KeepWithCtor;
}

export interface ToolFunction {
Expand Down Expand Up @@ -388,6 +394,8 @@ export abstract class AbstractKeepWith extends PromptElement {

let keepWidthId = 0;

export type KeepWithCtor = typeof AbstractKeepWith & { id: number };

/**
* Returns a PromptElement that ensures each wrapped element is retained only
* so long as each other wrapped is not empty.
Expand All @@ -412,9 +420,11 @@ let keepWidthId = 0;
* `ToolCallResponse` to be pruned, and if it's fully pruned it will also
* remove the `ToolCallRequest`.
*/
export function useKeepWith(): PromptElementCtor<BasePromptElementProps, void> {
export function useKeepWith(): KeepWithCtor {
const id = keepWidthId++;
return class KeepWith extends AbstractKeepWith {
public static readonly id = id;

public readonly id = id;

render(): PromptPiece {
Expand Down
65 changes: 65 additions & 0 deletions src/base/test/renderer.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,71 @@ suite('PromptRenderer', () => {

assert.deepStrictEqual(messages, ['a\nb\nc#\nc\nd\ne\nf', 'b\nc#\nc\nd\ne\nf', 'f', '']);
});

test('trims tool calls', async () => {
const KeepWith1 = useKeepWith();
const KeepWith2 = useKeepWith();
const KeepWith3 = useKeepWith();
const it = pruneDown(
<>
<AssistantMessage
toolCalls={[
{
id: '1',
type: 'function',
function: { name: 'tool1', arguments: '"a"' },
keepWith: KeepWith1,
},
{
id: '2',
type: 'function',
function: { name: 'tool2', arguments: '"b"' },
keepWith: KeepWith2,
},
]}
/>
<AssistantMessage
toolCalls={[
{
id: '3',
type: 'function',
function: { name: 'tool3', arguments: '"c"' },
keepWith: KeepWith3,
},
]}
/>
<UserMessage>
<KeepWith1 priority={1}>
<TextChunk priority={1}>a</TextChunk>
</KeepWith1>
<KeepWith2 priority={2}>
<TextChunk priority={2}>b</TextChunk>
</KeepWith2>
<KeepWith3 priority={3}>
<TextChunk priority={3}>c</TextChunk>
</KeepWith3>
</UserMessage>
</>
);

let messages: { content: string[]; tcIds: string[] }[] = [];
for await (const m of it) {
messages.push({
content: m.map(m => `${m.role}: ${m.content}`),
tcIds: m
.filter(m => m.role === ChatRole.Assistant)
.flatMap(m => m.tool_calls ?? [])
.map(tc => tc.id),
});
}

assert.deepStrictEqual(messages, [
{ content: ['assistant: ', 'assistant: ', 'user: a\nb\nc'], tcIds: ['1', '2', '3'] },
{ content: ['assistant: ', 'assistant: ', 'user: b\nc'], tcIds: ['2', '3'] },
{ content: ['assistant: ', 'user: c'], tcIds: ['3'] },
{ content: [], tcIds: [] },
]);
});
});

suite('prunes in priority order', () => {
Expand Down

0 comments on commit c600ee3

Please sign in to comment.