Skip to content

Commit

Permalink
some refactoring, add conn metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed May 21, 2024
1 parent 82e9e24 commit f26caca
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 44 deletions.
14 changes: 13 additions & 1 deletion src/core.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use anyhow::{anyhow, Context, Error};
use itertools::Itertools;
use prometheus::Registry;
use tokio_util::sync::CancellationToken;
use tracing::warn;
Expand Down Expand Up @@ -28,6 +29,12 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
));
}

for (&d, &c) in &domains.iter().counts() {
if c > 1 {
return Err(anyhow!("Domain '{d}' specified more than once"));
}
}

// Install crypto-provider
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
Expand Down Expand Up @@ -78,11 +85,14 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
clickhouse.clone(),
)?;

let metrics = http::server::Metrics::new(&registry);

// Set up HTTP
let http_server = Arc::new(http::Server::new(
cli.http_server.http,
router.clone(),
(&cli.http_server).into(),
metrics.clone(),
None,
));
tasks.add("http_server", http_server);
Expand All @@ -103,6 +113,7 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
cli.http_server.https,
router,
(&cli.http_server).into(),
metrics.clone(),
Some(rustls_cfg),
));
tasks.add("https_server", https_server);
Expand All @@ -115,12 +126,13 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
addr,
router,
(&cli.http_server).into(),
metrics.clone(),
None,
));
tasks.add("metrics_server", srv);
}

// Spawn & track runners
// Spawn & track tasks
tasks.start(&token);

warn!("Service is running, waiting for the shutdown signal");
Expand Down
2 changes: 1 addition & 1 deletion src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use http::{HeaderMap, Version};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pub use client::{Client, ReqwestClient};
pub use server::{ConnInfo, Server};
pub use server::{ConnInfo, Server, TlsInfo};

pub const ALPN_H1: &[u8] = b"http/1.1";
pub const ALPN_H2: &[u8] = b"h2";
Expand Down
199 changes: 177 additions & 22 deletions src/http/server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::{
net::SocketAddr,
str::FromStr,
sync::Arc,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};

Expand All @@ -13,6 +16,10 @@ use hyper_util::{
rt::{TokioExecutor, TokioIo, TokioTimer},
server::conn::auto::Builder,
};
use prometheus::{
register_histogram_vec_with_registry, register_int_counter_vec_with_registry,
register_int_gauge_vec_with_registry, HistogramVec, IntCounterVec, IntGaugeVec, Registry,
};
use rustls::{server::ServerConnection, CipherSuite, ProtocolVersion};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
Expand All @@ -23,13 +30,85 @@ use tokio_rustls::TlsAcceptor;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tower_service::Service;
use tracing::{debug, warn};
use uuid::Uuid;

use super::{is_http_alpn, AsyncCounter, Stats};

// Blanket async read+write trait to box streams
/// Histogram buckets for the connection duration metric (`conn_duration_sec`), in seconds.
pub const CONN_DURATION_BUCKETS: &[f64] = &[1.0, 8.0, 32.0, 64.0, 256.0, 512.0, 1024.0];
/// Histogram buckets for the requests-per-connection metric (`conn_requests_per_conn`).
pub const CONN_REQUESTS: &[f64] = &[1.0, 4.0, 8.0, 16.0, 32.0, 64.0, 256.0];

// Blanket trait combining async read + write so that different stream types
// (plain or TLS-wrapped) can be boxed behind a single `Box<dyn AsyncReadWrite>`.
trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
// Implemented automatically for every type that satisfies the bounds.
impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}

/// Connection-level Prometheus metrics shared by the HTTP server instances.
///
/// Every metric is a labelled vector; the labels are `addr`, `tls` and `family`.
#[derive(Clone)]
pub struct Metrics {
    /// Number of currently open connections (incremented on accept, decremented on close).
    pub conns_open: IntGaugeVec,
    /// Total number of requests; increased by the per-connection request count when a connection closes.
    pub requests: IntCounterVec,
    /// Total number of bytes sent.
    pub bytes_sent: IntCounterVec,
    /// Total number of bytes received.
    pub bytes_rcvd: IntCounterVec,
    /// Connection duration in seconds.
    pub conn_duration: HistogramVec,
    /// Number of requests served per connection.
    pub requests_per_conn: HistogramVec,
}

