Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 34 additions & 25 deletions src/test_utils/fixture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,35 +84,44 @@ impl FlightSqlService for TestFlightSqlServiceImpl {
) -> Result<Response<FlightInfo>, Status> {
let CommandStatementQuery { query, .. } = query;
let dialect = datafusion::sql::sqlparser::dialect::GenericDialect {};
let statements = DFParser::parse_sql_with_dialect(&query, &dialect).unwrap();
// For testing purposes, we only support a single statement
assert_eq!(statements.len(), 1, "Only single statements are supported");
let statement = statements[0].clone();
let logical_plan = self
.context
.state()
.statement_to_plan(statement)
.await
.unwrap();
let schema = logical_plan.schema();
match DFParser::parse_sql_with_dialect(&query, &dialect) {
Ok(statements) => {
// For testing purposes, we only support a single statement
assert_eq!(statements.len(), 1, "Only single statements are supported");
let statement = statements[0].clone();
match self.context.state().statement_to_plan(statement).await {
Ok(logical_plan) => {
let schema = logical_plan.schema();

let uuid = uuid::Uuid::new_v4();
let ticket = TicketStatementQuery {
statement_handle: uuid.to_string().into(),
};
let mut bytes: Vec<u8> = Vec::new();
ticket.encode(&mut bytes).unwrap();
let uuid = uuid::Uuid::new_v4();
let ticket = TicketStatementQuery {
statement_handle: uuid.to_string().into(),
};
let mut bytes: Vec<u8> = Vec::new();
ticket.encode(&mut bytes).unwrap();

let info = FlightInfo::new()
.try_with_schema(schema.as_arrow())
.unwrap()
.with_endpoint(FlightEndpoint::new().with_ticket(Ticket::new(bytes)))
.with_descriptor(FlightDescriptor::new_cmd(query));
let info = FlightInfo::new()
.try_with_schema(schema.as_arrow())
.unwrap()
.with_endpoint(FlightEndpoint::new().with_ticket(Ticket::new(bytes)))
.with_descriptor(FlightDescriptor::new_cmd(query));

let mut guard = self.requests.lock().unwrap();
guard.insert(uuid, logical_plan);
let mut guard = self.requests.lock().unwrap();
guard.insert(uuid, logical_plan);

Ok(Response::new(info))
Ok(Response::new(info))
}
Err(e) => {
let error = format!("{:?}", e);
Err(Status::internal(error))
}
}
}
Err(e) => {
let error = format!("{:?}", e);
Err(Status::internal(error))
}
}
}

async fn do_get_statement(
Expand Down
49 changes: 49 additions & 0 deletions tests/extension_cases/flightsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,31 @@ pub async fn test_execute() {
fixture.shutdown_and_wait().await;
}

#[tokio::test]
pub async fn test_invalid_sql_command() {
let test_server = TestFlightSqlServiceImpl::new();
let fixture = TestFixture::new(test_server.service(), "127.0.0.1:50051").await;

let assert = tokio::task::spawn_blocking(|| {
Command::cargo_bin("dft")
.unwrap()
.arg("-c")
.arg("SELEC 1;")
.arg("--flightsql")
.timeout(Duration::from_secs(5))
.assert()
.failure()
})
.await
.unwrap();

// I think its implementation specific how they decide to return errors but I believe they will
// all be in the form of an IPC error
let expected = r##"Error: Ipc error"##;
assert.stderr(contains_str(expected));
fixture.shutdown_and_wait().await;
}

#[tokio::test]
pub async fn test_execute_multiple_commands() {
let test_server = TestFlightSqlServiceImpl::new();
Expand Down Expand Up @@ -127,6 +152,30 @@ pub async fn test_command_in_file() {
fixture.shutdown_and_wait().await;
}

#[tokio::test]
pub async fn test_invalid_sql_command_in_file() {
let test_server = TestFlightSqlServiceImpl::new();
let fixture = TestFixture::new(test_server.service(), "127.0.0.1:50051").await;
let file = sql_in_file("SELEC 1");
let assert = tokio::task::spawn_blocking(move || {
Command::cargo_bin("dft")
.unwrap()
.arg("--flightsql")
.arg("-f")
.arg(file.path())
.assert()
.failure()
})
.await
.unwrap();

// I think its implementation specific how they decide to return errors but I believe they will
// all be in the form of an IPC error
let expected = r##"Error: Ipc error"##;
assert.stderr(contains_str(expected));
fixture.shutdown_and_wait().await;
}

#[tokio::test]
pub async fn test_command_multiple_files() {
let test_server = TestFlightSqlServiceImpl::new();
Expand Down