Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/wild-aliens-kneel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"@voltagent/postgres": patch
"@voltagent/supabase": patch
"@voltagent/core": patch
---

fix: validate UI/response messages and keep streaming response message IDs consistent across UI streams - #1010
fix(postgres/supabase): upsert conversation messages by (conversation_id, message_id) to avoid duplicate insert failures
1 change: 1 addition & 0 deletions packages/core/src/agent/agent-semantic-search.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ vi.mock("ai", () => ({
streamObject: vi.fn(),
convertToModelMessages: vi.fn((messages) => messages),
stepCountIs: vi.fn(() => vi.fn(() => false)),
validateUIMessages: vi.fn(async ({ messages }) => messages),
}));

// Mock embedding adapter
Expand Down
62 changes: 62 additions & 0 deletions packages/core/src/agent/agent.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,68 @@ describe("Agent", () => {
expect(text).toBe("Streamed response");
});

// Regression test: streaming runs must persist an empty assistant
// "placeholder" message before the stream starts, and the UI message
// stream must be told to reuse that placeholder's id so the streamed
// assistant message and the persisted record share one id.
it("pre-creates streaming messages and forwards the id to UI streams", async () => {
const agent = new Agent({
name: "TestAgent",
instructions: "You are a helpful assistant",
model: mockModel as any,
});

// Minimal stand-in for the AI SDK streamText result; only the members
// the agent actually touches are given working implementations.
const mockStream = {
text: Promise.resolve("Streamed response"),
textStream: (async function* () {
yield "Streamed response";
})(),
fullStream: (async function* () {
yield {
type: "text-delta" as const,
id: "text-1",
delta: "Streamed response",
text: "Streamed response",
};
})(),
usage: Promise.resolve({
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
}),
finishReason: Promise.resolve("stop"),
warnings: [],
toUIMessageStream: vi.fn().mockReturnValue((async function* () {})()),
toUIMessageStreamResponse: vi.fn(),
pipeUIMessageStreamToResponse: vi.fn(),
pipeTextStreamToResponse: vi.fn(),
toTextStreamResponse: vi.fn(),
partialOutputStream: undefined,
};

// Spy on persistence so we can capture the placeholder saved up-front.
const memoryManager = agent.getMemoryManager();
const saveMessageSpy = vi.spyOn(memoryManager, "saveMessage");

vi.mocked(ai.streamText).mockReturnValue(mockStream as any);

const result = await agent.streamText("Stream this", {
userId: "user-1",
conversationId: "conv-1",
});

// The pre-created placeholder is an assistant message with no parts yet.
const savedMessage = saveMessageSpy.mock.calls
.map((call) => call[1] as UIMessage)
.find((message) => message.role === "assistant" && message.parts.length === 0);

expect(savedMessage).toBeDefined();

result.toUIMessageStream();

// The underlying stream must receive a generateMessageId option whose
// return value matches the persisted placeholder's id.
const callArgs = mockStream.toUIMessageStream.mock.calls[0]?.[0];
expect(callArgs).toEqual(
expect.objectContaining({
generateMessageId: expect.any(Function),
}),
);
expect(callArgs?.generateMessageId()).toBe(savedMessage?.id);
});