impl Metrics {
    /// Creates the connection metrics and registers them in `registry`.
    ///
    /// All metrics share the same label set:
    /// * `addr`   - listening address of the server
    /// * `tls`    - `"yes"` / `"no"`, whether the server terminates TLS
    /// * `family` - `"v4"` / `"v6"`, IP family of the remote address
    ///
    /// The returned value should be created once per registry and then cloned
    /// for each server.
    ///
    /// # Panics
    ///
    /// Panics if any metric fails to register, e.g. when a metric with the
    /// same name is already registered in `registry`.
    pub fn new(registry: &Registry) -> Self {
        const LABELS: &[&str] = &["addr", "tls", "family"];

        // The registration macros accept anything that is `Into<String>`,
        // so plain string literals are passed directly.
        Self {
            conns_open: register_int_gauge_vec_with_registry!(
                "conn_open",
                "Number of currently open connections",
                LABELS,
                registry
            )
            .expect("unable to register metric 'conn_open'"),

            requests: register_int_counter_vec_with_registry!(
                "conn_requests_total",
                "Counts the number of requests",
                LABELS,
                registry
            )
            .expect("unable to register metric 'conn_requests_total'"),

            bytes_sent: register_int_counter_vec_with_registry!(
                "conn_bytes_sent_total",
                "Counts number of bytes sent",
                LABELS,
                registry
            )
            .expect("unable to register metric 'conn_bytes_sent_total'"),

            bytes_rcvd: register_int_counter_vec_with_registry!(
                "conn_bytes_rcvd_total",
                "Counts number of bytes received",
                LABELS,
                registry
            )
            .expect("unable to register metric 'conn_bytes_rcvd_total'"),

            conn_duration: register_histogram_vec_with_registry!(
                "conn_duration_sec",
                "Records the duration of connection in seconds",
                LABELS,
                CONN_DURATION_BUCKETS.to_vec(),
                registry
            )
            .expect("unable to register metric 'conn_duration_sec'"),

            requests_per_conn: register_histogram_vec_with_registry!(
                "conn_requests_per_conn",
                "Records the number of requests per connection",
                LABELS,
                CONN_REQUESTS.to_vec(),
                registry
            )
            .expect("unable to register metric 'conn_requests_per_conn'"),
        }
    }
}

