2828
2929package schemacrawler .tools .command .aichat .mcp ;
3030
31- import java .util .ArrayList ;
32- import java .util .HashMap ;
33- import java .util .HashSet ;
34- import java .util .List ;
35- import java .util .Map ;
36- import java .util .Map .Entry ;
37- import java .util .Objects ;
38- import java .util .Set ;
39- import java .util .logging .Level ;
40- import java .util .logging .Logger ;
41- import org .springframework .ai .chat .model .ToolContext ;
42- import org .springframework .ai .tool .ToolCallback ;
43- import org .springframework .ai .tool .definition .ToolDefinition ;
44- import org .springframework .ai .util .json .JsonParser ;
45- import org .springframework .lang .Nullable ;
4631import com .fasterxml .jackson .databind .JsonNode ;
4732import com .fasterxml .jackson .databind .ObjectMapper ;
33+ import com .fasterxml .jackson .databind .node .ArrayNode ;
4834import com .fasterxml .jackson .databind .node .ObjectNode ;
4935import com .fasterxml .jackson .module .jsonSchema .JsonSchema ;
5036import com .fasterxml .jackson .module .jsonSchema .JsonSchemaGenerator ;
5137import com .github .victools .jsonschema .generator .SchemaVersion ;
38+ import org .springframework .ai .chat .model .ToolContext ;
39+ import org .springframework .ai .tool .ToolCallback ;
40+ import org .springframework .ai .tool .definition .ToolDefinition ;
41+ import org .springframework .ai .util .json .JsonParser ;
42+ import org .springframework .lang .Nullable ;
5243import schemacrawler .tools .command .aichat .FunctionDefinition ;
5344import schemacrawler .tools .command .aichat .FunctionDefinition .FunctionType ;
5445import schemacrawler .tools .command .aichat .functions .FunctionDefinitionRegistry ;
5546import us .fatehi .utility .UtilityMarker ;
5647
57- @ UtilityMarker
58- public class SpringAIUtility {
59-
60- public record SpringAIToolCallback (ToolDefinition toolDefinition ) implements ToolCallback {
61-
62- public SpringAIToolCallback {
63- Objects .requireNonNull (toolDefinition , "Tool definition must not be null" );
64- }
48+ import java .util .*;
49+ import java .util .Map .Entry ;
50+ import java .util .logging .Level ;
51+ import java .util .logging .Logger ;
6552
66- @ Override
67- public ToolDefinition getToolDefinition () {
68- return toolDefinition ;
69- }
53+ @ UtilityMarker
54+ public final class SpringAIUtility {
7055
71- @ Override
72- public String call (final String toolInput ) {
73- final String callMessage =
74- String .format (
75- "Call to <%s>%n%s%nTool was successfully executed with no return value." ,
76- toolDefinition .name (), toolInput );
77- System .out .println (callMessage );
78- return callMessage ;
79- }
56+ private static final Logger LOGGER = Logger .getLogger (SpringAIUtility .class .getCanonicalName ());
8057
81- @ Override
82- public String call (final String toolInput , @ Nullable final ToolContext tooContext ) {
83- return call (toolInput );
84- }
58+ private SpringAIUtility () {
59+ // Prevent instantiation
8560 }
8661
87- private static final Logger LOGGER = Logger .getLogger (SpringAIUtility .class .getCanonicalName ());
88-
8962 public static List <ToolCallback > toolCallbacks (final List <ToolDefinition > tools ) {
9063 Objects .requireNonNull (tools , "Tools must not be null" );
9164 final List <ToolCallback > toolCallbacks = new ArrayList <>();
@@ -99,22 +72,22 @@ public static List<ToolDefinition> tools() {
9972
10073 final List <ToolDefinition > toolDefinitions = new ArrayList <>();
10174 for (final FunctionDefinition <?> functionDefinition :
102- FunctionDefinitionRegistry .getFunctionDefinitionRegistry ().getFunctionDefinitions ()) {
75+ FunctionDefinitionRegistry .getFunctionDefinitionRegistry ().getFunctionDefinitions ()) {
10376 if (functionDefinition .getFunctionType () != FunctionType .USER ) {
10477 continue ;
10578 }
10679
10780 try {
10881 final ToolDefinition toolDefinition =
109- ToolDefinition .builder ()
110- .name (functionDefinition .getName ())
111- .description (functionDefinition .getDescription ())
112- .inputSchema (generateToolInput (functionDefinition .getParametersClass ()))
113- .build ();
82+ ToolDefinition .builder ()
83+ .name (functionDefinition .getName ())
84+ .description (functionDefinition .getDescription ())
85+ .inputSchema (generateToolInput (functionDefinition .getParametersClass ()))
86+ .build ();
11487 toolDefinitions .add (toolDefinition );
11588 } catch (final Exception e ) {
11689 LOGGER .log (
117- Level .WARNING , String .format ("Could not load <%s>" , functionDefinition .getName ()), e );
90+ Level .WARNING , String .format ("Could not load <%s>" , functionDefinition .getName ()), e );
11891 }
11992 }
12093
@@ -132,10 +105,21 @@ private static String generateToolInput(final Class<?> parametersClass) throws E
132105 schema .put ("$schema" , SchemaVersion .DRAFT_2020_12 .getIdentifier ());
133106 schema .put ("type" , "object" );
134107
108+ final List <String > required = new ArrayList <>();
135109 final ObjectNode properties = schema .putObject ("properties" );
136110 for (final Entry <String , JsonNode > parameter : parametersJsonSchema .entrySet ()) {
137- properties .set (parameter .getKey (), parameter .getValue ());
111+ final String parameterName = parameter .getKey ();
112+ final JsonNode parameterSchema = parameter .getValue ();
113+ if (parameterSchema .has ("required" ) && parameterSchema .get ("required" ).asBoolean ()) {
114+ ((ObjectNode ) parameterSchema ).remove ("required" );
115+ required .add (parameterName );
116+ }
117+ properties .set (parameterName , parameterSchema );
138118 }
119+ final ArrayNode requiredArray = schema .putArray ("required" );
120+ required .forEach (requiredArray ::add );
121+
122+ schema .put ("additionalProperties" , false );
139123
140124 return schema .toPrettyString ();
141125 }
@@ -159,7 +143,31 @@ private static Map<String, JsonNode> jsonSchema(final Class<?> parametersClass)
159143 return propertiesMap ;
160144 }
161145
162- private SpringAIUtility () {
163- // Prevent instantiation
146+ public record SpringAIToolCallback (
147+ ToolDefinition toolDefinition ) implements ToolCallback {
148+
149+ public SpringAIToolCallback {
150+ Objects .requireNonNull (toolDefinition , "Tool definition must not be null" );
151+ }
152+
153+ @ Override
154+ public ToolDefinition getToolDefinition () {
155+ return toolDefinition ;
156+ }
157+
158+ @ Override
159+ public String call (final String toolInput ) {
160+ final String callMessage =
161+ String .format (
162+ "Call to <%s>%n%s%nTool was successfully executed with no return value." ,
163+ toolDefinition .name (), toolInput );
164+ System .out .println (callMessage );
165+ return callMessage ;
166+ }
167+
168+ @ Override
169+ public String call (final String toolInput , @ Nullable final ToolContext tooContext ) {
170+ return call (toolInput );
171+ }
164172 }
165173}
0 commit comments