diff --git a/.gitignore b/.gitignore index 6527148..1d9b32b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target tunnelto_server +!tunnelto_server/ .env \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index d7f4e3d..622aae7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,7 @@ dependencies = [ "humansize", "num-traits", "serde", + "serde_json", "toml", ] @@ -2211,6 +2212,7 @@ dependencies = [ "pretty_env_logger", "serde", "serde_json", + "serde_urlencoded", "structopt", "thiserror", "tokio", diff --git a/tunnelto/Cargo.toml b/tunnelto/Cargo.toml index 3124bef..eaefffc 100644 --- a/tunnelto/Cargo.toml +++ b/tunnelto/Cargo.toml @@ -32,8 +32,9 @@ indicatif = "0.15.0" httparse = "1.3.4" warp = "0.2.2" bytes = "0.5.5" -askama = "0.9.0" +askama = { version = "0.9.0", features = ["serde-json"] } chrono = "0.4.11" uuid = {version = "0.8.1", features = ["serde", "v4"] } hyper = "0.13.6" -http-body = "0.3.1" \ No newline at end of file +http-body = "0.3.1" +serde_urlencoded = "0.6.1" \ No newline at end of file diff --git a/tunnelto/src/introspect/mod.rs b/tunnelto/src/introspect/mod.rs index d3fe83d..8e3da41 100644 --- a/tunnelto/src/introspect/mod.rs +++ b/tunnelto/src/introspect/mod.rs @@ -11,23 +11,32 @@ use bytes::Buf; use uuid::Uuid; use http_body::Body; +type HttpClient = hyper::Client; + #[derive(Debug, Clone)] pub struct Request { id: String, status: u16, + is_replay: bool, path: String, method: Method, headers: HashMap>, body_data: Vec, response_headers: HashMap>, response_data: Vec, - timestamp: chrono::NaiveDateTime, + started: chrono::NaiveDateTime, + completed: chrono::NaiveDateTime, } -#[derive(Debug, Clone, askama::Template)] -#[template(path="base.html")] -struct Inspector { - requests: Vec +impl Request { + pub fn elapsed(&self) -> String { + let duration = self.completed - self.started; + if duration.num_seconds() == 0 { + format!("{}ms", duration.num_milliseconds()) + } else { + format!("{}s", duration.num_seconds()) + } + } } lazy_static::lazy_static! { @@ -52,12 +61,20 @@ impl warp::reject::Reject for ForwardError {} pub fn start_introspection_server(config: Config) -> IntrospectionAddrs { let local_addr = format!("localhost:{}", &config.local_port); + let http_client = HttpClient::new(); + + let get_client = move || { + let client = http_client.clone(); + warp::any().map(move || client.clone()).boxed() + }; + let intercept = warp::any() .and(warp::any().map(move || local_addr.clone())) .and(warp::method()) .and(warp::path::full()) .and(warp::header::headers_cloned()) .and(warp::body::stream()) + .and(get_client()) .and_then(forward); let (forward_address, intercept_server) = warp::serve(intercept).bind_ephemeral(SocketAddr::from(([0,0,0,0], 0))); @@ -81,8 +98,22 @@ pub fn start_introspection_server(config: Config) -> IntrospectionAddrs { ); res })); - - let web_explorer = warp::get().and(warp::path::end()).and_then(inspector).or(css).or(logo); + let forward_clone = forward_address.clone(); + + let web_explorer = warp::get().and(warp::path::end()).and_then(inspector) + .or(warp::get() + .and(warp::path("detail")) + .and(warp::path::param()) + .and_then(request_detail)) + .or(warp::post() + .and(warp::path("replay")) + .and(warp::path::param()) + .and(get_client()) + .and_then(move |id, client| { + replay_request(id, client, forward_clone.clone()) + })) + .or(css) + .or(logo); let (web_explorer_address, explorer_server) = warp::serve(web_explorer).bind_ephemeral(SocketAddr::from(([0,0,0,0], 0))); tokio::spawn(explorer_server); @@ -94,10 +125,11 @@ async fn forward(local_addr: String, method: Method, path: FullPath, mut headers: HeaderMap, - mut body: impl Stream> + Send + Sync + Unpin + 'static) - -> Result, warp::reject::Rejection> + mut body: impl Stream> + Send + Sync + Unpin + 'static, + client: HttpClient) -> Result, warp::reject::Rejection> { - let now = chrono::Utc::now().naive_utc(); + let started = chrono::Utc::now().naive_utc(); + let mut request_headers = HashMap::new(); headers.keys().for_each(|k| { let values = headers.get_all(k).iter().filter_map(|v| v.to_str().ok()).map(|s| s.to_owned()).collect(); @@ -115,7 +147,6 @@ async fn forward(local_addr: String, collected.extend_from_slice(chunk.as_ref()) } - let client = hyper::client::Client::new(); let url = format!("http://{}{}", local_addr, path.as_str()); let mut request = hyper::Request::builder() @@ -138,7 +169,7 @@ async fn forward(local_addr: String, let mut response_headers = HashMap::new(); response.headers().keys().for_each(|k| { - let values = headers.get_all(k).iter().filter_map(|v| v.to_str().ok()).map(|s| s.to_owned()).collect(); + let values = response.headers().get_all(k).iter().filter_map(|v| v.to_str().ok()).map(|s| s.to_owned()).collect(); response_headers.insert(k.as_str().to_owned(), values); }); @@ -163,7 +194,9 @@ async fn forward(local_addr: String, body_data: collected, response_headers, response_data: response_data.clone(), - timestamp: now, + started, + completed: chrono::Utc::now().naive_utc(), + is_replay: false, }; REQUESTS.write().unwrap().insert(stored_request.id.clone(), stored_request); @@ -171,13 +204,121 @@ async fn forward(local_addr: String, Ok(Box::new(warp::http::Response::from_parts(parts, response_data))) } +#[derive(Debug, Clone, askama::Template)] +#[template(path="index.html")] +struct Inspector { + requests: Vec +} + +#[derive(Debug, Clone, askama::Template)] +#[template(path="detail.html")] +struct InspectorDetail { + request: Request, + incoming: BodyData, + response: BodyData, +} + +#[derive(Debug, Clone)] +struct BodyData { + data_type: DataType, + content: Option, + raw: String, +} + +impl AsRef for BodyData { + fn as_ref(&self) -> &BodyData { + &self + } +} + +#[derive(Debug, Clone)] +enum DataType { + Json, + Unknown +} + async fn inspector() -> Result, warp::reject::Rejection> { let mut requests:Vec = REQUESTS.read().unwrap().values().map(|r| r.clone()).collect(); - requests.sort_by(|a,b| a.timestamp.cmp(&b.timestamp)); + requests.sort_by(|a,b| b.completed.cmp(&a.completed)); let inspect = Inspector { requests }; Ok(Page(inspect)) } +async fn request_detail(rid: String) -> Result, warp::reject::Rejection> { + let request:Request = match REQUESTS.read().unwrap().get(&rid) { + Some(r) => r.clone(), + None => return Err(warp::reject::not_found()) + }; + + let detail = InspectorDetail{ + incoming: get_body_data(&request.body_data), + response: get_body_data(&request.response_data), + request, + }; + + Ok(Page(detail)) +} + +fn get_body_data(input: &[u8]) -> BodyData { + let mut body = BodyData { + data_type: DataType::Unknown, + content: None, + raw: std::str::from_utf8(input).map(|s| s.to_string()).unwrap_or("No UTF-8 Data".to_string()) + }; + + match serde_json::from_slice::(input) { + Ok(serde_json::Value::Object(map)) => { + body.data_type = DataType::Json; + body.content = serde_json::to_string_pretty(&map).ok(); + }, + Ok(serde_json::Value::Array(arr)) => { + body.data_type = DataType::Json; + body.content = serde_json::to_string_pretty(&arr).ok(); + }, + _ => {} + } + + body +} + +async fn replay_request(rid: String, client: HttpClient, addr: SocketAddr) -> Result, warp::reject::Rejection> { + let request:Request = match REQUESTS.read().unwrap().get(&rid) { + Some(r) => r.clone(), + None => return Err(warp::reject::not_found()) + }; + + let url = format!("http://localhost:{}{}", addr.port(), &request.path); + + let mut new_request = hyper::Request::builder() + .method(request.method) + .uri(url.parse::().map_err(|e| { + log::error!("invalid incoming url: {}, error: {:?}", url, e); + warp::reject::custom(ForwardError::InvalidURL) + })?); + + for (header, values) in &request.headers { + for v in values { + new_request = new_request.header(header, v) + } + } + + let new_request = new_request.body(hyper::Body::from(request.body_data)).map_err(|e| { + log::error!("failed to build request: {:?}", e); + warp::reject::custom(ForwardError::InvalidRequest) + })?; + + let _ = client.request(new_request).await.map_err(|e| { + log::error!("local server error: {:?}", e); + warp::reject::custom(ForwardError::LocalServerError) + })?; + + let response = warp::http::Response::builder() + .status(warp::http::StatusCode::SEE_OTHER) + .header(warp::http::header::LOCATION, "/") + .body(b"".to_vec()); + + Ok(Box::new(response)) +} struct Page(T); diff --git a/tunnelto/templates/base.html b/tunnelto/templates/base.html index 7c2b183..35dc814 100644 --- a/tunnelto/templates/base.html +++ b/tunnelto/templates/base.html @@ -31,50 +31,7 @@

- {% if requests.is_empty() %} -

No requests yet

- {% else %} -
- - - - - - - - - - - {% for r in requests %} - - - - - - - - - {% endfor %} - -
- Timestamp - StatusMethodPathReq.Resp.
- {{r.timestamp}} - - {{r.status}} - - {{r.method}} - - {{r.path}} - - {{r.body_data.len()/1024}} KB - - {{r.response_data.len() / 1024}} KB -
-
- {% endif %} - - + {% block content %}{% endblock %}
\ No newline at end of file diff --git a/tunnelto/templates/body_detail.html b/tunnelto/templates/body_detail.html new file mode 100644 index 0000000..2804784 --- /dev/null +++ b/tunnelto/templates/body_detail.html @@ -0,0 +1,79 @@ + + +
+
    +
  • + + Text + +
  • + {% match body.data_type %} + {% when DataType::Json %} +
  • + + JSON + +
  • + {% when DataType::Unknown %} + {% endmatch %} +
+
+
+
+

{{ body.raw }}

+
+
+ {% match body.content %} + {% when Some with (contents) %} + {{ contents|linebreaks|safe }} + {% when None %} + {% endmatch %} +
+
+ + diff --git a/tunnelto/templates/detail.html b/tunnelto/templates/detail.html new file mode 100644 index 0000000..315c4cf --- /dev/null +++ b/tunnelto/templates/detail.html @@ -0,0 +1,94 @@ +{% extends "base.html" %} + +{% block content %} + + + + + Go Back + + +
+
+ + + + + + + + + + + + + + + + + + + + + + + + +
Time StartDurationStatusMethodPathINOUT
+ {{request.completed.format("%H:%M:%S")}} + + {{request.elapsed() }} + + {% if request.status >= 200 && request.status < 300 %} + {{request.status}} + {% elseif request.status >= 300 && request.status < 400 %} + {{request.status}} + {% elseif request.status >= 400 && request.status < 500 %} + {{request.status}} + {% elseif request.status >= 500 %} + {{request.status}} + {% else %} + {{request.status}} + {% endif %} + + {{request.method}} + + {{request.path}} + + {{request.body_data.len()/1024}} KB + + {{request.response_data.len() / 1024}} KB + +
+ +
+
+
+
+ + +
+

Request

+ {# hacky to get local vars #} + {% if 1 == 1 %} + {% let prefix = "req" %} + {% let body = incoming.as_ref() %} + {% let headers = request.headers.clone() %} + {% include "headers_detail.html" %} + {% include "body_detail.html" %} + {% endif %} +
+ +
+

Response

+ {# hacky to get local vars #} + {% if 1 == 1 %} + {% let prefix = "resp" %} + {% let body = response.as_ref() %} + {% let headers = request.response_headers.clone() %} + {% include "headers_detail.html" %} + {% include "body_detail.html" %} + {% endif %} +
+ +{% endblock %} diff --git a/tunnelto/templates/headers_detail.html b/tunnelto/templates/headers_detail.html new file mode 100644 index 0000000..d72e22f --- /dev/null +++ b/tunnelto/templates/headers_detail.html @@ -0,0 +1,24 @@ +
+ + + + + + + {% for (name, values) in headers %} + + + {% for value in values %} + + {% endfor %} + + + {% endfor %} + + +
Header NameHeader Values
+ {{name}} + + {{ value }} +
+
diff --git a/tunnelto/templates/index.html b/tunnelto/templates/index.html new file mode 100644 index 0000000..1de811a --- /dev/null +++ b/tunnelto/templates/index.html @@ -0,0 +1,74 @@ +{% extends "base.html" %} + +{% block content %} + + + + + Load new data + + {% if requests.is_empty() %} +

No requests yet

+ {% else %} +
+ + + + + + + + + + + + + {% for r in requests %} + + + + + + + + + + + {% endfor %} + +
Time StartDurationStatusMethodPathINOUT
+ + {{r.completed.format("%H:%M:%S")}} + + + {{r.elapsed() }} + + {% if r.status >= 200 && r.status < 300 %} + {{r.status}} + {% elseif r.status >= 300 && r.status < 400 %} + {{r.status}} + {% elseif r.status >= 400 && r.status < 500 %} + {{r.status}} + {% elseif r.status >= 500 %} + {{r.status}} + {% else %} + {{r.status}} + {% endif %} + + {{r.method}} + + {{r.path}} + + {{r.body_data.len()/1024}} KB + + {{r.response_data.len() / 1024}} KB + + + + + + +
+
+ {% endif %} +{% endblock %} diff --git a/tunnelto_server/Cargo.toml b/tunnelto_server/Cargo.toml new file mode 100644 index 0000000..4de40be --- /dev/null +++ b/tunnelto_server/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "tunnelto_server" +description = "expose your local web server to the internet with a public url" +version = "0.1.10" +authors = ["Alex Grinman "] +edition = "2018" +license = "MIT" +repository = "https://tunnelto.dev" +readme = "../README.md" + +[[bin]] +name = "tunnelto_server" +path = "src/main.rs" + +[dependencies] +tunnelto_lib = { path = "../tunnelto_lib" } +warp = "0.2.2" +tokio = { version = "0.2", features = ["full"] } +base64 = "0.11.0" +futures = "0.3.4" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +lazy_static = "1.4.0" +chrono = "0.4.11" +pretty_env_logger = "0.4.0" +log = "0.4.8" +httparse = "1.3.4" +url = "2.1.1" +thiserror = "1.0" +uuid = {version = "0.8.1", features = ["serde", "v4"] } +sha2 = "0.9.0" + +# auth handler +rusoto_core = "0.44.0" +rusoto_dynamodb = "0.44.0" +rusoto_credential = "0.44.0" \ No newline at end of file diff --git a/tunnelto_server/src/active_stream.rs b/tunnelto_server/src/active_stream.rs new file mode 100644 index 0000000..ee63367 --- /dev/null +++ b/tunnelto_server/src/active_stream.rs @@ -0,0 +1,28 @@ +#[derive(Debug, Clone)] +pub struct ActiveStream { + pub id: StreamId, + pub client: ConnectedClient, + pub tx: UnboundedSender +} + +impl ActiveStream { + pub fn new(client: ConnectedClient) -> (Self, UnboundedReceiver) { + let (tx, rx) = unbounded(); + (ActiveStream { + id: StreamId::generate(), + client, + tx + }, + rx) + } +} + +pub type ActiveStreams = Arc>>; + +use super::*; +#[derive(Debug, Clone)] +pub enum StreamMessage { + Data(Vec), + TunnelRefused, + NoClientTunnel +} diff --git a/tunnelto_server/src/auth_db.rs b/tunnelto_server/src/auth_db.rs new file mode 100644 index 0000000..67c2c3c --- /dev/null +++ b/tunnelto_server/src/auth_db.rs @@ -0,0 +1,129 @@ +use rusoto_dynamodb::{DynamoDbClient, DynamoDb, AttributeValue, GetItemInput, GetItemError}; +use rusoto_core::{HttpClient, Client, Region}; + +use std::collections::HashMap; +use uuid::Uuid; +use thiserror::Error; +use sha2::Digest; +use rusoto_credential::EnvironmentProvider; +use std::str::FromStr; + +pub struct AuthDbService { + client: DynamoDbClient, +} + +impl AuthDbService { + pub fn new() -> Result> { + let provider = EnvironmentProvider::default(); + let http_client = HttpClient::new()?; + let client = Client::new_with(provider, http_client); + + Ok(Self { client: DynamoDbClient::new_with_client(client, Region::UsEast1) }) + } +} + +mod domain_db { + pub const TABLE_NAME:&'static str = "tunnelto_domains"; + pub const PRIMARY_KEY:&'static str = "subdomain"; + pub const ACCOUNT_ID:&'static str = "account_id"; +} + +mod key_db { + pub const TABLE_NAME:&'static str = "tunnelto_auth"; + pub const PRIMARY_KEY:&'static str = "auth_key_hash"; + pub const ACCOUNT_ID:&'static str = "account_id"; +} + +fn key_id(auth_key: &str) -> String { + let hash = sha2::Sha256::digest(auth_key.as_bytes()).to_vec(); + base64::encode_config(&hash, base64::URL_SAFE_NO_PAD) +} + +#[derive(Error, Debug)] +pub enum Error { + #[error("failed to get domain item")] + AuthDbGetItem(#[from] rusoto_core::RusotoError), + + #[error("The authentication key is invalid")] + AccountNotFound, + + #[error("The authentication key is invalid")] + InvalidAccountId(#[from] uuid::Error), + + #[error("The subdomain is not authorized")] + SubdomainNotAuthorized, +} + +pub enum AuthResult { + ReservedByYou, + ReservedByOther, + Available, +} +impl AuthDbService { + pub async fn auth_sub_domain(&self, auth_key: &str, subdomain: &str) -> Result { + let authenticated_account_id = self.get_account_id_for_auth_key(auth_key).await?; + match self.get_account_id_for_subdomain(subdomain).await? { + Some(account_id) => { + if authenticated_account_id == account_id { + return Ok(AuthResult::ReservedByYou) + } + + Ok(AuthResult::ReservedByOther) + }, + None => Ok(AuthResult::Available) + } + } + + async fn get_account_id_for_auth_key(&self, auth_key: &str) -> Result { + let auth_key_hash = key_id(auth_key); + + let mut input = GetItemInput { table_name: key_db::TABLE_NAME.to_string(), ..Default::default() }; + input.key = { + let mut item = HashMap::new(); + item.insert(key_db::PRIMARY_KEY.to_string(), AttributeValue { + s: Some(auth_key_hash), + ..Default::default() + }); + item + }; + + let result = self.client.get_item(input).await?; + let account_str = result.item + .unwrap_or(HashMap::new()) + .get(key_db::ACCOUNT_ID) + .cloned() + .unwrap_or(AttributeValue::default()) + .s + .ok_or(Error::AccountNotFound)?; + + let uuid = Uuid::from_str(&account_str)?; + Ok(uuid) + } + + async fn get_account_id_for_subdomain(&self, subdomain: &str) -> Result, Error> { + let mut input = GetItemInput { table_name: domain_db::TABLE_NAME.to_string(), ..Default::default() }; + input.key = { + let mut item = HashMap::new(); + item.insert(domain_db::PRIMARY_KEY.to_string(), AttributeValue { + s: Some(subdomain.to_string()), + ..Default::default() + }); + item + }; + + let result = self.client.get_item(input).await?; + let account_str = result.item + .unwrap_or(HashMap::new()) + .get(domain_db::ACCOUNT_ID) + .cloned() + .unwrap_or(AttributeValue::default()) + .s; + + if let Some(account_str) = account_str { + let uuid = Uuid::from_str(&account_str)?; + Ok(Some(uuid)) + } else { + Ok(None) + } + } +} \ No newline at end of file diff --git a/tunnelto_server/src/client_auth.rs b/tunnelto_server/src/client_auth.rs new file mode 100644 index 0000000..032c018 --- /dev/null +++ b/tunnelto_server/src/client_auth.rs @@ -0,0 +1,140 @@ +use tunnelto_lib::{ClientHelloV1, ClientHello, ClientId, ServerHello, ClientType}; +use warp::filters::ws::{WebSocket, Message}; +use futures::{SinkExt, StreamExt}; +use crate::connected_clients::Connections; +use crate::auth_db::AuthResult; +use log::error; +use crate::BLOCKED_SUB_DOMAINS; + +pub struct ClientHandshake { + pub id: ClientId, + pub sub_domain: String, + pub is_anonymous: bool, +} + +pub async fn auth_client_handshake(mut websocket: WebSocket) -> Option<(WebSocket, ClientHandshake)> { + let client_hello_data = match websocket.next().await { + Some(Ok(msg)) => msg, + _ => { + error!("no client init message"); + return None + }, + }; + + if let Ok(client_hello_v1) = serde_json::from_slice::(client_hello_data.as_bytes()) { + auth_client_v1(client_hello_v1, websocket).await + } else { + auth_client(client_hello_data.as_bytes(), websocket).await + } +} + +async fn auth_client_v1(client_hello: ClientHelloV1, mut websocket:WebSocket) -> Option<(WebSocket, ClientHandshake)> { + let sub_domain = match client_hello.sub_domain { + None => ServerHello::random_domain(), + + // otherwise, try to assign the sub domain + Some(sub_domain) => { + let (ws, sub_domain) = match sanitize_sub_domain_and_pre_validate(websocket, sub_domain, &client_hello.id).await { + Some(s) => s, + None => return None, + }; + websocket = ws; + + + // don't allow specified domains for anonymous v1 clients + ServerHello::prefixed_random_domain(&sub_domain) + } + }; + + Some((websocket, ClientHandshake {id: client_hello.id, sub_domain, is_anonymous: true})) +} + +async fn auth_client(client_hello_data: &[u8], mut websocket: WebSocket) -> Option<(WebSocket, ClientHandshake)> { + // parse the client hello + let client_hello:ClientHello = match serde_json::from_slice(client_hello_data) { + Ok(ch) => ch, + Err(e) => { + error!("invalid client hello: {}", e); + let data = serde_json::to_vec(&ServerHello::AuthFailed).unwrap_or_default(); + let _ = websocket.send(Message::binary(data)).await; + return None + } + }; + + let (auth_key, requested_sub_domain) = match client_hello.client_type { + ClientType::Anonymous => { + let sub_domain = match client_hello.sub_domain { + Some(sd) => ServerHello::prefixed_random_domain(&sd), + None => ServerHello::random_domain(), + }; + return Some((websocket, ClientHandshake { id: client_hello.id, sub_domain, is_anonymous: true })); + }, + ClientType::Auth { key } => { + match client_hello.sub_domain { + Some(requested_sub_domain) => { + let (ws, sub_domain) = match sanitize_sub_domain_and_pre_validate(websocket, requested_sub_domain, &client_hello.id).await { + Some(s) => s, + None => return None, + }; + websocket = ws; + + (key, sub_domain) + }, + None => { + let sub_domain = ServerHello::random_domain(); + return Some((websocket, ClientHandshake { id: client_hello.id, sub_domain, is_anonymous: false })); + } + } + } + }; + + + // next authenticate the sub-domain + let sub_domain = match crate::AUTH_DB_SERVICE.auth_sub_domain(&auth_key.0, &requested_sub_domain).await { + Ok(AuthResult::Available) | Ok(AuthResult::ReservedByYou) => requested_sub_domain, + Ok(AuthResult::ReservedByOther) => { + let data = serde_json::to_vec(&ServerHello::SubDomainInUse).unwrap_or_default(); + let _ = websocket.send(Message::binary(data)).await; + return None + } + Err(e) => { + error!("error auth-ing user {:?}!", e); + let data = serde_json::to_vec(&ServerHello::AuthFailed).unwrap_or_default(); + let _ = websocket.send(Message::binary(data)).await; + return None + } + }; + + Some((websocket, ClientHandshake { id: client_hello.id, sub_domain, is_anonymous: false })) +} + +async fn sanitize_sub_domain_and_pre_validate(mut websocket: WebSocket, requested_sub_domain: String, client_id: &ClientId) -> Option<(WebSocket, String)>{ + // ignore uppercase + let sub_domain = requested_sub_domain.to_lowercase(); + + if sub_domain.chars().filter(|c| !c.is_alphanumeric()).count() > 0 { + error!("invalid client hello: only alphanumeric chars allowed!"); + let data = serde_json::to_vec(&ServerHello::InvalidSubDomain).unwrap_or_default(); + let _ = websocket.send(Message::binary(data)).await; + return None + } + + // ensure this sub-domain isn't taken + let existing_client = Connections::client_for_host(&sub_domain); + if existing_client.is_some() && Some(client_id) != existing_client.as_ref() { + error!("invalid client hello: requested sub domain in use already!"); + let data = serde_json::to_vec(&ServerHello::SubDomainInUse).unwrap_or_default(); + let _ = websocket.send(Message::binary(data)).await; + return None + } + + // ensure it's not a restricted one + if BLOCKED_SUB_DOMAINS.contains(&sub_domain) { + error!("invalid client hello: sub-domain restrict!"); + let data = serde_json::to_vec(&ServerHello::SubDomainInUse).unwrap_or_default(); + let _ = websocket.send(Message::binary(data)).await; + return None + } + + Some((websocket, sub_domain)) +} \ No newline at end of file diff --git a/tunnelto_server/src/connected_clients.rs b/tunnelto_server/src/connected_clients.rs new file mode 100644 index 0000000..449af80 --- /dev/null +++ b/tunnelto_server/src/connected_clients.rs @@ -0,0 +1,66 @@ +use super::*; + +#[derive(Debug, Clone)] +pub struct ConnectedClient { + pub id: ClientId, + pub host: String, + pub tx: UnboundedSender, +} + +pub struct Connections { + clients: Arc>>, + hosts: Arc>> +} + +impl Connections { + pub fn new() -> Self { + Self { + clients: Arc::new(RwLock::new(HashMap::new())), + hosts: Arc::new(RwLock::new(HashMap::new())) + } + } + + pub fn remove(client: &ConnectedClient) { + client.tx.close_channel(); + let mut connected = CONNECTIONS.clients.write().unwrap(); + let mut hosts = CONNECTIONS.hosts.write().unwrap(); + + // ensure another client isn't using this host + match hosts.get(&client.host) { + Some(client_for_host) if client_for_host.id == client.id => { + log::debug!("dropping sub-domain: {}", &client.host); + hosts.remove(&client.host); + }, + _ => {} + }; + + connected.remove(&client.id); + log::debug!("rm client: {}", &client.id); + + // drop all the streams + // if there are no more tunnel clients + if connected.is_empty() { + let mut streams = ACTIVE_STREAMS.write().unwrap(); + for (_, stream) in streams.drain() { + stream.tx.close_channel(); + } + } + } + + pub fn client_for_host(host: &String) -> Option { + CONNECTIONS.hosts.read().unwrap().get(host).map(|c| c.id.clone()) + } + + pub fn get(client_id: &ClientId) -> Option { + CONNECTIONS.clients.read().unwrap().get(&client_id).cloned() + } + + pub fn find_by_host(host: &String) -> Option { + CONNECTIONS.hosts.read().unwrap().get(host).cloned() + } + + pub fn add(client: ConnectedClient) { + CONNECTIONS.clients.write().unwrap().insert(client.id.clone(), client.clone()); + CONNECTIONS.hosts.write().unwrap().insert(client.host.clone(), client); + } +} diff --git a/tunnelto_server/src/control_server.rs b/tunnelto_server/src/control_server.rs new file mode 100644 index 0000000..1dc055e --- /dev/null +++ b/tunnelto_server/src/control_server.rs @@ -0,0 +1,161 @@ +pub use super::*; +use std::net::SocketAddr; +use std::time::Duration; + +pub fn spawn>(addr: A) { + let health_check = warp::get().and(warp::path("health_check")).map(|| { + log::info!("Health Check #2 triggered"); + "ok" + }); + let client_conn = warp::path("wormhole").and(warp::ws()).map(move |ws: Ws| { + ws.on_upgrade(handle_new_connection) + }); + + // spawn our websocket control server + tokio::spawn(warp::serve(client_conn.or(health_check)).run(addr.into())); +} + +async fn handle_new_connection(websocket: WebSocket) { + let (websocket, client_id, sub_domain) = match try_client_handshake(websocket).await { + Some(ws) => ws, + None => return, + }; + + log::debug!("open tunnel: {}.", &sub_domain); + + let (tx, rx) = unbounded::(); + let mut client = ConnectedClient { id: client_id, host: sub_domain, tx }; + Connections::add(client.clone()); + + let (sink, stream) = websocket.split(); + + let client_clone = client.clone(); + + tokio::spawn(async move { + tunnel_client(client_clone, sink, rx).await; + }); + + let client_clone = client.clone(); + + tokio::spawn(async move { + process_client_messages(client_clone, stream).await; + }); + + // play ping pong + tokio::spawn(async move { + loop { + log::trace!("sending ping"); + match client.tx.send(ControlPacket::Ping).await { + Ok(_) => {}, + Err(e) => { + log::debug!("Failed to send ping: {:?}, removing client", e); + Connections::remove(&client); + return + } + }; + + tokio::time::delay_for(Duration::new(PING_INTERVAL, 0)).await; + } + }); +} + +async fn try_client_handshake(websocket: WebSocket) -> Option<(WebSocket, ClientId, String)> { + // Authenticate client handshake + let (mut websocket, client_handshake) = client_auth::auth_client_handshake(websocket).await?; + + // Send server hello success + let data = serde_json::to_vec(&ServerHello::Success { sub_domain: client_handshake.sub_domain.clone() }).unwrap_or_default(); + let send_result = websocket.send(Message::binary(data)).await; + if let Err(e) = send_result { + error!("aborting...failed to write server hello: {:?}", e); + return None + } + + info!("new client connected: {:?}{}", &client_handshake.id, if client_handshake.is_anonymous { " (anonymous)"} else { "" }); + Some((websocket, client_handshake.id, client_handshake.sub_domain)) +} + +/// Send the client a "stream init" message +pub async fn send_client_stream_init(mut stream: ActiveStream) { + match stream.client.tx.send(ControlPacket::Init(stream.id.clone())).await { + Ok(_) => { + info!("sent control to client: {}", &stream.client.id); + }, + Err(_) => { + info!("removing disconnected client: {}", &stream.client.id); + Connections::remove(&stream.client); + } + } + +} + +/// Process client control messages +async fn process_client_messages(client: ConnectedClient, mut client_conn: SplitStream) { + loop { + let result = client_conn.next().await; + + let message = match result { + Some(Ok(msg)) if !msg.as_bytes().is_empty() => msg, + _ => { + info!("goodbye client: {:?}", &client.id); + Connections::remove(&client); + return + }, + }; + + let packet = match ControlPacket::deserialize(message.as_bytes()) { + Ok(packet) => packet, + Err(e) => { + eprintln!("invalid data packet: {:?}", e); + continue + } + }; + + let (stream_id, message) = match packet { + ControlPacket::Data(stream_id, data) => { + info!("forwarding to stream[id={}]: {} bytes", &stream_id.to_string(), data.len()); + (stream_id, StreamMessage::Data(data)) + }, + ControlPacket::Refused(stream_id) => { + log::info!("tunnel says: refused"); + (stream_id, StreamMessage::TunnelRefused) + } + ControlPacket::Init(_) | ControlPacket::End(_) => { + error!("invalid protocol control::init message"); + continue + }, + ControlPacket::Ping => { + log::trace!("pong"); + continue + }, + }; + + let stream = ACTIVE_STREAMS.read().unwrap().get(&stream_id).cloned(); + + if let Some(mut stream) = stream { + let _ = stream.tx.send(message).await.map_err(|e| { + log::error!("Failed to send to stream tx: {:?}", e); + }); + } + } +} + +async fn tunnel_client(client: ConnectedClient, mut sink: SplitSink, mut queue: UnboundedReceiver) { + loop { + match queue.next().await { + Some(packet) => { + let result = sink.send(Message::binary(packet.serialize())).await; + if result.is_err() { + eprintln!("client disconnected: aborting."); + Connections::remove(&client); + return + } + }, + None => { + info!("ending client tunnel"); + return + }, + }; + + } +} \ No newline at end of file diff --git a/tunnelto_server/src/main.rs b/tunnelto_server/src/main.rs new file mode 100644 index 0000000..1ec1732 --- /dev/null +++ b/tunnelto_server/src/main.rs @@ -0,0 +1,80 @@ +use futures::{StreamExt, SinkExt}; +use warp::{Filter}; +use warp::ws::{Ws, Message, WebSocket}; + +pub use tunnelto_lib::*; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +use tokio::net::{TcpListener}; + +use futures::stream::{SplitSink, SplitStream}; +use futures::channel::mpsc::{unbounded, UnboundedSender, UnboundedReceiver}; +use lazy_static::lazy_static; +use log::{info, error}; + +mod connected_clients; +use self::connected_clients::*; +mod active_stream; +use self::active_stream::*; + +mod client_auth; +mod auth_db; +pub use self::auth_db::AuthDbService; + +mod remote; +mod control_server; + +lazy_static! { + pub static ref CONNECTIONS:Connections = Connections::new(); + pub static ref ACTIVE_STREAMS:ActiveStreams = Arc::new(RwLock::new(HashMap::new())); + pub static ref ALLOWED_HOSTS:Vec = allowed_host_suffixes(); + pub static ref BLOCKED_SUB_DOMAINS:Vec = blocked_sub_domains_suffixes(); + pub static ref AUTH_DB_SERVICE:AuthDbService = AuthDbService::new().expect("failed to init auth-service"); +} + +/// What hosts do we allow tunnels on: +/// i.e: baz.com => *.baz.com +/// foo.bar => *.foo.bar +pub fn allowed_host_suffixes() -> Vec { + std::env::var("ALLOWED_HOSTS") + .map(|s| s.split(",").map(String::from).collect()) + .unwrap_or(vec![]) +} + + +/// What sub-domains do we always block: +/// i.e: dashboard.tunnelto.dev +pub fn blocked_sub_domains_suffixes() -> Vec { + std::env::var("BLOCKED_SUB_DOMAINS") + .map(|s| s.split(",").map(String::from).collect()) + .unwrap_or(vec![]) +} + +#[tokio::main] +async fn main() { + pretty_env_logger::init(); + + info!("starting wormhole server"); + control_server::spawn(([0,0,0,0], 5000)); + + let listen_addr = format!("0.0.0.0:{}", std::env::var("PORT").unwrap_or("8080".to_string())); + info!("listening on: {}", &listen_addr); + + // create our accept any server + let mut listener = TcpListener::bind(listen_addr).await.expect("failed to bind"); + + loop { + let socket = match listener.accept().await { + Ok((socket, _)) => socket, + _ => { + error!("failed to accept socket"); + continue; + } + }; + + tokio::spawn(async move { + remote::accept_connection(socket).await; + }); + } +} \ No newline at end of file diff --git a/tunnelto_server/src/remote.rs b/tunnelto_server/src/remote.rs new file mode 100644 index 0000000..63ef46a --- /dev/null +++ b/tunnelto_server/src/remote.rs @@ -0,0 +1,272 @@ +use super::*; +use tokio::net::TcpStream; +use tokio::io::{ReadHalf, WriteHalf}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +async fn direct_to_control(mut incoming: TcpStream) { + let mut control_socket = match TcpStream::connect("localhost:5000").await { + Ok(s) => s, + Err(e) => { + log::warn!("failed to connect to local control server {:?}", e); + return + } + }; + + let (mut control_r, mut control_w) = control_socket.split(); + let (mut incoming_r, mut incoming_w) = incoming.split(); + + let join_1 = tokio::io::copy(&mut control_r, &mut incoming_w); + let join_2 = tokio::io::copy(&mut incoming_r, &mut control_w); + + match futures::future::join(join_1, join_2).await { + (Ok(_), Ok(_)) => {}, + (Err(e), _) | (_, Err(e)) => { + log::error!("directing stream to control failed: {:?}", e); + }, + } +} + +pub async fn accept_connection(socket: TcpStream) { + // peek the host of the http request + // if health check, then handle it and return + let (mut socket, host) = match peek_http_request_host(socket).await { + Some(s) => s, + None => return, + }; + + // parse the host string and find our client + if ALLOWED_HOSTS.contains(&host) { + error!("redirect to homepage"); + let _ = socket.write_all(HTTP_REDIRECT_RESPONSE).await; + return + + } + let host = match validate_host_prefix(&host) { + Some(sub_domain) => sub_domain, + None => { + error!("invalid host specified"); + let _ = socket.write_all(HTTP_INVALID_HOST_RESPONSE).await; + return + } + }; + + // Special case -- we redirect this tcp connection to the control server + if host.as_str() == "wormhole" { + direct_to_control(socket).await; + return + } + + // find the client listening for this host + let client = match Connections::find_by_host(&host) { + Some(client) => client.clone(), + None => { + error!("No tunnel found for host: {}.<>", host); + let _ = socket.write_all(HTTP_NOT_FOUND_RESPONSE).await; + return + } + }; + + // allocate a new stream for this request + let (active_stream, queue_rx) = ActiveStream::new(client.clone()); + let stream_id = active_stream.id.clone(); + + info!("new stream connected: {}", active_stream.id.to_string()); + let (stream, sink) = tokio::io::split(socket); + + // add our stream + ACTIVE_STREAMS.write().unwrap().insert(stream_id.clone(), active_stream.clone()); + + // read from socket, write to client + tokio::spawn(async move { + process_tcp_stream(active_stream, stream).await; + }); + + // read from client, write to socket + tokio::spawn( async move { + tunnel_to_stream(stream_id, sink, queue_rx).await; + }); +} + +fn validate_host_prefix(host: &str) -> Option { + let url = format!("http://{}", host); + + let host = match url::Url::parse(&url) + .map(|u| u.host().map(|h| h.to_owned())) + .unwrap_or(None) + { + Some(domain) => { + domain.to_string() + }, + None => { + error!("invalid host header"); + return None + } + }; + + let domain_segments = host.split(".").collect::>(); + let prefix = &domain_segments[0]; + let remaining = &domain_segments[1..].join("."); + + if ALLOWED_HOSTS.contains(remaining) { + Some(prefix.to_string()) + } else { + None + } +} + +/// Response Constants +const HTTP_REDIRECT_RESPONSE:&'static [u8] = b"HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnelto.dev/\r\nContent-Length: 20\r\n\r\nhttps://tunnelto.dev"; +const HTTP_INVALID_HOST_RESPONSE:&'static [u8] = b"HTTP/1.1 400\r\nContent-Length: 23\r\n\r\nError: Invalid Hostname"; +const HTTP_NOT_FOUND_RESPONSE:&'static [u8] = b"HTTP/1.1 400\r\nContent-Length: 23\r\n\r\nError: Tunnel Not Found"; +const HTTP_TUNNEL_REFUSED_RESPONSE:&'static [u8] = b"HTTP/1.1 500\r\nContent-Length: 32\r\n\r\nTunnel says: connection refused."; +const HTTP_OK_RESPONSE:&'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"; +const HEALTH_CHECK_PATH:&'static [u8] = b"/0xDEADBEEF_HEALTH_CHECK"; + +/// Filter incoming remote streams +async fn peek_http_request_host(mut socket: TcpStream) -> Option<(TcpStream, String)> { + /// Note we return out if the host header is not found + /// within the first 4kb of the request. + const MAX_HEADER_PEAK:usize = 4096; + let mut buf = vec![0; MAX_HEADER_PEAK]; //1kb + + log::debug!("checking stream headers"); + + let n = match socket.peek(&mut buf).await { + Ok(n) => n, + Err(e) => { + error!("failed to read from tcp socket to determine host: {:?}", e); + return None + }, + }; + + // make sure we're not peeking the same header bytes + if n == 0 { + log::debug!("unable to peek header bytes"); + return None; + } + + log::debug!("peeked {} stream bytes ", n); + + let mut headers = [httparse::EMPTY_HEADER; 64]; // 30 seems like a generous # of headers + let mut req = httparse::Request::new(&mut headers); + + if let Err(e) = req.parse(&buf[..n]) { + error!("failed to parse incoming http bytes: {:?}", e); + return None + } + + // Handle the health check route + if req.path.map(|s| s.as_bytes()) == Some(HEALTH_CHECK_PATH) { + info!("Health Check Triggered"); + + let _ = socket.write_all(HTTP_OK_RESPONSE).await.map_err(|e| { + error!("failed to write health_check: {:?}", e); + }); + + return None + } + + // look for a host header + if let Some(Ok(host)) = req.headers.iter() + .filter(|h| h.name.to_lowercase() == "host".to_string()) + .map(|h| std::str::from_utf8(h.value)) + .next() + { + return Some((socket, host.to_string())) + } + + log::debug!("Found no host header, dropping connection."); + None +} + +/// Process Messages from the control path in & out of the remote stream +async fn process_tcp_stream(mut tunnel_stream: ActiveStream, mut tcp_stream: ReadHalf) { + // send initial control stream init to client + control_server::send_client_stream_init(tunnel_stream.clone()).await; + + // now read from stream and forward to clients + let mut buf = [0; 1024]; + + loop { + // client is no longer connected + if Connections::get(&tunnel_stream.client.id).is_none() { + info!("client disconnected, closing stream"); + let _ = tunnel_stream.tx.send(StreamMessage::NoClientTunnel).await; + tunnel_stream.tx.close_channel(); + return + } + + // read from stream + let n = match tcp_stream.read(&mut buf).await { + Ok(n) => n, + Err(e) => { + eprintln!("failed to read from tcp socket: {:?}", e); + return + } + }; + + if n == 0 { + info!("stream ended"); + let _ = tunnel_stream.client.tx.send(ControlPacket::End(tunnel_stream.id.clone())).await + .map_err(|e| { + error!("failed to send end signal: {:?}", e); + }); + return; + } + + info!("read {} bytes", n); + + let data = &buf[..n]; + let packet = ControlPacket::Data(tunnel_stream.id.clone(), data.to_vec()); + + match tunnel_stream.client.tx.send(packet.clone()).await { + Ok(_) => info!("sent data packet to client: {}", &tunnel_stream.client.id), + Err(_) => { + error!("failed to forward tcp packets to disconnected client. dropping client."); + Connections::remove(&tunnel_stream.client); + } + } + } +} + +async fn tunnel_to_stream(stream_id: StreamId, mut sink: WriteHalf, mut queue: UnboundedReceiver) { + loop { + let result = queue.next().await; + + let result = if let Some(message) = result { + match message { + StreamMessage::Data(data) => Some(data), + StreamMessage::TunnelRefused => { + info!("tunnel refused"); + let _ = sink.write_all(HTTP_TUNNEL_REFUSED_RESPONSE).await; + None + } + StreamMessage::NoClientTunnel => { + info!("client tunnel not found"); + let _ = sink.write_all(HTTP_NOT_FOUND_RESPONSE).await; + None + } + } + } else { None }; + + let data = match result { + Some(data) => data, + None => { + info!("done tunneling to sink"); + let _ = sink.shutdown().await.map_err(|_e| { + error!("error shutting down tcp stream"); + }); + + ACTIVE_STREAMS.write().unwrap().remove(&stream_id); + return + } + }; + + let result = sink.write_all(&data).await; + + if result.is_err() { + info!("stream closed, disconnecting"); + return + } + } +}