Skip to content

Commit 662d07d

Browse files
author
Salman Mohammed
committed
feat: Updated Graphql implementation to be compatible with UTCP 1.0v
1 parent 45793cf commit 662d07d

File tree

4 files changed

+317
-91
lines changed

4 files changed

+317
-91
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from utcp.plugins.discovery import register_communication_protocol, register_call_template
2+
3+
from .gql_communication_protocol import GraphQLCommunicationProtocol
4+
from .gql_call_template import GraphQLProvider, GraphQLProviderSerializer
5+
6+
7+
def register():
8+
register_communication_protocol("graphql", GraphQLCommunicationProtocol())
9+
register_call_template("graphql", GraphQLProviderSerializer())

plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
from utcp.data.call_template import CallTemplate
2-
from utcp.data.auth import Auth
1+
from utcp.data.call_template import CallTemplate, CallTemplateSerializer
2+
from utcp.data.auth import Auth, AuthSerializer
3+
from utcp.interfaces.serializer import Serializer
4+
from utcp.exceptions import UtcpSerializerValidationError
5+
import traceback
36
from typing import Dict, List, Optional, Literal
4-
from pydantic import Field
7+
from pydantic import Field, field_serializer, field_validator
58

69
class GraphQLProvider(CallTemplate):
710
"""Provider configuration for GraphQL-based tools.
@@ -27,3 +30,31 @@ class GraphQLProvider(CallTemplate):
2730
auth: Optional[Auth] = None
2831
headers: Optional[Dict[str, str]] = None
2932
header_fields: Optional[List[str]] = Field(default=None, description="List of input fields to be sent as request headers for the initial connection.")
33+
34+
@field_serializer("auth")
35+
def serialize_auth(self, auth: Optional[Auth]):
36+
if auth is None:
37+
return None
38+
return AuthSerializer().to_dict(auth)
39+
40+
@field_validator("auth", mode="before")
41+
@classmethod
42+
def validate_auth(cls, v: Optional[Auth | dict]):
43+
if v is None:
44+
return None
45+
if isinstance(v, Auth):
46+
return v
47+
return AuthSerializer().validate_dict(v)
48+
49+
50+
class GraphQLProviderSerializer(Serializer[GraphQLProvider]):
51+
def to_dict(self, obj: GraphQLProvider) -> dict:
52+
return obj.model_dump()
53+
54+
def validate_dict(self, data: dict) -> GraphQLProvider:
55+
try:
56+
return GraphQLProvider.model_validate(data)
57+
except Exception as e:
58+
raise UtcpSerializerValidationError(
59+
f"Invalid GraphQLProvider: {e}\n{traceback.format_exc()}"
60+
)
Lines changed: 163 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,54 @@
1-
import sys
2-
from typing import Dict, Any, List, Optional, Callable
1+
import logging
2+
from typing import Dict, Any, List, Optional, AsyncGenerator, TYPE_CHECKING
3+
34
import aiohttp
4-
import asyncio
5-
import ssl
65
from gql import Client as GqlClient, gql as gql_query
76
from gql.transport.aiohttp import AIOHTTPTransport
8-
from utcp.client.client_transport_interface import ClientTransportInterface
9-
from utcp.shared.provider import Provider, GraphQLProvider
10-
from utcp.shared.tool import Tool, ToolInputOutputSchema
11-
from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth
12-
import logging
7+
8+
from utcp.interfaces.communication_protocol import CommunicationProtocol
9+
from utcp.data.call_template import CallTemplate
10+
from utcp.data.tool import Tool, JsonSchema
11+
from utcp.data.utcp_manual import UtcpManual
12+
from utcp.data.register_manual_response import RegisterManualResult
13+
from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth
14+
from utcp.data.auth_implementations.basic_auth import BasicAuth
15+
from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth
16+
17+
from utcp_gql.gql_call_template import GraphQLProvider
18+
19+
if TYPE_CHECKING:
20+
from utcp.utcp_client import UtcpClient
21+
1322

1423
logging.basicConfig(
1524
level=logging.INFO,
16-
format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s"
25+
format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s",
1726
)
1827

1928
logger = logging.getLogger(__name__)
2029

21-
class GraphQLClientTransport(ClientTransportInterface):
22-
"""
23-
Simple, robust, production-ready GraphQL transport using gql.
24-
Stateless, per-operation. Supports all GraphQL features.
30+
31+
class GraphQLCommunicationProtocol(CommunicationProtocol):
32+
"""GraphQL protocol implementation for UTCP 1.0.
33+
34+
- Discovers tools via GraphQL schema introspection.
35+
- Executes per-call sessions using `gql` over HTTP(S).
36+
- Supports `ApiKeyAuth`, `BasicAuth`, and `OAuth2Auth`.
37+
- Enforces HTTPS or localhost for security.
2538
"""
26-
def __init__(self):
39+
40+
def __init__(self) -> None:
2741
self._oauth_tokens: Dict[str, Dict[str, Any]] = {}
2842

