Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Commit

Permalink
Improve the ergonomy of passing JWT key for Hrana
Browse files Browse the repository at this point in the history
  • Loading branch information
honzasp committed Feb 22, 2023
1 parent c9da363 commit 560c695
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 62 deletions.
5 changes: 1 addition & 4 deletions packages/js/hrana-client/examples/jwt_auth.mjs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import * as hrana from "@libsql/hrana-client";

const jwt = (
"eyJ0eXAiOiJKV1QiLCJhbGciOiJFZERTQSJ9.eyJleHAiOjE2NzY5MDkwOTR9._8Dt3MSN7b5-ykbxM2dCh8CzIPpkqDmPagRXfSO3s1es-6vRN_qMrNGsEUdCFP6tAmCNYd9RJZ9zaUT_wCQ3Bg"
);
const client = hrana.open("ws://localhost:2023", jwt);
const client = hrana.open("ws://localhost:2023", process.env.JWT);
const stream = client.openStream();
console.log(await stream.queryValue("SELECT 1"));
client.close();
11 changes: 7 additions & 4 deletions scripts/gen_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

pubkey_base64 = base64.b64encode(pubkey.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
))
pubkey_base64 = base64.b64encode(
pubkey.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
),
altchars=b"-_",
)
while pubkey_base64[-1] == ord("="):
pubkey_base64 = pubkey_base64[:-1]

Expand Down
32 changes: 18 additions & 14 deletions sqld/src/hrana/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use crate::database::service::DbFactory;
use anyhow::{Context as _, Result, bail};
use anyhow::{bail, Context as _, Result};
use enclose::enclose;
use std::{fs, str};
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;

mod conn;
Expand All @@ -20,7 +18,16 @@ pub async fn serve(
bind_addr: SocketAddr,
jwt_key: Option<jsonwebtoken::DecodingKey>,
) -> Result<()> {
let server = Arc::new(Server { db_factory, jwt_key });
let server = Arc::new(Server {
db_factory,
jwt_key,
});

if server.jwt_key.is_some() {
tracing::info!("Hrana authentication is enabled");
} else {
tracing::warn!("Hrana authentication is disabled, the server is unprotected");
}

let listener = tokio::net::TcpListener::bind(bind_addr)
.await
Expand Down Expand Up @@ -53,19 +60,16 @@ pub async fn serve(
}
}

