Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions ktls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ libc = { version = "0.2.155", features = ["const-extern-fn"] }
thiserror = "2"
tracing = "0.1.40"
tokio-rustls = { default-features = false, version = "0.26.0" }
rustls = { version = "0.23.12", default-features = false }
rustls = { version = "0.23.27", default-features = false }
smallvec = "1.13.2"
memoffset = "0.9.1"
pin-project-lite = "0.2.14"
tokio = { version = "1.39.2", features = ["net", "macros", "io-util"] }
tokio = { version = "1.39.2", features = ["net", "macros", "io-util", "sync"] }
ktls-sys = "1.0.1"
num_enum = "0.7.3"
futures-util = "0.3.30"
nix = { version = "0.29.0", features = ["socket", "uio", "net"] }
bitflags = "2.9.1"

[dev-dependencies]
lazy_static = "1.5.0"
Expand Down
137 changes: 137 additions & 0 deletions ktls/src/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use std::io;
use std::os::fd::{AsRawFd, RawFd};
use std::pin::Pin;
use std::task::{Context, Poll};

use rustls::client::{ClientConnectionData, UnbufferedClientConnection};
use rustls::kernel::KernelConnection;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::ffi::Direction;
use crate::stream::KTlsStreamImpl;
use crate::CryptoInfo;
use crate::{ConnectError, TryConnectError};

pin_project_lite::pin_project! {
/// The client half of a kTLS stream.
pub struct KTlsClientStream<IO> {
#[pin]
pub(crate) stream: KTlsStreamImpl<IO, KernelConnection<ClientConnectionData>>
}
}

impl<IO> KTlsClientStream<IO>
where
IO: AsyncWrite + AsyncRead + AsRawFd,
{
pub fn from_unbuffered_connnection(
socket: IO,
conn: UnbufferedClientConnection,
) -> Result<Self, TryConnectError<IO, UnbufferedClientConnection>> {
// We attempt to set up the TLS ULP before doing anything else so that
// we can indicate that the kernel doesn't support kTLS before returning
// any other error.
if let Err(e) = crate::ffi::setup_ulp(socket.as_raw_fd()) {
let error = if e.raw_os_error() == Some(libc::ENOENT) {
ConnectError::KTlsUnsupported
} else {
ConnectError::IO(e)
};

return Err(TryConnectError {
error,
socket: Some(socket),
conn: Some(conn),
});
}

// TODO: Validate that the negotiated connection is actually
// supported by kTLS on the current machine.

Ok(Self::from_unbuffered_connnection_with_tls_ulp_enabled(
socket, conn,
)?)
}

/// Create a new `KTlsClientStream` from a socket that already has had the TLS ULP
/// enabled on it.
fn from_unbuffered_connnection_with_tls_ulp_enabled(
socket: IO,
conn: UnbufferedClientConnection,
) -> Result<Self, ConnectError> {
let (secrets, kconn) = match conn.dangerous_into_kernel_connection() {
Ok(secrets) => secrets,
Err(e) => return Err(ConnectError::ExtractSecrets(e)),
};

let suite = kconn.negotiated_cipher_suite();
let tx = CryptoInfo::from_rustls(suite, secrets.tx)
.map_err(|_| ConnectError::UnsupportedCipherSuite(suite))?;
let rx = CryptoInfo::from_rustls(suite, secrets.rx)
.map_err(|_| ConnectError::UnsupportedCipherSuite(suite))?;

crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Tx, tx)
.map_err(ConnectError::IO)?;
crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Rx, rx)
.map_err(ConnectError::IO)?;

Ok(Self {
stream: KTlsStreamImpl::new(socket, Vec::new(), kconn),
})
}
}

impl<IO> AsyncRead for KTlsClientStream<IO>
where
IO: AsyncWrite + AsyncRead + AsRawFd,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().stream.poll_read(cx, buf)
}
}

impl<IO> AsyncWrite for KTlsClientStream<IO>
where
IO: AsyncWrite + AsyncRead + AsRawFd,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.project().stream.poll_write(cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_shutdown(cx)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
self.project().stream.poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.stream.is_write_vectored()
}
}

impl<IO> AsRawFd for KTlsClientStream<IO>
where
IO: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.stream.as_raw_fd()
}
}
77 changes: 77 additions & 0 deletions ktls/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use std::{fmt, io};

use rustls::SupportedCipherSuite;

#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ConnectError {
/// kTLS is not supported by the current kernel.
#[error("kTLS is not supported by the current kernel")]
KTlsUnsupported,

#[error("the negotiated cipher suite is not supported by kTLS")]
UnsupportedCipherSuite(SupportedCipherSuite),

#[error("the peer closed the connection before the TLS handshake could be completed")]
ConnectionClosedBeforeHandshakeCompleted,

#[error("{0}")]
IO(#[source] io::Error),

#[error("failed to create rustls connection: {0}")]
Config(#[source] rustls::Error),

#[error("an error occurred during the handshake: {0}")]
Handshake(#[source] rustls::Error),

#[error("unable to extract connection secrets from rustls connection: {0}")]
ExtractSecrets(#[source] rustls::Error),
}

impl From<ConnectError> for io::Error {
fn from(error: ConnectError) -> Self {
match error {
ConnectError::IO(error) => error,
_ => io::Error::other(error),
}
}
}

#[derive(thiserror::Error)]
#[error("{error}")]
pub struct TryConnectError<IO, Conn> {
#[source]
pub error: ConnectError,
pub socket: Option<IO>,
pub conn: Option<Conn>,
}

impl<IO, Conn> fmt::Debug for TryConnectError<IO, Conn> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TryConnectError")
.field("error", &self.error)
.finish_non_exhaustive()
}
}

impl<IO, Conn> From<ConnectError> for TryConnectError<IO, Conn> {
fn from(error: ConnectError) -> Self {
Self {
error,
socket: None,
conn: None,
}
}
}

impl<IO, Conn> From<TryConnectError<IO, Conn>> for ConnectError {
fn from(value: TryConnectError<IO, Conn>) -> Self {
value.error
}
}

impl<IO, Conn> From<TryConnectError<IO, Conn>> for io::Error {
fn from(error: TryConnectError<IO, Conn>) -> Self {
error.error.into()
}
}
Loading
Loading