Skip to content

feat: Add tool_calls support to JdbcChatMemoryRepository #3343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@

/**
* HSQLDB-specific SQL dialect for chat memory repository.
*
* @author DoHoon Kim
* @since 1.0.0
*/
public class HsqldbChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {

@Override
public String getSelectMessagesSql() {
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY timestamp ASC";
return "SELECT content, type, tool_calls FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY timestamp ASC";
}

@Override
public String getInsertMessageSql() {
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, timestamp) VALUES (?, ?, ?, ?)";
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, tool_calls, timestamp) VALUES (?, ?, ?, ?, ?)";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.sql.Timestamp;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

import javax.sql.DataSource;
Expand All @@ -36,6 +37,7 @@
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
Expand All @@ -53,6 +55,7 @@
* @author Linar Abzaltdinov
* @author Mark Pollack
* @author Yanming Zhou
* @author DoHoon Kim
* @since 1.0.0
*/
public final class JdbcChatMemoryRepository implements ChatMemoryRepository {
Expand Down Expand Up @@ -124,7 +127,15 @@ public void setValues(PreparedStatement ps, int i) throws SQLException {
ps.setString(1, this.conversationId);
ps.setString(2, message.getText());
ps.setString(3, message.getMessageType().name());
ps.setTimestamp(4, new Timestamp(this.instantSeq.getAndIncrement()));

// Handle tool_calls column
String toolCallsJson = null;
if (message instanceof AssistantMessage assistantMessage && assistantMessage.hasToolCalls()) {
toolCallsJson = ModelOptionsUtils.toJsonString(assistantMessage.getToolCalls());
}
ps.setString(4, toolCallsJson);

ps.setTimestamp(5, new Timestamp(this.instantSeq.getAndIncrement()));
}

@Override
Expand All @@ -140,10 +151,24 @@ private static class MessageRowMapper implements RowMapper<Message> {
public Message mapRow(ResultSet rs, int i) throws SQLException {
var content = rs.getString(1);
var type = MessageType.valueOf(rs.getString(2));
var toolCallsJson = rs.getString(3);

return switch (type) {
case USER -> new UserMessage(content);
case ASSISTANT -> new AssistantMessage(content);
case ASSISTANT -> {
List<AssistantMessage.ToolCall> toolCalls = List.of();
if (toolCallsJson != null && !toolCallsJson.isBlank()) {
try {
toolCalls = ModelOptionsUtils.OBJECT_MAPPER.readValue(toolCallsJson,
ModelOptionsUtils.OBJECT_MAPPER.getTypeFactory()
.constructCollectionType(List.class, AssistantMessage.ToolCall.class));
}
catch (Exception e) {
logger.warn("Failed to deserialize tool calls JSON: {}", toolCallsJson, e);
}
}
yield new AssistantMessage(content, Map.of(), toolCalls);
}
case SYSTEM -> new SystemMessage(content);
// The content is always stored empty for ToolResponseMessages.
// If we want to capture the actual content, we need to extend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

package org.springframework.ai.chat.memory.repository.jdbc;

import javax.sql.DataSource;
import java.sql.Connection;

import javax.sql.DataSource;

/**
* Abstraction for database-specific SQL for chat memory repository.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@
* MySQL dialect for chat memory repository.
*
* @author Mark Pollack
* @author DoHoon Kim
* @since 1.0.0
*/
public class MysqlChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {

@Override
public String getSelectMessagesSql() {
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`";
return "SELECT content, type, tool_calls FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`";
}

@Override
public String getInsertMessageSql() {
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, `timestamp`) VALUES (?, ?, ?, ?)";
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, tool_calls, `timestamp`) VALUES (?, ?, ?, ?, ?)";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@
* Dialect for Postgres.
*
* @author Mark Pollack
* @author DoHoon Kim
* @since 1.0.0
*/
public class PostgresChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {

@Override
public String getSelectMessagesSql() {
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY \"timestamp\"";
return "SELECT content, type, tool_calls FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY \"timestamp\"";
}

@Override
public String getInsertMessageSql() {
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, \"timestamp\") VALUES (?, ?, ?, ?)";
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, tool_calls, \"timestamp\") VALUES (?, ?, ?, ?::jsonb, ?)";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@
* Dialect for SQL Server.
*
* @author Mark Pollack
* @author DoHoon Kim
* @since 1.0.0
*/
public class SqlServerChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {

@Override
public String getSelectMessagesSql() {
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp]";
return "SELECT content, type, tool_calls FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp]";
}

@Override
public String getInsertMessageSql() {
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, [timestamp]) VALUES (?, ?, ?, ?)";
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, tool_calls, [timestamp]) VALUES (?, ?, ?, ?, ?)";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ CREATE TABLE SPRING_AI_CHAT_MEMORY (
conversation_id VARCHAR(36) NOT NULL,
content LONGVARCHAR NOT NULL,
type VARCHAR(10) NOT NULL,
tool_calls LONGVARCHAR NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS SPRING_AI_CHAT_MEMORY (
conversation_id VARCHAR(36) NOT NULL,
content TEXT NOT NULL,
type VARCHAR(10) NOT NULL,
tool_calls LONGTEXT NULL,
`timestamp` TIMESTAMP NOT NULL,
CONSTRAINT TYPE_CHECK CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL'))
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS SPRING_AI_CHAT_MEMORY (
conversation_id VARCHAR(36) NOT NULL,
content TEXT NOT NULL,
type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')),
tool_calls JSONB NULL,
"timestamp" TIMESTAMP NOT NULL
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ CREATE TABLE SPRING_AI_CHAT_MEMORY (
conversation_id VARCHAR(36) NOT NULL,
content NVARCHAR(MAX) NOT NULL,
type VARCHAR(10) NOT NULL,
tool_calls NVARCHAR(MAX) NULL,
[timestamp] DATETIME2 NOT NULL DEFAULT SYSDATETIME(),
CONSTRAINT TYPE_CHECK CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL'))
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -77,10 +78,10 @@ void saveMessagesSingleMessage(String content, MessageType messageType) {
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
.from(this.jdbcTemplate.getDataSource());
String selectSql = dialect.getSelectMessagesSql()
.replace("content, type", "conversation_id, content, type, timestamp");
.replace("content, type, tool_calls", "conversation_id, content, type, tool_calls, timestamp");
var result = this.jdbcTemplate.queryForMap(selectSql, conversationId);

assertThat(result.size()).isEqualTo(4);
assertThat(result.size()).isEqualTo(5);
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
assertThat(result.get("content")).isEqualTo(message.getText());
assertThat(result.get("type")).isEqualTo(messageType.name());
Expand All @@ -102,7 +103,7 @@ void saveMessagesMultipleMessages() {
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
.from(this.jdbcTemplate.getDataSource());
String selectSql = dialect.getSelectMessagesSql()
.replace("content, type", "conversation_id, content, type, timestamp");
.replace("content, type, tool_calls", "conversation_id, content, type, tool_calls, timestamp");
var results = this.jdbcTemplate.queryForList(selectSql, conversationId);

assertThat(results).hasSize(messages.size());
Expand Down Expand Up @@ -186,6 +187,67 @@ void testMessageOrder() {
"4-Fourth message");
}

@Test
void saveAndRetrieveAssistantMessageWithToolCalls() {
String conversationId = UUID.randomUUID().toString();

// Create tool calls
List<AssistantMessage.ToolCall> toolCalls = List.of(
new AssistantMessage.ToolCall("call_1", "function", "get_weather", "{\"location\":\"Seoul\"}"),
new AssistantMessage.ToolCall("call_2", "function", "get_time", "{\"timezone\":\"Asia/Seoul\"}"));

var assistantMessage = new AssistantMessage("I'll help you with that.", Map.of(), toolCalls);

this.chatMemoryRepository.saveAll(conversationId, List.of(assistantMessage));

// Retrieve and verify
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);

Message retrievedMessage = retrievedMessages.get(0);
assertThat(retrievedMessage).isInstanceOf(AssistantMessage.class);

AssistantMessage retrievedAssistantMessage = (AssistantMessage) retrievedMessage;
assertThat(retrievedAssistantMessage.getText()).isEqualTo("I'll help you with that.");
assertThat(retrievedAssistantMessage.hasToolCalls()).isTrue();
assertThat(retrievedAssistantMessage.getToolCalls()).hasSize(2);

// Verify first tool call
AssistantMessage.ToolCall firstToolCall = retrievedAssistantMessage.getToolCalls().get(0);
assertThat(firstToolCall.id()).isEqualTo("call_1");
assertThat(firstToolCall.type()).isEqualTo("function");
assertThat(firstToolCall.name()).isEqualTo("get_weather");
assertThat(firstToolCall.arguments()).isEqualTo("{\"location\":\"Seoul\"}");

// Verify second tool call
AssistantMessage.ToolCall secondToolCall = retrievedAssistantMessage.getToolCalls().get(1);
assertThat(secondToolCall.id()).isEqualTo("call_2");
assertThat(secondToolCall.type()).isEqualTo("function");
assertThat(secondToolCall.name()).isEqualTo("get_time");
assertThat(secondToolCall.arguments()).isEqualTo("{\"timezone\":\"Asia/Seoul\"}");
}

@Test
void saveAndRetrieveAssistantMessageWithoutToolCalls() {
String conversationId = UUID.randomUUID().toString();

var assistantMessage = new AssistantMessage("Simple response without tool calls.");

this.chatMemoryRepository.saveAll(conversationId, List.of(assistantMessage));

// Retrieve and verify
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);

Message retrievedMessage = retrievedMessages.get(0);
assertThat(retrievedMessage).isInstanceOf(AssistantMessage.class);

AssistantMessage retrievedAssistantMessage = (AssistantMessage) retrievedMessage;
assertThat(retrievedAssistantMessage.getText()).isEqualTo("Simple response without tool calls.");
assertThat(retrievedAssistantMessage.hasToolCalls()).isFalse();
assertThat(retrievedAssistantMessage.getToolCalls()).isEmpty();
}

/**
* Base configuration for all integration tests.
*/
Expand Down
Loading