Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions plugins/examples/nemocheck/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import logging
import os
import requests
import json

# Initialize logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -113,7 +112,7 @@ async def tool_pre_invoke(
Returns:
The result of the plugin's analysis, including whether the tool can proceed.
"""
logger.info(
logger.debug(
f"[NemoCheck] Starting tool pre invoke hook with payload {payload}"
)

Expand Down Expand Up @@ -148,19 +147,29 @@ async def tool_pre_invoke(
data = response.json()
status = data.get("status", "blocked")
logger.debug(f"[NemoCheck] Rails reply: {data}")
metadata = data.get("rails_status")

if status == "success":
metadata = data.get("rails_status")
return ToolPreInvokeResult(
continue_processing=True, metadata=metadata
)
else:
metadata = data.get("rails_status")
logger.info(
f"[NemoCheck] Tool request blocked. Full NeMo response: {data}"
)
# Extract rail names from rails_status for more informative description
rails_run = list(metadata.keys()) if metadata else []
rails_info = (
f"Rails: {', '.join(rails_run)}"
if rails_run
else "No rails info"
)
violation = PluginViolation(
reason=f"Check tool rails:{status}.",
description=json.dumps(data),
reason=f"Tool request check failed: {status}",
description=f"{rails_info}",
code="NEMO_RAILS_BLOCKED",
details=metadata,
mcp_error_code=-32602, # Invalid params
)
return ToolPreInvokeResult(
continue_processing=False,
Expand All @@ -170,7 +179,7 @@ async def tool_pre_invoke(
else:
violation = PluginViolation(
reason="Tool Check Unavailable",
description=f"Tool arguments check server returned error. Status code: {response.status_code}, Response: {response.text}",
description=f"Tool request check server returned error. Status code: {response.status_code}, Response: {response.text}",
code="NEMO_SERVER_ERROR",
details={"status_code": response.status_code},
)
Expand All @@ -179,7 +188,7 @@ async def tool_pre_invoke(
)

except Exception as e:
logger.error(f"[NemoCheck] Error checking tool arguments: {e}")
logger.error(f"[NemoCheck] Error checking tool request: {e}")
violation = PluginViolation(
reason="Tool Check Error",
description=f"Failed to connect to check server: {str(e)}",
Expand All @@ -202,7 +211,7 @@ async def tool_post_invoke(
Returns:
The result of the plugin's analysis, including whether the tool result should proceed.
"""
logger.info(
logger.debug(
f"[NemoCheck] Starting tool post invoke hook with payload {payload}"
)

Expand Down Expand Up @@ -245,19 +254,29 @@ async def tool_post_invoke(
data = response.json()
status = data.get("status", "blocked")
logger.debug(f"[NemoCheck] Rails reply: {data}")
metadata = data.get("rails_status")

if status == "success":
metadata = data.get("rails_status")
result = ToolPostInvokeResult(
continue_processing=True, metadata=metadata
)
else: # blocked
metadata = data.get("rails_status")
logger.info(
f"[NemoCheck] Tool response blocked. Full NeMo response: {data}"
)
# Extract rail names from rails_status for more informative description
rails_run = list(metadata.keys()) if metadata else []
rails_info = (
f"Rails: {', '.join(rails_run)}"
if rails_run
else "No rails info"
)
violation = PluginViolation(
reason=f"Check tool rails:{status}.",
description=json.dumps(data),
reason=f"Tool response check failed: {status}",
description=f"{rails_info}",
code="NEMO_RAILS_BLOCKED",
details=metadata,
mcp_error_code=-32603, # Internal error for invalid tool response
)
result = ToolPostInvokeResult(
continue_processing=False,
Expand Down
82 changes: 76 additions & 6 deletions plugins/examples/nemocheck/tests/test_nemocheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def test_prompt_pre_fetch(plugin, context):

@pytest.mark.asyncio
@pytest.mark.parametrize(
"status_code,response_data,expected_continue,has_violation,expected_code",
"status_code,response_data,expected_continue,has_violation,expected_code,expected_mcp_code",
[
(
200,
Expand All @@ -72,6 +72,7 @@ async def test_prompt_pre_fetch(plugin, context):
True,
False,
None,
None,
),
(
200,
Expand All @@ -82,8 +83,9 @@ async def test_prompt_pre_fetch(plugin, context):
False,
True,
"NEMO_RAILS_BLOCKED",
-32602, # Invalid params for tool request
),
(503, None, False, True, "NEMO_SERVER_ERROR"),
(503, None, False, True, "NEMO_SERVER_ERROR", None),
],
)
async def test_tool_pre_invoke_scenarios(
Expand All @@ -94,8 +96,9 @@ async def test_tool_pre_invoke_scenarios(
expected_continue,
has_violation,
expected_code,
expected_mcp_code,
):
"""Test tool_pre_invoke with various scenarios including error codes."""
"""Test tool_pre_invoke with various scenarios including error codes and MCP error codes."""
payload = ToolPreInvokePayload(
name="test_tool",
args={"tool_args": '{"param": "value"}'},
Expand All @@ -111,11 +114,12 @@ async def test_tool_pre_invoke_scenarios(
assert (result.violation is not None) == has_violation
if has_violation:
assert result.violation.code == expected_code
assert result.violation.mcp_error_code == expected_mcp_code


@pytest.mark.asyncio
@pytest.mark.parametrize(
"status_code,response_data,expected_continue,has_violation,expected_code",
"status_code,response_data,expected_continue,has_violation,expected_code,expected_mcp_code",
[
(
200,
Expand All @@ -128,6 +132,7 @@ async def test_tool_pre_invoke_scenarios(
True,
False,
None,
None,
),
(
200,
Expand All @@ -138,8 +143,9 @@ async def test_tool_pre_invoke_scenarios(
False,
True,
"NEMO_RAILS_BLOCKED",
-32603, # Internal error for invalid tool response
),
(500, None, False, True, "NEMO_SERVER_ERROR"),
(500, None, False, True, "NEMO_SERVER_ERROR", None),
],
)
async def test_tool_post_invoke_http_scenarios(
Expand All @@ -150,8 +156,9 @@ async def test_tool_post_invoke_http_scenarios(
expected_continue,
has_violation,
expected_code,
expected_mcp_code,
):
"""Test tool_post_invoke with various HTTP response scenarios including error codes."""
"""Test tool_post_invoke with various HTTP response scenarios including error codes and MCP error codes."""
payload = ToolPostInvokePayload(
name="test_tool",
result={"content": [{"type": "text", "text": "Test content"}]},
Expand All @@ -167,6 +174,7 @@ async def test_tool_post_invoke_http_scenarios(
assert (result.violation is not None) == has_violation
if has_violation:
assert result.violation.code == expected_code
assert result.violation.mcp_error_code == expected_mcp_code


@pytest.mark.asyncio
Expand Down Expand Up @@ -272,3 +280,65 @@ async def test_connection_error_handling(
assert result.violation is not None
assert result.violation.code == "NEMO_CONNECTION_ERROR"
assert "Network error" in result.violation.description


@pytest.mark.asyncio
@pytest.mark.parametrize(
"hook_name,payload_factory,expected_reason_prefix",
[
(
"tool_pre_invoke",
lambda: ToolPreInvokePayload(
name="test_tool", args={"tool_args": '{"param": "value"}'}
),
"Tool request check failed",
),
(
"tool_post_invoke",
lambda: ToolPostInvokePayload(
name="test_tool",
result={"content": [{"type": "text", "text": "content"}]},
),
"Tool response check failed",
),
],
)
async def test_violation_includes_rail_names(
plugin, context, hook_name, payload_factory, expected_reason_prefix
):
"""Test that violation descriptions include the rail names from rails_status."""
payload = payload_factory()
hook = getattr(plugin, hook_name)

# Mock response with multiple rails
response_data = {
"status": "blocked",
"rails_status": {
"detect hap": {"status": "blocked"},
"detect sensitive data": {"status": "success"},
},
}

with patch(
"plugin.requests.post",
return_value=mock_http_response(200, response_data),
):
result = await hook(payload, context)

assert not result.continue_processing
assert result.violation is not None
assert result.violation.code == "NEMO_RAILS_BLOCKED"

# Verify reason includes the expected prefix
assert result.violation.reason.startswith(expected_reason_prefix)

# Verify description includes rail names
assert "Rails:" in result.violation.description
assert "detect hap" in result.violation.description
assert "detect sensitive data" in result.violation.description

# Verify MCP error code is set appropriately
if hook_name == "tool_pre_invoke":
assert result.violation.mcp_error_code == -32602
else:
assert result.violation.mcp_error_code == -32603
7 changes: 6 additions & 1 deletion src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,15 @@ def create_mcp_immediate_error_response(body, error_message, violation=None):
if violation is not None:
error_message = f"{violation.reason} -- {violation.description}"

# Use mcp_error_code from violation if present
error_code = -32000 # Otherwise default: generic server error
if violation is not None and violation.mcp_error_code is not None:
error_code = violation.mcp_error_code

error_body = {
"jsonrpc": body["jsonrpc"],
"id": body["id"],
"error": {"code": -32000, "message": error_message},
"error": {"code": error_code, "message": error_message},
}

return ep.ProcessingResponse(
Expand Down