Skip to content

Tool call + content - UI change #382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Added
- Reasoning component.

## Changed
- Allow for tool call + content in the UI.

## [v0.6.2] - 23.06.2025

### Added
Expand Down
3 changes: 1 addition & 2 deletions backend/src/neuroagent/app/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def format_messages_vercel(
else:
status = "pending"

parts.append(TextPartVercel(text=text_content or ""))
parts.append(
ToolCallPartVercel(
toolInvocation=ToolCallVercel(
Expand All @@ -392,8 +393,6 @@ def format_messages_vercel(
)
)

parts.append(TextPartVercel(text=text_content or ""))

# Merge the actual tool result back into the buffered part
elif msg.entity == Entity.TOOL:
tool_call_id = json.loads(msg.content).get("tool_call_id")
Expand Down
3 changes: 2 additions & 1 deletion backend/src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
StrainGetOneTool,
SubjectGetAllTool,
SubjectGetOneTool,
WeatherTool,
WebSearchTool,
)
from neuroagent.tools.base_tool import BaseTool
Expand Down Expand Up @@ -341,7 +342,7 @@ def get_tool_list(
SubjectGetAllTool,
SubjectGetOneTool,
# NowTool,
# WeatherTool,
WeatherTool,
# RandomPlotGeneratorTool,
]

Expand Down
18 changes: 9 additions & 9 deletions backend/tests/app/routers/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,22 +575,22 @@ async def test_get_thread_messages_vercel_format(
assert len(parts) == 3

first_part = parts[0]
assert first_part.get("type") == "tool-invocation"
tool_inv = first_part.get("toolInvocation")
assert first_part.get("type") == "text"
assert first_part.get("text") == ""

second_part = parts[1]
assert second_part.get("type") == "tool-invocation"
tool_inv = second_part.get("toolInvocation")
assert isinstance(tool_inv, dict)
assert tool_inv.get("toolCallId") == "mock_id_tc"
assert tool_inv.get("toolName") == "get_weather"
assert tool_inv.get("args") == {"location": "Geneva"}
assert tool_inv.get("state") == "call"
assert tool_inv.get("results") is None

second_part = parts[1]
assert second_part.get("type") == "text"
assert second_part.get("text") == ""

second_part = parts[2]
assert second_part.get("type") == "text"
assert second_part.get("text") == "sample response content."
third_part = parts[2]
assert third_part.get("type") == "text"
assert third_part.get("text") == "sample response content."

annotations = item.get("annotations")
assert isinstance(annotations, list)
Expand Down
2 changes: 1 addition & 1 deletion backend/tests/app/test_app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def test_format_messages_vercel():
createdAt=datetime(2025, 6, 4, 14, 4, 41),
content="DUMMY_AI_CONTENT",
parts=[
TextPartVercel(type="text", text="DUMMY_AI_TOOL_CONTENT"),
ToolCallPartVercel(
type="tool-invocation",
toolInvocation=ToolCallVercel(
Expand All @@ -404,7 +405,6 @@ def test_format_messages_vercel():
results=None,
),
),
TextPartVercel(type="text", text="DUMMY_AI_TOOL_CONTENT"),
TextPartVercel(type="text", text="DUMMY_AI_CONTENT"),
],
annotations=[
Expand Down
47 changes: 47 additions & 0 deletions frontend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"@radix-ui/react-dropdown-menu": "^2.1.5",
"@radix-ui/react-label": "^2.1.1",
"@radix-ui/react-popover": "^1.1.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.3",
"@radix-ui/react-tooltip": "^1.1.8",
"@t3-oss/env-nextjs": "^0.12.0",
Expand Down
112 changes: 58 additions & 54 deletions frontend/src/__tests__/lib/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import {
convert_tools_to_set,
isLastMessageComplete,
getToolInvocations,
getStorageIDs,
getStorageID,
getValidationStatus,
getStoppedStatus,
} from "@/lib/utils";
Expand Down Expand Up @@ -95,71 +95,75 @@ describe("getToolInvocations", () => {
expect(getToolInvocations(msgNoParts)).toEqual([]);
});
});
describe("getStorageID", () => {
test("extracts a single storage_id from a JSON string", () => {
const toolCall = {
type: "tool-invocation",
toolInvocation: {
state: "result",
result: JSON.stringify({ storage_id: "abc-123" }),
},
} as ToolInvocationUIPart;

describe("getStorageIDs", () => {
test("parses JSON strings and extracts single storage_id", () => {
const jsonResult = JSON.stringify({ storage_id: "abc-123" });
const message = {
parts: [
{
type: "tool-invocation",
toolInvocation: {
state: "result",
result: jsonResult,
},
} as ToolInvocationUIPart,
],
} as unknown as MessageStrict;

const ids = getStorageIDs(message);
const ids = getStorageID(toolCall);
expect(ids).toEqual(["abc-123"]);
});

test("parses object results and extracts array of storage_id values", () => {
const message = {
parts: [
{
type: "tool-invocation",
toolInvocation: {
state: "result",
result: { storage_id: ["id1", "id2"] },
},
} as ToolInvocationUIPart,
],
} as unknown as MessageStrict;
test("extracts multiple storage_ids from an object", () => {
const toolCall = {
type: "tool-invocation",
toolInvocation: {
state: "result",
result: { storage_id: ["id1", "id2"] },
},
} as ToolInvocationUIPart;

const ids = getStorageIDs(message);
const ids = getStorageID(toolCall);
expect(ids).toEqual(["id1", "id2"]);
});

test("ignores parts whose result is invalid JSON or has no storage_id field", () => {
const message = {
parts: [
{
type: "tool-invocation",
toolInvocation: {
state: "result",
result: "not a json",
},
} as ToolInvocationUIPart,
{
type: "tool-invocation",
toolInvocation: {
state: "result",
result: { some_other_field: "value" },
},
} as ToolInvocationUIPart,
],
} as unknown as MessageStrict;
test("returns an empty array if result is invalid JSON string", () => {
const toolCall = {
type: "tool-invocation",
toolInvocation: {
state: "result",
result: "not json",
},
} as ToolInvocationUIPart;

const ids = getStorageIDs(message);
const ids = getStorageID(toolCall);
expect(ids).toEqual([]);
});

test("returns empty array if message or parts is undefined", () => {
expect(getStorageIDs(undefined)).toEqual([]);
const msgNoParts = { parts: undefined } as unknown as MessageStrict;
expect(getStorageIDs(msgNoParts)).toEqual([]);
test("returns an empty array if storage_id is missing", () => {
const toolCall = {
type: "tool-invocation",
toolInvocation: {
state: "result",
result: { some_other_field: "value" },
},
} as ToolInvocationUIPart;

const ids = getStorageID(toolCall);
expect(ids).toEqual([]);
});

test("returns empty array if toolCall is undefined", () => {
const ids = getStorageID(undefined);
expect(ids).toEqual([]);
});

test("returns empty array if type is not 'tool-invocation'", () => {
const toolCall = {
type: "not-a-tool-invocation",
toolInvocation: {
state: "result",
result: { storage_id: "abc-123" },
},
} as unknown as ToolInvocationUIPart;

const ids = getStorageID(toolCall);
expect(ids).toEqual([]);
});
});

Expand Down
46 changes: 4 additions & 42 deletions frontend/src/components/chat/chat-message-ai.tsx
Original file line number Diff line number Diff line change
@@ -1,58 +1,20 @@
import { Card, CardContent } from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import { LoaderPinwheel, ChevronDown, Wrench } from "lucide-react";
import { MemoizedMarkdown } from "@/components/memoized-markdown";

type ChatMessageAIProps = {
content?: string;
hasTools: boolean;
isToolsCollapsed: boolean;
toggleCollapse: () => void;
messageId: string;
isLoading: boolean;
isLastMessage: boolean;
};

export const ChatMessageAI = function ChatMessageAI({
content,
hasTools,
isToolsCollapsed,
toggleCollapse,
messageId,
isLoading,
isLastMessage,
}: ChatMessageAIProps) {
const lastMessageLoading = isLastMessage && isLoading;
return (
<div className="mt-4 flex justify-start">
{!lastMessageLoading ? (
hasTools ? (
<Button
className="ml-8 mt-1 rounded-full bg-blue-500 p-2.5 hover:scale-105 active:scale-[1.10]"
onClick={toggleCollapse}
>
{isToolsCollapsed ? (
<Wrench className="text-black dark:text-white" />
) : (
<ChevronDown className="text-black dark:text-white" />
)}
</Button>
) : (
<Button className="ml-8 mt-1 rounded-full bg-blue-500 p-2.5 hover:bg-blue-500">
<LoaderPinwheel className="text-black dark:text-white" />
</Button>
)
) : (
<Button
className={`ml-8 mt-1 rounded-full bg-blue-500 p-2.5 hover:bg-blue-500 ${!content && "animate-pulse"}`}
>
<LoaderPinwheel className="text-black dark:text-white" />
</Button>
)}

<Card className="mt-1 max-w-[70%] break-all border-none bg-transparent shadow-none">
<CardContent>
<div className="prose max-w-none pt-1 text-left text-lg dark:prose-invert">
<div className="flex items-center justify-start">
<Card className="ml-6 max-w-[70%] break-all border-none bg-transparent shadow-none">
<CardContent className="flex items-center py-2">
<div className="prose max-w-none text-left text-lg dark:prose-invert">
<MemoizedMarkdown content={content || ""} id={messageId} />
</div>
</CardContent>
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/components/chat/chat-message-human.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export const ChatMessageHuman = function ChatMessageHuman({
content,
}: ChatMessageHumanProps) {
return (
<div className="flex justify-end break-all border-solid p-8">
<div className="flex justify-end break-all border-solid p-4 pt-6">
<Card className="max-w-[70%]">
<CardContent>
<h1 className="whitespace-pre-wrap pt-8 text-lg">{content}</h1>
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/components/chat/chat-message-loading.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { LoaderPinwheel } from "lucide-react";
export const ChatMessageLoading = function ChatMessageLoading() {
return (
<div className="mt-4 flex justify-start">
<Button className="ml-8 mt-1 animate-pulse rounded-full bg-blue-500 p-2.5 hover:bg-blue-500">
<Button className="ml-12 mt-1 animate-pulse rounded-full bg-blue-500 p-2.5 hover:bg-blue-500">
<LoaderPinwheel className="text-black dark:text-white" />
</Button>
</div>
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/components/chat/chat-message-tool.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export const ChatMessageTool = function ChatMessageTool({
?.label ?? tool.toolName;

return (
<div className="border-white-300 ml-5 border-solid p-3.5">
<div className="border-white-300 ml-5 border-solid p-0.5">
<HumanValidationDialog
key={tool.toolCallId}
threadId={threadId}
Expand All @@ -91,7 +91,7 @@ export const ChatMessageTool = function ChatMessageTool({
setMessage={setMessage}
mutate={mutate}
/>
<div className="flex justify-start">
<div className="ml-5 flex justify-start">
<ToolCallCollapsible
tool={tool}
stopped={stopped}
Expand Down
Loading