Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions hack/values-dev.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
namespace: paperdebugger-dev

mongo:
in_cluster: false
2 changes: 1 addition & 1 deletion internal/api/gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func NewGinServer(cfg *cfg.Cfg, oauthHandler *auth.OAuthHandler) *GinServer {
ginServer := &GinServer{Engine: gin.New(), cfg: cfg}
ginServer.Use(ginServer.ginLogMiddleware(), gin.Recovery())
ginServer.Use(cors.New(cors.Config{
AllowOrigins: []string{"https://overleaf.com", "https://*.overleaf.com", "https://*.paperdebugger.com", "http://localhost:3000", "http://127.0.0.1:3000"},
AllowOrigins: []string{"*"},
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"*"},
ExposeHeaders: []string{"*"},
Expand Down
37 changes: 28 additions & 9 deletions internal/services/toolkit/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"paperdebugger/internal/services"
"paperdebugger/internal/services/toolkit/handler"
"paperdebugger/internal/services/toolkit/registry"
"paperdebugger/internal/services/toolkit/tools"
"paperdebugger/internal/services/toolkit/tools/xtramcp"

"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
Expand Down Expand Up @@ -42,18 +42,37 @@ func NewAIClient(
option.WithAPIKey(cfg.OpenAIAPIKey),
)
CheckOpenAIWorks(oaiClient, logger)

toolPaperScore := tools.NewPaperScoreTool(db, projectService)
toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService)
// toolPaperScore := tools.NewPaperScoreTool(db, projectService)
// toolPaperScoreComment := tools.NewPaperScoreCommentTool(db, projectService, reverseCommentService)

toolRegistry := registry.NewToolRegistry()
toolRegistry.Register("always_exception", tools.AlwaysExceptionToolDescription, tools.AlwaysExceptionTool)
toolRegistry.Register("greeting", tools.GreetingToolDescription, tools.GreetingTool)
toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call)
toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call)

toolCallHandler := handler.NewToolCallHandler(toolRegistry)
// toolRegistry.Register("always_exception", tools.AlwaysExceptionToolDescription, tools.AlwaysExceptionTool)
// toolRegistry.Register("greeting", tools.GreetingToolDescription, tools.GreetingTool)
// toolRegistry.Register("paper_score", toolPaperScore.Description, toolPaperScore.Call)
// toolRegistry.Register("paper_score_comment", toolPaperScoreComment.Description, toolPaperScoreComment.Call)

// Load tools dynamically from backend (TODO: Make URL configurable / Xtramcp url)
xtraMCPLoader := xtramcp.NewXtraMCPLoader(db, projectService, "http://localhost:8080/mcp")

// initialize MCP session first and log session ID
sessionID, err := xtraMCPLoader.InitializeMCP()
if err != nil {
logger.Errorf("[AI Client] Failed to initialize XtraMCP session: %v", err)
// TODO: Fallback to static tools or exit?
} else {
logger.Info("[AI Client] XtraMCP session initialized", "sessionID", sessionID)

// dynamically load all tools from XtraMCP backend
err = xtraMCPLoader.LoadToolsFromBackend(toolRegistry)
if err != nil {
logger.Errorf("[AI Client] Failed to load XtraMCP tools: %v", err)
} else {
logger.Info("[AI Client] Successfully loaded XtraMCP tools")
}
}

toolCallHandler := handler.NewToolCallHandler(toolRegistry)
client := &AIClient{
openaiClient: &oaiClient,
toolCallHandler: toolCallHandler,
Expand Down
221 changes: 221 additions & 0 deletions internal/services/toolkit/tools/xtramcp/loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package xtramcp

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"paperdebugger/internal/libs/db"
"paperdebugger/internal/services"
"paperdebugger/internal/services/toolkit/registry"
)

// MCPListToolsResponse represents the JSON-RPC response from tools/list method
type MCPListToolsResponse struct {
JSONRPC string `json:"jsonrpc"`
ID int `json:"id"`
Result struct {
Tools []ToolSchema `json:"tools"`
} `json:"result"`
}

// loads tools dynamically from backend
type XtraMCPLoader struct {
db *db.DB
projectService *services.ProjectService
baseURL string
client *http.Client
sessionID string // Store the MCP session ID after initialization for re-use
}

