Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Hrana server authentication #235

Merged
merged 2 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Improve the ergonomy of passing JWT key for Hrana
  • Loading branch information
honzasp committed Feb 22, 2023
commit 560c69584eb80e844ba448872774d879f77c0fc3
5 changes: 1 addition & 4 deletions packages/js/hrana-client/examples/jwt_auth.mjs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import * as hrana from "@libsql/hrana-client";

const jwt = (
"eyJ0eXAiOiJKV1QiLCJhbGciOiJFZERTQSJ9.eyJleHAiOjE2NzY5MDkwOTR9._8Dt3MSN7b5-ykbxM2dCh8CzIPpkqDmPagRXfSO3s1es-6vRN_qMrNGsEUdCFP6tAmCNYd9RJZ9zaUT_wCQ3Bg"
);
const client = hrana.open("ws://localhost:2023", jwt);
const client = hrana.open("ws://localhost:2023", process.env.JWT);
const stream = client.openStream();
console.log(await stream.queryValue("SELECT 1"));
client.close();
11 changes: 7 additions & 4 deletions scripts/gen_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

pubkey_base64 = base64.b64encode(pubkey.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
))
pubkey_base64 = base64.b64encode(
pubkey.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
),
altchars=b"-_",
)
while pubkey_base64[-1] == ord("="):
pubkey_base64 = pubkey_base64[:-1]

Expand Down
32 changes: 18 additions & 14 deletions sqld/src/hrana/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use crate::database::service::DbFactory;
use anyhow::{Context as _, Result, bail};
use anyhow::{bail, Context as _, Result};
use enclose::enclose;
use std::{fs, str};
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;

mod conn;
Expand All @@ -20,7 +18,16 @@ pub async fn serve(
bind_addr: SocketAddr,
jwt_key: Option<jsonwebtoken::DecodingKey>,
) -> Result<()> {
let server = Arc::new(Server { db_factory, jwt_key });
let server = Arc::new(Server {
db_factory,
jwt_key,
});

if server.jwt_key.is_some() {
tracing::info!("Hrana authentication is enabled");
} else {
tracing::warn!("Hrana authentication is disabled, the server is unprotected");
}

let listener = tokio::net::TcpListener::bind(bind_addr)
.await
Expand Down Expand Up @@ -53,19 +60,16 @@ pub async fn serve(
}
}

