1- import sys
2- from typing import Dict , Any , List , Optional , Callable
1+ import logging
2+ from typing import Dict , Any , List , Optional , AsyncGenerator , TYPE_CHECKING
3+
34import aiohttp
4- import asyncio
5- import ssl
65from gql import Client as GqlClient , gql as gql_query
76from gql .transport .aiohttp import AIOHTTPTransport
8- from utcp .client .client_transport_interface import ClientTransportInterface
9- from utcp .shared .provider import Provider , GraphQLProvider
10- from utcp .shared .tool import Tool , ToolInputOutputSchema
11- from utcp .shared .auth import ApiKeyAuth , BasicAuth , OAuth2Auth
12- import logging
7+
8+ from utcp .interfaces .communication_protocol import CommunicationProtocol
9+ from utcp .data .call_template import CallTemplate
10+ from utcp .data .tool import Tool , JsonSchema
11+ from utcp .data .utcp_manual import UtcpManual
12+ from utcp .data .register_manual_response import RegisterManualResult
13+ from utcp .data .auth_implementations .api_key_auth import ApiKeyAuth
14+ from utcp .data .auth_implementations .basic_auth import BasicAuth
15+ from utcp .data .auth_implementations .oauth2_auth import OAuth2Auth
16+
17+ from utcp_gql .gql_call_template import GraphQLProvider
18+
19+ if TYPE_CHECKING :
20+ from utcp .utcp_client import UtcpClient
21+
1322
1423logging .basicConfig (
1524 level = logging .INFO ,
16- format = "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s"
25+ format = "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s" ,
1726)
1827
1928logger = logging .getLogger (__name__ )
2029
21- class GraphQLClientTransport (ClientTransportInterface ):
22- """
23- Simple, robust, production-ready GraphQL transport using gql.
24- Stateless, per-operation. Supports all GraphQL features.
30+
31+ class GraphQLCommunicationProtocol (CommunicationProtocol ):
32+ """GraphQL protocol implementation for UTCP 1.0.
33+
34+ - Discovers tools via GraphQL schema introspection.
35+ - Executes per-call sessions using `gql` over HTTP(S).
36+ - Supports `ApiKeyAuth`, `BasicAuth`, and `OAuth2Auth`.
37+ - Enforces HTTPS or localhost for security.
2538 """
26- def __init__ (self ):
39+
40+ def __init__ (self ) -> None :
2741 self ._oauth_tokens : Dict [str , Dict [str , Any ]] = {}
2842
29- def _enforce_https_or_localhost (self , url : str ):
30- if not (url .startswith ("https://" ) or url .startswith ("http://localhost" ) or url .startswith ("http://127.0.0.1" )):
43+ def _enforce_https_or_localhost (self , url : str ) -> None :
44+ if not (
45+ url .startswith ("https://" )
46+ or url .startswith ("http://localhost" )
47+ or url .startswith ("http://127.0.0.1" )
48+ ):
3149 raise ValueError (
32- f "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. Got: { url } . "
33- "Non-secure URLs are vulnerable to man-in-the-middle attacks ."
50+ "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. "
51+ f"Got: { url } ."
3452 )
3553
3654 async def _handle_oauth2 (self , auth : OAuth2Auth ) -> str :
@@ -39,98 +57,155 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str:
3957 return self ._oauth_tokens [client_id ]["access_token" ]
4058 async with aiohttp .ClientSession () as session :
4159 data = {
42- ' grant_type' : ' client_credentials' ,
43- ' client_id' : client_id ,
44- ' client_secret' : auth .client_secret ,
45- ' scope' : auth .scope
60+ " grant_type" : " client_credentials" ,
61+ " client_id" : client_id ,
62+ " client_secret" : auth .client_secret ,
63+ " scope" : auth .scope ,
4664 }
4765 async with session .post (auth .token_url , data = data ) as resp :
4866 resp .raise_for_status ()
4967 token_response = await resp .json ()
5068 self ._oauth_tokens [client_id ] = token_response
5169 return token_response ["access_token" ]
5270
53- async def _prepare_headers (self , provider : GraphQLProvider ) -> Dict [str , str ]:
54- headers = provider .headers .copy () if provider .headers else {}
71+ async def _prepare_headers (
72+ self , provider : GraphQLProvider , tool_args : Optional [Dict [str , Any ]] = None
73+ ) -> Dict [str , str ]:
74+ headers : Dict [str , str ] = provider .headers .copy () if provider .headers else {}
5575 if provider .auth :
5676 if isinstance (provider .auth , ApiKeyAuth ):
57- if provider .auth .api_key :
58- if provider .auth .location == "header" :
59- headers [provider .auth .var_name ] = provider .auth .api_key
60- # (query/cookie not supported for GraphQL by default)
77+ if provider .auth .api_key and provider .auth .location == "header" :
78+ headers [provider .auth .var_name ] = provider .auth .api_key
6179 elif isinstance (provider .auth , BasicAuth ):
6280 import base64
81+
6382 userpass = f"{ provider .auth .username } :{ provider .auth .password } "
6483 headers ["Authorization" ] = "Basic " + base64 .b64encode (userpass .encode ()).decode ()
6584 elif isinstance (provider .auth , OAuth2Auth ):
6685 token = await self ._handle_oauth2 (provider .auth )
6786 headers ["Authorization" ] = f"Bearer { token } "
87+
88+ # Map selected tool_args into headers if requested
89+ if tool_args and provider .header_fields :
90+ for field in provider .header_fields :
91+ if field in tool_args and isinstance (tool_args [field ], str ):
92+ headers [field ] = tool_args [field ]
93+
6894 return headers
6995
70- async def register_tool_provider (self , manual_provider : Provider ) -> List [Tool ]:
71- if not isinstance (manual_provider , GraphQLProvider ):
72- raise ValueError ("GraphQLClientTransport can only be used with GraphQLProvider" )
73- self ._enforce_https_or_localhost (manual_provider .url )
74- headers = await self ._prepare_headers (manual_provider )
75- transport = AIOHTTPTransport (url = manual_provider .url , headers = headers )
76- async with GqlClient (transport = transport , fetch_schema_from_transport = True ) as session :
77- schema = session .client .schema
78- tools = []
79- # Queries
80- if hasattr (schema , 'query_type' ) and schema .query_type :
81- for name , field in schema .query_type .fields .items ():
82- tools .append (Tool (
83- name = name ,
84- description = getattr (field , 'description' , '' ) or '' ,
85- inputs = ToolInputOutputSchema (required = None ),
86- tool_provider = manual_provider
87- ))
88- # Mutations
89- if hasattr (schema , 'mutation_type' ) and schema .mutation_type :
90- for name , field in schema .mutation_type .fields .items ():
91- tools .append (Tool (
92- name = name ,
93- description = getattr (field , 'description' , '' ) or '' ,
94- inputs = ToolInputOutputSchema (required = None ),
95- tool_provider = manual_provider
96- ))
97- # Subscriptions (listed, but not called here)
98- if hasattr (schema , 'subscription_type' ) and schema .subscription_type :
99- for name , field in schema .subscription_type .fields .items ():
100- tools .append (Tool (
101- name = name ,
102- description = getattr (field , 'description' , '' ) or '' ,
103- inputs = ToolInputOutputSchema (required = None ),
104- tool_provider = manual_provider
105- ))
106- return tools
107-
108- async def deregister_tool_provider (self , manual_provider : Provider ) -> None :
109- # Stateless: nothing to do
110- pass
111-
112- async def call_tool (self , tool_name : str , tool_args : Dict [str , Any ], tool_provider : Provider , query : Optional [str ] = None ) -> Any :
113- if not isinstance (tool_provider , GraphQLProvider ):
114- raise ValueError ("GraphQLClientTransport can only be used with GraphQLProvider" )
115- self ._enforce_https_or_localhost (tool_provider .url )
116- headers = await self ._prepare_headers (tool_provider )
117- transport = AIOHTTPTransport (url = tool_provider .url , headers = headers )
96+ async def register_manual (
97+ self , caller : "UtcpClient" , manual_call_template : CallTemplate
98+ ) -> RegisterManualResult :
99+ if not isinstance (manual_call_template , GraphQLProvider ):
100+ raise ValueError ("GraphQLCommunicationProtocol requires a GraphQLProvider call template" )
101+ self ._enforce_https_or_localhost (manual_call_template .url )
102+
103+ try :
104+ headers = await self ._prepare_headers (manual_call_template )
105+ transport = AIOHTTPTransport (url = manual_call_template .url , headers = headers )
106+ async with GqlClient (transport = transport , fetch_schema_from_transport = True ) as session :
107+ schema = session .client .schema
108+ tools : List [Tool ] = []
109+
110+ # Queries
111+ if hasattr (schema , "query_type" ) and schema .query_type :
112+ for name , field in schema .query_type .fields .items ():
113+ tools .append (
114+ Tool (
115+ name = name ,
116+ description = getattr (field , "description" , "" ) or "" ,
117+ inputs = JsonSchema (type = "object" ),
118+ outputs = JsonSchema (type = "object" ),
119+ tool_call_template = manual_call_template ,
120+ )
121+ )
122+
123+ # Mutations
124+ if hasattr (schema , "mutation_type" ) and schema .mutation_type :
125+ for name , field in schema .mutation_type .fields .items ():
126+ tools .append (
127+ Tool (
128+ name = name ,
129+ description = getattr (field , "description" , "" ) or "" ,
130+ inputs = JsonSchema (type = "object" ),
131+ outputs = JsonSchema (type = "object" ),
132+ tool_call_template = manual_call_template ,
133+ )
134+ )
135+
136+ # Subscriptions (listed for completeness)
137+ if hasattr (schema , "subscription_type" ) and schema .subscription_type :
138+ for name , field in schema .subscription_type .fields .items ():
139+ tools .append (
140+ Tool (
141+ name = name ,
142+ description = getattr (field , "description" , "" ) or "" ,
143+ inputs = JsonSchema (type = "object" ),
144+ outputs = JsonSchema (type = "object" ),
145+ tool_call_template = manual_call_template ,
146+ )
147+ )
148+
149+ manual = UtcpManual (tools = tools )
150+ return RegisterManualResult (
151+ manual_call_template = manual_call_template ,
152+ manual = manual ,
153+ success = True ,
154+ errors = [],
155+ )
156+ except Exception as e :
157+ logger .error (f"GraphQL manual registration failed for '{ manual_call_template .name } ': { e } " )
158+ return RegisterManualResult (
159+ manual_call_template = manual_call_template ,
160+ manual = UtcpManual (manual_version = "0.0.0" , tools = []),
161+ success = False ,
162+ errors = [str (e )],
163+ )
164+
165+ async def deregister_manual (
166+ self , caller : "UtcpClient" , manual_call_template : CallTemplate
167+ ) -> None :
168+ # Stateless: nothing to clean up
169+ return None
170+
171+ async def call_tool (
172+ self ,
173+ caller : "UtcpClient" ,
174+ tool_name : str ,
175+ tool_args : Dict [str , Any ],
176+ tool_call_template : CallTemplate ,
177+ ) -> Any :
178+ if not isinstance (tool_call_template , GraphQLProvider ):
179+ raise ValueError ("GraphQLCommunicationProtocol requires a GraphQLProvider call template" )
180+ self ._enforce_https_or_localhost (tool_call_template .url )
181+
182+ headers = await self ._prepare_headers (tool_call_template , tool_args )
183+ transport = AIOHTTPTransport (url = tool_call_template .url , headers = headers )
118184 async with GqlClient (transport = transport , fetch_schema_from_transport = True ) as session :
119- if query is not None :
120- document = gql_query (query )
121- result = await session .execute (document , variable_values = tool_args )
122- return result
123- # If no query provided, build a simple query
124- # Default to query operation
125- op_type = getattr (tool_provider , 'operation_type' , 'query' )
126- arg_str = ', ' .join (f"${ k } : String" for k in tool_args .keys ())
185+ op_type = getattr (tool_call_template , "operation_type" , "query" )
186+ # Strip manual prefix if present (client prefixes at save time)
187+ base_tool_name = tool_name .split ("." , 1 )[- 1 ] if "." in tool_name else tool_name
188+
189+ arg_str = ", " .join (f"${ k } : String" for k in tool_args .keys ())
127190 var_defs = f"({ arg_str } )" if arg_str else ""
128- arg_pass = ', ' .join (f"{ k } : ${ k } " for k in tool_args .keys ())
191+ arg_pass = ", " .join (f"{ k } : ${ k } " for k in tool_args .keys ())
129192 arg_pass = f"({ arg_pass } )" if arg_pass else ""
130- gql_str = f"{ op_type } { var_defs } {{ { tool_name } { arg_pass } }}"
193+
194+ gql_str = f"{ op_type } { var_defs } {{ { base_tool_name } { arg_pass } }}"
131195 document = gql_query (gql_str )
132196 result = await session .execute (document , variable_values = tool_args )
133197 return result
134198
199+ async def call_tool_streaming (
200+ self ,
201+ caller : "UtcpClient" ,
202+ tool_name : str ,
203+ tool_args : Dict [str , Any ],
204+ tool_call_template : CallTemplate ,
205+ ) -> AsyncGenerator [Any , None ]:
206+ # Basic implementation: execute non-streaming and yield once
207+ result = await self .call_tool (caller , tool_name , tool_args , tool_call_template )
208+ yield result
209+
135210 async def close (self ) -> None :
136- self ._oauth_tokens .clear ()
211+ self ._oauth_tokens .clear ()
0 commit comments