// NewXtraMCPLoader creates a new dynamic XtraMCP loader
func NewXtraMCPLoader(db *db.DB, projectService *services.ProjectService, baseURL string) *XtraMCPLoader {
return &XtraMCPLoader{
db: db,
projectService: projectService,
baseURL: baseURL,
client: &http.Client{},
}
}

// LoadToolsFromBackend fetches tool schemas from backend and registers them
func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolRegistry) error {
if loader.sessionID == "" {
return fmt.Errorf("MCP session not initialized - call InitializeMCP first")
}

// Fetch tools from backend using the established session
toolSchemas, err := loader.fetchAvailableTools()
if err != nil {
return fmt.Errorf("failed to fetch tools from backend: %w", err)
}

// Register each tool dynamically, passing the session ID
for _, toolSchema := range toolSchemas {
dynamicTool := NewDynamicTool(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID)

// Register the tool with the registry
toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call)

fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
}

return nil
}

// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it
func (loader *XtraMCPLoader) InitializeMCP() (string, error) {
// Step 1: Initialize
sessionID, err := loader.performInitialize()
if err != nil {
return "", fmt.Errorf("step 1 - initialize failed: %w", err)
}

// Step 2: Send notifications/initialized
err = loader.sendInitializedNotification(sessionID)
if err != nil {
return "", fmt.Errorf("step 2 - notifications/initialized failed: %w", err)
}

// Store session ID for future use and return it
loader.sessionID = sessionID

return sessionID, nil
}

// performInitialize performs MCP initialization (1. establish connection)
func (loader *XtraMCPLoader) performInitialize() (string, error) {
initReq := map[string]interface{}{
"jsonrpc": "2.0",
"method": "initialize",
"id": 1,
"params": map[string]interface{}{
"protocolVersion": "2024-11-05",
"capabilities": map[string]interface{}{},
"clientInfo": map[string]interface{}{
"name": "paperdebugger-client",
"version": "1.0.0",
},
},
}

jsonData, err := json.Marshal(initReq)
if err != nil {
return "", fmt.Errorf("failed to marshal initialize request: %w", err)
}

req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("failed to create initialize request: %w", err)
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")

resp, err := loader.client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to make initialize request: %w", err)
}
defer resp.Body.Close()

// Extract session ID from response headers
sessionID := resp.Header.Get("mcp-session-id")
if sessionID == "" {
return "", fmt.Errorf("no session ID returned from initialize")
}

return sessionID, nil
}

// sendInitializedNotification completes MCP initialization (acknowledges initialization)
func (loader *XtraMCPLoader) sendInitializedNotification(sessionID string) error {
notifyReq := map[string]interface{}{
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": map[string]interface{}{},
}

jsonData, err := json.Marshal(notifyReq)
if err != nil {
return fmt.Errorf("failed to marshal notification: %w", err)
}

req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create notification request: %w", err)
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
req.Header.Set("mcp-session-id", sessionID)

resp, err := loader.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send notification: %w", err)
}
defer resp.Body.Close()

return nil
}

// fetchAvailableTools makes a request to get available tools from backend
func (loader *XtraMCPLoader) fetchAvailableTools() ([]ToolSchema, error) {
// List all tools using the established session
requestBody := map[string]interface{}{
"jsonrpc": "2.0",
"method": "tools/list",
"params": map[string]interface{}{},
"id": 2,
}

jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

req, err := http.NewRequest("POST", loader.baseURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
req.Header.Set("mcp-session-id", loader.sessionID)

resp, err := loader.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()

// Read the raw response body (SSE format) for debugging
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}

// Parse SSE format - extract JSON from "data: " lines
lines := strings.Split(string(bodyBytes), "\n")
var extractedJSON string
for _, line := range lines {
if strings.HasPrefix(line, "data: ") {
extractedJSON = strings.TrimPrefix(line, "data: ")
break
}
}

if extractedJSON == "" {
return nil, fmt.Errorf("no data line found in SSE response")
}

// Parse the extracted JSON
var mcpResponse MCPListToolsResponse
err = json.Unmarshal([]byte(extractedJSON), &mcpResponse)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON from SSE data: %w. JSON data: %s", err, extractedJSON)
}

return mcpResponse.Result.Tools, nil
}
Loading
Loading