it("uses last-step usage for finish events when provider is anthropic", async () => {
const agent = new Agent({
name: "TestAgent",
Expand Down
120 changes: 96 additions & 24 deletions packages/core/src/agent/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ import {
createTextStreamResponse,
createUIMessageStream,
createUIMessageStreamResponse,
generateId,
generateObject,
generateText,
pipeTextStreamToResponse,
pipeUIMessageStreamToResponse,
stepCountIs,
streamObject,
streamText,
validateUIMessages,
} from "ai";
import { z } from "zod";
import { LogEvents, LoggerProxy } from "../logger";
Expand Down Expand Up @@ -183,6 +185,7 @@ const QUEUE_CONTEXT_KEY = Symbol("memoryPersistQueue");
const STEP_PERSIST_COUNT_KEY = Symbol("persistedStepCount");
const ABORT_LISTENER_ATTACHED_KEY = Symbol("abortListenerAttached");
const MIDDLEWARE_RETRY_FEEDBACK_KEY = Symbol("middlewareRetryFeedback");
const STREAM_RESPONSE_MESSAGE_ID_KEY = Symbol("streamResponseMessageId");
const DEFAULT_FEEDBACK_KEY = "satisfaction";
const DEFAULT_CONVERSATION_TITLE_PROMPT = [
"You generate concise titles for new conversations.",
Expand Down Expand Up @@ -1461,6 +1464,7 @@ export class Agent {
| undefined;
applyForcedToolChoice(aiSDKOptions, forcedToolChoice);

const responseMessageId = await this.ensureStreamingResponseMessageId(oc, buffer);
const guardrailStreamingEnabled = guardrailSet.output.length > 0;

let guardrailPipeline: GuardrailPipeline | null = null;
Expand Down Expand Up @@ -1837,6 +1841,17 @@ export class Agent {
: never;

const agent = this;
// Inject the pre-created response message id into UI stream options so
// the emitted assistant message keeps the same id as the persisted
// placeholder. Acts as a pass-through when no id was pre-created.
const applyResponseMessageId = (
  streamOptions?: ToUIMessageStreamOptions,
): ToUIMessageStreamOptions | undefined =>
  responseMessageId
    ? {
        ...(streamOptions ?? {}),
        generateMessageId: () => responseMessageId,
      }
    : streamOptions;

const createBaseFullStream = (): AsyncIterable<VoltAgentTextStreamPart> => {
// Wrap the base stream with abort handling
Expand Down Expand Up @@ -2013,10 +2028,11 @@ export class Agent {
const createMergedUIStream = (
streamOptions?: ToUIMessageStreamOptions,
): ToUIMessageStreamReturn => {
const resolvedStreamOptions = applyResponseMessageId(streamOptions);
const mergedStream = createUIMessageStream({
execute: async ({ writer }) => {
oc.systemContext.set("uiStreamWriter", writer);
writer.merge(getGuardrailAwareUIStream(streamOptions));
writer.merge(getGuardrailAwareUIStream(resolvedStreamOptions));
},
onError: (error) => String(error),
});
Expand Down Expand Up @@ -2080,9 +2096,10 @@ export class Agent {
const toUIMessageStreamSanitized = (
streamOptions?: ToUIMessageStreamOptions,
): ToUIMessageStreamReturn => {
const resolvedStreamOptions = applyResponseMessageId(streamOptions);
const baseStream = agent.subAgentManager.hasSubAgents()
? createMergedUIStream(streamOptions)
: getGuardrailAwareUIStream(streamOptions);
? createMergedUIStream(resolvedStreamOptions)
: getGuardrailAwareUIStream(resolvedStreamOptions);
return attachFeedbackMetadata(baseStream);
};

Expand Down Expand Up @@ -3246,6 +3263,7 @@ export class Agent {
oc.systemContext.delete(STEP_PERSIST_COUNT_KEY);
oc.systemContext.delete("conversationSteps");
oc.systemContext.delete("bailedResult");
oc.systemContext.delete(STREAM_RESPONSE_MESSAGE_ID_KEY);
oc.conversationSteps = [];
oc.output = undefined;
}
Expand All @@ -3268,6 +3286,37 @@ export class Agent {
return queue;
}

/**
 * Returns the stable assistant-message id for this streaming operation.
 *
 * The first call generates an id, ingests an empty assistant placeholder
 * into the conversation buffer, caches the id in the operation context,
 * and — when a user/conversation and conversation memory are available —
 * persists the placeholder. Subsequent calls return the cached id so the
 * streamed response and the persisted record stay consistent.
 */
private async ensureStreamingResponseMessageId(
  oc: OperationContext,
  buffer: ConversationBuffer,
): Promise<string | null> {
  // Reuse a previously generated id so repeated calls stay consistent.
  const cachedId = oc.systemContext.get(STREAM_RESPONSE_MESSAGE_ID_KEY);
  if (typeof cachedId === "string" && cachedId.trim().length > 0) {
    return cachedId;
  }

  const newId = generateId();
  const emptyAssistantMessage: UIMessage = {
    id: newId,
    role: "assistant",
    parts: [],
  };

  // Track the placeholder locally and remember its id for later calls.
  buffer.ingestUIMessages([emptyAssistantMessage], false);
  oc.systemContext.set(STREAM_RESPONSE_MESSAGE_ID_KEY, newId);

  // Persist only when there is a target conversation and memory support;
  // the short-circuit also skips the memory check without a user/conversation.
  if (!oc.userId || !oc.conversationId || !this.memoryManager.hasConversationMemory()) {
    return newId;
  }

  await this.memoryManager.saveMessage(oc, emptyAssistantMessage, oc.userId, oc.conversationId);
  return newId;
}

private async flushPendingMessagesOnError(oc: OperationContext): Promise<void> {
const buffer = this.getConversationBuffer(oc);
const queue = this.getMemoryPersistQueue(oc);
Expand Down Expand Up @@ -3786,10 +3835,11 @@ export class Agent {
options: BaseGenerationOptions | undefined,
buffer: ConversationBuffer,
): Promise<UIMessage[]> {
const resolvedInput = await this.validateIncomingUIMessages(input, oc);
const messages: UIMessage[] = [];

// Get system message with retriever context and working memory
const systemMessage = await this.getSystemMessage(input, oc, options);
const systemMessage = await this.getSystemMessage(resolvedInput, oc, options);
if (systemMessage) {
const systemMessagesAsUI: UIMessage[] = (() => {
if (typeof systemMessage === "string") {
Expand Down Expand Up @@ -3851,7 +3901,7 @@ export class Agent {
const useSemanticSearch = options?.semanticMemory?.enabled ?? this.hasSemanticSearchSupport();

// Extract user query for semantic search if enabled
const currentQuery = useSemanticSearch ? this.extractUserQuery(input) : undefined;
const currentQuery = useSemanticSearch ? this.extractUserQuery(resolvedInput) : undefined;

// Prepare memory read parameters
const semanticLimit = options?.semanticMemory?.semanticLimit ?? 5;
Expand All @@ -3865,7 +3915,7 @@ export class Agent {
// Create unified memory read span

const spanInput = {
query: isSemanticSearch ? currentQuery : input,
query: isSemanticSearch ? currentQuery : resolvedInput,
userId: options?.userId,
conversationId: options?.conversationId,
};
Expand Down Expand Up @@ -3908,11 +3958,11 @@ export class Agent {
// Regular memory context
// Convert model messages to UI for memory context if needed
const inputForMemory =
typeof input === "string"
? input
: Array.isArray(input) && (input as any[])[0]?.parts
? (input as UIMessage[])
: convertModelMessagesToUIMessages(input as BaseMessage[]);
typeof resolvedInput === "string"
? resolvedInput
: Array.isArray(resolvedInput) && (resolvedInput as any[])[0]?.parts
? (resolvedInput as UIMessage[])
: convertModelMessagesToUIMessages(resolvedInput as BaseMessage[]);

const result = await this.memoryManager.prepareConversationContext(
oc,
Expand Down Expand Up @@ -3952,11 +4002,11 @@ export class Agent {
if (isSemanticSearch && oc.userId && oc.conversationId) {
try {
const inputForMemory =
typeof input === "string"
? input
: Array.isArray(input) && (input as any[])[0]?.parts
? (input as UIMessage[])
: convertModelMessagesToUIMessages(input as BaseMessage[]);
typeof resolvedInput === "string"
? resolvedInput
: Array.isArray(resolvedInput) && (resolvedInput as any[])[0]?.parts
? (resolvedInput as UIMessage[])
: convertModelMessagesToUIMessages(resolvedInput as BaseMessage[]);
this.memoryManager.queueSaveInput(oc, inputForMemory, oc.userId, oc.conversationId);
} catch (_e) {
// Non-fatal: background persistence should not block message preparation
Expand All @@ -3972,16 +4022,16 @@ export class Agent {
}

// Add current input
if (typeof input === "string") {
if (typeof resolvedInput === "string") {
messages.push({
id: randomUUID(),
role: "user",
parts: [{ type: "text", text: input }],
parts: [{ type: "text", text: resolvedInput }],
});
} else if (Array.isArray(input)) {
const first = (input as any[])[0];
} else if (Array.isArray(resolvedInput)) {
const first = (resolvedInput as any[])[0];
if (first && Array.isArray(first.parts)) {
const inputMessages = input as UIMessage[];
const inputMessages = resolvedInput as UIMessage[];
const idsToReplace = new Set(
inputMessages
.map((message) => message.id)
Expand All @@ -3998,7 +4048,7 @@ export class Agent {

messages.push(...inputMessages);
} else {
messages.push(...convertModelMessagesToUIMessages(input as BaseMessage[]));
messages.push(...convertModelMessagesToUIMessages(resolvedInput as BaseMessage[]));
}
}

Expand All @@ -4022,10 +4072,32 @@ export class Agent {
agent: this,
context: oc,
});
return result?.messages || summarizedMessages;
const preparedMessages = result?.messages || summarizedMessages;
return await validateUIMessages({ messages: preparedMessages });
}

return summarizedMessages;
return await validateUIMessages({ messages: summarizedMessages });
}

private async validateIncomingUIMessages(
input: string | UIMessage[] | BaseMessage[],
oc: OperationContext,
): Promise<string | UIMessage[] | BaseMessage[]> {
if (!Array.isArray(input) || input.length === 0) {
return input;
}

const first = (input as any[])[0];
if (!first || !Array.isArray((first as { parts?: unknown }).parts)) {
return input;
}

try {
return await validateUIMessages({ messages: input as UIMessage[] });
} catch (error) {
oc.logger?.error?.("Invalid UI messages", { error });
throw error;
}
}

/**
Expand Down
22 changes: 18 additions & 4 deletions packages/postgres/src/memory-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ export class PostgreSQLMemoryAdapter implements StorageAdapter {
await client.query("BEGIN");

const messagesTable = this.getTableName(`${this.tablePrefix}_messages`);
const messageId = message.id || this.generateId();

// Ensure conversation exists
const conversation = await this.getConversation(conversationId);
Expand All @@ -336,10 +337,16 @@ export class PostgreSQLMemoryAdapter implements StorageAdapter {
await client.query(
`INSERT INTO ${messagesTable}
(conversation_id, message_id, user_id, role, parts, metadata, format_version, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (conversation_id, message_id) DO UPDATE SET
user_id = EXCLUDED.user_id,
role = EXCLUDED.role,
parts = EXCLUDED.parts,
metadata = EXCLUDED.metadata,
format_version = EXCLUDED.format_version`,
[
conversationId,
message.id || this.generateId(),
messageId,
userId,
message.role,
safeStringify(message.parts),
Expand Down Expand Up @@ -381,13 +388,20 @@ export class PostgreSQLMemoryAdapter implements StorageAdapter {

// Insert all messages
for (const message of messages) {
const messageId = message.id || this.generateId();
await client.query(
`INSERT INTO ${messagesTable}
(conversation_id, message_id, user_id, role, parts, metadata, format_version, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (conversation_id, message_id) DO UPDATE SET
user_id = EXCLUDED.user_id,
role = EXCLUDED.role,
parts = EXCLUDED.parts,
metadata = EXCLUDED.metadata,
format_version = EXCLUDED.format_version`,
[
conversationId,
message.id || this.generateId(),
messageId,
userId,
message.role,
safeStringify(message.parts),
Expand Down
Loading
Loading