Skip to content

Commit 7112ddb

Browse files
authored
Merge pull request #69 from PaperDebugger/feat-verify-citations
feat: Verify citations
2 parents 5d71138 + 418274a commit 7112ddb

File tree

8 files changed

+245
-90
lines changed

8 files changed

+245
-90
lines changed

internal/services/toolkit/client/utils.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,17 @@ func initializeToolkit(
111111
// initialize MCP session first and log session ID
112112
sessionID, err := xtraMCPLoader.InitializeMCP()
113113
if err != nil {
114-
logger.Errorf("[AI Client] Failed to initialize XtraMCP session: %v", err)
114+
logger.Errorf("[XtraMCP Client] Failed to initialize XtraMCP session: %v", err)
115115
// TODO: Fallback to static tools or exit?
116116
} else {
117-
logger.Info("[AI Client] XtraMCP session initialized", "sessionID", sessionID)
117+
logger.Info("[XtraMCP Client] XtraMCP session initialized", "sessionID", sessionID)
118118

119119
// dynamically load all tools from XtraMCP backend
120120
err = xtraMCPLoader.LoadToolsFromBackend(toolRegistry)
121121
if err != nil {
122-
logger.Errorf("[AI Client] Failed to load XtraMCP tools: %v", err)
122+
logger.Errorf("[XtraMCP Client] Failed to load XtraMCP tools: %v", err)
123123
} else {
124-
logger.Info("[AI Client] Successfully loaded XtraMCP tools")
124+
logger.Info("[XtraMCP Client] Successfully loaded XtraMCP tools")
125125
}
126126
}
127127

internal/services/toolkit/client/utils_v2.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,23 +144,20 @@ func initializeToolkitV2(
144144

145145
logger.Info("[AI Client V2] Registered static LaTeX tools", "count", 0)
146146

147-
// Load tools dynamically from backend
147+
// // Load tools dynamically from backend
148148
// xtraMCPLoader := xtramcp.NewXtraMCPLoaderV2(db, projectService, cfg.XtraMCPURI)
149149

150-
// initialize MCP session first and log session ID
150+
// // initialize MCP session first and log session ID
151151
// sessionID, err := xtraMCPLoader.InitializeMCP()
152152
// if err != nil {
153-
// logger.Errorf("[AI Client V2] Failed to initialize XtraMCP session: %v", err)
154-
// // TODO: Fallback to static tools or exit?
153+
// logger.Errorf("[XtraMCP Client] Failed to initialize XtraMCP session: %v", err)
155154
// } else {
156-
// logger.Info("[AI Client V2] XtraMCP session initialized", "sessionID", sessionID)
155+
// logger.Info("[XtraMCP Client] XtraMCP session initialized", "sessionID", sessionID)
157156

158157
// // dynamically load all tools from XtraMCP backend
159158
// err = xtraMCPLoader.LoadToolsFromBackend(toolRegistry)
160159
// if err != nil {
161-
// logger.Errorf("[AI Client V2] Failed to load XtraMCP tools: %v", err)
162-
// } else {
163-
// logger.Info("[AI Client V2] Successfully loaded XtraMCP tools")
160+
// logger.Errorf("[XtraMCP Client] Failed to load XtraMCP tools: %v", err)
164161
// }
165162
// }
166163

internal/services/toolkit/tools/xtramcp/loader.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,45 @@ func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolReg
5353

5454
// Register each tool dynamically, passing the session ID
5555
for _, toolSchema := range toolSchemas {
56-
dynamicTool := NewDynamicTool(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID)
56+
// some tools require secrutiy context injection e.g. user_id to authenticate
57+
requiresInjection := loader.requiresSecurityInjection(toolSchema)
58+
59+
dynamicTool := NewDynamicTool(
60+
loader.db,
61+
loader.projectService,
62+
toolSchema,
63+
loader.baseURL,
64+
loader.sessionID,
65+
requiresInjection,
66+
)
5767

5868
// Register the tool with the registry
5969
toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call)
6070

61-
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
71+
if requiresInjection {
72+
fmt.Printf("Registered dynamic tool with security injection: %s\n", toolSchema.Name)
73+
} else {
74+
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
75+
}
6276
}
6377

6478
return nil
6579
}
6680

