Skip to content

Commit

Permalink
review pass adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
justschen committed Jan 13, 2025
1 parent 1e24294 commit 742aec3
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 110 deletions.
15 changes: 3 additions & 12 deletions src/base/htmlTracer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
TraceMaterializedNodeType,
} from './htmlTracerTypes';
import {
MaterializedChatMesageImage,
MaterializedChatMessageImage,
MaterializedChatMessage,
MaterializedChatMessageTextChunk,
MaterializedContainer,
Expand Down Expand Up @@ -225,23 +225,14 @@ async function serializeMaterialized(
value: materialized.text,
tokens: await materialized.upperBoundTokenCount(tokenizer),
};
} else if (materialized instanceof MaterializedChatMesageImage) {
} else if (materialized instanceof MaterializedChatMessageImage) {
return {
...common,
name: materialized.id.toString(),
id: materialized.id,
type: TraceMaterializedNodeType.Image,
value: materialized.imageUrl,
value: materialized.src,
tokens: await materialized.upperBoundTokenCount(tokenizer),
children: await Promise.all(
materialized.children.map(c =>
serializeMaterialized(
tokenizer,
c,
inChatMessage || materialized instanceof MaterializedChatMesageImage
)
)
),
}
} else {
const containerCommon = {
Expand Down
3 changes: 1 addition & 2 deletions src/base/htmlTracerTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ export interface ITraceMaterializedChatMessageImage extends ITraceMaterializedCo
name: string
value: string;
priority: number;
tokens: number;
children: ITraceMaterializedNode[];
tokens: number,
}

4 changes: 4 additions & 0 deletions src/base/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ export async function renderPrompt<P extends BasePromptElementProps>(
'countTokens' in tokenizerMetadata
? new AnyTokenizer((text, token) => tokenizerMetadata.countTokens(text, token))
: tokenizerMetadata;

if (tokenizer instanceof AnyTokenizer && mode !== 'vscode') {
throw new Error('Tokenizer must be an instance of AnyTokenizer when not in vscode mode.');
}
const renderer = new PromptRenderer(endpoint, ctor, props, tokenizer);
const renderResult = await renderer.render(progress, token);
const { tokenCount, references, metadata } = renderResult;
Expand Down
2 changes: 1 addition & 1 deletion src/base/jsonTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export interface ImageChatMessagePieceJSON {
children: PromptNodeJSON[];
references: PromptReferenceJSON[] | undefined;
props: {
imageUrl: string;
src: string;
detail?: "low" | "high";
};
}
Expand Down
84 changes: 29 additions & 55 deletions src/base/materialized.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export type MaterializedNode =
| MaterializedContainer
| MaterializedChatMessage
| MaterializedChatMessageTextChunk
| MaterializedChatMesageImage;
| MaterializedChatMessageImage;

export const enum ContainerFlags {
/** It's a {@link LegacyPrioritization} instance */
Expand Down Expand Up @@ -93,7 +93,7 @@ export class MaterializedContainer implements IMaterializedNode {
/**
* Finds a node in the tree by ID.
*/
findById(nodeId: number): MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage | undefined {
findById(nodeId: number): MaterializedContainer | MaterializedChatMessage | undefined {
return findNodeById(nodeId, this);
}

Expand Down Expand Up @@ -173,17 +173,13 @@ export class MaterializedChatMessage implements IMaterializedNode {
}

/** Gets the text this message contains */
public get text(): (string | MaterializedChatMesageImage)[] {
public get text(): (string | MaterializedChatMessageImage)[] {
return this._text();
}

/** Gets whether the message is empty */
public get isEmpty() {
const content = this.text
.filter(element => typeof element === 'string')
.join('').trimEnd();

return !this.toolCalls?.length && !this.text.some(element => element instanceof MaterializedChatMesageImage || /\S/.test(content));
return !this.toolCalls?.length && !this.text.some(element => element instanceof MaterializedChatMessageImage || /\S/.test(element));
}

/**
Expand Down Expand Up @@ -212,7 +208,7 @@ export class MaterializedChatMessage implements IMaterializedNode {
/**
* Finds a node in the tree by ID.
*/
findById(nodeId: number): MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage | undefined {
findById(nodeId: number): MaterializedContainer | MaterializedChatMessage | MaterializedChatMessageImage | undefined {
return findNodeById(nodeId, this);
}

Expand All @@ -235,13 +231,10 @@ export class MaterializedChatMessage implements IMaterializedNode {
return tokenizer.countMessageTokens({ ...this.toChatMessage(), content: '' });
});

private readonly _text = once((): (string | MaterializedChatMesageImage)[] => {
let result: (string | MaterializedChatMesageImage)[] = [];
private readonly _text = once((): (string | MaterializedChatMessageImage)[] => {
let result: (string | MaterializedChatMessageImage)[] = [];
for (const { text, isTextSibling } of textChunks(this)) {
if (text instanceof MaterializedChatMesageImage) {
if (text.children.length > 0) {
throw new Error('Images cannot have children');
}
if (text instanceof MaterializedChatMessageImage) {
result.push(text);
continue;
}
Expand Down Expand Up @@ -270,18 +263,18 @@ export class MaterializedChatMessage implements IMaterializedNode {
.filter(element => typeof element === 'string')
.join('').trim();

if (this.text.some(element => element instanceof MaterializedChatMesageImage)) {
if (this.text.some(element => element instanceof MaterializedChatMessageImage)) {
if (this.role !== ChatRole.User) {
throw new Error('Only User messages can have images');
}

let prompts: ChatCompletionContentPart[] = this.text.map(element => {
if (typeof element === 'string') {
return { type: 'text', text: element };
} else if (element instanceof MaterializedChatMesageImage) {
} else if (element instanceof MaterializedChatMessageImage) {
return {
type: 'image_url',
image_url: { url: getEncodedBase64(element.imageUrl), detail: element.detail },
image_url: { url: getEncodedBase64(element.src), detail: element.detail },
};
} else {
throw new Error('Unexpected element type');
Expand Down Expand Up @@ -329,15 +322,14 @@ export class MaterializedChatMessage implements IMaterializedNode {
}
}

export class MaterializedChatMesageImage implements IMaterializedNode {
export class MaterializedChatMessageImage implements IMaterializedNode {
constructor(
public readonly id: number,
public readonly role: ChatRole,
public readonly imageUrl: string,
// public readonly role: ChatRole,
public readonly src: string,
public readonly priority: number,
public readonly metadata: PromptMetadata[] = [],
public readonly lineBreakBefore: LineBreakBefore,
public readonly children: MaterializedNode[],
public readonly detail?: 'low' | 'high',
) { }
upperBoundTokenCount(tokenizer: ITokenizer): Promise<number> {
Expand All @@ -351,50 +343,32 @@ export class MaterializedChatMesageImage implements IMaterializedNode {
return 0;
});

removeLowestPriorityChild(): void {
removeLowestPriorityChild(this);
}

/**
* Replaces a node in the tree with the given one, by its ID.
*/
replaceNode(nodeId: number, withNode: MaterializedNode): MaterializedNode | undefined {
return replaceNode(nodeId, this.children, withNode);
}

/**
* Finds a node in the tree by ID.
*/
findById(nodeId: number): MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage | undefined {
return findNodeById(nodeId, this);
}

isEmpty: boolean = false;
}

function isContainerType(
node: MaterializedNode
): node is MaterializedContainer | MaterializedChatMessage {
return !(node instanceof MaterializedChatMessageTextChunk);
return !(node instanceof MaterializedChatMessageTextChunk || node instanceof MaterializedChatMessageImage);
}

function assertContainerOrChatMessage(
v: MaterializedNode
): asserts v is MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage {
if (!(v instanceof MaterializedContainer) && !(v instanceof MaterializedChatMessage) && !(v instanceof MaterializedChatMesageImage)) {
): asserts v is MaterializedContainer | MaterializedChatMessage | MaterializedChatMessageImage {
if (!(v instanceof MaterializedContainer) && !(v instanceof MaterializedChatMessage) && !(v instanceof MaterializedChatMessageImage)) {
throw new Error(`Cannot have a text node outside a ChatMessage. Text: "${v.text}"`);
}
}

function* textChunks(
node: MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage,
node: MaterializedContainer | MaterializedChatMessage,
isTextSibling = false
): Generator<{ text: MaterializedChatMessageTextChunk | MaterializedChatMesageImage; isTextSibling: boolean }> {
): Generator<{ text: MaterializedChatMessageTextChunk | MaterializedChatMessageImage; isTextSibling: boolean }> {
for (const child of node.children) {
if (child instanceof MaterializedChatMessageTextChunk) {
yield { text: child, isTextSibling };
isTextSibling = true;
} else if (child instanceof MaterializedChatMesageImage) {
} else if (child instanceof MaterializedChatMessageImage) {
yield { text: child, isTextSibling: false };
} else {
if (child)
Expand All @@ -408,15 +382,15 @@ function removeLowestPriorityLegacy(root: MaterializedNode) {
let lowest:
| undefined
| {
chain: (MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage)[];
node: MaterializedChatMessageTextChunk;
chain: (MaterializedContainer | MaterializedChatMessage)[];
node: MaterializedChatMessageTextChunk | MaterializedChatMessageImage;
};

function findLowestInTree(
node: MaterializedNode,
chain: (MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage)[]
chain: (MaterializedContainer | MaterializedChatMessage)[]
) {
if (node instanceof MaterializedChatMessageTextChunk) {
if (node instanceof MaterializedChatMessageTextChunk || node instanceof MaterializedChatMessageImage) {
if (!lowest || node.priority < lowest.node.priority) {
lowest = { chain: chain.slice(), node };
}
Expand Down Expand Up @@ -458,11 +432,11 @@ function removeLowestPriorityLegacy(root: MaterializedNode) {
}
}

function removeLowestPriorityChild(node: MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage) {
function removeLowestPriorityChild(node: MaterializedContainer | MaterializedChatMessage) {
let lowest:
| undefined
| {
chain: (MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage)[];
chain: (MaterializedContainer | MaterializedChatMessage)[];
index: number;
value: MaterializedNode;
lowestNested?: number;
Expand Down Expand Up @@ -498,7 +472,7 @@ function removeLowestPriorityChild(node: MaterializedContainer | MaterializedCha

const containingList = lowest.chain[lowest.chain.length - 1].children;
if (
lowest.value instanceof MaterializedChatMessageTextChunk ||
lowest.value instanceof MaterializedChatMessageTextChunk || lowest.value instanceof MaterializedChatMessageImage ||
(lowest.value instanceof MaterializedContainer && lowest.value.has(ContainerFlags.IsChunk)) ||
(isContainerType(lowest.value) && !lowest.value.children.length)
) {
Expand Down Expand Up @@ -567,8 +541,8 @@ function replaceNode(

function findNodeById(
nodeId: number,
container: MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage
): MaterializedContainer | MaterializedChatMessage | MaterializedChatMesageImage | undefined {
container: MaterializedContainer | MaterializedChatMessage
): MaterializedContainer | MaterializedChatMessage | undefined {
if (container.id === nodeId) {
return container;
}
Expand Down
4 changes: 1 addition & 3 deletions src/base/promptElements.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,8 @@ export interface TextChunkProps extends BasePromptElementProps {
}

export interface ImageProps extends BasePromptElementProps {
imageUrl: string;
src: string;
detail?: 'low' | 'high';
role?: ChatRole.User;
}

/**
Expand Down Expand Up @@ -229,7 +228,6 @@ async function getTextContentBelowBudget(

export class BaseImageMessage extends BaseChatMessage<ImageProps> {
constructor(props: ImageProps) {
props.role = ChatRole.User;
super(props);
}
}
Expand Down
Loading

0 comments on commit 742aec3

Please sign in to comment.