Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,20 @@ func NewToolWithInputSchema[Out any](name, description string, inputSchema map[s
return &tool{Action: toolAction}
}

// ToolSchema is a struct that contains the input and output schemas for a tool.
type ToolSchema struct {
Input map[string]any
Output map[string]any
}

// NewToolWithOutputSchema creates a new [Tool] with a custom output schema. It can be passed directly to [Generate].
func NewToolWithSchema[In, Out any](name, description string, schema ToolSchema, fn ToolFunc[In, Out]) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata["dynamic"] = true
toolAction := core.NewStructuredAction(name, api.ActionTypeTool, metadata, schema.Input, schema.Output, wrappedFn)
return &tool{Action: toolAction}
}

// implementTool creates the metadata and wrapped function common to both DefineTool and NewTool.
func implementTool[In, Out any](name, description string, fn ToolFunc[In, Out]) (map[string]any, func(context.Context, In) (Out, error)) {
metadata := map[string]any{
Expand Down
33 changes: 26 additions & 7 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,24 @@ func NewAction[In, Out any](
inputSchema map[string]any,
fn Func[In, Out],
) *ActionDef[In, Out, struct{}] {
return newAction(name, atype, metadata, inputSchema,
return newAction(name, atype, metadata, inputSchema, nil,
func(ctx context.Context, in In, cb noStream) (Out, error) {
return fn(ctx, in)
})
}

// NewStructuredAction creates a new non-streaming [Action] without registering it.
// It can be used to create a tool with a custom input and output schema.
// If either inputSchema or outputSchema are nil, they are inferred from the function's input or output api.
func NewStructuredAction[In, Out any](
name string,
atype api.ActionType,
metadata map[string]any,
inputSchema map[string]any,
outputSchema map[string]any,
fn Func[In, Out],
) *ActionDef[In, Out, struct{}] {
return newAction(name, atype, metadata, inputSchema, outputSchema,
func(ctx context.Context, in In, cb noStream) (Out, error) {
return fn(ctx, in)
})
Expand All @@ -77,7 +94,7 @@ func NewStreamingAction[In, Out, Stream any](
inputSchema map[string]any,
fn StreamingFunc[In, Out, Stream],
) *ActionDef[In, Out, Stream] {
return newAction(name, atype, metadata, inputSchema, fn)
return newAction(name, atype, metadata, inputSchema, nil, fn)
}

// DefineAction creates a new non-streaming Action and registers it.
Expand Down Expand Up @@ -118,7 +135,7 @@ func defineAction[In, Out, Stream any](
inputSchema map[string]any,
fn StreamingFunc[In, Out, Stream],
) *ActionDef[In, Out, Stream] {
a := newAction(name, atype, metadata, inputSchema, fn)
a := newAction(name, atype, metadata, inputSchema, nil, fn)
provider, id := api.ParseName(name)
key := api.NewKey(atype, provider, id)
r.RegisterAction(key, a)
Expand All @@ -133,6 +150,7 @@ func newAction[In, Out, Stream any](
atype api.ActionType,
metadata map[string]any,
inputSchema map[string]any,
outputSchema map[string]any,
fn StreamingFunc[In, Out, Stream],
) *ActionDef[In, Out, Stream] {
if inputSchema == nil {
Expand All @@ -142,10 +160,11 @@ func newAction[In, Out, Stream any](
}
}

var o Out
var outputSchema map[string]any
if reflect.ValueOf(o).Kind() != reflect.Invalid {
outputSchema = InferSchemaMap(o)
if outputSchema == nil {
var o Out
if reflect.ValueOf(o).Kind() != reflect.Invalid {
outputSchema = InferSchemaMap(o)
}
}

var description string
Expand Down
2 changes: 1 addition & 1 deletion go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ require (
github.com/jackc/pgx/v5 v5.7.5
github.com/jba/slog v0.2.0
github.com/lib/pq v1.10.9
github.com/mark3labs/mcp-go v0.29.0
github.com/mark3labs/mcp-go v0.42.0
github.com/pgvector/pgvector-go v0.3.0
github.com/stretchr/testify v1.10.0
github.com/weaviate/weaviate v1.30.0
Expand Down
2 changes: 2 additions & 0 deletions go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mark3labs/mcp-go v0.29.0 h1:sH1NBcumKskhxqYzhXfGc201D7P76TVXiT0fGVhabeI=
github.com/mark3labs/mcp-go v0.29.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
github.com/mark3labs/mcp-go v0.42.0 h1:gk/8nYJh8t3yroCAOBhNbYsM9TCKvkM13I5t5Hfu6Ls=
github.com/mark3labs/mcp-go v0.42.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE=
github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0=
github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A=
Expand Down
4 changes: 4 additions & 0 deletions go/internal/base/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ import (
"fmt"
"strings"

"github.com/mark3labs/mcp-go/mcp"
"github.com/xeipuuv/gojsonschema"
)

// ValidateValue will validate any value against the expected schema.
// It will return an error if it doesn't match the schema, otherwise it will return nil.
func ValidateValue(data any, schema map[string]any) error {
if callToolResult, ok := data.(*mcp.CallToolResult); ok {
data = callToolResult.StructuredContent
}
if schema == nil {
return nil
}
Expand Down
33 changes: 32 additions & 1 deletion go/plugins/mcp/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,23 @@ func (c *GenkitMCPClient) getInputSchema(mcpTool mcp.Tool) (map[string]any, erro
return out, nil
}

// getOutputSchema returns the MCP output schema as a generic map for Genkit
func (c *GenkitMCPClient) getOutputSchema(mcpTool mcp.Tool) (map[string]any, error) {
var out map[string]any
schemaBytes, err := json.Marshal(mcpTool.OutputSchema)
if err != nil {
return nil, fmt.Errorf("failed to marshal MCP output schema for tool %s: %w", mcpTool.Name, err)
}
if err := json.Unmarshal(schemaBytes, &out); err != nil {
// Fall back to empty map if unmarshalling fails
out = map[string]any{}
}
if out == nil {
out = map[string]any{}
}
return out, nil
}

// createTool converts a single MCP tool to a Genkit tool
func (c *GenkitMCPClient) createTool(mcpTool mcp.Tool) (ai.Tool, error) {
// Use namespaced tool name
Expand All @@ -84,8 +101,22 @@ func (c *GenkitMCPClient) createTool(mcpTool mcp.Tool) (ai.Tool, error) {
if err != nil {
return nil, fmt.Errorf("failed to get input schema for tool %s: %w", mcpTool.Name, err)
}
outputSchema, err := c.getOutputSchema(mcpTool)
if err != nil {
return nil, fmt.Errorf("failed to get output schema for tool %s: %w", mcpTool.Name, err)
}
var tool ai.Tool
if len(inputSchema) > 0 {
if len(inputSchema) > 0 && len(outputSchema) > 0 {
tool = ai.NewToolWithSchema(
namespacedToolName,
mcpTool.Description,
ai.ToolSchema{
Input: inputSchema,
Output: outputSchema,
},
toolFunc,
)
} else if len(inputSchema) > 0 {
tool = ai.NewToolWithInputSchema(
namespacedToolName,
mcpTool.Description,
Expand Down
110 changes: 110 additions & 0 deletions go/plugins/mcp/tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
package mcp

import (
"context"
"encoding/json"
"testing"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

func asMap(t *testing.T, v any, label string) map[string]any {
Expand Down Expand Up @@ -161,3 +163,111 @@ func TestPrepareToolArguments(t *testing.T) {
t.Fatalf("expected error for nil args with required field")
}
}

// TestToolOutputSchema tests that both input and output schemas are correctly retrieved
// from the MCP server.
func TestToolOutputSchema(t *testing.T) {
// Start a test MCP server with a tool that has an input and output schema.
type InputSchema struct {
City string
}
type OutputSchema struct {
Weather string
Temperature int
}
mcpServer := server.NewMCPServer("test", "1.0.0",
server.WithToolCapabilities(true),
)
mcpServer.AddTool(
mcp.NewTool("getWeather",
mcp.WithInputSchema[InputSchema](),
mcp.WithOutputSchema[OutputSchema](),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultStructured(
OutputSchema{Weather: "Sunny, 25°C", Temperature: 25},
"{\"weather\": \"Sunny, 25°C\", \"temperature\": 25}",
), nil
},
)
// Start the stdio server
sseServer := server.NewTestServer(mcpServer)
defer sseServer.Close()
client, err := NewGenkitMCPClient(MCPClientOptions{
Name: "test",
SSE: &SSEConfig{
BaseURL: sseServer.URL + "/sse",
},
})
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
defer client.Disconnect()
// Retrieve tools from the MCP server
tools, err := client.GetActiveTools(context.Background(), nil)
if err != nil {
t.Fatalf("GetActiveTools error: %v", err)
}
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
for _, tool := range tools {
if tool.Name() != "test_getWeather" {
t.Fatalf("unexpected tool: %s", tool.Name())
}
inputSchema := tool.Definition().InputSchema
assertSchemaProperty(t, inputSchema, "City", "string")

outputSchema := tool.Definition().OutputSchema
assertSchemaProperty(t, outputSchema, "Weather", "string")
assertSchemaProperty(t, outputSchema, "Temperature", "integer")

result, err := tool.RunRaw(t.Context(), InputSchema{
City: "Paris",
})
if err != nil {
t.Fatalf("RunRaw error: %v", err)
}
if result == nil {
t.Fatalf("RunRaw result is nil")
}
toolResult := ParseMapToStruct[mcp.CallToolResult](t, result)
toolResultOutput := ParseMapToStruct[OutputSchema](t, toolResult.StructuredContent)
if toolResultOutput.Weather != "Sunny, 25°C" {
t.Fatalf("unexpected weather: %s", toolResultOutput.Weather)
}
if toolResultOutput.Temperature != 25 {
t.Fatalf("unexpected temperature: %d", toolResultOutput.Temperature)
}
}
}

func ParseMapToStruct[T any](t *testing.T, v any) T {
t.Helper()
var result T
jsonBytes, err := json.Marshal(v)
if err != nil {
t.Fatalf("failed to marshal map to JSON: %v", err)
}
err = json.Unmarshal(jsonBytes, &result)
if err != nil {
t.Fatalf("failed to unmarshal JSON to struct: %v", err)
}
return result
}

// assertSchemaProperty asserts that a property in a schema is present and of the expected type.
func assertSchemaProperty(t *testing.T, schema map[string]any, propName string, propType string) {
t.Helper()
if schema == nil {
t.Fatalf("schema is nil")
}
if props, ok := schema["properties"].(map[string]any); !ok {
t.Fatalf("schema properties is nil")
} else if propValue, ok := props[propName].(map[string]any); !ok {
t.Fatalf("schema property %s is nil. schema: %v", propName, schema)
} else if propValue["type"] != propType {
t.Fatalf("schema property %s type is %s, expected %s",
propName, propValue["type"], propType)
}
}