pub fn load_jwt_key(path: &Path) -> Result<jsonwebtoken::DecodingKey> {
let data = fs::read(path)?;
if data.starts_with(b"-----BEGIN PUBLIC KEY-----") {
jsonwebtoken::DecodingKey::from_ed_pem(&data)
pub fn parse_jwt_key(data: &str) -> Result<jsonwebtoken::DecodingKey> {
if data.starts_with("-----BEGIN PUBLIC KEY-----") {
jsonwebtoken::DecodingKey::from_ed_pem(data.as_bytes())
.context("Could not decode Ed25519 public key from PEM")
} else if data.starts_with(b"-----BEGIN PRIVATE KEY-----") {
} else if data.starts_with("-----BEGIN PRIVATE KEY-----") {
bail!("Received a private key, but a public key is expected")
} else if data.starts_with(b"-----BEGIN") {
} else if data.starts_with("-----BEGIN") {
bail!("Key is in unsupported PEM format")
} else if let Ok(data_str) = str::from_utf8(&data) {
jsonwebtoken::DecodingKey::from_ed_components(&data_str)
.context("Could not decode Ed25519 public key from base64")
} else {
bail!("Key is in an unsupported binary format")
jsonwebtoken::DecodingKey::from_ed_components(data)
.context("Could not decode Ed25519 public key from base64")
}
}
4 changes: 1 addition & 3 deletions sqld/src/hrana/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ pub enum ResponseError {
#[error("Authentication using JWT is required")]
AuthJwtRequired,
#[error("Authentication using JWT failed")]
AuthJwtRejected {
source: jsonwebtoken::errors::Error,
},
AuthJwtRejected { source: jsonwebtoken::errors::Error },

#[error("Stream {stream_id} not found")]
StreamNotFound { stream_id: i32 },
Expand Down
10 changes: 6 additions & 4 deletions sqld/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub struct Config {
pub http_auth: Option<String>,
pub enable_http_console: bool,
pub hrana_addr: Option<SocketAddr>,
pub hrana_jwt_key: Option<PathBuf>,
pub hrana_jwt_key: Option<String>,
pub backend: Backend,
#[cfg(feature = "mwal_backend")]
pub mwal_addr: Option<String>,
Expand Down Expand Up @@ -117,10 +117,12 @@ async fn run_service(
}

if let Some(addr) = config.hrana_addr {
let jwt_key = config.hrana_jwt_key.as_deref()
.map(hrana::load_jwt_key)
let jwt_key = config
.hrana_jwt_key
.as_deref()
.map(hrana::parse_jwt_key)
.transpose()
.context("Could not load JWT decoding key for Hrana")?;
.context("Could not parse JWT decoding key for Hrana")?;

join_set.spawn(async move {
hrana::serve(service.factory, addr, jwt_key)
Expand Down
81 changes: 48 additions & 33 deletions sqld/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{net::SocketAddr, path::PathBuf, time::Duration};
use std::{env, fs, net::SocketAddr, path::PathBuf, time::Duration};

use anyhow::Result;
use anyhow::{bail, Context as _, Result};
use clap::Parser;
use sqld::Config;
use tracing_subscriber::filter::LevelFilter;
Expand All @@ -24,7 +24,9 @@ struct Cli {
hrana_listen_addr: Option<SocketAddr>,
/// Path to a file with a JWT decoding key used to authenticate Hrana connections. If you do
/// not specify a key, Hrana authentication is not required. The key is either PKCS#8-encoded
/// Ed25519 public key, or just plain bytes of the Ed25519 public key in base64.
/// Ed25519 public key, or just plain bytes of the Ed25519 public key in URL-safe base64.
///
/// You can also pass the key directly in the env variable SQLD_HRANA_JWT_KEY.
#[clap(long, env = "SQLD_HRANA_JWT_KEY_FILE")]
hrana_jwt_key_file: Option<PathBuf>,

Expand Down Expand Up @@ -141,39 +143,51 @@ impl Cli {
if let Some(ref addr) = self.pg_listen_addr {
eprintln!("\t- listening for PostgreSQL wire on: {addr}");
}
eprintln!("\t- gprc_tls: {}", if self.grpc_tls { "yes" } else { "no" });
eprintln!("\t- grpc_tls: {}", if self.grpc_tls { "yes" } else { "no" });
}
}

impl From<Cli> for Config {
fn from(cli: Cli) -> Self {
Self {
db_path: cli.db_path,
tcp_addr: cli.pg_listen_addr,
ws_addr: cli.ws_listen_addr,
http_addr: Some(cli.http_listen_addr),
http_auth: cli.http_auth,
enable_http_console: cli.enable_http_console,
hrana_addr: cli.hrana_listen_addr,
hrana_jwt_key: cli.hrana_jwt_key_file,
backend: cli.backend,
writer_rpc_addr: cli.primary_grpc_url,
writer_rpc_tls: cli.primary_grpc_tls,
writer_rpc_cert: cli.primary_grpc_cert_file,
writer_rpc_key: cli.primary_grpc_key_file,
writer_rpc_ca_cert: cli.primary_grpc_ca_cert_file,
rpc_server_addr: cli.grpc_listen_addr,
rpc_server_tls: cli.grpc_tls,
rpc_server_cert: cli.grpc_cert_file,
rpc_server_key: cli.grpc_key_file,
rpc_server_ca_cert: cli.grpc_ca_cert_file,
#[cfg(feature = "mwal_backend")]
mwal_addr: cli.mwal_addr,
enable_bottomless_replication: cli.enable_bottomless_replication,
create_local_http_tunnel: cli.create_local_http_tunnel,
idle_shutdown_timeout: cli.idle_shutdown_timeout_s.map(Duration::from_secs),
fn config_from_args(args: Cli) -> Result<Config> {
let hrana_jwt_key = if let Some(file_path) = args.hrana_jwt_key_file {
let data =
fs::read_to_string(file_path).context("Could not read file with Hrana JWT key")?;
Some(data)
} else {
match env::var("SQLD_HRANA_JWT_KEY") {
Ok(key) => Some(key),
Err(env::VarError::NotPresent) => None,
Err(env::VarError::NotUnicode(_)) => {
bail!("Env variable SQLD_HRANA_JWT_KEY does not contain a valid Unicode value")
}
}
}
};

Ok(Config {
db_path: args.db_path,
tcp_addr: args.pg_listen_addr,
ws_addr: args.ws_listen_addr,
http_addr: Some(args.http_listen_addr),
http_auth: args.http_auth,
enable_http_console: args.enable_http_console,
hrana_addr: args.hrana_listen_addr,
hrana_jwt_key,
backend: args.backend,
writer_rpc_addr: args.primary_grpc_url,
writer_rpc_tls: args.primary_grpc_tls,
writer_rpc_cert: args.primary_grpc_cert_file,
writer_rpc_key: args.primary_grpc_key_file,
writer_rpc_ca_cert: args.primary_grpc_ca_cert_file,
rpc_server_addr: args.grpc_listen_addr,
rpc_server_tls: args.grpc_tls,
rpc_server_cert: args.grpc_cert_file,
rpc_server_key: args.grpc_key_file,
rpc_server_ca_cert: args.grpc_ca_cert_file,
#[cfg(feature = "mwal_backend")]
mwal_addr: args.mwal_addr,
enable_bottomless_replication: args.enable_bottomless_replication,
create_local_http_tunnel: args.create_local_http_tunnel,
idle_shutdown_timeout: args.idle_shutdown_timeout_s.map(Duration::from_secs),
})
}

#[tokio::main]
Expand Down Expand Up @@ -203,7 +217,8 @@ async fn main() -> Result<()> {
_ => (),
}

sqld::run_server(args.into()).await?;
let config = config_from_args(args)?;
sqld::run_server(config).await?;

Ok(())
}

0 comments on commit 560c695

Please sign in to comment.