Skip to content

Commit

Permalink
metrics module
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed Apr 30, 2024
1 parent b72a60b commit f20542c
Show file tree
Hide file tree
Showing 11 changed files with 407 additions and 56 deletions.
10 changes: 10 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ pub struct Cli {
#[command(flatten, next_help_heading = "Policy")]
pub policy: Policy,

#[command(flatten, next_help_heading = "Metrics")]
pub metrics: Metrics,

#[command(flatten, next_help_heading = "Misc")]
pub misc: Misc,
}
Expand Down Expand Up @@ -175,6 +178,13 @@ pub struct Policy {
pub denylist_poll_interval: Duration,
}

#[derive(Args)]
pub struct Metrics {
/// Where to listen for Prometheus metrics scraping
#[clap(long = "metrics-listen")]
pub listen: Option<SocketAddr>,
}

#[derive(Args)]
pub struct Misc {
/// Path to a GeoIP database
Expand Down
23 changes: 18 additions & 5 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ use anyhow::Error;
use async_trait::async_trait;
use prometheus::Registry;
use rustls::sign::CertifiedKey;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tracing::{error, warn};

use crate::{
cli::Cli,
http::{server, ReqwestClient, Server},
metrics,
routing::{
self,
canister::{CanisterResolver, ResolvesCanister},
Expand All @@ -30,11 +31,10 @@ pub trait Run: Send + Sync {
}

pub async fn main(cli: &Cli) -> Result<(), Error> {
// Prepare some general stuff
let token = CancellationToken::new();
let tracker = TaskTracker::new();

let registry = Registry::new();

let http_client = Arc::new(ReqwestClient::new(cli)?);

// Handle SIGTERM/SIGHUP and Ctrl+C
Expand All @@ -46,14 +46,17 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
let mut domains = cli.domain.domains_system.clone();
domains.extend(cli.domain.domains_app.clone());

// Prepare certificate storage
let storage = Arc::new(Storage::new());

// Prepare canister resolver to infer canister_id from requests
let canister_resolver = CanisterResolver::new(
domains,
cli.domain.canister_aliases.clone(),
storage.clone() as Arc<dyn LooksupCustomDomain>,
)?;

// List of cancellable tasks to execute & watch
// List of cancellable tasks to execute & track
let mut runners: Vec<(String, Arc<dyn Run>)> = vec![];

// Create a router
Expand All @@ -68,6 +71,7 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
}

let server_options = server::Options::from(&cli.http_server);

// Set up HTTP
let http_server = Arc::new(Server::new(
cli.http_server.http,
Expand All @@ -94,7 +98,16 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
)) as Arc<dyn Run>;
runners.push(("https_server".into(), https_server));

// Spawn runners
// Setup metrics
if let Some(addr) = cli.metrics.listen {
let (router, runner) = metrics::setup(&registry);
runners.push(("metrics_runner".into(), runner));

let srv = Arc::new(Server::new(addr, router, server_options, None));
runners.push(("metrics_server".into(), srv as Arc<dyn Run>));
}

// Spawn & track runners
for (name, obj) in runners {
let token = token.child_token();
tracker.spawn(async move {
Expand Down
20 changes: 20 additions & 0 deletions src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,25 @@ pub mod client;
pub mod dns;
pub mod server;

use http::{HeaderMap, Version};

pub use client::{Client, ReqwestClient};
pub use server::{ConnInfo, Server};

// Calculate very approximate HTTP request/response headers size in bytes.
// More or less accurate only for http/1.1 since in h2 headers are in HPACK-compressed.
// But it seems there's no better way.
pub fn calc_headers_size(h: &HeaderMap) -> usize {
h.iter().map(|(k, v)| k.as_str().len() + v.len() + 2).sum()
}

pub const fn http_version(v: Version) -> &'static str {
match v {
Version::HTTP_09 => "0.9",
Version::HTTP_10 => "1.0",
Version::HTTP_11 => "1.1",
Version::HTTP_2 => "2.0",
Version::HTTP_3 => "3.0",
_ => "-",
}
}
4 changes: 4 additions & 0 deletions src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl TryFrom<&ServerConnection> for TlsInfo {

#[derive(Clone, Debug)]
pub struct ConnInfo {
pub accepted_at: Instant,
pub local_addr: SocketAddr,
pub remote_addr: SocketAddr,
pub tls: Option<TlsInfo>,
Expand Down Expand Up @@ -137,6 +138,8 @@ impl Conn {
}

pub async fn handle(&self, stream: TcpStream) -> Result<(), Error> {
let accepted_at = Instant::now();

debug!(
"Server {}: {}: got a new connection",
self.addr, self.remote_addr
Expand Down Expand Up @@ -164,6 +167,7 @@ impl Conn {
// Since it will be cloned for each request served over this connection
// it's probably better to wrap it into Arc
let conn_info = ConnInfo {
accepted_at,
local_addr: self.addr,
remote_addr: self.remote_addr,
tls: tls_info,
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async fn main() -> Result<(), Error> {
let cli = Cli::parse();

let subscriber = tracing_subscriber::FmtSubscriber::builder()
.with_max_level(tracing::Level::INFO)
.with_max_level(tracing::Level::DEBUG)
.finish();
tracing::subscriber::set_global_default(subscriber)?;

Expand Down
27 changes: 17 additions & 10 deletions src/metrics/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::{
task::{Context, Poll},
};

use crate::http::calc_headers_size;

// Body that counts the bytes streamed
pub struct CountingBody<D, E> {
inner: Pin<Box<dyn Body<Data = D, Error = E> + Send + 'static>>,
Expand Down Expand Up @@ -74,21 +76,22 @@ where
// There is still some data available
Poll::Ready(Some(v)) => match v {
Ok(buf) => {
// Ignore if it's not a data frame for now.
// It can also be trailers that are uncommon
// Normal data frame
if buf.is_data() {
self.bytes_sent += buf.data_ref().unwrap().remaining() as u64;
} else if buf.is_trailers() {
// Trailers are very uncommon, for the sake of completeness
self.bytes_sent += calc_headers_size(buf.trailers_ref().unwrap()) as u64;
}

// Check if we already got what was expected
if Some(self.bytes_sent) >= self.expected_size {
self.do_callback(Ok(()));
}
// Check if we already got what was expected
if Some(self.bytes_sent) >= self.expected_size {
self.do_callback(Ok(()));
}
}

// Error occured, execute callback
Err(e) => {
// Error is not Copy/Clone so use string instead
self.do_callback(Err(e.to_string()));
}
},
Expand Down Expand Up @@ -117,8 +120,13 @@ mod test {

#[tokio::test]
async fn test_body_stream() {
let data = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblah";
let mut stream = tokio_util::io::ReaderStream::new(&data[..]);
let data = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbl\
ahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahbla\
hfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoob\
arblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbla\
blahfoobarblahblah";

let stream = tokio_util::io::ReaderStream::new(&data[..]);
let body = axum::body::Body::from_stream(stream);

let (tx, rx) = std::sync::mpsc::channel();
Expand All @@ -141,7 +149,6 @@ mod test {
#[tokio::test]
async fn test_body_full() {
let data = vec![0; 512];

let buf = bytes::Bytes::from_iter(data.clone());
let body = http_body_util::Full::new(buf);

Expand Down
Loading

0 comments on commit f20542c

Please sign in to comment.