diff --git a/README.md b/README.md index 716e88297..5db8cf49d 100644 --- a/README.md +++ b/README.md @@ -103,11 +103,14 @@ Restricts which repositories a guard allows and at what integrity level: **`trusted-users`** *(optional)* — Array of GitHub usernames whose content is unconditionally elevated to `approved` integrity. Useful for granting specific external contributors (e.g., trusted open-source maintainers) the same treatment as repository members, without lowering `min-integrity` globally. Uses `max(base, approved)` so it never lowers integrity. Does not override `blocked-users`. +**`tool-call-limits`** *(optional)* — Map of tool names to per-session call limits enforced by the gateway before the backend is invoked. Positive values hard-limit that tool for the session, while `0` or an omitted entry leaves the tool unlimited. + ```json "guard-policies": { "allow-only": { "repos": ["myorg/*"], "min-integrity": "approved", + "tool-call-limits": {"issue_read": 1}, "blocked-users": ["spam-bot", "compromised-user"], "approval-labels": ["human-reviewed", "safe-for-agent"], "trusted-users": ["alice", "trusted-contributor"] diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index de8edb10a..834525c74 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -251,6 +251,8 @@ min-integrity = "unapproved" - **`trusted-users`** *(optional)*: Array of GitHub usernames whose content is unconditionally elevated to `approved` integrity. Useful for granting specific external contributors the same treatment as repository members without lowering `min-integrity` globally. Uses `max(base, approved)` so it never lowers integrity. Does not override `blocked-users`. +- **`tool-call-limits`** *(optional)*: Map of tool names to per-session call limits enforced by the gateway. Positive values cap how many times that tool may be called in one session; `0` or an omitted entry leaves the tool unlimited. 
+ - **Meaning**: Restricts the GitHub MCP server to only access specified repositories. Tools like `get_file_contents`, `search_code`, etc. will only work on allowed repositories. Attempts to access other repositories will be denied by the guard policy. ### write-sink (output servers) diff --git a/docs/ENVIRONMENT_VARIABLES.md b/docs/ENVIRONMENT_VARIABLES.md index 446b7d9dd..c8df35c05 100644 --- a/docs/ENVIRONMENT_VARIABLES.md +++ b/docs/ENVIRONMENT_VARIABLES.md @@ -22,7 +22,7 @@ When running locally (`run.sh`), these variables are optional (warnings shown if | `MCP_GATEWAY_DOMAIN` | Gateway domain | `localhost` | | `MCP_GATEWAY_API_KEY` | Informational only — not read directly by the binary; must be referenced in your config via `"${MCP_GATEWAY_API_KEY}"` to enable authentication | (disabled) | | `MCP_GATEWAY_LOG_DIR` | Log file directory (sets default for `--log-dir` flag) | `/tmp/gh-aw/mcp-logs` | -| `MCP_GATEWAY_WASM_CACHE_DIR` | Disk-backed wazero compilation cache directory (sets default for `--wasm-cache-dir`; defaults to `/wazero-cache`) | `/tmp/gh-aw/mcp-logs/wazero-cache` | +| `MCP_GATEWAY_WASM_CACHE_DIR` | Disk-backed wazero compilation cache directory (sets default for `--wasm-cache-dir`; defaults to `/wazero-cache`, a sibling of the log directory) | `/tmp/gh-aw/wazero-cache` | | `MCP_GATEWAY_PAYLOAD_DIR` | Large payload storage directory (sets default for `--payload-dir` flag). Must be an absolute path. 
| `/tmp/jq-payloads` | | `MCP_GATEWAY_PAYLOAD_PATH_PREFIX` | Path prefix for remapping payloadPath returned to clients (sets default for `--payload-path-prefix` flag) | (empty - use actual filesystem path) | | `MCP_GATEWAY_PAYLOAD_SIZE_THRESHOLD` | Size threshold in bytes for payload storage (sets default for `--payload-size-threshold` flag) | `524288` | diff --git a/docs/PROXY_MODE.md b/docs/PROXY_MODE.md index 6dd004369..2e058bacf 100644 --- a/docs/PROXY_MODE.md +++ b/docs/PROXY_MODE.md @@ -55,7 +55,7 @@ gh CLI → awmg proxy (localhost:8443, TLS) → api.github.com 2. It maps the URL/query to a guard tool name (e.g., `/repos/:owner/:repo/issues` → `list_issues`) 3. The guard WASM module evaluates access based on the configured policy 4. If allowed, the request is forwarded to `api.github.com` -5. The response is filtered per-item based on secrecy/integrity labels +5. The response is filtered per-item based on secrecy/integrity labels; top-level REST arrays that have no per-item guard schema (for example, issue comment list endpoints) are passed through unchanged after the coarse policy check 6. The filtered response is returned to the client Write operations (PUT, POST, DELETE, PATCH) pass through unmodified. 
@@ -102,8 +102,10 @@ Supported path families include: |-------------|-----------| | `/repos/:owner/:repo/issues` | `list_issues` | | `/repos/:owner/:repo/issues/:number` | `issue_read` | +| `/repos/:owner/:repo/issues/:number/comments` | `issue_read` (`method=get_comments`) | | `/repos/:owner/:repo/pulls` | `list_pull_requests` | | `/repos/:owner/:repo/pulls/:number` | `pull_request_read` | +| `/repos/:owner/:repo/pulls/:number/comments` | `pull_request_read` (`method=get_review_comments`) | | `/repos/:owner/:repo/commits` | `list_commits` | | `/repos/:owner/:repo/commits/:sha` | `get_commit` | | `/repos/:owner/:repo/contents/:path` | `get_file_contents` | diff --git a/guards/github-guard/docs/OVERVIEW.md b/guards/github-guard/docs/OVERVIEW.md index d7cde1e95..dfefc2006 100644 --- a/guards/github-guard/docs/OVERVIEW.md +++ b/guards/github-guard/docs/OVERVIEW.md @@ -81,7 +81,7 @@ Rule evaluation and fallback semantics: - If DIFC evaluation fails for `write`/`read-write`, backend and `label_response` are both skipped. - `read` calls execute and are forwarded to `label_response`. - `label_response` returning `0` means skip fine-grained response labeling. -- When path-based labeling is unavailable, the guard falls back to legacy item labeling (including singleton fallback labeling when needed). +- When path-based labeling is unavailable, the guard falls back to legacy item labeling; singleton fallback applies to object/singleton responses, while top-level arrays with no fine-grained labels return `0` (coarse-only passthrough). 
## Operating Modes diff --git a/guards/github-guard/rust-guard/src/lib.rs b/guards/github-guard/rust-guard/src/lib.rs index c3c3a795c..ef4a58960 100644 --- a/guards/github-guard/rust-guard/src/lib.rs +++ b/guards/github-guard/rust-guard/src/lib.rs @@ -47,6 +47,10 @@ fn safe_preview(s: &str, max_bytes: usize) -> &str { &s[..end] } +fn should_fallback_to_single_item_label(response: &Value) -> bool { + !response.is_array() +} + /// Global policy context for WASM runtime entry points. /// /// `label_agent` stores the parsed policy here; `label_resource` and @@ -301,6 +305,101 @@ struct LabelResponseOutput { items: Vec, } +enum FallbackAction { + ContinueProcessing, + SkipLabeling, +} + +/// Applies metadata/singleton fallback labeling when no fine-grained items exist. +/// Returns [`FallbackAction::SkipLabeling`] when the caller should return `0` +/// (top-level array passthrough), or [`FallbackAction::ContinueProcessing`] +/// when normal output generation should continue. +fn apply_singleton_fallback_if_needed( + input: &LabelResponseInput, + ctx: &PolicyContext, + labeled_items: &mut Vec, +) -> FallbackAction { + if !labeled_items.is_empty() { + return FallbackAction::ContinueProcessing; + } + + // Extract repo info from tool args (same logic as label_resource) + let (_, _, repo_id) = extract_repo_info(&input.tool_args); + let baseline_scope = infer_scope_for_baseline(&input.tool_name, &input.tool_args, &repo_id); + + // Server-generated metadata (pagination errors, empty search results) contains + // no repository data — pass through with approved integrity so the agent can + // see instructional messages and empty-result confirmations. 
+ let actual_response = labels::extract_mcp_response(&input.tool_result); + let is_server_metadata = labels::is_mcp_text_wrapper(&actual_response) + || (labels::is_search_result_wrapper(&actual_response) + && labels::search_result_total_count(&actual_response) == Some(0)); + + if is_server_metadata { + let scope = if baseline_scope.is_empty() { + scope_names::GITHUB + } else { + &baseline_scope + }; + // Use writer_integrity which goes through normalize_scope to match + // the policy scope token (e.g., "github" for owner-scoped policies). + let integrity = labels::writer_integrity(scope, ctx); + let desc = format!("metadata:{}", input.tool_name); + + log_info(&format!( + " server metadata (text message or empty search), integrity={:?}", + integrity + )); + + labeled_items.push(LabeledItem { + data: input.tool_result.clone(), + labels: ResourceLabels { + description: desc, + secrecy: vec![].into(), + integrity: integrity.into(), + }, + }); + return FallbackAction::ContinueProcessing; + } + + if !should_fallback_to_single_item_label(&actual_response) { + log_info(" no fine-grained items for top-level array response, skipping fallback label"); + return FallbackAction::SkipLabeling; + } + + log_info(" no fine-grained items, creating fallback single-item label"); + + // Use apply_tool_labels to get proper labels for this tool + let desc = format!("resource:{}", input.tool_name); + let (secrecy, integrity, final_desc) = labels::apply_tool_labels( + &input.tool_name, + &input.tool_args, + &repo_id, + vec![], // default secrecy + vec![], // default integrity + desc, + ctx, + ); + + let integrity = labels::ensure_integrity_baseline(&baseline_scope, integrity, ctx); + + log_info(&format!( + " fallback labels: secrecy={:?}, integrity={:?}", + secrecy, integrity + )); + + labeled_items.push(LabeledItem { + data: input.tool_result.clone(), + labels: ResourceLabels { + description: final_desc, + secrecy: secrecy.into(), + integrity: integrity.into(), + }, + }); + + 
FallbackAction::ContinueProcessing +} + fn infer_scope_for_baseline<'a>( tool_name: &str, tool_args: &Value, @@ -913,75 +1012,14 @@ pub extern "C" fn label_response( labels::label_response_items(&input.tool_name, &input.tool_args, &input.tool_result, &ctx); // If no items were generated, wrap entire response as single item with computed labels - // This ensures single-item responses (like get_file_contents) are properly labeled - if labeled_items.is_empty() { - // Extract repo info from tool args (same logic as label_resource) - let (_, _, repo_id) = extract_repo_info(&input.tool_args); - let baseline_scope = infer_scope_for_baseline(&input.tool_name, &input.tool_args, &repo_id); - - // Server-generated metadata (pagination errors, empty search results) contains - // no repository data — pass through with approved integrity so the agent can - // see instructional messages and empty-result confirmations. - let actual_response = labels::extract_mcp_response(&input.tool_result); - let is_server_metadata = labels::is_mcp_text_wrapper(&actual_response) - || (labels::is_search_result_wrapper(&actual_response) - && labels::search_result_total_count(&actual_response) == Some(0)); - - if is_server_metadata { - let scope = if baseline_scope.is_empty() { - scope_names::GITHUB - } else { - &baseline_scope - }; - // Use writer_integrity which goes through normalize_scope to match - // the policy scope token (e.g., "github" for owner-scoped policies). 
- let integrity = labels::writer_integrity(scope, &ctx); - let desc = format!("metadata:{}", input.tool_name); - - log_info(&format!( - " server metadata (text message or empty search), integrity={:?}", - integrity - )); - - labeled_items.push(LabeledItem { - data: input.tool_result.clone(), - labels: ResourceLabels { - description: desc, - secrecy: vec![].into(), - integrity: integrity.into(), - }, - }); - } else { - log_info(" no fine-grained items, creating fallback single-item label"); - - // Use apply_tool_labels to get proper labels for this tool - let desc = format!("resource:{}", input.tool_name); - let (secrecy, integrity, final_desc) = labels::apply_tool_labels( - &input.tool_name, - &input.tool_args, - &repo_id, - vec![], // default secrecy - vec![], // default integrity - desc, - &ctx, - ); - - let integrity = labels::ensure_integrity_baseline(&baseline_scope, integrity, &ctx); - - log_info(&format!( - " fallback labels: secrecy={:?}, integrity={:?}", - secrecy, integrity - )); - - labeled_items.push(LabeledItem { - data: input.tool_result.clone(), - labels: ResourceLabels { - description: final_desc, - secrecy: secrecy.into(), - integrity: integrity.into(), - }, - }); - } + // when appropriate. This ensures single-item responses (like get_file_contents) + // are properly labeled while preserving unlabeled top-level array passthrough. 
+ if matches!( + apply_singleton_fallback_if_needed(&input, &ctx, &mut labeled_items), + FallbackAction::SkipLabeling + ) { + log_info("<<< label_response returning 0 (top-level array passthrough)"); + return 0; } log_info(&format!( @@ -1082,6 +1120,43 @@ mod tests { use super::*; use serde_json::json; + #[test] + fn top_level_array_responses_skip_single_item_fallback() { + assert!(!should_fallback_to_single_item_label(&json!([{"id": 1}]))); + } + + #[test] + fn singleton_object_responses_use_single_item_fallback() { + assert!(should_fallback_to_single_item_label(&json!({"id": 1}))); + } + + #[test] + fn label_response_control_flow_skips_fallback_for_unlabeled_top_level_array() { + let input = LabelResponseInput { + tool_name: "issue_read".to_string(), + tool_args: json!({ + "owner": "org", + "repo": "repo", + "issue_number": "7", + "method": "get_comments" + }), + tool_result: json!([ + {"id": 1, "body": "first"}, + {"id": 2, "body": "second"} + ]), + }; + + let mut labeled_items = Vec::new(); + let action = apply_singleton_fallback_if_needed( + &input, + &PolicyContext::default(), + &mut labeled_items, + ); + + assert!(matches!(action, FallbackAction::SkipLabeling)); + assert!(labeled_items.is_empty()); + } + #[test] fn parse_scope_accepts_owner_wildcard_array_entry() { let parsed = parse_scope(ReposValue::ScopedList(vec!["octocat/*".to_string()])) diff --git a/internal/cmd/flags_difc.go b/internal/cmd/flags_difc.go index f21cfcfa9..27a9f1612 100644 --- a/internal/cmd/flags_difc.go +++ b/internal/cmd/flags_difc.go @@ -25,6 +25,9 @@ var ( allowOnlyMinInt string ) +// containerGuardWasmPath is the baked-in guard path in the container image. 
+const containerGuardWasmPath = "/guards/github/00-github-guard.wasm" + func init() { RegisterFlag(func(cmd *cobra.Command) { cmd.Flags().StringVar(&difcMode, "guards-mode", getDefaultDIFCMode(), "Guards enforcement mode: strict (deny violations), filter (remove denied tools), or propagate (auto-adjust agent labels on reads)") @@ -37,6 +40,18 @@ func init() { }) } +// detectGuardWasm returns the baked-in container guard path if it exists, +// or empty string if not found (requiring the user to specify --guard-wasm). +func detectGuardWasm() string { + debugLog.Printf("Checking for baked-in guard at %s", containerGuardWasmPath) + if _, err := os.Stat(containerGuardWasmPath); err == nil { + debugLog.Printf("Auto-detected baked-in guard: %s", containerGuardWasmPath) + return containerGuardWasmPath + } + debugLog.Print("Baked-in guard not found, --guard-wasm flag required") + return "" +} + func resolveGuardPolicyOverride(cmd *cobra.Command) (*config.GuardPolicy, string, error) { cliGuardPolicyChanged := cmd.Flags().Changed("guard-policy-json") cliChanged := cliGuardPolicyChanged || diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 362cd7387..1e34cd469 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -226,6 +226,7 @@ func run(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("invalid guard policy configuration: %w", err) } + debugLog.Printf("Guard policy resolved: hasOverride=%v, source=%s", policyOverride != nil, policySource) if policyOverride != nil { cfg.GuardPolicy = policyOverride cfg.GuardPolicySource = policySource @@ -342,10 +343,12 @@ func run(cmd *cobra.Command, args []string) error { } // Create unified MCP server (backend for both modes) + debugLog.Printf("Creating unified MCP server: mode=%s, servers=%d", mode, len(cfg.Servers)) unifiedServer, err := server.NewUnified(ctx, cfg) if err != nil { return fmt.Errorf("failed to create unified server: %w", err) } + debugLog.Printf("Unified MCP server created 
successfully") defer unifiedServer.Close() // Handle graceful shutdown via context cancellation @@ -399,6 +402,7 @@ func run(cmd *cobra.Command, args []string) error { hasCert := tlsCertPath != "" hasKey := tlsKeyPath != "" hasCA := tlsCAPath != "" + debugLog.Printf("TLS configuration: hasCert=%v, hasKey=%v, hasCA=%v", hasCert, hasKey, hasCA) if hasCert != hasKey { return fmt.Errorf("--tls-cert and --tls-key must both be provided together") } @@ -410,6 +414,7 @@ func run(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("failed to listen on %s: %w", listenAddr, err) } + debugLog.Printf("TCP listener created on %s", listenAddr) tlsEnabled := hasCert && hasKey var tlsCfg *tls.Config if tlsEnabled { @@ -446,6 +451,7 @@ func run(cmd *cobra.Command, args []string) error { // Wait for shutdown signal <-ctx.Done() + debugLog.Print("Shutdown signal received, initiating graceful shutdown") // Gracefully shutdown HTTP server with timeout shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) diff --git a/internal/cmd/wasm_cache.go b/internal/cmd/wasm_cache.go index eedad4641..be89b4dd5 100644 --- a/internal/cmd/wasm_cache.go +++ b/internal/cmd/wasm_cache.go @@ -12,21 +12,6 @@ import ( "github.com/github/gh-aw-mcpg/internal/guard" ) -// containerGuardWasmPath is the baked-in guard path in the container image. -const containerGuardWasmPath = "/guards/github/00-github-guard.wasm" - -// detectGuardWasm returns the baked-in container guard path if it exists, -// or empty string if not found (requiring the user to specify --guard-wasm). 
-func detectGuardWasm() string { - debugLog.Printf("Checking for baked-in guard at %s", containerGuardWasmPath) - if _, err := os.Stat(containerGuardWasmPath); err == nil { - debugLog.Printf("Auto-detected baked-in guard: %s", containerGuardWasmPath) - return containerGuardWasmPath - } - debugLog.Print("Baked-in guard not found, --guard-wasm flag required") - return "" -} - func defaultWasmCacheDir(logDir string) string { if logDir == "" { return config.DefaultWasmCacheDirName diff --git a/internal/config/config_tracing.go b/internal/config/config_tracing.go index 3e5fd6cbb..9285de708 100644 --- a/internal/config/config_tracing.go +++ b/internal/config/config_tracing.go @@ -102,43 +102,3 @@ func init() { } }) } - -// expandTracingVariables expands ${VAR} expressions in TracingConfig fields. -// This is called for TOML-loaded configs before validation, mirroring the -// stdin JSON path where ExpandRawJSONVariables handles expansion. -func expandTracingVariables(cfg *TracingConfig) error { - if cfg == nil { - return nil - } - - logValidation.Printf("Expanding tracing config variables: hasEndpoint=%v, hasTraceID=%v, hasSpanID=%v, hasHeaders=%v", - cfg.Endpoint != "", cfg.TraceID != "", cfg.SpanID != "", cfg.Headers != "") - - fields := []struct { - name string - jsonPath string - value *string - }{ - {name: "endpoint", jsonPath: "gateway.opentelemetry.endpoint", value: &cfg.Endpoint}, - {name: "traceId", jsonPath: "gateway.opentelemetry.traceId", value: &cfg.TraceID}, - {name: "spanId", jsonPath: "gateway.opentelemetry.spanId", value: &cfg.SpanID}, - {name: "headers", jsonPath: "gateway.opentelemetry.headers", value: &cfg.Headers}, - } - - for _, field := range fields { - if *field.value == "" { - continue - } - - expanded, err := expandVariables(*field.value, field.jsonPath) - if err != nil { - return err - } - - logValidation.Printf("Expanded tracing %s variable", field.name) - *field.value = expanded - } - - logValidation.Print("Tracing config variable expansion 
completed") - return nil -} diff --git a/internal/config/expand.go b/internal/config/expand.go index 3323845ad..53801c672 100644 --- a/internal/config/expand.go +++ b/internal/config/expand.go @@ -88,3 +88,43 @@ func expandEnvVariables(env map[string]string, serverName string) (map[string]st logValidation.Printf("Env variable expansion completed for server: %s", serverName) return result, nil } + +// expandTracingVariables expands ${VAR} expressions in TracingConfig fields. +// This is called for TOML-loaded configs before validation, mirroring the +// stdin JSON path where ExpandRawJSONVariables handles expansion. +func expandTracingVariables(cfg *TracingConfig) error { + if cfg == nil { + return nil + } + + logValidation.Printf("Expanding tracing config variables: hasEndpoint=%v, hasTraceID=%v, hasSpanID=%v, hasHeaders=%v", + cfg.Endpoint != "", cfg.TraceID != "", cfg.SpanID != "", cfg.Headers != "") + + fields := []struct { + name string + jsonPath string + value *string + }{ + {name: "endpoint", jsonPath: "gateway.opentelemetry.endpoint", value: &cfg.Endpoint}, + {name: "traceId", jsonPath: "gateway.opentelemetry.traceId", value: &cfg.TraceID}, + {name: "spanId", jsonPath: "gateway.opentelemetry.spanId", value: &cfg.SpanID}, + {name: "headers", jsonPath: "gateway.opentelemetry.headers", value: &cfg.Headers}, + } + + for _, field := range fields { + if *field.value == "" { + continue + } + + expanded, err := expandVariables(*field.value, field.jsonPath) + if err != nil { + return err + } + + logValidation.Printf("Expanded tracing %s variable", field.name) + *field.value = expanded + } + + logValidation.Print("Tracing config variable expansion completed") + return nil +} diff --git a/internal/config/guard_policy.go b/internal/config/guard_policy.go index 1d94cb877..9968c97c7 100644 --- a/internal/config/guard_policy.go +++ b/internal/config/guard_policy.go @@ -38,33 +38,35 @@ type WriteSinkPolicy struct { // AllowOnlyPolicy configures scope and minimum required 
integrity. type AllowOnlyPolicy struct { - Repos interface{} `toml:"repos" json:"repos"` - MinIntegrity string `toml:"min-integrity" json:"min-integrity"` - BlockedUsers []string `toml:"blocked-users" json:"blocked-users,omitempty"` - ApprovalLabels []string `toml:"approval-labels" json:"approval-labels,omitempty"` - TrustedUsers []string `toml:"trusted-users" json:"trusted-users,omitempty"` - EndorsementReactions []string `toml:"endorsement-reactions" json:"endorsement-reactions,omitempty"` - DisapprovalReactions []string `toml:"disapproval-reactions" json:"disapproval-reactions,omitempty"` - DisapprovalIntegrity string `toml:"disapproval-integrity" json:"disapproval-integrity,omitempty"` - EndorserMinIntegrity string `toml:"endorser-min-integrity" json:"endorser-min-integrity,omitempty"` - PromotionLabel string `toml:"promotion-label" json:"promotion-label,omitempty"` - DemotionLabel string `toml:"demotion-label" json:"demotion-label,omitempty"` + Repos interface{} `toml:"repos" json:"repos"` + MinIntegrity string `toml:"min-integrity" json:"min-integrity"` + ToolCallLimits map[string]int `toml:"tool-call-limits" json:"tool-call-limits,omitempty"` + BlockedUsers []string `toml:"blocked-users" json:"blocked-users,omitempty"` + ApprovalLabels []string `toml:"approval-labels" json:"approval-labels,omitempty"` + TrustedUsers []string `toml:"trusted-users" json:"trusted-users,omitempty"` + EndorsementReactions []string `toml:"endorsement-reactions" json:"endorsement-reactions,omitempty"` + DisapprovalReactions []string `toml:"disapproval-reactions" json:"disapproval-reactions,omitempty"` + DisapprovalIntegrity string `toml:"disapproval-integrity" json:"disapproval-integrity,omitempty"` + EndorserMinIntegrity string `toml:"endorser-min-integrity" json:"endorser-min-integrity,omitempty"` + PromotionLabel string `toml:"promotion-label" json:"promotion-label,omitempty"` + DemotionLabel string `toml:"demotion-label" json:"demotion-label,omitempty"` } // 
NormalizedGuardPolicy is a canonical policy representation for caching and observability. type NormalizedGuardPolicy struct { - ScopeKind string `json:"scope_kind"` - ScopeValues []string `json:"scope_values,omitempty"` - MinIntegrity string `json:"min-integrity"` - BlockedUsers []string `json:"blocked-users,omitempty"` - ApprovalLabels []string `json:"approval-labels,omitempty"` - TrustedUsers []string `json:"trusted-users,omitempty"` - EndorsementReactions []string `json:"endorsement-reactions,omitempty"` - DisapprovalReactions []string `json:"disapproval-reactions,omitempty"` - DisapprovalIntegrity string `json:"disapproval-integrity,omitempty"` - EndorserMinIntegrity string `json:"endorser-min-integrity,omitempty"` - PromotionLabel string `json:"promotion-label,omitempty"` - DemotionLabel string `json:"demotion-label,omitempty"` + ScopeKind string `json:"scope_kind"` + ScopeValues []string `json:"scope_values,omitempty"` + MinIntegrity string `json:"min-integrity"` + ToolCallLimits map[string]int `json:"tool-call-limits,omitempty"` + BlockedUsers []string `json:"blocked-users,omitempty"` + ApprovalLabels []string `json:"approval-labels,omitempty"` + TrustedUsers []string `json:"trusted-users,omitempty"` + EndorsementReactions []string `json:"endorsement-reactions,omitempty"` + DisapprovalReactions []string `json:"disapproval-reactions,omitempty"` + DisapprovalIntegrity string `json:"disapproval-integrity,omitempty"` + EndorserMinIntegrity string `json:"endorser-min-integrity,omitempty"` + PromotionLabel string `json:"promotion-label,omitempty"` + DemotionLabel string `json:"demotion-label,omitempty"` } func (p *GuardPolicy) UnmarshalJSON(data []byte) error { @@ -144,6 +146,10 @@ func (p *AllowOnlyPolicy) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(value, &p.MinIntegrity); err != nil { return fmt.Errorf("invalid allow-only.min-integrity: %w", err) } + case "tool-call-limits": + if err := json.Unmarshal(value, &p.ToolCallLimits); err != nil { + 
return fmt.Errorf("invalid allow-only.tool-call-limits: %w", err) + } case "blocked-users": if err := json.Unmarshal(value, &p.BlockedUsers); err != nil { return fmt.Errorf("invalid allow-only.blocked-users: %w", err) @@ -198,17 +204,18 @@ func (p *AllowOnlyPolicy) UnmarshalJSON(data []byte) error { func (p AllowOnlyPolicy) MarshalJSON() ([]byte, error) { type serializedAllowOnly struct { - Repos interface{} `json:"repos"` - MinIntegrity string `json:"min-integrity"` - BlockedUsers []string `json:"blocked-users,omitempty"` - ApprovalLabels []string `json:"approval-labels,omitempty"` - TrustedUsers []string `json:"trusted-users,omitempty"` - EndorsementReactions []string `json:"endorsement-reactions,omitempty"` - DisapprovalReactions []string `json:"disapproval-reactions,omitempty"` - DisapprovalIntegrity string `json:"disapproval-integrity,omitempty"` - EndorserMinIntegrity string `json:"endorser-min-integrity,omitempty"` - PromotionLabel string `json:"promotion-label,omitempty"` - DemotionLabel string `json:"demotion-label,omitempty"` + Repos interface{} `json:"repos"` + MinIntegrity string `json:"min-integrity"` + ToolCallLimits map[string]int `json:"tool-call-limits,omitempty"` + BlockedUsers []string `json:"blocked-users,omitempty"` + ApprovalLabels []string `json:"approval-labels,omitempty"` + TrustedUsers []string `json:"trusted-users,omitempty"` + EndorsementReactions []string `json:"endorsement-reactions,omitempty"` + DisapprovalReactions []string `json:"disapproval-reactions,omitempty"` + DisapprovalIntegrity string `json:"disapproval-integrity,omitempty"` + EndorserMinIntegrity string `json:"endorser-min-integrity,omitempty"` + PromotionLabel string `json:"promotion-label,omitempty"` + DemotionLabel string `json:"demotion-label,omitempty"` } return json.Marshal(serializedAllowOnly(p)) diff --git a/internal/config/guard_policy_test.go b/internal/config/guard_policy_test.go index 648ffd1af..db84a3e45 100644 --- a/internal/config/guard_policy_test.go +++ 
b/internal/config/guard_policy_test.go @@ -680,6 +680,13 @@ func TestAllowOnlyPolicyUnmarshalJSON(t *testing.T) { assert.Equal(t, []string{"evil-bot", "bad-actor"}, p.BlockedUsers) }, }, + { + name: "tool-call-limits parsed correctly", + json: `{"repos":"public","min-integrity":"none","tool-call-limits":{"issue_read":1,"list_issues":2}}`, + check: func(t *testing.T, p *AllowOnlyPolicy) { + assert.Equal(t, map[string]int{"issue_read": 1, "list_issues": 2}, p.ToolCallLimits) + }, + }, { name: "approval-labels parsed correctly", json: `{"repos":"public","min-integrity":"none","approval-labels":["approved","human-reviewed"]}`, @@ -829,6 +836,21 @@ func TestAllowOnlyPolicyMarshalJSON(t *testing.T) { assert.Contains(t, jsonStr, `"human-reviewed"`) }) + t.Run("tool-call-limits is included when set", func(t *testing.T) { + policy := AllowOnlyPolicy{ + Repos: "public", + MinIntegrity: "none", + ToolCallLimits: map[string]int{"issue_read": 1}, + } + + data, err := json.Marshal(policy) + require.NoError(t, err) + + jsonStr := string(data) + assert.Contains(t, jsonStr, `"tool-call-limits"`) + assert.Contains(t, jsonStr, `"issue_read"`) + }) + t.Run("nil blocked-users and approval-labels are omitted", func(t *testing.T) { policy := AllowOnlyPolicy{ Repos: "public", @@ -966,6 +988,16 @@ func TestValidateGuardPolicy(t *testing.T) { require.NoError(t, err) }) + t.Run("zero tool-call-limit is treated as unlimited", func(t *testing.T) { + policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ + Repos: "all", + MinIntegrity: "none", + ToolCallLimits: map[string]int{"issue_read": 0}, + }} + err := ValidateGuardPolicy(policy) + require.NoError(t, err) + }) + t.Run("invalid policy returns error", func(t *testing.T) { policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ Repos: "all", @@ -974,6 +1006,17 @@ func TestValidateGuardPolicy(t *testing.T) { err := ValidateGuardPolicy(policy) require.Error(t, err) }) + + t.Run("negative tool-call-limit returns error", func(t *testing.T) { + policy := 
&GuardPolicy{AllowOnly: &AllowOnlyPolicy{ + Repos: "all", + MinIntegrity: "none", + ToolCallLimits: map[string]int{"issue_read": -1}, + }} + err := ValidateGuardPolicy(policy) + require.Error(t, err) + assert.ErrorContains(t, err, `allow-only.tool-call-limits["issue_read"] must be >= 0`) + }) } // TestIsScopeTokenChar tests valid and invalid characters for scope tokens. @@ -1004,6 +1047,17 @@ func TestNormalizeGuardPolicyReactionEndorsement(t *testing.T) { assert.Equal(t, []string{"THUMBS_UP", "HEART"}, got.EndorsementReactions) }) + t.Run("tool-call-limits propagated to normalized policy", func(t *testing.T) { + policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ + Repos: "public", + MinIntegrity: "approved", + ToolCallLimits: map[string]int{"issue_read": 1, "list_issues": 0}, + }} + got, err := NormalizeGuardPolicy(policy) + require.NoError(t, err) + assert.Equal(t, map[string]int{"issue_read": 1, "list_issues": 0}, got.ToolCallLimits) + }) + t.Run("disapproval-reactions propagated and normalized to uppercase", func(t *testing.T) { policy := &GuardPolicy{AllowOnly: &AllowOnlyPolicy{ Repos: "public", diff --git a/internal/config/guard_policy_unmarshal_coverage_test.go b/internal/config/guard_policy_unmarshal_coverage_test.go index 593ca2554..e3dbb9817 100644 --- a/internal/config/guard_policy_unmarshal_coverage_test.go +++ b/internal/config/guard_policy_unmarshal_coverage_test.go @@ -157,6 +157,11 @@ func TestAllowOnlyPolicyUnmarshalJSON_FieldErrorPaths(t *testing.T) { json: `{"repos": "all", "min-integrity": "none", "blocked-users": "notanarray"}`, wantErr: "invalid allow-only.blocked-users", }, + { + name: "tool-call-limits field invalid JSON type", + json: `{"repos": "all", "min-integrity": "none", "tool-call-limits": "notamap"}`, + wantErr: "invalid allow-only.tool-call-limits", + }, { name: "approval-labels field invalid JSON type", json: `{"repos": "all", "min-integrity": "none", "approval-labels": 42}`, @@ -416,6 +421,7 @@ func 
TestAllowOnlyPolicyUnmarshalJSON_FullRoundTrip(t *testing.T) { BlockedUsers: []string{"bad-actor"}, ApprovalLabels: []string{"approved"}, TrustedUsers: []string{"contractor"}, + ToolCallLimits: map[string]int{"issue_read": 1}, EndorsementReactions: []string{"THUMBS_UP"}, DisapprovalReactions: []string{"THUMBS_DOWN"}, DisapprovalIntegrity: "none", @@ -434,6 +440,7 @@ func TestAllowOnlyPolicyUnmarshalJSON_FullRoundTrip(t *testing.T) { assert.Equal(t, original.BlockedUsers, parsed.BlockedUsers) assert.Equal(t, original.ApprovalLabels, parsed.ApprovalLabels) assert.Equal(t, original.TrustedUsers, parsed.TrustedUsers) + assert.Equal(t, original.ToolCallLimits, parsed.ToolCallLimits) assert.Equal(t, original.EndorsementReactions, parsed.EndorsementReactions) assert.Equal(t, original.DisapprovalReactions, parsed.DisapprovalReactions) assert.Equal(t, original.DisapprovalIntegrity, parsed.DisapprovalIntegrity) diff --git a/internal/config/guard_policy_validation.go b/internal/config/guard_policy_validation.go index 7c4f916c8..525cef236 100644 --- a/internal/config/guard_policy_validation.go +++ b/internal/config/guard_policy_validation.go @@ -108,6 +108,11 @@ func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) { var err error + normalized.ToolCallLimits, err = normalizeToolCallLimits(policy.AllowOnly.ToolCallLimits) + if err != nil { + return nil, err + } + // Validate and normalize blocked-users, approval-labels, trusted-users. // Dedup uses lowercased keys; original trimmed values are stored. 
normalized.BlockedUsers, err = normalizeStringSlice("blocked-users", policy.AllowOnly.BlockedUsers, strings.ToLower, false) @@ -332,3 +337,22 @@ func normalizeStringSlice(field string, input []string, caseNorm func(string) st } return out, nil } + +func normalizeToolCallLimits(input map[string]int) (map[string]int, error) { + if len(input) == 0 { + return nil, nil + } + + out := make(map[string]int, len(input)) + for toolName, limit := range input { + toolName = strings.TrimSpace(toolName) + if toolName == "" { + return nil, fmt.Errorf("allow-only.tool-call-limits keys must not be empty") + } + if limit < 0 { + return nil, fmt.Errorf("allow-only.tool-call-limits[%q] must be >= 0", toolName) + } + out[toolName] = limit + } + return out, nil +} diff --git a/internal/logger/rpc_logger_test.go b/internal/logger/rpc_logger_test.go index 71b6528fc..27da8a9e4 100644 --- a/internal/logger/rpc_logger_test.go +++ b/internal/logger/rpc_logger_test.go @@ -657,6 +657,44 @@ func TestLogRPCMessage(t *testing.T) { assert.Contains(t, string(mdContent), "**custom-server**→`custom/method`") } +// TestRPCMessageType_JSONLEvent verifies that JSONLEvent returns the correct +// event name for each RPCMessageType, including the default/unknown case. 
+func TestRPCMessageType_JSONLEvent(t *testing.T) { + tests := []struct { + name string + msgType RPCMessageType + wantEvent string + }{ + { + name: "request type returns rpc_request", + msgType: RPCMessageRequest, + wantEvent: "rpc_request", + }, + { + name: "response type returns rpc_response", + msgType: RPCMessageResponse, + wantEvent: "rpc_response", + }, + { + name: "unknown type returns rpc_unknown", + msgType: RPCMessageType("UNKNOWN"), + wantEvent: "rpc_unknown", + }, + { + name: "empty type returns rpc_unknown", + msgType: RPCMessageType(""), + wantEvent: "rpc_unknown", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.msgType.JSONLEvent() + assert.Equal(t, tt.wantEvent, got) + }) + } +} + func TestLogRPCResponse_NoError(t *testing.T) { tmpDir := t.TempDir() logDir := filepath.Join(tmpDir, "logs") diff --git a/internal/mcp/http_transport_test.go b/internal/mcp/http_transport_test.go index eac6ffb87..9a1563dd0 100644 --- a/internal/mcp/http_transport_test.go +++ b/internal/mcp/http_transport_test.go @@ -1545,4 +1545,55 @@ func TestParseHTTPResult(t *testing.T) { assert.Equal(t, -32601, resp.Error.Code) assert.Equal(t, "Method not found", resp.Error.Message) }) + + t.Run("non-200 status with JSON-RPC body synthesises HTTP error", func(t *testing.T) { + // parseJSONRPCResponseWithSSE synthesises a synthetic HTTP error for non-200 responses. + // In the current implementation this means the JSON-RPC error already present in the body + // is overridden by the synthetic error, so the -32603 code is what callers observe. + // This test documents that current behaviour; if parseJSONRPCResponseWithSSE is changed + // to pass through the body-level error, the assertions below will need updating. 
+ result := &httpRequestResult{ + StatusCode: http.StatusInternalServerError, + ResponseBody: []byte(`{"jsonrpc":"2.0","id":4,"error":{"code":-32000,"message":"Server overloaded"}}`), + } + resp, err := parseHTTPResult(result) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Error, "non-200 response should have an error set") + // Synthetic HTTP error is produced by parseJSONRPCResponseWithSSE for non-200 statuses. + assert.Equal(t, -32603, resp.Error.Code, "synthetic HTTP error code should be -32603") + assert.Contains(t, resp.Error.Message, "500", "synthetic error should include HTTP status") + }) +} + +// TestBuildHTTPClientWithHeaders_NilTransport verifies that when the base client has +// a nil Transport, buildHTTPClientWithHeaders falls back to http.DefaultTransport as +// the inner transport for the injecting round-tripper. +func TestBuildHTTPClientWithHeaders_NilTransport(t *testing.T) { + // Use a buffered channel to safely pass the observed header value from the + // handler goroutine to the test goroutine without a data race. + receivedHeader := make(chan string, 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeader <- r.Header.Get("X-Test-Header") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + // base client has no Transport set (nil) — the wrapper must fall back to + // http.DefaultTransport so real network requests still work. 
+ base := &http.Client{} + injected := buildHTTPClientWithHeaders(base, map[string]string{ + "X-Test-Header": "nil-transport-value", + }) + assert.NotSame(t, base, injected, "non-empty headers should return a new client") + + req, err := http.NewRequestWithContext(context.Background(), "GET", srv.URL, nil) + require.NoError(t, err) + + resp, err := injected.Do(req) + require.NoError(t, err) + resp.Body.Close() + + assert.Equal(t, "nil-transport-value", <-receivedHeader, + "header should be injected even when base client Transport is nil") } diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go index 0487bf6a5..3a63d596b 100644 --- a/internal/proxy/handler_test.go +++ b/internal/proxy/handler_test.go @@ -388,6 +388,65 @@ func TestHandleWithDIFC_JSONArrayResponse(t *testing.T) { // so the original responseData is written } +func TestHandleWithDIFC_IssueCommentsArrayResponse(t *testing.T) { + upstreamBody := []interface{}{ + map[string]interface{}{"id": float64(1), "body": "first"}, + map[string]interface{}{"id": float64(2), "body": "second"}, + } + upstream := mockUpstream(t, http.StatusOK, upstreamBody) + defer upstream.Close() + + // Simulate the legacy singleton fallback behavior from the guard: the entire + // top-level array is emitted as one labeled collection item. 
+ g := &stubGuard{ + labelResourceResult: publicResource(), + labelResourceOp: difc.OperationRead, + labelResponseData: &difc.CollectionLabeledData{ + Items: []difc.LabeledItem{ + { + Data: upstreamBody, + Labels: publicResource(), + }, + }, + }, + } + s := newTestServerWithStub(t, upstream.URL, g, difc.EnforcementFilter) + h := &proxyHandler{server: s} + + req := httptest.NewRequest(http.MethodGet, "/repos/org/repo/issues/7/comments", nil) + w := httptest.NewRecorder() + h.handleWithDIFC(w, req, "/repos/org/repo/issues/7/comments", "issue_read", + map[string]interface{}{"owner": "org", "repo": "repo", "issue_number": "7", "method": "get_comments"}, nil) + + assert.Equal(t, http.StatusOK, w.Code) + assert.JSONEq(t, `[{"id":1,"body":"first"},{"id":2,"body":"second"}]`, w.Body.String()) +} + +func TestHandleWithDIFC_IssueCommentsArrayResponse_NoFineGrainedLabels(t *testing.T) { + upstreamBody := []interface{}{ + map[string]interface{}{"id": float64(1), "body": "first"}, + map[string]interface{}{"id": float64(2), "body": "second"}, + } + upstream := mockUpstream(t, http.StatusOK, upstreamBody) + defer upstream.Close() + + g := &stubGuard{ + labelResourceResult: publicResource(), + labelResourceOp: difc.OperationRead, + labelResponseData: nil, // simulate label_response returning 0 (no fine-grained labels) + } + s := newTestServerWithStub(t, upstream.URL, g, difc.EnforcementFilter) + h := &proxyHandler{server: s} + + req := httptest.NewRequest(http.MethodGet, "/repos/org/repo/issues/7/comments", nil) + w := httptest.NewRecorder() + h.handleWithDIFC(w, req, "/repos/org/repo/issues/7/comments", "issue_read", + map[string]interface{}{"owner": "org", "repo": "repo", "issue_number": "7", "method": "get_comments"}, nil) + + assert.Equal(t, http.StatusOK, w.Code) + assert.JSONEq(t, `[{"id":1,"body":"first"},{"id":2,"body":"second"}]`, w.Body.String()) +} + // ─── handleWithDIFC: GraphQL query passes through DIFC ─────────────────────── func TestHandleWithDIFC_GraphQLBody(t 
*testing.T) { diff --git a/internal/proxy/response_transform.go b/internal/proxy/response_transform.go index 0c65e6afa..db81368d8 100644 --- a/internal/proxy/response_transform.go +++ b/internal/proxy/response_transform.go @@ -1,6 +1,8 @@ package proxy import ( + "reflect" + "github.com/github/gh-aw-mcpg/internal/difc" "github.com/github/gh-aw-mcpg/internal/logger" ) @@ -48,6 +50,24 @@ func rewrapSearchResponse(originalData interface{}, filteredItems interface{}) i // This unwraps it back to obj when the original response was a single object // (e.g., get_file_contents, get_commit, issue_read). func unwrapSingleObject(originalData interface{}, filteredData interface{}) interface{} { + // Guard compatibility: older singleton fallback could wrap a top-level array + // as a single collection item, producing [[...]]. If the wrapped value is + // exactly the original array payload, restore the original top-level shape. + if originalArray, isArray := originalData.([]interface{}); isArray { + if arr, ok := filteredData.([]interface{}); ok && len(arr) == 1 { + if wrapped, ok := arr[0].([]interface{}); ok && + len(wrapped) == len(originalArray) { + if len(wrapped) == 0 { + return wrapped + } + if reflect.DeepEqual(wrapped, originalArray) { + return wrapped + } + } + } + return filteredData + } + original, isMap := originalData.(map[string]interface{}) if !isMap { return filteredData diff --git a/internal/proxy/response_transform_coverage_test.go b/internal/proxy/response_transform_coverage_test.go index 16eb8eecf..de6c445f8 100644 --- a/internal/proxy/response_transform_coverage_test.go +++ b/internal/proxy/response_transform_coverage_test.go @@ -96,6 +96,31 @@ func TestUnwrapSingleObject_NonMapOriginal(t *testing.T) { } } +// TestUnwrapSingleObject_LegacyWrappedTopLevelArray verifies the compatibility +// unwrap path for legacy singleton fallback output: [[original-array]] → [original-array]. 
+func TestUnwrapSingleObject_LegacyWrappedTopLevelArray(t *testing.T) { + original := []interface{}{ + map[string]interface{}{"id": float64(1), "body": "first"}, + map[string]interface{}{"id": float64(2), "body": "second"}, + } + filtered := []interface{}{original} + + result := unwrapSingleObject(original, filtered) + + assert.Equal(t, original, result, "legacy wrapped array should be restored to top-level array") +} + +// TestUnwrapSingleObject_ArrayNotLegacyWrapped verifies that arbitrary array +// responses are not unwrapped unless they exactly match the legacy wrapper shape. +func TestUnwrapSingleObject_ArrayNotLegacyWrapped(t *testing.T) { + original := []interface{}{[]interface{}{float64(1), float64(2)}} + filtered := []interface{}{[]interface{}{float64(1), float64(2)}} + + result := unwrapSingleObject(original, filtered) + + assert.Equal(t, filtered, result, "non-legacy array responses must be left unchanged") +} + // TestUnwrapSingleObject_SearchEnvelope verifies that a map containing // "total_count" (search envelope) is NOT unwrapped — filteredData is returned as-is. 
func TestUnwrapSingleObject_SearchEnvelope(t *testing.T) { diff --git a/internal/server/call_backend_tool_difc_test.go b/internal/server/call_backend_tool_difc_test.go index 1e88c38c3..791bf815b 100644 --- a/internal/server/call_backend_tool_difc_test.go +++ b/internal/server/call_backend_tool_difc_test.go @@ -820,3 +820,156 @@ func TestCallBackendTool_GuardInitError(t *testing.T) { require.Error(err) assert.ErrorContains(err, "guard session initialization failed") } + +func TestCallBackendTool_ToolCallLimitEnforcedPerSession(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + backendCalls := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + method, _ := req["method"].(string) + w.Header().Set("Content-Type", "application/json") + switch method { + case "initialize": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{"name": "test-backend", "version": "1.0"}, + }, + }) + case "tools/list": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "issue_read", + "description": "test tool", + "inputSchema": map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}, + }, + }, + }, + }) + case "tools/call": + backendCalls++ + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "content": []map[string]interface{}{{"type": "text", "text": "tool result"}}, + "isError": false, + }, + }) + } + })) + defer backend.Close() + + g := 
&difcTestGuard{name: "difc-tool-call-limit-guard"} + us := makeUnifiedWithGuard(t, "difc-tool-call-limit-type", g, backend, "strict") + us.cfg.GuardPolicy.AllowOnly.ToolCallLimits = map[string]int{"issue_read": 2} + defer us.Close() + + result, _, err := us.callBackendTool(callCtx("session-limit-a"), "test-server", "issue_read", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + + result, _, err = us.callBackendTool(callCtx("session-limit-a"), "test-server", "issue_read", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + + result, _, err = us.callBackendTool(callCtx("session-limit-a"), "test-server", "issue_read", nil) + require.NotNil(result) + require.Error(err) + assert.True(result.IsError) + assert.Contains(result.Content[0].(*sdk.TextContent).Text, `tool call limit reached for "issue_read" (max: 2)`) + assert.Equal(2, backendCalls, "over-limit call must not reach the backend") + + result, _, err = us.callBackendTool(callCtx("session-limit-b"), "test-server", "issue_read", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + assert.Equal(3, backendCalls, "a new session must get a fresh per-tool budget") +} + +func TestCallBackendTool_ToolCallLimitZeroOrAbsentIsUnlimited(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + backendCalls := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + method, _ := req["method"].(string) + w.Header().Set("Content-Type", "application/json") + switch method { + case "initialize": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": 
map[string]interface{}{"name": "test-backend", "version": "1.0"}, + }, + }) + case "tools/list": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "zero_limit_tool", + "description": "zero limit tool", + "inputSchema": map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}, + }, + { + "name": "unlisted_tool", + "description": "unlisted tool", + "inputSchema": map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}, + }, + }, + }, + }) + case "tools/call": + backendCalls++ + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", "id": req["id"], + "result": map[string]interface{}{ + "content": []map[string]interface{}{{"type": "text", "text": "tool result"}}, + "isError": false, + }, + }) + } + })) + defer backend.Close() + + g := &difcTestGuard{name: "difc-zero-limit-guard"} + us := makeUnifiedWithGuard(t, "difc-zero-limit-type", g, backend, "strict") + us.cfg.GuardPolicy.AllowOnly.ToolCallLimits = map[string]int{"zero_limit_tool": 0} + defer us.Close() + + for i := 0; i < 3; i++ { + result, _, err := us.callBackendTool(callCtx("session-unlimited"), "test-server", "zero_limit_tool", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + } + for i := 0; i < 2; i++ { + result, _, err := us.callBackendTool(callCtx("session-unlimited"), "test-server", "unlisted_tool", nil) + require.NotNil(result) + require.NoError(err) + assert.False(result.IsError) + } + + assert.Equal(5, backendCalls, "zero or absent limits must not block tool calls") +} diff --git a/internal/server/guard_init.go b/internal/server/guard_init.go index 996470c10..9b0afe907 100644 --- a/internal/server/guard_init.go +++ b/internal/server/guard_init.go @@ -382,12 +382,17 @@ func (us *UnifiedServer) ensureGuardInitialized( if session.GuardInit == nil { session.GuardInit = 
make(map[string]*GuardSessionState) } + var toolCallLimits map[string]int + if policy.AllowOnly != nil { + toolCallLimits = copyToolCallLimits(policy.AllowOnly.ToolCallLimits) + } session.GuardInit[serverID] = &GuardSessionState{ Initialized: true, PolicyHash: policyHash, PolicySource: source, DIFCMode: mode, NormalizedPolicy: normalizedPolicy, + ToolCallLimits: toolCallLimits, } us.sessionMu.Unlock() @@ -397,6 +402,20 @@ func (us *UnifiedServer) ensureGuardInitialized( return mode, nil } +// copyToolCallLimits returns a defensive copy of tool-call-limits so per-session +// counters cannot be affected by later config mutations. Keys are trimmed of +// surrounding whitespace to match the normalization applied during validation. +func copyToolCallLimits(input map[string]int) map[string]int { + if len(input) == 0 { + return nil + } + out := make(map[string]int, len(input)) + for toolName, limit := range input { + out[strings.TrimSpace(toolName)] = limit + } + return out +} + // getTrustedBots returns the configured list of additional trusted bot usernames, // or nil if none are configured. func (us *UnifiedServer) getTrustedBots() []string { diff --git a/internal/server/unified.go b/internal/server/unified.go index 5ff9b6278..5d6fe2242 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -50,6 +50,9 @@ type GuardSessionState struct { PolicySource string DIFCMode difc.EnforcementMode NormalizedPolicy map[string]interface{} + ToolCallLimits map[string]int + ToolCallCounts map[string]int + CallCountMu sync.Mutex } // ServerStatus represents the health status of a backend server @@ -377,7 +380,8 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName oteltrace.WithSpanKind(oteltrace.SpanKindInternal), ) // httpStatusCode tracks the conceptual HTTP status of the proxied response (spec §4.1.3.6). - // It starts at 200 and is updated to 500 (error) or 403 (access denied) before each exit. 
+ // It starts at 200 and is updated to 500 (error), 403 (access denied), or 429 (budget + // exhaustion) before each exit. httpStatusCode := 200 defer func() { toolSpan.SetAttributes(semconv.HTTPResponseStatusCodeKey.Int(httpStatusCode)) @@ -414,6 +418,12 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName httpStatusCode = 500 return mcp.NewErrorCallToolResult(fmt.Errorf("guard session initialization failed: %w", err)) } + if err := us.enforceToolCallLimit(sessionID, serverID, toolName); err != nil { + httpStatusCode = 429 + toolSpan.RecordError(err) + toolSpan.SetStatus(codes.Error, "tool call limit reached") + return mcp.NewErrorCallToolResult(err) + } requestEvaluator := difc.NewEvaluatorWithMode(enforcementMode) @@ -635,6 +645,40 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName return callResult, finalResult, nil } +// enforceToolCallLimit applies the configured per-session budget for toolName on +// the given server, incrementing the call counter for in-budget attempts and +// returning an error without incrementing when the session has exhausted its limit. 
+func (us *UnifiedServer) enforceToolCallLimit(sessionID, serverID, toolName string) error { + us.sessionMu.RLock() + session := us.sessions[sessionID] + var state *GuardSessionState + if session != nil { + state = session.GuardInit[serverID] + } + us.sessionMu.RUnlock() + + if state == nil || len(state.ToolCallLimits) == 0 { + return nil + } + + state.CallCountMu.Lock() + defer state.CallCountMu.Unlock() + + limit, ok := state.ToolCallLimits[toolName] + if !ok || limit == 0 { + return nil + } + if state.ToolCallCounts == nil { + state.ToolCallCounts = make(map[string]int) + } + + if state.ToolCallCounts[toolName] >= limit { + return fmt.Errorf("tool call limit reached for %q (max: %d)", toolName, limit) + } + state.ToolCallCounts[toolName]++ + return nil +} + // Run starts the unified MCP server on the specified transport func (us *UnifiedServer) Run(transport sdk.Transport) error { logger.LogInfo("startup", "Starting unified MCP server...") diff --git a/internal/tracing/config_resolver.go b/internal/tracing/config_resolver.go new file mode 100644 index 000000000..d8cc59c40 --- /dev/null +++ b/internal/tracing/config_resolver.go @@ -0,0 +1,196 @@ +package tracing + +import ( + "context" + "encoding/hex" + "net/url" + "os" + "strings" + + "go.opentelemetry.io/otel/trace" + + "github.com/github/gh-aw-mcpg/internal/config" +) + +// defaultSignalPath is the OTLP traces signal path per the OpenTelemetry spec. +const defaultSignalPath = "/v1/traces" + +// resolveEndpoint returns the OTLP endpoint from config. +// CLI flags set the config value using env vars as defaults, so config already +// reflects the correct precedence: CLI flag > env var > config file. +// +// Per the OpenTelemetry specification, OTEL_EXPORTER_OTLP_ENDPOINT is a base URL +// and SDKs must append the signal path (/v1/traces for traces). Since we use +// WithEndpointURL (which takes the URL as-is), we append the signal path here +// when it is not already present. 
The path defaults to /v1/traces but can be +// overridden via TracingConfig.SignalPath. +func resolveEndpoint(cfg *config.TracingConfig) string { + if cfg == nil || cfg.Endpoint == "" { + return "" + } + endpoint := cfg.Endpoint + signalPath := cfg.SignalPath + if signalPath == "" { + signalPath = defaultSignalPath + } + + u, err := url.Parse(endpoint) + if err != nil { + // If unparseable, fall back to string append for best-effort. + // Normalize trailing slashes before the suffix check to avoid + // duplicating the signal path when input already ends with it. + normalized := strings.TrimRight(endpoint, "/") + if !strings.HasSuffix(normalized, signalPath) { + normalized += signalPath + } + return normalized + } + + // Normalize path and check whether signal path is already the suffix + normalizedPath := strings.TrimRight(u.Path, "/") + if !strings.HasSuffix(normalizedPath, signalPath) { + u.Path = normalizedPath + signalPath + } else { + u.Path = normalizedPath + } + return u.String() +} + +// resolveServiceName returns the service name from config. +func resolveServiceName(cfg *config.TracingConfig) string { + if cfg != nil && cfg.ServiceName != "" { + return cfg.ServiceName + } + return config.DefaultTracingServiceName +} + +// resolveSampleRate returns the sample rate from config (defaults to 1.0). +// Valid configured values are in the range [0.0, 1.0], where 0.0 disables sampling. +func resolveSampleRate(cfg *config.TracingConfig) float64 { + rate := cfg.GetSampleRate() + + if rate >= 0.0 && rate <= 1.0 { + return rate + } + + logTracing.Printf("Warning: invalid tracing sample rate %.4f; using default %.2f", rate, config.DefaultTracingSampleRate) + return config.DefaultTracingSampleRate +} + +// parseOTLPHeaders parses a comma-separated "key=value" string into a map. +// Empty pairs are skipped silently; pairs without "=" and pairs with an empty key are +// logged as warnings and skipped to avoid invalid HTTP header field names.
+// Leading/trailing whitespace around keys and values is trimmed. +func parseOTLPHeaders(raw string) map[string]string { + return parseOTLPHeadersWithDecoder(raw, false) +} + +func parseOTLPHeadersWithDecoder(raw string, decodeValues bool) map[string]string { + headers := make(map[string]string) + for _, pair := range strings.Split(raw, ",") { + trimmed := strings.TrimSpace(pair) + if trimmed == "" { + continue + } + k, v, ok := strings.Cut(trimmed, "=") + if !ok { + logTracing.Printf("Warning: skipping malformed OTLP header pair (missing '=')") + continue + } + key := strings.TrimSpace(k) + if key == "" { + logTracing.Printf("Warning: skipping OTLP header pair with empty key") + continue + } + value := strings.TrimSpace(v) + if decodeValues { + decoded, err := url.PathUnescape(value) + if err != nil { + logTracing.Printf("Warning: invalid percent-encoding in OTLP header value for key %q; using raw value", key) + } else { + value = decoded + } + } + headers[key] = value + } + return headers +} + +// resolveHeaders parses the configured OTLP export headers string (or returns nil). +// When no headers are configured via config, it falls back to the standard +// OTEL_EXPORTER_OTLP_HEADERS environment variable (W3C Baggage format: +// "key1=value1,key2=value2") per the OTel OTLP Exporter specification. +func resolveHeaders(cfg *config.TracingConfig) map[string]string { + raw := "" + if cfg != nil { + raw = cfg.Headers + } + if raw == "" { + raw = os.Getenv("OTEL_EXPORTER_OTLP_HEADERS") + if raw != "" { + logTracing.Printf("Using OTEL_EXPORTER_OTLP_HEADERS env var for OTLP export headers") + } + } + if raw == "" { + return nil + } + if cfg == nil || cfg.Headers == "" { + return parseOTLPHeadersWithDecoder(raw, true) + } + return parseOTLPHeaders(raw) +} + +// resolveParentContext builds a context carrying the W3C remote parent span context +// from the configured traceId and spanId (spec §4.1.3.6). 
+// If traceId is absent or malformed, the original context is returned unchanged. +// A missing or malformed spanId is replaced with a random span ID so the traceparent is still valid. +func resolveParentContext(ctx context.Context, cfg *config.TracingConfig) context.Context { + if cfg == nil || cfg.TraceID == "" { + return ctx + } + + traceIDBytes, err := hex.DecodeString(cfg.TraceID) + if err != nil || len(traceIDBytes) != 16 { + logTracing.Printf("Warning: invalid traceId '%s'; skipping W3C parent context", cfg.TraceID) + return ctx + } + var traceID trace.TraceID + copy(traceID[:], traceIDBytes) + + var spanID trace.SpanID + if cfg.SpanID != "" { + spanIDBytes, err := hex.DecodeString(cfg.SpanID) + if err != nil || len(spanIDBytes) != 8 { + logTracing.Printf("Warning: invalid spanId '%s'; generating a random span ID", cfg.SpanID) + // Fall through to generate a random span ID below + } else { + copy(spanID[:], spanIDBytes) + } + } + + // When spanId is all-zeros (absent or invalid), generate a random span ID. + // A valid SpanContext requires a non-zero SpanID (W3C Trace Context spec). + // T-OTEL-008: when only traceId is provided, a random spanId is generated.
+ if spanID == (trace.SpanID{}) { + generatedID, genErr := generateRandomSpanID() + if genErr != nil { + logTracing.Printf("Warning: failed to generate random span ID: %v; skipping W3C parent context", genErr) + return ctx + } + spanID = generatedID + logTracing.Printf("Generated random spanId for W3C parent context") + } + + sc := trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: traceID, + SpanID: spanID, + TraceFlags: trace.FlagsSampled, + Remote: true, + }) + if !sc.IsValid() { + logTracing.Printf("Warning: constructed parent SpanContext is not valid; skipping W3C parent context") + return ctx + } + logTracing.Printf("W3C parent context resolved: traceId=%s, spanId=%s", traceID, spanID) + return trace.ContextWithRemoteSpanContext(ctx, sc) +} diff --git a/internal/tracing/provider.go b/internal/tracing/provider.go index 5a0dd978d..91f92de70 100644 --- a/internal/tracing/provider.go +++ b/internal/tracing/provider.go @@ -20,11 +20,7 @@ package tracing import ( "context" "crypto/rand" - "encoding/hex" "fmt" - "net/url" - "os" - "strings" "time" "go.opentelemetry.io/otel" @@ -66,189 +62,6 @@ func (p *Provider) Shutdown(ctx context.Context) error { return nil } -// defaultSignalPath is the OTLP traces signal path per the OpenTelemetry spec. -const defaultSignalPath = "/v1/traces" - -// resolveEndpoint returns the OTLP endpoint from config. -// CLI flags set the config value using env vars as defaults, so config already -// reflects the correct precedence: CLI flag > env var > config file. -// -// Per the OpenTelemetry specification, OTEL_EXPORTER_OTLP_ENDPOINT is a base URL -// and SDKs must append the signal path (/v1/traces for traces). Since we use -// WithEndpointURL (which takes the URL as-is), we append the signal path here -// when it is not already present. The path defaults to /v1/traces but can be -// overridden via TracingConfig.SignalPath. 
-func resolveEndpoint(cfg *config.TracingConfig) string { - if cfg == nil || cfg.Endpoint == "" { - return "" - } - endpoint := cfg.Endpoint - signalPath := cfg.SignalPath - if signalPath == "" { - signalPath = defaultSignalPath - } - - u, err := url.Parse(endpoint) - if err != nil { - // If unparseable, fall back to string append for best-effort. - // Normalize trailing slashes before the suffix check to avoid - // duplicating the signal path when input already ends with it. - normalized := strings.TrimRight(endpoint, "/") - if !strings.HasSuffix(normalized, signalPath) { - normalized += signalPath - } - return normalized - } - - // Normalize path and check whether signal path is already the suffix - normalizedPath := strings.TrimRight(u.Path, "/") - if !strings.HasSuffix(normalizedPath, signalPath) { - u.Path = normalizedPath + signalPath - } else { - u.Path = normalizedPath - } - return u.String() -} - -// resolveServiceName returns the service name from config. -func resolveServiceName(cfg *config.TracingConfig) string { - if cfg != nil && cfg.ServiceName != "" { - return cfg.ServiceName - } - return config.DefaultTracingServiceName -} - -// resolveSampleRate returns the sample rate from config (defaults to 1.0). -// Valid configured values are in the range [0.0, 1.0], where 0.0 disables sampling. -func resolveSampleRate(cfg *config.TracingConfig) float64 { - rate := cfg.GetSampleRate() - - if rate >= 0.0 && rate <= 1.0 { - return rate - } - - logTracing.Printf("Warning: invalid tracing sample rate %.4f; using default %.2f", rate, config.DefaultTracingSampleRate) - return config.DefaultTracingSampleRate -} - -// parseOTLPHeaders parses a comma-separated "key=value" string into a map. -// Empty pairs, pairs without "=", and pairs with an empty key are logged as -// warnings and skipped to avoid invalid HTTP header field names. -// Leading/trailing whitespace around keys and values is trimmed. 
-func parseOTLPHeaders(raw string) map[string]string { - return parseOTLPHeadersWithDecoder(raw, false) -} - -func parseOTLPHeadersWithDecoder(raw string, decodeValues bool) map[string]string { - headers := make(map[string]string) - for _, pair := range strings.Split(raw, ",") { - trimmed := strings.TrimSpace(pair) - if trimmed == "" { - continue - } - k, v, ok := strings.Cut(trimmed, "=") - if !ok { - logTracing.Printf("Warning: skipping malformed OTLP header pair (missing '=')") - continue - } - key := strings.TrimSpace(k) - if key == "" { - logTracing.Printf("Warning: skipping OTLP header pair with empty key") - continue - } - value := strings.TrimSpace(v) - if decodeValues { - decoded, err := url.PathUnescape(value) - if err != nil { - logTracing.Printf("Warning: invalid percent-encoding in OTLP header value for key %q; using raw value", key) - } else { - value = decoded - } - } - headers[key] = value - } - return headers -} - -// resolveHeaders parses the configured OTLP export headers string (or returns nil). -// When no headers are configured via config, it falls back to the standard -// OTEL_EXPORTER_OTLP_HEADERS environment variable (W3C Baggage format: -// "key1=value1,key2=value2") per the OTel OTLP Exporter specification. -func resolveHeaders(cfg *config.TracingConfig) map[string]string { - raw := "" - if cfg != nil { - raw = cfg.Headers - } - if raw == "" { - raw = os.Getenv("OTEL_EXPORTER_OTLP_HEADERS") - if raw != "" { - logTracing.Printf("Using OTEL_EXPORTER_OTLP_HEADERS env var for OTLP export headers") - } - } - if raw == "" { - return nil - } - if cfg == nil || cfg.Headers == "" { - return parseOTLPHeadersWithDecoder(raw, true) - } - return parseOTLPHeaders(raw) -} - -// resolveParentContext builds a context carrying the W3C remote parent span context -// from the configured traceId and spanId (spec §4.1.3.6). -// If traceId is absent, or either ID is malformed, the original context is returned unchanged. 
-// A missing spanId is replaced with a random span ID so the traceparent is still valid. -func resolveParentContext(ctx context.Context, cfg *config.TracingConfig) context.Context { - if cfg == nil || cfg.TraceID == "" { - return ctx - } - - traceIDBytes, err := hex.DecodeString(cfg.TraceID) - if err != nil || len(traceIDBytes) != 16 { - logTracing.Printf("Warning: invalid traceId '%s'; skipping W3C parent context", cfg.TraceID) - return ctx - } - var traceID trace.TraceID - copy(traceID[:], traceIDBytes) - - var spanID trace.SpanID - if cfg.SpanID != "" { - spanIDBytes, err := hex.DecodeString(cfg.SpanID) - if err != nil || len(spanIDBytes) != 8 { - logTracing.Printf("Warning: invalid spanId '%s'; generating a random span ID", cfg.SpanID) - // Fall through to generate a random span ID below - } else { - copy(spanID[:], spanIDBytes) - } - } - - // When spanId is all-zeros (absent or invalid), generate a random span ID. - // A valid SpanContext requires a non-zero SpanID (W3C Trace Context spec). - // T-OTEL-008: when only traceId is provided, a random spanId is generated. - if spanID == (trace.SpanID{}) { - generatedID, genErr := generateRandomSpanID() - if genErr != nil { - logTracing.Printf("Warning: failed to generate random span ID: %v; skipping W3C parent context", genErr) - return ctx - } - spanID = generatedID - logTracing.Printf("Generated random spanId for W3C parent context") - } - - sc := trace.NewSpanContext(trace.SpanContextConfig{ - TraceID: traceID, - SpanID: spanID, - TraceFlags: trace.FlagsSampled, - Remote: true, - }) - if !sc.IsValid() { - logTracing.Printf("Warning: constructed parent SpanContext is not valid; skipping W3C parent context") - return ctx - } - logTracing.Printf("W3C parent context resolved: traceId=%s, spanId=%s", traceID, spanID) - return trace.ContextWithRemoteSpanContext(ctx, sc) -} - // generateRandomSpanID creates a cryptographically random 8-byte span ID. 
func generateRandomSpanID() (trace.SpanID, error) { var id trace.SpanID