Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Add support for queries over HTTP #12

Merged
merged 7 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
support query over HTTP
  • Loading branch information
MarinPostma committed Jan 9, 2023
commit e8838afe01dab8079543403c6aa96793b81b482c
9 changes: 9 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ edition = "2021"
anyhow = "1.0.66"
async-lock = "2.6.0"
async-trait = "0.1.58"
base64 = "0.21.0"
bincode = "1.3.3"
byteorder = "1.4.3"
bytes = { version = "1.2.1", features = ["serde"] }
clap = { version = "4.0.23", features = [ "derive" ] }
crossbeam = "0.8.2"
futures = "0.3.25"
hex = "0.4.3"
hyper = { version = "0.14.23", features = ["http2"] }
# Regular mvfs prevents users from enabling WAL mode
mvfs = { git = "https://github.com/psarna/mvsqlite", branch = "mwal", optional = true }
mwal = { git = "https://github.com/psarna/mvsqlite", branch = "mwal", optional = true }
Expand All @@ -26,6 +28,7 @@ prost = "0.11.3"
regex = "1.7.0"
rusqlite = { version = "0.28.0", features = [ "buildtime_bindgen", "column_decltype" ] }
serde = { version = "1.0.149", features = ["derive"] }
serde_json = "1.0.91"
smallvec = "1.10.0"
sqlparser = "0.27.0"
tokio = { version = "1.21.2", features = ["full"] }
Expand Down
1 change: 1 addition & 0 deletions server/src/database/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ where
}
}

#[derive(Clone)]
pub struct DbFactoryService<F> {
factory: F,
}
Expand Down
124 changes: 124 additions & 0 deletions server/src/http/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
use std::collections::HashMap;
use std::future::poll_fn;
use std::{convert::Infallible, net::SocketAddr};

use base64::prelude::BASE64_STANDARD_NO_PAD;
use base64::Engine;
use bytes::{BufMut, Bytes, BytesMut};
use hyper::body::to_bytes;
use hyper::server::conn::AddrStream;
use hyper::service::make_service_fn;
use hyper::{Body, Request, Response};
use serde::Deserialize;
use serde_json::Number;
use tokio::sync::{mpsc, oneshot};
use tower::balance::pool;
use tower::load::Load;
use tower::{service_fn, BoxError, MakeService, Service};

use crate::query::{self, Query, QueryError, QueryResponse, ResultSet};

fn query_response_to_json(rows: QueryResponse) -> anyhow::Result<Bytes> {
let QueryResponse::ResultSet(ResultSet { columns, rows }) = rows;
let mut values = Vec::new();
MarinPostma marked this conversation as resolved.
Show resolved Hide resolved
for row in rows {
let val = row
.values
.into_iter()
.zip(columns.iter().map(|c| &c.name))
.try_fold(
HashMap::new(),
|mut map, (value, name)| -> anyhow::Result<_> {
let value = match value {
query::Value::Null => serde_json::Value::Null,
query::Value::Integer(i) => serde_json::Value::Number(Number::from(i)),
query::Value::Real(x) => serde_json::Value::Number(
Number::from_f64(x).ok_or(anyhow::anyhow!("invalid float value"))?,
),
query::Value::Text(s) => serde_json::Value::String(s),
query::Value::Blob(v) => {
serde_json::Value::String(BASE64_STANDARD_NO_PAD.encode(v))
}
};
MarinPostma marked this conversation as resolved.
Show resolved Hide resolved

map.insert(name.to_string(), value);
Ok(map)
},
)?;

values.push(val);
}

let mut buffer = BytesMut::new().writer();
serde_json::to_writer(&mut buffer, &values)?;

Ok(buffer.into_inner().freeze())
}

async fn handle_request(
mut req: Request<Body>,
sender: mpsc::Sender<(oneshot::Sender<Result<QueryResponse, BoxError>>, Query)>,
) -> anyhow::Result<Response<Body>> {
let bytes = to_bytes(req.body_mut()).await?;
let req: HttpQueryRequest = serde_json::from_slice(&bytes)?;
let (s, resp) = oneshot::channel();
let _ = sender
.send((s, Query::SimpleQuery(req.statements.join(";"), Vec::new())))
.await;

let result = resp.await?;
match result {
Ok(rows) => {
let json = query_response_to_json(rows)?;
Ok(Response::new(Body::from(json)))
}
Err(_) => todo!(),
}
}

