Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 25 additions & 88 deletions cmd/mcp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,9 @@ package main

import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"

"github.com/algolia/algoliasearch-client-go/v3/algolia/search"
"github.com/algolia/mcp/pkg/abtesting"
Expand All @@ -22,12 +16,15 @@ import (
searchpkg "github.com/algolia/mcp/pkg/search"
"github.com/algolia/mcp/pkg/usage"

"github.com/mark3labs/mcp-go/server"
"github.com/modelcontextprotocol/go-sdk/mcp"
)

func main() {
// Create a new MCP server with name and version
mcps := server.NewMCPServer("Algolia MCP", "0.0.2")
mcpServer := mcp.NewServer(&mcp.Implementation{
Name: "Algolia MCP",
Version: "0.0.2",
}, nil)

// Parse MCP_ENABLED_TOOLS environment variable to determine which toolsets to enable
enabledToolsEnv := os.Getenv("MCP_ENABLED_TOOLS")
Expand Down Expand Up @@ -66,36 +63,36 @@ func main() {

// Register tools from enabled packages.
if enabled["abtesting"] {
abtesting.RegisterTools(mcps)
abtesting.RegisterTools(mcpServer)
}
if enabled["analytics"] {
analytics.RegisterTools(mcps)
analytics.RegisterTools(mcpServer)
}
if enabled["collections"] {
collections.RegisterTools(mcps)
collections.RegisterTools(mcpServer)
}
if enabled["monitoring"] {
monitoring.RegisterTools(mcps)
monitoring.RegisterTools(mcpServer)
}
if enabled["querysuggestions"] {
querysuggestions.RegisterAll(mcps)
querysuggestions.RegisterAll(mcpServer)
}
if enabled["recommend"] {
recommend.RegisterAll(mcps)
recommend.RegisterAll(mcpServer)
}
if enabled["search"] {
searchpkg.RegisterAll(mcps)
searchpkg.RegisterAll(mcpServer)
} else {
// Only register specific search tools if "search" is not enabled
if enabled["search_read"] {
searchpkg.RegisterReadAll(mcps, searchClient, searchIndex)
searchpkg.RegisterReadAll(mcpServer, searchClient, searchIndex)
}
if enabled["search_write"] {
searchpkg.RegisterWriteAll(mcps, searchClient, searchIndex)
searchpkg.RegisterWriteAll(mcpServer, searchClient, searchIndex)
}
}
if enabled["usage"] {
usage.RegisterAll(mcps)
usage.RegisterAll(mcpServer)
}

// Create a logger that writes to stderr instead of stdout
Expand All @@ -107,77 +104,17 @@ func main() {
// Check server type from environment variable (defaults to "stdio" if not set)
serverType := strings.ToLower(strings.TrimSpace(os.Getenv("MCP_SERVER_TYPE")))

// Start the appropriate server type
if serverType == "sse" {
// Get port from environment variable or use default
portStr := os.Getenv("MCP_SSE_PORT")
port := 8080 // Default port
if portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
} else {
logger.Printf("Warning: Invalid MCP_SSE_PORT value '%s', using default port 8080", portStr)
}
}

// Create the address string (e.g., ":8080")
addr := fmt.Sprintf(":%d", port)
logger.Printf("Starting SSE server on port %d...", port)

// Create the SSE server
sseServer := server.NewSSEServer(mcps)

// Set up signal handling for graceful shutdown
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)

// Start server in a goroutine
serverErrCh := make(chan error, 1)
go func() {
if err := sseServer.Start(addr); err != nil && err != http.ErrServerClosed {
serverErrCh <- fmt.Errorf("MCP server failed: %v", err)
return
}
serverErrCh <- nil
}()

// Wait for either a shutdown signal or a server error
select {
case sig := <-signalChan:
logger.Printf("Received signal %v, shutting down gracefully...", sig)
case err := <-serverErrCh:
if err != nil {
logger.Fatalf("Server error: %v", err)
}
}

// Use the server's shutdown method with a timeout context
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)

// Attempt to shut down the server
err := sseServer.Shutdown(shutdownCtx)

// Always cancel the context to prevent resource leaks
cancel()

// Check for shutdown errors after ensuring context is canceled
if err != nil {
logger.Fatalf("Server shutdown failed: %v", err)
}

