feat: add logging to streams #6924

Draft: wants to merge 10 commits into main
51 changes: 41 additions & 10 deletions src/imap/client.rs
@@ -8,13 +8,14 @@ use tokio::io::BufWriter;

use super::capabilities::Capabilities;
use crate::context::Context;
use crate::log::LoggingStream;
use crate::login_param::{ConnectionCandidate, ConnectionSecurity};
use crate::net::dns::{lookup_host_with_cache, update_connect_timestamp};
use crate::net::proxy::ProxyConfig;
use crate::net::session::SessionStream;
use crate::net::tls::wrap_tls;
use crate::net::{
connect_tcp_inner, connect_tls_inner, run_connection_attempts, update_connection_history,
connect_tcp_inner, run_connection_attempts, update_connection_history,
};
use crate::tools::time;

@@ -125,12 +126,12 @@ impl Client {
);
let res = match security {
ConnectionSecurity::Tls => {
Client::connect_secure(resolved_addr, host, strict_tls).await
Client::connect_secure(&context, resolved_addr, host, strict_tls).await
}
ConnectionSecurity::Starttls => {
Client::connect_starttls(resolved_addr, host, strict_tls).await
Client::connect_starttls(&context, resolved_addr, host, strict_tls).await
}
ConnectionSecurity::Plain => Client::connect_insecure(resolved_addr).await,
ConnectionSecurity::Plain => Client::connect_insecure(&context, resolved_addr).await,
};
match res {
Ok(client) => {
@@ -201,8 +202,22 @@ impl Client {
}
}

async fn connect_secure(addr: SocketAddr, hostname: &str, strict_tls: bool) -> Result<Self> {
let tls_stream = connect_tls_inner(addr, hostname, strict_tls, alpn(addr.port())).await?;
async fn connect_secure(
context: &Context,
addr: SocketAddr,
hostname: &str,
strict_tls: bool,
) -> Result<Self> {
let tcp_stream = connect_tcp_inner(addr).await?;
let account_id = context.get_id();
let events = context.events.clone();
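// Attach the logging wrapper to the raw TCP stream; TLS is layered on top
// below, so the wrapper observes the encrypted bytes sent over the wire.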
let logging_stream = LoggingStream::new(
tcp_stream,
format!("TLS IMAP stream {hostname} ({addr})"),
account_id,
events,
);
let tls_stream = wrap_tls(strict_tls, hostname, alpn(addr.port()), logging_stream).await?;
let buffered_stream = BufWriter::new(tls_stream);
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
let mut client = Client::new(session_stream);
@@ -213,9 +228,17 @@ impl Client {
Ok(client)
}

async fn connect_insecure(addr: SocketAddr) -> Result<Self> {
async fn connect_insecure(context: &Context, addr: SocketAddr) -> Result<Self> {
let tcp_stream = connect_tcp_inner(addr).await?;
let buffered_stream = BufWriter::new(tcp_stream);
let account_id = context.get_id();
let events = context.events.clone();
let logging_stream = LoggingStream::new(
tcp_stream,
format!("Plain IMAP stream ({addr})"),
account_id,
events,
);
let buffered_stream = BufWriter::new(logging_stream);
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
let mut client = Client::new(session_stream);
let _greeting = client
@@ -225,9 +248,18 @@ impl Client {
Ok(client)
}

async fn connect_starttls(addr: SocketAddr, host: &str, strict_tls: bool) -> Result<Self> {
async fn connect_starttls(
context: &Context,
addr: SocketAddr,
host: &str,
strict_tls: bool,
) -> Result<Self> {
let tcp_stream = connect_tcp_inner(addr).await?;

let account_id = context.get_id();
let events = context.events.clone();
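// The logging wrapper again sits below TLS, so it observes the plaintext
// STARTTLS negotiation and, after the upgrade, the encrypted traffic.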
let tcp_stream = LoggingStream::new(
tcp_stream,
format!("STARTTLS IMAP stream {host} ({addr})"),
account_id,
events,
);

// Run STARTTLS command and convert the client back into a stream.
let buffered_tcp_stream = BufWriter::new(tcp_stream);
let mut client = async_imap::Client::new(buffered_tcp_stream);
@@ -245,7 +277,6 @@
let tls_stream = wrap_tls(strict_tls, host, &[], tcp_stream)
.await
.context("STARTTLS upgrade failed")?;

let buffered_stream = BufWriter::new(tls_stream);
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
let client = Client::new(session_stream);
4 changes: 4 additions & 0 deletions src/log.rs
@@ -4,6 +4,10 @@

use crate::context::Context;

mod logging_stream;

pub(crate) use logging_stream::LoggingStream;

#[macro_export]
macro_rules! info {
($ctx:expr, $msg:expr) => {
227 changes: 227 additions & 0 deletions src/log/logging_stream.rs
@@ -0,0 +1,227 @@
//! Stream that logs errors as events.
//!
//! This stream can be used to wrap IMAP,
//! SMTP and HTTP streams so errors
//! that occur are logged before
//! they are processed.

use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::events::{Event, EventType, Events};
use crate::net::session::SessionStream;

#[derive(Debug)]
struct ThroughputStats {
/// Total number of bytes read.
pub total_read: usize,

/// Number of bytes read since the last flush.
pub span_read: usize,

/// First timestamp of successful non-zero read.
///
/// Reset on flush.
pub first_read_timestamp: Option<Instant>,

/// Last non-zero read.
pub last_read_timestamp: Instant,

pub total_duration: Duration,

/// Whether to collect throughput statistics or not.
///
/// Disabled when read timeout is disabled,
/// i.e. when we are in IMAP IDLE.
pub enabled: bool,
}

impl ThroughputStats {
fn new() -> Self {
Self {
total_read: 0,
span_read: 0,
first_read_timestamp: None,
last_read_timestamp: Instant::now(),
total_duration: Duration::ZERO,
enabled: false,
}
}

/// Returns throughput in bytes per second.
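///
/// Example: 600_000 bytes accumulated over 1 s of measured read time
/// is 600_000 B/s, which the flush log reports as 4800 kbps.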
pub fn throughput(&self) -> Option<f64> {
let total_duration_secs = self.total_duration.as_secs_f64();
if total_duration_secs > 0.0 {
Some((self.total_read as f64) / total_duration_secs)
} else {
None
}
}
}

/// Stream that logs errors to the event channel.
#[derive(Debug)]
#[pin_project]
pub(crate) struct LoggingStream<S: SessionStream> {
#[pin]
inner: S,

/// Name of the stream to distinguish log messages produced by it.
tag: String,

/// Account ID for logging.
account_id: u32,

/// Event channel.
events: Events,

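/// Read throughput statistics, collected only while a read timeout is set.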
throughput: ThroughputStats,
}

impl<S: SessionStream> LoggingStream<S> {
pub fn new(inner: S, tag: String, account_id: u32, events: Events) -> Self {
Self {
inner,
tag,
account_id,
events,
throughput: ThroughputStats::new(),
}
}
}

impl<S: SessionStream> AsyncRead for LoggingStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let projected = self.project();
let old_remaining = buf.remaining();

let now = Instant::now();
let res = projected.inner.poll_read(cx, buf);

if projected.throughput.enabled {
let first_read_timestamp =
if let Some(first_read_timestamp) = projected.throughput.first_read_timestamp {
first_read_timestamp
} else {
projected.throughput.first_read_timestamp = Some(now);
now
};

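// Bytes read by this poll: the difference in `buf`'s remaining capacity
// before and after the inner `poll_read` call.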
let n = old_remaining - buf.remaining();
if n > 0 {
projected.throughput.last_read_timestamp = now;
projected.throughput.span_read = projected.throughput.span_read.saturating_add(n);
}

let duration = projected
.throughput
.last_read_timestamp
.duration_since(first_read_timestamp);

let log_message = format!(
"{}: SPAN: {} {}",
projected.tag,
duration.as_secs_f64(),
projected.throughput.span_read
);
projected.events.emit(Event {
id: 0,
typ: EventType::Info(log_message),
});

let log_message = format!("{}: READING {}", projected.tag, n);
projected.events.emit(Event {
id: 0,
typ: EventType::Info(log_message),
});
}

res
}
}

impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
let projected = self.project();
if let Some(first_read_timestamp) = projected.throughput.first_read_timestamp.take() {
let duration = projected
.throughput
.last_read_timestamp
.duration_since(first_read_timestamp);

// Only measure when more than about 2 MTU is transferred.
// We cannot measure throughput on small responses
// like `A1000 OK`.
if projected.throughput.span_read > 3000 {
projected.throughput.total_read = projected
.throughput
.total_read
.saturating_add(projected.throughput.span_read);
projected.throughput.total_duration =
projected.throughput.total_duration.saturating_add(duration);
}

projected.throughput.span_read = 0;
}

if let Some(throughput) = projected.throughput.throughput() {
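// `throughput()` is in bytes per second: multiply by 8 for bits,
// then by 1e-3 for kilobits per second.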
let log_message = format!("{}: FLUSH: {} kbps", projected.tag, throughput * 8e-3);

projected.events.emit(Event {
id: 0,
typ: EventType::Info(log_message),
});
} else {
let log_message = format!("{}: FLUSH: unknown throughput", projected.tag);

projected.events.emit(Event {
id: 0,
typ: EventType::Info(log_message),
});
}

projected.inner.poll_flush(cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_shutdown(cx)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

impl<S: SessionStream> SessionStream for LoggingStream<S> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.throughput.enabled = timeout.is_some();

self.inner.set_read_timeout(timeout)
}
}
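
A minimal usage sketch of the new wrapper, for illustration only: the helper name `wrap_and_ping`, the tag string, the timeout and the account id are made up, and the input is assumed to be any stream implementing `SessionStream` (such as the one returned by `connect_tcp_inner`).

use std::time::Duration;

use tokio::io::AsyncWriteExt;

use crate::events::Events;
use crate::log::LoggingStream;
use crate::net::session::SessionStream;

// Hypothetical helper: wraps an already-connected session stream so its
// I/O is reported as Info events, then writes an IMAP NOOP through it.
async fn wrap_and_ping<S: SessionStream + Unpin>(
    stream: S,
    account_id: u32,
    events: Events,
) -> std::io::Result<()> {
    let mut stream = LoggingStream::new(
        stream,
        "example IMAP stream".to_string(), // tag shown in the log messages
        account_id,
        events,
    );

    // Throughput statistics are only collected while a read timeout is set.
    stream.set_read_timeout(Some(Duration::from_secs(60)));

    // Reads and writes pass through unchanged; flush() additionally emits
    // a throughput summary event.
    stream.write_all(b"A001 NOOP\r\n").await?;
    stream.flush().await?;
    Ok(())
}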