|
| 1 | +use async_graphql::{ |
| 2 | + http::{playground_source, GraphQLPlaygroundConfig, ALL_WEBSOCKET_PROTOCOLS}, |
| 3 | + Data, EmptyMutation, Schema, |
| 4 | +}; |
| 5 | +use async_graphql_axum::{GraphQLProtocol, GraphQLRequest, GraphQLResponse, GraphQLWebSocket}; |
| 6 | +use axum::{ |
| 7 | + extract::{ws::WebSocketUpgrade, Extension}, |
| 8 | + http::header::HeaderMap, |
| 9 | + response::{Html, IntoResponse, Response}, |
| 10 | + routing::get, |
| 11 | + Router, Server, |
| 12 | +}; |
| 13 | +use token::{on_connection_init, QueryRoot, SubscriptionRoot, Token, TokenSchema}; |
| 14 | + |
| 15 | +async fn graphql_playground() -> impl IntoResponse { |
| 16 | + Html(playground_source( |
| 17 | + GraphQLPlaygroundConfig::new("/").subscription_endpoint("/ws"), |
| 18 | + )) |
| 19 | +} |
| 20 | + |
| 21 | +fn get_token_from_headers(headers: &HeaderMap) -> Option<Token> { |
| 22 | + headers |
| 23 | + .get("Token") |
| 24 | + .and_then(|value| value.to_str().map(|s| Token(s.to_string())).ok()) |
| 25 | +} |
| 26 | + |
| 27 | +async fn graphql_handler( |
| 28 | + req: GraphQLRequest, |
| 29 | + Extension(schema): Extension<TokenSchema>, |
| 30 | + headers: HeaderMap, |
| 31 | +) -> GraphQLResponse { |
| 32 | + let mut req = req.into_inner(); |
| 33 | + if let Some(token) = get_token_from_headers(&headers) { |
| 34 | + req = req.data(token); |
| 35 | + } |
| 36 | + schema.execute(req).await.into() |
| 37 | +} |
| 38 | + |
| 39 | +async fn graphql_ws_handler( |
| 40 | + Extension(schema): Extension<TokenSchema>, |
| 41 | + protocol: GraphQLProtocol, |
| 42 | + websocket: WebSocketUpgrade, |
| 43 | + headers: HeaderMap, |
| 44 | +) -> Response { |
| 45 | + let mut data = Data::default(); |
| 46 | + if let Some(token) = get_token_from_headers(&headers) { |
| 47 | + data.insert(token); |
| 48 | + } |
| 49 | + |
| 50 | + websocket |
| 51 | + .protocols(ALL_WEBSOCKET_PROTOCOLS) |
| 52 | + .on_upgrade(move |stream| { |
| 53 | + GraphQLWebSocket::new(stream, schema.clone(), protocol) |
| 54 | + .with_data(data) |
| 55 | + .on_connection_init(on_connection_init) |
| 56 | + .serve() |
| 57 | + }) |
| 58 | +} |
| 59 | + |
| 60 | +#[tokio::main] |
| 61 | +async fn main() { |
| 62 | + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); |
| 63 | + |
| 64 | + let app = Router::new() |
| 65 | + .route("/", get(graphql_playground).post(graphql_handler)) |
| 66 | + .route("/ws", get(graphql_ws_handler)) |
| 67 | + .layer(Extension(schema)); |
| 68 | + |
| 69 | + println!("Playground: http://localhost:8000"); |
| 70 | + |
| 71 | + Server::bind(&"0.0.0.0:8000".parse().unwrap()) |
| 72 | + .serve(app.into_make_service()) |
| 73 | + .await |
| 74 | + .unwrap(); |
| 75 | +} |
0 commit comments