pub fn load_jwt_key(path: &Path) -> Result<jsonwebtoken::DecodingKey> {
let data = fs::read(path)?;
if data.starts_with(b"-----BEGIN PUBLIC KEY-----") {
jsonwebtoken::DecodingKey::from_ed_pem(&data)
pub fn parse_jwt_key(data: &str) -> Result<jsonwebtoken::DecodingKey> {
if data.starts_with("-----BEGIN PUBLIC KEY-----") {
jsonwebtoken::DecodingKey::from_ed_pem(data.as_bytes())
.context("Could not decode Ed25519 public key from PEM")
} else if data.starts_with(b"-----BEGIN PRIVATE KEY-----") {
} else if data.starts_with("-----BEGIN PRIVATE KEY-----") {
bail!("Received a private key, but a public key is expected")
} else if data.starts_with(b"-----BEGIN") {
} else if data.starts_with("-----BEGIN") {
bail!("Key is in unsupported PEM format")
} else if let Ok(data_str) = str::from_utf8(&data) {
jsonwebtoken::DecodingKey::from_ed_components(&data_str)
.context("Could not decode Ed25519 public key from base64")
} else {
bail!("Key is in an unsupported binary format")
jsonwebtoken::DecodingKey::from_ed_components(data)
.context("Could not decode Ed25519 public key from base64")
}
}
4 changes: 1 addition & 3 deletions sqld/src/hrana/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ pub enum ResponseError {
#[error("Authentication using JWT is required")]
AuthJwtRequired,
#[error("Authentication using JWT failed")]
AuthJwtRejected {
source: jsonwebtoken::errors::Error,
},
AuthJwtRejected { source: jsonwebtoken::errors::Error },

#[error("Stream {stream_id} not found")]
StreamNotFound { stream_id: i32 },
Expand Down
10 changes: 6 additions & 4 deletions sqld/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub struct Config {
pub http_auth: Option<String>,
pub enable_http_console: bool,
pub hrana_addr: Option<SocketAddr>,
pub hrana_jwt_key: Option<PathBuf>,
pub hrana_jwt_key: Option<String>,
pub backend: Backend,
#[cfg(feature = "mwal_backend")]
pub mwal_addr: Option<String>,
Expand Down Expand Up @@ -117,10 +117,12 @@ async fn run_service(
}

if let Some(addr) = config.hrana_addr {
let jwt_key = config.hrana_jwt_key.as_deref()
.map(hrana::load_jwt_key)
let jwt_key = config
.hrana_jwt_key
.as_deref()
.map(hrana::parse_jwt_key)
.transpose()
.context("Could not load JWT decoding key for Hrana")?;
.context("Could not parse JWT decoding key for Hrana")?;

join_set.spawn(async move {
hrana::serve(service.factory, addr, jwt_key)
Expand Down
81 changes: 48 additions & 33 deletions sqld/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{net::SocketAddr, path::PathBuf, time::Duration};
use std::{env, fs, net::SocketAddr, path::PathBuf, time::Duration};

use anyhow::Result;
use anyhow::{bail, Context as _, Result};
use clap::Parser;
use sqld::Config;
use tracing_subscriber::filter::LevelFilter;
Expand All @@ -24,7 +24,9 @@ struct Cli {
hrana_listen_addr: Option<SocketAddr>,
/// Path to a file with a JWT decoding key used to authenticate Hrana connections. If you do
/// not specify a key, Hrana authentication is not required. The key is either PKCS#8-encoded
/// Ed25519 public key, or just plain bytes of the Ed25519 public key in base64.
/// Ed25519 public key, or just plain bytes of the Ed25519 public key in URL-safe base64.
///
/// You can also pass the key directly in the env variable SQLD_HRANA_JWT_KEY.
#[clap(long, env = "SQLD_HRANA_JWT_KEY_FILE")]
hrana_jwt_key_file: Option<PathBuf>,

Expand Down Expand Up @@ -141,39 +143,51 @@ impl Cli {
if let Some(ref addr) = self.pg_listen_addr {
eprintln!("\t- listening for PostgreSQL wire on: {addr}");
}
eprintln!("\t- gprc_tls: {}", if self.grpc_tls { "yes" } else { "no" });
eprintln!("\t- grpc_tls: {}", if self.grpc_tls { "yes" } else { "no" });
}
}

impl From<Cli> for Config {
fn from(cli: Cli) -> Self {
Self {
db_path: cli.db_path,
tcp_addr: cli.pg_listen_addr,
ws_addr: cli.ws_listen_addr,
http_addr: Some(cli.http_listen_addr),
http_auth: cli.http_auth,
enable_http_console: cli.enable_http_console,
hrana_addr: cli.hrana_listen_addr,
hrana_jwt_key: cli.hrana_jwt_key_file,
backend: cli.backend,
writer_rpc_addr: cli.primary_grpc_url,
writer_rpc_tls: cli.primary_grpc_tls,
writer_rpc_cert: cli.primary_grpc_cert_file,
writer_rpc_key: cli.primary_grpc_key_file,
writer_rpc_ca_cert: cli.primary_grpc_ca_cert_file,
rpc_server_addr: cli.grpc_listen_addr,
rpc_server_tls: cli.grpc_tls,
rpc_server_cert: cli.grpc_cert_file,
rpc_server_key: cli.grpc_key_file,
rpc_server_ca_cert: cli.grpc_ca_cert_file,
#[cfg(feature = "mwal_backend")]
mwal_addr: cli.mwal_addr,
enable_bottomless_replication: cli.enable_bottomless_replication,
create_local_http_tunnel: cli.create_local_http_tunnel,
idle_shutdown_timeout: cli.idle_shutdown_timeout_s.map(Duration::from_secs),
fn config_from_args(args: Cli) -> Result<Config> {
let hrana_jwt_key = if let Some(file_path) = args.hrana_jwt_key_file {
let data =
fs::read_to_string(file_path).context("Could not read file with Hrana JWT key")?;
Some(data)
} else {
match env::var("SQLD_HRANA_JWT_KEY") {
Ok(key) => Some(key),
Err(env::VarError::NotPresent) => None,
Err(env::VarError::NotUnicode(_)) => {
bail!("Env variable SQLD_HRANA_JWT_KEY does not contain a valid Unicode value")
}
}
}
};

Ok(Config {
db_path: args.db_path,
tcp_addr: args.pg_listen_addr,
ws_addr: args.ws_listen_addr,
http_addr: Some(args.http_listen_addr),
http_auth: args.http_auth,
enable_http_console: args.enable_http_console,
hrana_addr: args.hrana_listen_addr,
hrana_jwt_key,
backend: args.backend,
writer_rpc_addr: args.primary_grpc_url,
writer_rpc_tls: args.primary_grpc_tls,
writer_rpc_cert: args.primary_grpc_cert_file,
writer_rpc_key: args.primary_grpc_key_file,
writer_rpc_ca_cert: args.primary_grpc_ca_cert_file,
rpc_server_addr: args.grpc_listen_addr,
rpc_server_tls: args.grpc_tls,
rpc_server_cert: args.grpc_cert_file,
rpc_server_key: args.grpc_key_file,
rpc_server_ca_cert: args.grpc_ca_cert_file,
#[cfg(feature = "mwal_backend")]
mwal_addr: args.mwal_addr,
enable_bottomless_replication: args.enable_bottomless_replication,
create_local_http_tunnel: args.create_local_http_tunnel,
idle_shutdown_timeout: args.idle_shutdown_timeout_s.map(Duration::from_secs),
})
}

#[tokio::main]
Expand Down Expand Up @@ -203,7 +217,8 @@ async fn main() -> Result<()> {
_ => (),
}

sqld::run_server(args.into()).await?;
let config = config_from_args(args)?;
sqld::run_server(config).await?;

Ok(())
}