Skip to content

Commit

Permalink
feat: add deltalake support again
Browse files Browse the repository at this point in the history
  • Loading branch information
timvw committed Mar 28, 2024
1 parent f81ef62 commit 6ae7076
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub struct Args {

/// Query to execute
#[clap(short, long, default_value_t = String::from("select * from tbl"), group = "sql")]
query: String,
pub query: String,

/// When provided the schema is shown
#[clap(short, long, group = "sql")]
Expand Down
115 changes: 61 additions & 54 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
mod args;
mod globbing_path;
mod globbing_table;
mod object_store_util;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;

use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;
use aws_sdk_glue::types::StorageDescriptor;
use aws_sdk_glue::Client;
use aws_types::SdkConfig;
use clap::Parser;
use datafusion::catalog::TableReference;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;

use aws_types::SdkConfig;

use datafusion::common::{DataFusionError, Result};
use datafusion::datasource::file_format::avro::AvroFormat;
use datafusion::datasource::file_format::csv::CsvFormat;
Expand All @@ -24,7 +18,9 @@ use datafusion::datasource::file_format::FileFormat;
use datafusion::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::TableProvider;
use datafusion::prelude::*;
use deltalake::open_table;
use object_store::aws::{AmazonS3, AmazonS3Builder};
use object_store::path::Path;
use object_store::ObjectStore;
Expand All @@ -33,37 +29,10 @@ use url::Url;

use crate::args::Args;

/// Builds an [`AmazonS3`] object store client for the bucket named in `url`.
///
/// Credentials are resolved from the provided AWS `sdk_config`; the session
/// token is attached only when present, and a custom endpoint is honored via
/// the `AWS_ENDPOINT_URL` environment variable (useful for MinIO/localstack).
///
/// # Errors
/// Returns `DataFusionError::Execution` when credential resolution fails,
/// and propagates any error from the object-store builder.
///
/// # Panics
/// Panics if `sdk_config` carries no credentials provider or if `url` has no
/// host (bucket) component — NOTE(review): consider returning errors instead.
async fn build_s3(url: &Url, sdk_config: &SdkConfig) -> Result<AmazonS3> {
    // Resolve credentials asynchronously from the SDK-configured provider.
    let cp = sdk_config.credentials_provider().unwrap();
    let creds = cp
        .provide_credentials()
        .await
        .map_err(|e| DataFusionError::Execution(format!("Failed to get credentials: {e}")))?;

    // The bucket name is the host portion of an s3://bucket/path URL.
    let bucket_name = url.host_str().unwrap();

    let builder = AmazonS3Builder::from_env()
        .with_bucket_name(bucket_name)
        .with_access_key_id(creds.access_key_id())
        .with_secret_access_key(creds.secret_access_key());

    // Temporary (STS) credentials carry a session token; attach it when set.
    let builder = if let Some(session_token) = creds.session_token() {
        builder.with_token(session_token)
    } else {
        builder
    };

    //https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html
    let builder = if let Ok(aws_endpoint_url) = env::var("AWS_ENDPOINT_URL") {
        builder.with_endpoint(aws_endpoint_url)
    } else {
        builder
    };

    let s3 = builder.build()?;

    Ok(s3)
}
mod args;
mod globbing_path;
mod globbing_table;
mod object_store_util;

#[tokio::main]
async fn main() -> Result<()> {
Expand All @@ -87,6 +56,8 @@ async fn main() -> Result<()> {
ctx.runtime_env()
.register_object_store(&s3_url, s3_arc.clone());

deltalake::aws::register_handlers(None);

// add trailing slash to folder
if !data_path.ends_with('/') {
let path = Path::parse(s3_url.path())?;
Expand All @@ -102,23 +73,27 @@ async fn main() -> Result<()> {
data_path
};

let table_path = ListingTableUrl::parse(data_path)?;
let mut config = ListingTableConfig::new(table_path);

config = if let Some(format) = file_format {
config.with_listing_options(ListingOptions::new(format))
let table: Arc<dyn TableProvider> = if let Ok(mut delta_table) = open_table(&data_path).await {
if let Some(at) = args.at {
delta_table.load_with_datetime(at).await?;
}
Arc::new(delta_table)
} else {
config.infer_options(&ctx.state()).await?
};
let table_path = ListingTableUrl::parse(&data_path)?;
let mut config = ListingTableConfig::new(table_path);

config = config.infer_schema(&ctx.state()).await?;
config = if let Some(format) = file_format {
config.with_listing_options(ListingOptions::new(format))
} else {
config.infer_options(&ctx.state()).await?
};

let table = ListingTable::try_new(config)?;
config = config.infer_schema(&ctx.state()).await?;
let table = ListingTable::try_new(config)?;
Arc::new(table)
};

ctx.register_table(
TableReference::from("datafusion.public.tbl"),
Arc::new(table),
)?;
ctx.register_table(TableReference::from("datafusion.public.tbl"), table)?;

let query = &args.get_query();
let df = ctx.sql(query).await?;
Expand Down Expand Up @@ -349,3 +324,35 @@ fn lookup_file_format(sd: &StorageDescriptor) -> Result<Arc<dyn FileFormat>> {
let format = format_result?;
Ok(format)
}

/// Builds an [`AmazonS3`] object store client for the bucket named in `url`.
///
/// Credentials come from the AWS `sdk_config`; a session token is attached
/// only when present, and the `AWS_ENDPOINT_URL` environment variable
/// overrides the endpoint (useful for MinIO/localstack test setups).
///
/// # Errors
/// Returns `DataFusionError::Execution` when the config has no credentials
/// provider, credential resolution fails, or `url` lacks a host (bucket)
/// component; propagates any error from the object-store builder.
async fn build_s3(url: &Url, sdk_config: &SdkConfig) -> Result<AmazonS3> {
    // Fail with a descriptive error instead of panicking when the SDK config
    // was built without a credentials provider.
    let cp = sdk_config.credentials_provider().ok_or_else(|| {
        DataFusionError::Execution("No credentials provider configured".to_string())
    })?;
    let creds = cp
        .provide_credentials()
        .await
        .map_err(|e| DataFusionError::Execution(format!("Failed to get credentials: {e}")))?;

    // The bucket name is the host portion of an s3://bucket/path URL; an URL
    // without a host cannot name a bucket, so surface that as an error too.
    let bucket_name = url.host_str().ok_or_else(|| {
        DataFusionError::Execution(format!("No bucket name (host) in URL: {url}"))
    })?;

    let builder = AmazonS3Builder::from_env()
        .with_bucket_name(bucket_name)
        .with_access_key_id(creds.access_key_id())
        .with_secret_access_key(creds.secret_access_key());

    // Temporary (STS) credentials carry a session token; attach it when set.
    let builder = if let Some(session_token) = creds.session_token() {
        builder.with_token(session_token)
    } else {
        builder
    };

    // https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html
    let builder = if let Ok(aws_endpoint_url) = env::var("AWS_ENDPOINT_URL") {
        builder.with_endpoint(aws_endpoint_url)
    } else {
        builder
    };

    Ok(builder.build()?)
}
60 changes: 60 additions & 0 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,63 @@ async fn run_with_s3_parquet_files_in_folder_no_trailing_slash() -> datafusion::
.stdout(data_predicate);
Ok(())
}

/// End-to-end check: query a local Delta Lake table at a fixed point in time
/// and verify both the header line and a known data row appear on stdout.
#[tokio::test]
async fn run_with_local_deltalake() -> datafusion::common::Result<()> {
    let mut command = get_qv_cmd()?;
    command
        .arg(get_qv_testing_path("data/delta/COVID-19_NYT"))
        .arg("--at")
        .arg("2022-01-13T16:39:00+01:00")
        .arg("-q")
        .arg("select * from tbl order by date, county, state, fips, cases, deaths");

    // The output must contain the column header line...
    let expect_header =
        build_row_regex_predicate(vec!["date", "county", "state", "fips", "case", "deaths"]);

    // ...and the earliest ordered row: Snohomish County, Washington.
    let expect_row = build_row_regex_predicate(vec![
        "2020-01-21",
        "Snohomish",
        "Washington",
        "53061",
        "1",
        "0",
    ]);

    command
        .assert()
        .success()
        .stdout(expect_header)
        .stdout(expect_row);
    Ok(())
}

/// End-to-end check: query a Delta Lake table stored in (MinIO-backed) S3 at
/// a fixed point in time and verify the header line and a known data row.
#[tokio::test]
async fn run_with_s3_deltalake() -> datafusion::common::Result<()> {
    configure_minio();

    let mut cmd = get_qv_cmd()?;
    let cmd = cmd
        .arg("s3://data/delta/COVID-19_NYT")
        .arg("--at")
        .arg("2022-01-13T16:39:00+01:00")
        .arg("-q")
        .arg("select * from tbl order by date, county, state, fips, cases, deaths");

    let header_predicate =
        build_row_regex_predicate(vec!["date", "county", "state", "fips", "case", "deaths"]);

    // Same table and snapshot as run_with_local_deltalake, so the first
    // ordered row is identical: deaths is "0" (the previous "0X" pattern
    // could never match the actual output).
    let data_predicate = build_row_regex_predicate(vec![
        "2020-01-21",
        "Snohomish",
        "Washington",
        "53061",
        "1",
        "0",
    ]);

    cmd.assert()
        .success()
        .stdout(header_predicate)
        .stdout(data_predicate);
    Ok(())
}

0 comments on commit 6ae7076

Please sign in to comment.