81+
// checks if a tool schema contains parameters that should be inejected instead of LLM-generated
82+
func (loader *XtraMCPLoader) requiresSecurityInjection(schema ToolSchema) bool {
83+
properties, ok := schema.InputSchema["properties"].(map[string]interface{})
84+
if !ok {
85+
return false
86+
}
87+
88+
// injected parameters
89+
_, hasUserId := properties["user_id"]
90+
_, hasProjectId := properties["project_id"]
91+
92+
return hasUserId || hasProjectId
93+
}
94+
6795
// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it
6896
func (loader *XtraMCPLoader) InitializeMCP() (string, error) {
6997
// Step 1: Initialize

internal/services/toolkit/tools/xtramcp/loader_v2.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,45 @@ func (loader *XtraMCPLoaderV2) LoadToolsFromBackend(toolRegistry *registry.ToolR
5353

5454
// Register each tool dynamically, passing the session ID
5555
for _, toolSchema := range toolSchemas {
56-
dynamicTool := NewDynamicToolV2(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID)
56+
// some tools require security context injection e.g. user_id to authenticate
57+
requiresInjection := loader.requiresSecurityInjection(toolSchema)
58+
59+
dynamicTool := NewDynamicToolV2(
60+
loader.db,
61+
loader.projectService,
62+
toolSchema,
63+
loader.baseURL,
64+
loader.sessionID,
65+
requiresInjection,
66+
)
5767

5868
// Register the tool with the registry
5969
toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call)
6070

61-
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
71+
if requiresInjection {
72+
fmt.Printf("Registered dynamic tool with security injection: %s\n", toolSchema.Name)
73+
} else {
74+
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
75+
}
6276
}
6377

6478
return nil
6579
}
6680

81+
// checks if a tool schema contains parameters that should be injected instead of LLM-generated
82+
func (loader *XtraMCPLoaderV2) requiresSecurityInjection(schema ToolSchemaV2) bool {
83+
properties, ok := schema.InputSchema["properties"].(map[string]interface{})
84+
if !ok {
85+
return false
86+
}
87+
88+
// injected parameters
89+
_, hasUserId := properties["user_id"]
90+
_, hasProjectId := properties["project_id"]
91+
92+
return hasUserId || hasProjectId
93+
}
94+
6795
// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it
6896
func (loader *XtraMCPLoaderV2) InitializeMCP() (string, error) {
6997
// Step 1: Initialize
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package xtramcp
2+
3+
import "encoding/json"
4+
5+
// parameters that should be injected server-side
6+
var securityParameters = []string{"user_id", "project_id"}
7+
8+
// removes security parameters from schema shown to LLM so LLM does not need to generate / fill
9+
func filterSecurityParameters(schema map[string]interface{}) map[string]interface{} {
10+
filtered := deepCopySchema(schema)
11+
12+
// Remove from properties
13+
if properties, ok := filtered["properties"].(map[string]interface{}); ok {
14+
for _, param := range securityParameters {
15+
delete(properties, param)
16+
}
17+
}
18+
19+
// Remove from required array
20+
if required, ok := filtered["required"].([]interface{}); ok {
21+
filtered["required"] = filterRequiredArray(required, securityParameters)
22+
}
23+
24+
return filtered
25+
}
26+
27+
// creates a deep copy of the schema using JSON marshal/unmarshal
28+
// Uses JSON round-trip because map[string]interface{} contains nested structures
29+
// This ensures modifications to the copy don't affect the original schema.
30+
func deepCopySchema(schema map[string]interface{}) map[string]interface{} {
31+
// Use JSON marshal/unmarshal for deep copy
32+
jsonBytes, err := json.Marshal(schema)
33+
if err != nil {
34+
// Extremely unlikely with valid JSON schemas (MCP schemas are JSON-compatible)
35+
// // If marshaling fails, return original schema
36+
return schema
37+
}
38+
39+
var copy map[string]interface{}
40+
err = json.Unmarshal(jsonBytes, &copy)
41+
if err != nil {
42+
// Should never happen if marshal succeeded
43+
return schema
44+
}
45+
46+
return copy
47+
}
48+
49+
// removes security parameters from the required array
50+
func filterRequiredArray(required []interface{}, toRemove []string) []interface{} {
51+
filtered := []interface{}{}
52+
removeMap := make(map[string]bool)
53+
54+
for _, r := range toRemove {
55+
removeMap[r] = true
56+
}
57+
58+
// filter out security params
59+
for _, item := range required {
60+
if str, ok := item.(string); ok {
61+
if !removeMap[str] {
62+
filtered = append(filtered, item)
63+
}
64+
}
65+
}
66+
67+
return filtered
68+
}

internal/services/toolkit/tools/xtramcp/tool.go

Lines changed: 31 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -41,74 +41,54 @@ type MCPParams struct {
4141

4242
// DynamicTool represents a generic tool that can handle any schema
4343
type DynamicTool struct {
44-
Name string
45-
Description responses.ToolUnionParam
46-
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
47-
projectService *services.ProjectService
48-
coolDownTime time.Duration
49-
baseURL string
50-
client *http.Client
51-
schema map[string]interface{}
52-
sessionID string // Reuse the session ID from initialization
44+
Name string
45+
Description responses.ToolUnionParam
46+
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
47+
projectService *services.ProjectService
48+
coolDownTime time.Duration
49+
baseURL string
50+
client *http.Client
51+
schema map[string]interface{}
52+
sessionID string // Reuse the session ID from initialization
53+
requiresInjection bool // Indicates if this tool needs user/project injection
5354
}
5455

5556
// NewDynamicTool creates a new dynamic tool from a schema
56-
func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string) *DynamicTool {
57-
// Create tool description with the schema
57+
func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string, requiresInjection bool) *DynamicTool {
58+
// filter schema if injection is required (hide security context like user_id/project_id from LLM)
59+
schemaForLLM := toolSchema.InputSchema
60+
if requiresInjection {
61+
schemaForLLM = filterSecurityParameters(toolSchema.InputSchema)
62+
}
63+
5864
description := responses.ToolUnionParam{
5965
OfFunction: &responses.FunctionToolParam{
6066
Name: toolSchema.Name,
6167
Description: param.NewOpt(toolSchema.Description),
62-
Parameters: openai.FunctionParameters(toolSchema.InputSchema),
68+
Parameters: openai.FunctionParameters(schemaForLLM), // Use filtered schema
6369
},
6470
}
6571

6672
toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db)
73+
//TODO: consider letting llm client know of output schema too
6774
return &DynamicTool{
68-
Name: toolSchema.Name,
69-
Description: description,
70-
toolCallRecordDB: toolCallRecordDB,
71-
projectService: projectService,
72-
coolDownTime: 5 * time.Minute,
73-
baseURL: baseURL,
74-
client: &http.Client{},
75-
schema: toolSchema.InputSchema,
76-
sessionID: sessionID, // Store the session ID for reuse
75+
Name: toolSchema.Name,
76+
Description: description,
77+
toolCallRecordDB: toolCallRecordDB,
78+
projectService: projectService,
79+
coolDownTime: 5 * time.Minute,
80+
baseURL: baseURL,
81+
client: &http.Client{},
82+
schema: toolSchema.InputSchema, // Store original schema for validation
83+
sessionID: sessionID, // Store the session ID for reuse
84+
requiresInjection: requiresInjection,
7785
}
7886
}
7987