logger.Println("Server gracefully stopped")
} else {
// Default to stdio server
if serverType != "" && serverType != "stdio" {
logger.Printf("Warning: Unknown server type '%s', defaulting to stdio", serverType)
}
// The official SDK primarily supports stdio transport
// For now, we'll use stdio transport as the main method
if serverType != "" && serverType != "stdio" {
logger.Printf("Warning: Server type '%s' not fully supported with official SDK, defaulting to stdio", serverType)
}

// Log to stderr to avoid interfering with JSON-RPC communication
logger.Println("Starting stdio server...")
// Log to stderr to avoid interfering with JSON-RPC communication
logger.Println("Starting stdio server...")

// Use the same logger for error logging in the stdio server
if err := server.ServeStdio(mcps, server.WithErrorLogger(logger)); err != nil {
logger.Fatalf("MCP server failed: %v", err)
}
// Use stdio transport with the official SDK
if err := mcpServer.Run(context.Background(), mcp.NewStdioTransport()); err != nil {
logger.Fatalf("MCP server failed: %v", err)
}
}
5 changes: 2 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ go 1.24.1

require (
github.com/algolia/algoliasearch-client-go/v3 v3.31.4
github.com/mark3labs/mcp-go v0.24.1
github.com/modelcontextprotocol/go-sdk v0.2.0
)

require (
github.com/google/uuid v1.6.0 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
)
22 changes: 6 additions & 16 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,20 @@ github.com/algolia/algoliasearch-client-go/v3 v3.31.4/go.mod h1:i7tLoP7TYDmHX3Q7
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mark3labs/mcp-go v0.24.1 h1:YV+5X/+W4oBdERLWgiA1uR7AIvenlKJaa5V4hqufI7E=
github.com/mark3labs/mcp-go v0.24.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/modelcontextprotocol/go-sdk v0.2.0 h1:PESNYOmyM1c369tRkzXLY5hHrazj8x9CY1Xu0fLCryM=
github.com/modelcontextprotocol/go-sdk v0.2.0/go.mod h1:0sL9zUKKs2FTTkeCCVnKqbLJTw5TScefPAzojjU459E=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
Expand Down
4 changes: 2 additions & 2 deletions pkg/abtesting/abtesting.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package abtesting

import "github.com/mark3labs/mcp-go/server"
import "github.com/modelcontextprotocol/go-sdk/mcp"