29-
def _enforce_https_or_localhost(self, url: str):
30-
if not (url.startswith("https://") or url.startswith("http://localhost") or url.startswith("http://127.0.0.1")):
43+
def _enforce_https_or_localhost(self, url: str) -> None:
44+
if not (
45+
url.startswith("https://")
46+
or url.startswith("http://localhost")
47+
or url.startswith("http://127.0.0.1")
48+
):
3149
raise ValueError(
32-
f"Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. Got: {url}. "
33-
"Non-secure URLs are vulnerable to man-in-the-middle attacks."
50+
"Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. "
51+
f"Got: {url}."
3452
)
3553

3654
async def _handle_oauth2(self, auth: OAuth2Auth) -> str:
@@ -39,98 +57,155 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str:
3957
return self._oauth_tokens[client_id]["access_token"]
4058
async with aiohttp.ClientSession() as session:
4159
data = {
42-
'grant_type': 'client_credentials',
43-
'client_id': client_id,
44-
'client_secret': auth.client_secret,
45-
'scope': auth.scope
60+
"grant_type": "client_credentials",
61+
"client_id": client_id,
62+
"client_secret": auth.client_secret,
63+
"scope": auth.scope,
4664
}
4765
async with session.post(auth.token_url, data=data) as resp:
4866
resp.raise_for_status()
4967
token_response = await resp.json()
5068
self._oauth_tokens[client_id] = token_response
5169
return token_response["access_token"]
5270

53-
async def _prepare_headers(self, provider: GraphQLProvider) -> Dict[str, str]:
54-
headers = provider.headers.copy() if provider.headers else {}
71+
async def _prepare_headers(
72+
self, provider: GraphQLProvider, tool_args: Optional[Dict[str, Any]] = None
73+
) -> Dict[str, str]:
74+
headers: Dict[str, str] = provider.headers.copy() if provider.headers else {}
5575
if provider.auth:
5676
if isinstance(provider.auth, ApiKeyAuth):
57-
if provider.auth.api_key:
58-
if provider.auth.location == "header":
59-
headers[provider.auth.var_name] = provider.auth.api_key
60-
# (query/cookie not supported for GraphQL by default)
77+
if provider.auth.api_key and provider.auth.location == "header":
78+
headers[provider.auth.var_name] = provider.auth.api_key
6179
elif isinstance(provider.auth, BasicAuth):
6280
import base64
81+
6382
userpass = f"{provider.auth.username}:{provider.auth.password}"
6483
headers["Authorization"] = "Basic " + base64.b64encode(userpass.encode()).decode()
6584
elif isinstance(provider.auth, OAuth2Auth):
6685
token = await self._handle_oauth2(provider.auth)
6786
headers["Authorization"] = f"Bearer {token}"
87+
88+
# Map selected tool_args into headers if requested
89+
if tool_args and provider.header_fields:
90+
for field in provider.header_fields:
91+
if field in tool_args and isinstance(tool_args[field], str):
92+
headers[field] = tool_args[field]
93+
6894
return headers
6995