8088
// Call handles the tool execution (generic for any tool)
89+
// DEPRECATED: v1 API is no longer supported. This method should not be called.
8190
func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.RawMessage) (string, string, error) {
82-
// Parse arguments as generic map since we don't know the structure
83-
var argsMap map[string]interface{}
84-
err := json.Unmarshal(args, &argsMap)
85-
if err != nil {
86-
return "", "", err
87-
}
88-
89-
// Create function call record
90-
record, err := t.toolCallRecordDB.Create(ctx, toolCallId, t.Name, argsMap)
91-
if err != nil {
92-
return "", "", err
93-
}
94-
95-
// Execute the tool via MCP
96-
respStr, err := t.executeTool(argsMap)
97-
if err != nil {
98-
err = fmt.Errorf("failed to execute tool %s: %v", t.Name, err)
99-
t.toolCallRecordDB.OnError(ctx, record, err)
100-
return "", "", err
101-
}
102-
103-
rawJson, err := json.Marshal(respStr)
104-
if err != nil {
105-
err = fmt.Errorf("failed to marshal tool result: %v", err)
106-
t.toolCallRecordDB.OnError(ctx, record, err)
107-
return "", "", err
108-
}
109-
t.toolCallRecordDB.OnSuccess(ctx, record, string(rawJson))
110-
111-
return respStr, "", nil
91+
return "", "", fmt.Errorf("v1 API is deprecated and no longer supported. Please use v2 API instead")
11292
}
11393

11494
// executeTool makes the MCP request (generic for any tool)

0 commit comments

Comments
 (0)