// RegisterTools aggregates all abtesting tool registrations.
func RegisterTools(mcps *server.MCPServer) {
func RegisterTools(mcps *mcp.Server) {
RegisterListABTests(mcps)
RegisterGetABTest(mcps)
RegisterCreateABTest(mcps)
Expand Down
57 changes: 28 additions & 29 deletions pkg/abtesting/create_abtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,37 @@ import (
"net/http"
"os"

"github.com/algolia/mcp/pkg/mcputil"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/modelcontextprotocol/go-sdk/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
)

// CreateABTestParams defines the parameters for creating an A/B test.
type CreateABTestParams struct {
Name string `json:"name" jsonschema:"A/B test name"`
EndAt string `json:"endAt" jsonschema:"End date and time of the A/B test in RFC 3339 format (e.g. 2023-06-17T00:00:00Z)"`
Variants string `json:"variants" jsonschema:"A/B test variants as JSON array (exactly 2 variants required). Each variant must have 'index' and 'trafficPercentage' fields and may optionally have 'description' and 'customSearchParameters' fields."`
}

// RegisterCreateABTest registers the create_abtest tool with the MCP server.
func RegisterCreateABTest(mcps *server.MCPServer) {
createABTestTool := mcp.NewTool(
"abtesting_create_abtest",
mcp.WithDescription("Create a new A/B test"),
mcp.WithString(
"name",
mcp.Description("A/B test name"),
mcp.Required(),
),
mcp.WithString(
"endAt",
mcp.Description("End date and time of the A/B test, in RFC 3339 format (e.g., 2023-06-17T00:00:00Z)"),
mcp.Required(),
),
mcp.WithString(
"variants",
mcp.Description("A/B test variants as JSON array (exactly 2 variants required). Each variant must have 'index' and 'trafficPercentage' fields, and may optionally have 'description' and 'customSearchParameters' fields."),
mcp.Required(),
),
)

mcps.AddTool(createABTestTool, func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
func RegisterCreateABTest(mcps *mcp.Server) {
schema, _ := jsonschema.For[CreateABTestParams]()
createABTestTool := &mcp.Tool{
Name: "abtesting_create_abtest",
Description: "Create a new A/B test",
InputSchema: schema,
}

mcp.AddTool(mcps, createABTestTool, func(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateABTestParams]) (*mcp.CallToolResultFor[any], error) {
appID := os.Getenv("ALGOLIA_APP_ID")
apiKey := os.Getenv("ALGOLIA_WRITE_API_KEY") // Note: Using write API key for creating AB tests
if appID == "" || apiKey == "" {
return nil, fmt.Errorf("ALGOLIA_APP_ID and ALGOLIA_WRITE_API_KEY environment variables are required")
}

// Extract parameters
name, _ := req.Params.Arguments["name"].(string)
endAt, _ := req.Params.Arguments["endAt"].(string)
variantsJSON, _ := req.Params.Arguments["variants"].(string)
name := params.Arguments.Name
endAt := params.Arguments.EndAt
variantsJSON := params.Arguments.Variants

// Parse variants JSON
var variants []any
Expand Down Expand Up @@ -105,6 +98,12 @@ func RegisterCreateABTest(mcps *server.MCPServer) {
return nil, fmt.Errorf("failed to parse response: %w", err)
}

return mcputil.JSONToolResult("AB Test Created", result)
return &mcp.CallToolResultFor[any]{
Content: []mcp.Content{
&mcp.TextContent{
Text: "AB Test Created: " + fmt.Sprintf("%v", result),
},
},
}, nil
})
}
45 changes: 24 additions & 21 deletions pkg/abtesting/delete_abtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,33 @@ import (
"os"

"github.com/algolia/algoliasearch-client-go/v3/algolia/analytics"
"github.com/algolia/mcp/pkg/mcputil"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/modelcontextprotocol/go-sdk/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
)

// DeleteABTestParams defines the parameters for deleting an A/B test.
type DeleteABTestParams struct {
ID float64 `json:"id" jsonschema:"Unique A/B test identifier"`
}

// RegisterDeleteABTest registers the delete_abtest tool with the MCP server.
func RegisterDeleteABTest(mcps *server.MCPServer) {
deleteABTestTool := mcp.NewTool(
"abtesting_delete_abtest",
mcp.WithDescription("Delete an A/B test by its ID"),
mcp.WithNumber(
"id",
mcp.Description("Unique A/B test identifier"),
mcp.Required(),
),
)

mcps.AddTool(deleteABTestTool, func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
func RegisterDeleteABTest(mcps *mcp.Server) {
schema, _ := jsonschema.For[DeleteABTestParams]()
deleteABTestTool := &mcp.Tool{
Name: "abtesting_delete_abtest",
Description: "Delete an A/B test by its ID",
InputSchema: schema,
}

mcp.AddTool(mcps, deleteABTestTool, func(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteABTestParams]) (*mcp.CallToolResultFor[any], error) {
appID := os.Getenv("ALGOLIA_APP_ID")
apiKey := os.Getenv("ALGOLIA_WRITE_API_KEY") // Note: Using write API key for deleting AB tests
if appID == "" || apiKey == "" {
return nil, fmt.Errorf("ALGOLIA_APP_ID and ALGOLIA_WRITE_API_KEY environment variables are required")
}

// Get the AB Test ID from the request
idFloat, ok := req.Params.Arguments["id"].(float64)
if !ok {
return nil, fmt.Errorf("invalid AB test ID")
}
id := int(idFloat)
id := int(params.Arguments.ID)

// Create Algolia Analytics client
client := analytics.NewClient(appID, apiKey)
Expand All @@ -52,6 +49,12 @@ func RegisterDeleteABTest(mcps *server.MCPServer) {
"index": res.Index,
}

return mcputil.JSONToolResult(fmt.Sprintf("AB Test %d Deleted", id), result)
return &mcp.CallToolResultFor[any]{
Content: []mcp.Content{
&mcp.TextContent{
Text: fmt.Sprintf("AB Test %d Deleted: %v", id, result),
},
},
}, nil
})
}
Loading
Loading