#[derive(Clone, Copy)]
pub struct Options {
pub backlog: u32,
Expand Down Expand Up @@ -79,11 +158,12 @@ impl TryFrom<&ServerConnection> for TlsInfo {

#[derive(Debug)]
pub struct ConnInfo {
pub id: Uuid,
pub accepted_at: Instant,
pub local_addr: SocketAddr,
pub remote_addr: SocketAddr,
pub stats: Arc<Stats>,
pub tls: Option<TlsInfo>,
pub traffic: Arc<Stats>,
pub req_count: AtomicU64,
}

struct Conn {
Expand All @@ -93,6 +173,7 @@ struct Conn {
builder: Builder<TokioExecutor>,
token: CancellationToken,
options: Options,
metrics: Metrics,
tls_acceptor: Option<TlsAcceptor>,
}

Expand All @@ -108,17 +189,17 @@ impl Conn {

let start = Instant::now();
let stream = self.tls_acceptor.as_ref().unwrap().accept(stream).await?;
let latency = start.elapsed();
let duration = start.elapsed();

let conn = stream.get_ref().1;
let mut tls_info = TlsInfo::try_from(conn)?;
tls_info.handshake = latency;
tls_info.handshake = duration;

debug!(
"Server {}: {}: handshake finished in {}ms (server: {}, proto: {:?}, cipher: {:?}, ALPN: {})",
self.addr,
self.remote_addr,
latency.as_millis(),
duration.as_millis(),
tls_info.sni,
tls_info.protocol,
tls_info.cipher,
Expand All @@ -136,50 +217,116 @@ impl Conn {
self.addr, self.remote_addr
);

// Prepare metric labels
let addr = self.addr.to_string();
let labels = &[
addr.as_str(), // Listening addr
if self.tls_acceptor.is_some() {
"yes"
} else {
"no"
}, // Is TLS
if self.remote_addr.is_ipv4() {
"v4"
} else {
"v6"
}, // IP Family
];

self.metrics.conns_open.with_label_values(labels).inc();

// Disable Nagle's algo
stream.set_nodelay(true)?;

// Wrap with counting
// Wrap with traffic counter
let stats = Arc::new(Stats::new());
let stream = AsyncCounter::new(stream, stats.clone());

let conn_info = Arc::new(ConnInfo {
id: Uuid::now_v7(),
accepted_at,
local_addr: self.addr,
remote_addr: self.remote_addr,
traffic: stats.clone(),
req_count: AtomicU64::new(0),
});

let result = self.handle_inner(stream, conn_info.clone()).await;

// Record connection metrics
let (sent, rcvd) = (stats.sent(), stats.rcvd());
let dur = accepted_at.elapsed().as_secs_f64();
let reqs = conn_info.req_count.load(Ordering::SeqCst);

self.metrics.conns_open.with_label_values(labels).dec();
self.metrics.requests.with_label_values(labels).inc_by(reqs);
self.metrics
.bytes_rcvd
.with_label_values(labels)
.inc_by(rcvd);
self.metrics
.bytes_sent
.with_label_values(labels)
.inc_by(sent);
self.metrics
.conn_duration
.with_label_values(labels)
.observe(dur);
self.metrics
.requests_per_conn
.with_label_values(labels)
.observe(reqs as f64);

debug!(
"Server {}: {} ({}): connection closed (rcvd: {}, sent: {}, reqs: {}, duration: {})",
self.addr, self.remote_addr, conn_info.id, rcvd, sent, reqs, dur,
);

result
}

async fn handle_inner(
&self,
stream: impl AsyncReadWrite + 'static,
conn_info: Arc<ConnInfo>,
) -> Result<(), Error> {
// Perform TLS handshake if we're in TLS mode
let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) = if self.tls_acceptor.is_some() {
let (mut stream, tls_info) = self.tls_handshake(stream).await?;

// Close the connection if agreed ALPN is not HTTP - probably it was an ACME challenge
if !is_http_alpn(tls_info.alpn.as_bytes()) {
debug!("Not HTTP ALPN ('{}') - closing connection", tls_info.alpn);
debug!(
"Server {}: {}: Not HTTP ALPN ('{}') - closing connection",
self.addr, self.remote_addr, tls_info.alpn
);

stream
.shutdown()
.await
.context("unable to shutdown stream")?;

return Ok(());
}

(Box::new(stream), Some(tls_info))
(Box::new(stream), Some(Arc::new(tls_info)))
} else {
(Box::new(stream), None)
};

// Since it will be cloned for each request served over this connection
// it's probably better to wrap it into Arc
let conn_info = ConnInfo {
accepted_at,
local_addr: self.addr,
remote_addr: self.remote_addr,
stats,
tls: tls_info,
};
let conn_info = Arc::new(conn_info);

// Convert stream from Tokio to Hyper
let stream = TokioIo::new(stream);

// Convert router to Hyper service
let service = hyper::service::service_fn(move |mut request: Request<Incoming>| {
conn_info.req_count.fetch_add(1, Ordering::SeqCst);
// Inject connection information
request.extensions_mut().insert(conn_info.clone());
if let Some(v) = &tls_info {
request.extensions_mut().insert(v.clone());
}

// Serve the request
self.router.clone().call(request)
});

Expand All @@ -202,7 +349,11 @@ impl Conn {
select! {
biased;
() = tokio::time::sleep(self.options.grace_period) => {},
_ = conn.as_mut() => {},
v = conn.as_mut() => {
if let Err(e) = v {
return Err(anyhow!("Unable to serve connection: {e}"));
}
},
}
}

Expand All @@ -223,6 +374,7 @@ pub struct Server {
router: Router,
tracker: TaskTracker,
options: Options,
metrics: Metrics,
tls_acceptor: Option<TlsAcceptor>,
}

Expand All @@ -231,12 +383,14 @@ impl Server {
addr: SocketAddr,
router: Router,
options: Options,
metrics: Metrics,
rustls_cfg: Option<rustls::ServerConfig>,
) -> Self {
Self {
addr,
router,
options,
metrics,
tracker: TaskTracker::new(),
tls_acceptor: rustls_cfg.map(|x| TlsAcceptor::from(Arc::new(x))),
}
Expand Down Expand Up @@ -303,6 +457,7 @@ impl Server {
builder: builder.clone(),
token: token.child_token(),
options: self.options,
metrics: self.metrics.clone(), // All metrics have Arc inside
tls_acceptor: self.tls_acceptor.clone(),
};

Expand Down
Loading

0 comments on commit f26caca

Please sign in to comment.