pub async fn run_http<F>(addr: SocketAddr, db_factory: F) -> anyhow::Result<()>
where
F: MakeService<(), Query> + Send + 'static,
F::Service: Load + Service<Query, Response = QueryResponse, Error = QueryError>,
<F::Service as Load>::Metric: std::fmt::Debug,
F::MakeError: Into<BoxError>,
F::Error: Into<BoxError>,
<F as MakeService<(), Query>>::Service: Send,
<F as MakeService<(), Query>>::Future: Send,
<<F as MakeService<(), Query>>::Service as Service<Query>>::Future: Send,
{
tracing::info!("listening for HTTP requests on {addr}");

let (sender, mut receiver) = mpsc::channel(1024);
let server =
hyper::server::Server::bind(&addr).serve(make_service_fn(move |_: &AddrStream| {
let sender = sender.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| handle_request(req, sender.clone())))
}
}));

tokio::spawn(async move {
let mut pool = pool::Builder::new().build(db_factory, ());
while let Some((resp, query)) = receiver.recv().await {
if let Err(e) = poll_fn(|c| pool.poll_ready(c)).await {
tracing::error!("Connection pool error: {e}");
continue;
MarinPostma marked this conversation as resolved.
Show resolved Hide resolved
}

let fut = pool.call(query);
tokio::spawn(async move {
let _ = resp.send(fut.await);
});
}
});

server.await?;

Ok(())
}

#[derive(Debug, Deserialize)]
pub struct HttpQueryRequest {
statements: Vec<String>,
}
12 changes: 11 additions & 1 deletion server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ use database::libsql::LibSqlDb;
use database::service::DbFactoryService;
use database::write_proxy::WriteProxyDbFactory;
use rpc::run_rpc_server;
use tower::load::Constant;
use tower::ServiceExt;
use wal_logger::{WalLogger, WalLoggerHook};

use crate::postgres::service::PgConnectionFactory;
use crate::server::Server;

mod database;
mod http;
mod libsql;
mod postgres;
mod query;
Expand All @@ -34,6 +37,7 @@ pub async fn run_server(
db_path: PathBuf,
tcp_addr: SocketAddr,
ws_addr: Option<SocketAddr>,
http_addr: Option<SocketAddr>,
backend: Backend,
#[cfg(feature = "mwal_backend")] mwal_addr: Option<String>,
writer_rpc_addr: Option<String>,
Expand Down Expand Up @@ -87,7 +91,13 @@ pub async fn run_server(
}
};
let service = DbFactoryService::new(db_factory.clone());
let factory = PgConnectionFactory::new(service);
let factory = PgConnectionFactory::new(service.clone());
if let Some(addr) = http_addr {
tokio::spawn(http::run_http(
addr,
service.map_response(|s| Constant::new(s, 1)),
));
}
if let Some(addr) = rpc_server_addr {
tokio::spawn(run_rpc_server(addr, db_factory, logger_clone));
}
Expand Down
4 changes: 4 additions & 0 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ struct Cli {
#[cfg(feature = "mwal_backend")]
#[clap(long, short)]
mwal_addr: Option<String>,

#[clap(long)]
http_addr: Option<SocketAddr>,
}

#[tokio::main]
Expand All @@ -53,6 +56,7 @@ async fn main() -> Result<()> {
args.db_path,
args.pg_listen_addr,
args.ws_listen_addr,
args.http_addr,
args.backend,
#[cfg(feature = "mwal_backend")]
args.mwal_addr,
Expand Down
9 changes: 9 additions & 0 deletions server/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::convert::Infallible;
use std::fmt;
use std::str::FromStr;

use futures::stream;
Expand Down Expand Up @@ -254,6 +255,14 @@ pub struct QueryError {
pub msg: String,
}

impl std::error::Error for QueryError {}

impl fmt::Display for QueryError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.msg)
}
}

impl From<RpcError> for QueryError {
fn from(other: RpcError) -> Self {
let code = match other.code() {
Expand Down