Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
lib/
tree.txt
**/gpl*.*
.idea/

# Created by https://www.toptal.com/developers/gitignore/api/java,eclipse,intellij+all,visualstudiocode,maven,gradle
# Edit at https://www.toptal.com/developers/gitignore?templates=java,eclipse,intellij+all,visualstudiocode,maven,gradle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@

package schemacrawler.tools.command.aichat;

import java.sql.Connection;
import java.util.Scanner;
import java.util.logging.Level;
import java.util.logging.Logger;
import static us.fatehi.utility.Utility.isBlank;
import schemacrawler.schema.Catalog;
import schemacrawler.schemacrawler.exceptions.SchemaCrawlerException;
import schemacrawler.tools.command.aichat.options.AiChatCommandOptions;
import schemacrawler.tools.executable.BaseSchemaCrawlerCommand;
Expand Down Expand Up @@ -66,7 +64,8 @@ public void execute() {

// Load ChatAssistant implementation using ChatAssistantRegistry
final ChatAssistantRegistry registry = ChatAssistantRegistry.getChatAssistantRegistry();
final ChatAssistant chatAssistant = registry.newChatAssistant(commandOptions, catalog, connection);
final ChatAssistant chatAssistant =
registry.newChatAssistant(commandOptions, catalog, connection);

try (final ChatAssistant assistant = chatAssistant;
final Scanner scanner = new Scanner(System.in)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import schemacrawler.tools.command.aichat.options.AiChatCommandOptions;
import schemacrawler.tools.registry.BasePluginRegistry;
import us.fatehi.utility.property.PropertyName;
import us.fatehi.utility.string.StringFormat;

/** Chat assistant registry for loading chat assistant implementations. */
public final class ChatAssistantRegistry extends BasePluginRegistry {
Expand Down Expand Up @@ -74,7 +73,8 @@ public String getName() {
public Collection<PropertyName> getRegisteredPlugins() {
final List<PropertyName> assistants = new ArrayList<>();
for (final Class<? extends ChatAssistant> chatAssistantClass : chatAssistantClasses) {
assistants.add(new PropertyName(chatAssistantClass.getSimpleName(), chatAssistantClass.getName()));
assistants.add(
new PropertyName(chatAssistantClass.getSimpleName(), chatAssistantClass.getName()));
}
Collections.sort(assistants);
return assistants;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package schemacrawler.tools.command.aichat.mcp;

import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Service;
import schemacrawler.Version;

@Service
public class CommonService {

@Tool(name = "get-schemacrawler-version", description = "Gets the version of SchemaCrawler", returnDirect = true)
public String getSchemaCrawlerVersion(
@ToolParam(
description =
"""
Current date, as an ISO 8601 local date.
""",
required = false) final String date) {
System.out.printf("get-schemacrawler-version called with %s", date);
return Version.about();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package schemacrawler.tools.command.aichat.mcp;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

import java.util.List;

/**
* Spring Boot application for the SchemaCrawler AI MCP server. This class enables the Spring AI MCP
* server capabilities.
*/
@SpringBootApplication
public class SchemaCrawlerMCPServer {

public static void main(final String[] args) {
SpringApplication.run(SchemaCrawlerMCPServer.class, args);
}

@Bean
public ToolCallbackProvider schemaCrawlerTools() {
final List<ToolCallback> tools = SpringAIUtility.toolCallbacks(SpringAIUtility.tools());
final ToolCallbackProvider toolCallbackProvider = ToolCallbackProvider.from(tools);
printTools(toolCallbackProvider);
return toolCallbackProvider;
}

@Bean
public ToolCallbackProvider weatherTools(final CommonService weatherService) {
final MethodToolCallbackProvider toolCallbackProvider =
MethodToolCallbackProvider.builder().toolObjects(weatherService).build();
printTools(toolCallbackProvider);
return toolCallbackProvider;
}

private void printTools(final ToolCallbackProvider toolCallbackProvider) {
List.of(toolCallbackProvider.getToolCallbacks())
.forEach(
toolCallback -> {
System.out.println(toolCallback.getToolDefinition().name());
System.out.println(toolCallback.getToolDefinition().description());
System.out.println(toolCallback.getToolDefinition().inputSchema());
System.out.println("----------------------------------------");
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
========================================================================
SchemaCrawler
http://www.schemacrawler.com
Copyright (c) 2000-2025, Sualeh Fatehi <sualeh@hotmail.com>.
All rights reserved.
------------------------------------------------------------------------

SchemaCrawler is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

SchemaCrawler and the accompanying materials are made available under
the terms of the Eclipse Public License v1.0, GNU General Public License
v3 or GNU Lesser General Public License v3.

You may elect to redistribute this code under any of these licenses.

The Eclipse Public License is available at:
http://www.eclipse.org/legal/epl-v10.html

The GNU General Public License v3 and the GNU Lesser General Public
License v3 are available at:
http://www.gnu.org/licenses/

========================================================================
*/

package schemacrawler.tools.command.aichat.mcp;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaVersion;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.lang.Nullable;
import schemacrawler.tools.command.aichat.FunctionDefinition;
import schemacrawler.tools.command.aichat.FunctionDefinition.FunctionType;
import schemacrawler.tools.command.aichat.functions.FunctionDefinitionRegistry;
import us.fatehi.utility.UtilityMarker;

import java.util.*;
import java.util.Map.Entry;
import java.util.logging.Level;
import java.util.logging.Logger;

@UtilityMarker
public final class SpringAIUtility {

private static final Logger LOGGER = Logger.getLogger(SpringAIUtility.class.getCanonicalName());

private SpringAIUtility() {
// Prevent instantiation
}

public static List<ToolCallback> toolCallbacks(final List<ToolDefinition> tools) {
Objects.requireNonNull(tools, "Tools must not be null");
final List<ToolCallback> toolCallbacks = new ArrayList<>();
for (final ToolDefinition toolDefinition : tools) {
toolCallbacks.add(new SpringAIToolCallback(toolDefinition));
}
return toolCallbacks;
}

public static List<ToolDefinition> tools() {

final List<ToolDefinition> toolDefinitions = new ArrayList<>();
for (final FunctionDefinition<?> functionDefinition :
FunctionDefinitionRegistry.getFunctionDefinitionRegistry().getFunctionDefinitions()) {
if (functionDefinition.getFunctionType() != FunctionType.USER) {
continue;
}

try {
final ToolDefinition toolDefinition =
ToolDefinition.builder()
.name(functionDefinition.getName())
.description(functionDefinition.getDescription())
.inputSchema(generateToolInput(functionDefinition.getParametersClass()))
.build();
toolDefinitions.add(toolDefinition);
} catch (final Exception e) {
LOGGER.log(
Level.WARNING, String.format("Could not load <%s>", functionDefinition.getName()), e);
}
}

return toolDefinitions;
}

/**
* @see org.springframework.ai.util.json.schema.JsonSchemaGenerator
*/
private static String generateToolInput(final Class<?> parametersClass) throws Exception {
Objects.requireNonNull(parametersClass, "Parameters must not be null");

final Map<String, JsonNode> parametersJsonSchema = jsonSchema(parametersClass);
final ObjectNode schema = JsonParser.getObjectMapper().createObjectNode();
schema.put("$schema", SchemaVersion.DRAFT_2020_12.getIdentifier());
schema.put("type", "object");

final List<String> required = new ArrayList<>();
final ObjectNode properties = schema.putObject("properties");
for (final Entry<String, JsonNode> parameter : parametersJsonSchema.entrySet()) {
final String parameterName = parameter.getKey();
final JsonNode parameterSchema = parameter.getValue();
if (parameterSchema.has("required") && parameterSchema.get("required").asBoolean()) {
((ObjectNode) parameterSchema).remove("required");
required.add(parameterName);
}
properties.set(parameterName, parameterSchema);
}
final ArrayNode requiredArray = schema.putArray("required");
required.forEach(requiredArray::add);

schema.put("additionalProperties", false);

return schema.toPrettyString();
}

private static Map<String, JsonNode> jsonSchema(final Class<?> parametersClass) throws Exception {
final ObjectMapper mapper = new ObjectMapper();
final JsonSchemaGenerator schemaGen = new JsonSchemaGenerator(mapper);
final JsonSchema schema = schemaGen.generateSchema(parametersClass);
final JsonNode schemaNode = mapper.valueToTree(schema);
final JsonNode properties = schemaNode.get("properties");
final Set<Entry<String, JsonNode>> namedProperties;
if (properties == null) {
namedProperties = new HashSet<>();
} else {
namedProperties = properties.properties();
}
final Map<String, JsonNode> propertiesMap = new HashMap<>();
for (final Entry<String, JsonNode> entry : namedProperties) {
propertiesMap.put(entry.getKey(), entry.getValue());
}
return propertiesMap;
}

public record SpringAIToolCallback(
ToolDefinition toolDefinition) implements ToolCallback {

public SpringAIToolCallback {
Objects.requireNonNull(toolDefinition, "Tool definition must not be null");
}

@Override
public ToolDefinition getToolDefinition() {
return toolDefinition;
}

@Override
public String call(final String toolInput) {
final String callMessage =
String.format(
"Call to <%s>%n%s%nTool was successfully executed with no return value.",
toolDefinition.name(), toolInput);
System.out.println(callMessage);
return callMessage;
}

@Override
public String call(final String toolInput, @Nullable final ToolContext tooContext) {
return call(toolInput);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package schemacrawler.tools.command.aichat.mcp.controller;

import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Map;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

/** Simple controller to check if the server is running. */
@RestController
public class HealthController {

@GetMapping("/health")
public Map<String, Object> healthCheck() {
final Map<String, Object> response = new HashMap<>();
response.put("status", "UP");
response.put("service", "SchemaCrawler MCP Server");
response.put("timestamp", LocalDateTime.now().toString());
return response;
}

@GetMapping("/")
public Map<String, Object> root() {
final Map<String, Object> response = new HashMap<>();
response.put("message", "SchemaCrawler AI MCP Server is running");
response.put("health_endpoint", "/health");
response.put("timestamp", LocalDateTime.now().toString());
return response;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package schemacrawler.tools.command.aichat.mcp;

import static org.assertj.core.api.Assertions.assertThat;
import java.util.Map;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class SchemaCrawlerMCPServerTest {

@LocalServerPort private int port;

@Autowired private TestRestTemplate restTemplate;

@Test
@DisplayName("Application context loads successfully")
public void contextLoads() {
// This test will fail if the application context cannot start
}

@Test
@DisplayName("Health endpoint returns status UP in integration test")
public void healthEndpoint() {
final ResponseEntity<Map> response =
restTemplate.getForEntity("http://localhost:" + port + "/health", Map.class);

assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
assertThat(response.getBody()).containsKey("status");
assertThat(response.getBody().get("status")).isEqualTo("UP");
assertThat(response.getBody()).containsKey("service");
assertThat(response.getBody().get("service")).isEqualTo("SchemaCrawler MCP Server");
}

@Test
@DisplayName("Root endpoint returns welcome message in integration test")
public void rootEndpoint() {
final ResponseEntity<Map> response =
restTemplate.getForEntity("http://localhost:" + port + "/", Map.class);

assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
assertThat(response.getBody()).containsKey("message");
assertThat(response.getBody().get("message").toString()).contains("running");
assertThat(response.getBody()).containsKey("health_endpoint");
assertThat(response.getBody().get("health_endpoint")).isEqualTo("/health");
}
}
Loading