@@ -11,17 +11,20 @@ use std::{
1111} ;
1212
1313use async_graphql:: {
14- EmptySubscription , Schema , SchemaBuilder ,
14+ Data , Schema , SchemaBuilder ,
1515 extensions:: { ApolloTracing , ExtensionFactory , Tracing } ,
16+ http:: ALL_WEBSOCKET_PROTOCOLS ,
1617} ;
17- use async_graphql_axum:: { GraphQLRequest , GraphQLResponse } ;
18+ use async_graphql_axum:: { GraphQLProtocol , GraphQLRequest , GraphQLResponse , GraphQLWebSocket } ;
1819use axum:: {
1920 Extension , Router ,
2021 body:: Body ,
21- extract:: { ConnectInfo , FromRef , Query as AxumQuery , State } ,
22+ extract:: {
23+ ConnectInfo , FromRef , FromRequestParts , Query as AxumQuery , State , ws:: WebSocketUpgrade ,
24+ } ,
2225 http:: { HeaderMap , StatusCode } ,
2326 middleware:: { self } ,
24- response:: IntoResponse ,
27+ response:: { IntoResponse , Response as AxumResponse } ,
2528 routing:: { MethodRouter , Route , get, post} ,
2629} ;
2730use chrono:: Utc ;
@@ -71,6 +74,7 @@ use crate::{
7174 object:: IObject ,
7275 owner:: IOwner ,
7376 query:: { IotaGraphQLSchema , Query } ,
77+ subscription:: { GraphQLStream , Subscription } ,
7478 } ,
7579} ;
7680
@@ -166,7 +170,7 @@ impl Server {
166170
167171pub ( crate ) struct ServerBuilder {
168172 state : AppState ,
169- schema : SchemaBuilder < Query , Mutation , EmptySubscription > ,
173+ schema : SchemaBuilder < Query , Mutation , Subscription > ,
170174 router : Option < Router > ,
171175 db_reader : Option < Db > ,
172176 resolver : Option < PackageResolver > ,
@@ -239,7 +243,7 @@ impl ServerBuilder {
239243 self
240244 }
241245
242- fn build_schema ( self ) -> Schema < Query , Mutation , EmptySubscription > {
246+ fn build_schema ( self ) -> Schema < Query , Mutation , Subscription > {
243247 self . schema . finish ( )
244248 }
245249
@@ -249,7 +253,7 @@ impl ServerBuilder {
249253 self ,
250254 ) -> (
251255 String ,
252- Schema < Query , Mutation , EmptySubscription > ,
256+ Schema < Query , Mutation , Subscription > ,
253257 Db ,
254258 PackageResolver ,
255259 Router ,
@@ -275,9 +279,16 @@ impl ServerBuilder {
275279 if self . router . is_none ( ) {
276280 let router: Router = Router :: new ( )
277281 . route ( "/" , post ( graphql_handler) )
282+ . route ( "/subscriptions" , get ( subscription_handler) )
278283 . route ( "/{version}" , post ( graphql_handler) )
284+ . route ( "/{version}/subscriptions" , get ( subscription_handler) )
279285 . route ( "/graphql" , post ( graphql_handler) )
286+ . route ( "/graphql/subscriptions" , get ( subscription_handler) )
280287 . route ( "/graphql/{version}" , post ( graphql_handler) )
288+ . route (
289+ "/graphql/{version}/subscriptions" ,
290+ get ( subscription_handler) ,
291+ )
281292 . route ( "/health" , get ( health_check) )
282293 . route ( "/graphql/health" , get ( health_check) )
283294 . route ( "/graphql/{version}/health" , get ( health_check) )
@@ -327,8 +338,8 @@ impl ServerBuilder {
327338 info ! ( "Access control allow origin set to: {acl:?}" ) ;
328339
329340 let cors = CorsLayer :: new ( )
330- // Allow `POST` when accessing the resource
331- . allow_methods ( [ Method :: POST ] )
341+ // Allow `POST` & `GET` when accessing the resource
342+ . allow_methods ( [ Method :: POST , Method :: GET ] )
332343 // Allow requests from any origin
333344 . allow_origin ( acl)
334345 . allow_headers ( [ hyper:: header:: CONTENT_TYPE , LIMITS_HEADER . clone ( ) ] ) ;
@@ -479,6 +490,8 @@ impl ServerBuilder {
479490 None
480491 } ;
481492
493+ let graphql_streams = GraphQLStream :: new ( & config. connection . db_url , reader) . await ?;
494+
482495 builder = builder
483496 . context_data ( config. service . clone ( ) )
484497 . context_data ( loader)
@@ -489,7 +502,8 @@ impl ServerBuilder {
489502 . context_data ( iota_names_config)
490503 . context_data ( zklogin_config)
491504 . context_data ( metrics. clone ( ) )
492- . context_data ( config. clone ( ) ) ;
505+ . context_data ( config. clone ( ) )
506+ . context_data ( graphql_streams) ;
493507
494508 if config. internal_features . feature_gate {
495509 builder = builder. extension ( FeatureGate ) ;
@@ -526,8 +540,8 @@ impl ServerBuilder {
526540 }
527541}
528542
529- fn schema_builder ( ) -> SchemaBuilder < Query , Mutation , EmptySubscription > {
530- async_graphql:: Schema :: build ( Query , Mutation , EmptySubscription )
543+ fn schema_builder ( ) -> SchemaBuilder < Query , Mutation , Subscription > {
544+ async_graphql:: Schema :: build ( Query , Mutation , Subscription )
531545 . register_output_type :: < IMoveObject > ( )
532546 . register_output_type :: < IObject > ( )
533547 . register_output_type :: < IOwner > ( )
@@ -572,6 +586,47 @@ async fn graphql_handler(
572586 ( extensions, result. into ( ) )
573587}
574588
589+ /// Entry point for graphql streaming requests. Each request is stamped with a
590+ /// unique ID, a `ShowUsage` flag if set in the request headers and tracks the
591+ /// connection information produced by the client.
592+ async fn subscription_handler (
593+ ConnectInfo ( addr) : ConnectInfo < SocketAddr > ,
594+ Extension ( schema) : Extension < IotaGraphQLSchema > ,
595+ req : http:: Request < Body > ,
596+ ) -> AxumResponse {
597+ let headers_contains_show_usage = req. headers ( ) . contains_key ( ShowUsage :: name ( ) ) ;
598+ let ( mut parts, _body) = req. into_parts ( ) ;
599+
600+ // extract GraphQL protocol
601+ let protocol = match GraphQLProtocol :: from_request_parts ( & mut parts, & ( ) ) . await {
602+ Ok ( protocol) => protocol,
603+ Err ( err) => return err. into_response ( ) ,
604+ } ;
605+
606+ // extract WebSocket upgrade from request
607+ let upgrade = match WebSocketUpgrade :: from_request_parts ( & mut parts, & ( ) ) . await {
608+ Ok ( upgrade) => upgrade,
609+ Err ( err) => return err. into_response ( ) ,
610+ } ;
611+
612+ let resp = upgrade
613+ . protocols ( ALL_WEBSOCKET_PROTOCOLS )
614+ . on_upgrade ( move |stream| async move {
615+ // create connection data with per-connection values
616+ let mut connection_data = Data :: default ( ) ;
617+ connection_data. insert ( Uuid :: new_v4 ( ) ) ;
618+ if headers_contains_show_usage {
619+ connection_data. insert ( ShowUsage )
620+ }
621+ connection_data. insert ( addr) ;
622+
623+ let connection =
624+ GraphQLWebSocket :: new ( stream, schema, protocol) . with_data ( connection_data) ;
625+ connection. serve ( ) . await ;
626+ } ) ;
627+ resp
628+ }
629+
575630#[ derive( Clone ) ]
576631struct MetricsMakeCallbackHandler {
577632 metrics : Metrics ,
0 commit comments