70-
async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]:
71-
if not isinstance(manual_provider, GraphQLProvider):
72-
raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider")
73-
self._enforce_https_or_localhost(manual_provider.url)
74-
headers = await self._prepare_headers(manual_provider)
75-
transport = AIOHTTPTransport(url=manual_provider.url, headers=headers)
76-
async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session:
77-
schema = session.client.schema
78-
tools = []
79-
# Queries
80-
if hasattr(schema, 'query_type') and schema.query_type:
81-
for name, field in schema.query_type.fields.items():
82-
tools.append(Tool(
83-
name=name,
84-
description=getattr(field, 'description', '') or '',
85-
inputs=ToolInputOutputSchema(required=None),
86-
tool_provider=manual_provider
87-
))
88-
# Mutations
89-
if hasattr(schema, 'mutation_type') and schema.mutation_type:
90-
for name, field in schema.mutation_type.fields.items():
91-
tools.append(Tool(
92-
name=name,
93-
description=getattr(field, 'description', '') or '',
94-
inputs=ToolInputOutputSchema(required=None),
95-
tool_provider=manual_provider
96-
))
97-
# Subscriptions (listed, but not called here)
98-
if hasattr(schema, 'subscription_type') and schema.subscription_type:
99-
for name, field in schema.subscription_type.fields.items():
100-
tools.append(Tool(
101-
name=name,
102-
description=getattr(field, 'description', '') or '',
103-
inputs=ToolInputOutputSchema(required=None),
104-
tool_provider=manual_provider
105-
))
106-
return tools
107-
108-
async def deregister_tool_provider(self, manual_provider: Provider) -> None:
109-
# Stateless: nothing to do
110-
pass
111-
112-
async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider, query: Optional[str] = None) -> Any:
113-
if not isinstance(tool_provider, GraphQLProvider):
114-
raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider")
115-
self._enforce_https_or_localhost(tool_provider.url)
116-
headers = await self._prepare_headers(tool_provider)
117-
transport = AIOHTTPTransport(url=tool_provider.url, headers=headers)
96+
async def register_manual(
97+
self, caller: "UtcpClient", manual_call_template: CallTemplate
98+
) -> RegisterManualResult:
99+
if not isinstance(manual_call_template, GraphQLProvider):
100+
raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template")
101+
self._enforce_https_or_localhost(manual_call_template.url)
102+
103+
try:
104+
headers = await self._prepare_headers(manual_call_template)
105+
transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers)
106+
async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session:
107+
schema = session.client.schema
108+
tools: List[Tool] = []
109+
110+
# Queries
111+
if hasattr(schema, "query_type") and schema.query_type:
112+
for name, field in schema.query_type.fields.items():
113+
tools.append(
114+
Tool(
115+
name=name,
116+
description=getattr(field, "description", "") or "",
117+
inputs=JsonSchema(type="object"),
118+
outputs=JsonSchema(type="object"),
119+
tool_call_template=manual_call_template,
120+
)
121+
)
122+
123+
# Mutations
124+
if hasattr(schema, "mutation_type") and schema.mutation_type:
125+
for name, field in schema.mutation_type.fields.items():
126+
tools.append(
127+
Tool(
128+
name=name,
129+
description=getattr(field, "description", "") or "",
130+
inputs=JsonSchema(type="object"),
131+
outputs=JsonSchema(type="object"),
132+
tool_call_template=manual_call_template,
133+
)
134+
)
135+
136+
# Subscriptions (listed for completeness)
137+
if hasattr(schema, "subscription_type") and schema.subscription_type:
138+
for name, field in schema.subscription_type.fields.items():
139+
tools.append(
140+
Tool(
141+
name=name,
142+
description=getattr(field, "description", "") or "",
143+
inputs=JsonSchema(type="object"),
144+
outputs=JsonSchema(type="object"),
145+
tool_call_template=manual_call_template,
146+
)
147+
)
148+
149+
manual = UtcpManual(tools=tools)
150+
return RegisterManualResult(
151+
manual_call_template=manual_call_template,
152+
manual=manual,
153+
success=True,
154+
errors=[],
155+
)
156+
except Exception as e:
157+
logger.error(f"GraphQL manual registration failed for '{manual_call_template.name}': {e}")
158+
return RegisterManualResult(
159+
manual_call_template=manual_call_template,
160+
manual=UtcpManual(manual_version="0.0.0", tools=[]),
161+
success=False,
162+
errors=[str(e)],
163+
)
164+
165+
async def deregister_manual(
166+
self, caller: "UtcpClient", manual_call_template: CallTemplate
167+
) -> None:
168+
# Stateless: nothing to clean up
169+
return None
170+
171+
async def call_tool(
172+
self,
173+
caller: "UtcpClient",
174+
tool_name: str,
175+
tool_args: Dict[str, Any],
176+
tool_call_template: CallTemplate,
177+
) -> Any:
178+
if not isinstance(tool_call_template, GraphQLProvider):
179+
raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template")
180+
self._enforce_https_or_localhost(tool_call_template.url)
181+
182+
headers = await self._prepare_headers(tool_call_template, tool_args)
183+
transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers)
118184
async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session:
119-
if query is not None:
120-
document = gql_query(query)
121-
result = await session.execute(document, variable_values=tool_args)
122-
return result
123-
# If no query provided, build a simple query
124-
# Default to query operation
125-
op_type = getattr(tool_provider, 'operation_type', 'query')
126-
arg_str = ', '.join(f"${k}: String" for k in tool_args.keys())
185+
op_type = getattr(tool_call_template, "operation_type", "query")
186+
# Strip manual prefix if present (client prefixes at save time)
187+
base_tool_name = tool_name.split(".", 1)[-1] if "." in tool_name else tool_name
188+
189+
arg_str = ", ".join(f"${k}: String" for k in tool_args.keys())
127190
var_defs = f"({arg_str})" if arg_str else ""
128-
arg_pass = ', '.join(f"{k}: ${k}" for k in tool_args.keys())
191+
arg_pass = ", ".join(f"{k}: ${k}" for k in tool_args.keys())
129192
arg_pass = f"({arg_pass})" if arg_pass else ""
130-
gql_str = f"{op_type} {var_defs} {{ {tool_name}{arg_pass} }}"
193+
194+
gql_str = f"{op_type} {var_defs} {{ {base_tool_name}{arg_pass} }}"
131195
document = gql_query(gql_str)
132196
result = await session.execute(document, variable_values=tool_args)
133197
return result
134198

199+
async def call_tool_streaming(
200+
self,
201+
caller: "UtcpClient",
202+
tool_name: str,
203+
tool_args: Dict[str, Any],
204+
tool_call_template: CallTemplate,
205+
) -> AsyncGenerator[Any, None]:
206+
# Basic implementation: execute non-streaming and yield once
207+
result = await self.call_tool(caller, tool_name, tool_args, tool_call_template)
208+
yield result
209+
135210
async def close(self) -> None:
136-
self._oauth_tokens.clear()
211+
self._oauth_tokens.clear()

0 commit comments

Comments
 (0)