105 changes: 105 additions & 0 deletions rust/cubesql/cubesql/e2e/tests/postgres.rs
@@ -666,6 +666,109 @@ impl PostgresIntegrationTestSuite {
Ok(())
}

async fn test_fetch_directions(&self) -> RunResult<()> {
self.test_simple_query(
r#"DECLARE test_fetch_directions CURSOR WITH HOLD FOR SELECT generate_series(1, 100);"#
.to_string(),
|_| {},
)
.await?;

// Test FETCH FORWARD 1 - should return row "1"
self.test_simple_query(
r#"FETCH FORWARD 1 IN test_fetch_directions;"#.to_string(),
|messages| {
assert_eq!(messages.len(), 2); // 1 row + completion
if let SimpleQueryMessage::Row(row) = &messages[0] {
assert_eq!(row.get(0), Some("1"));
} else {
panic!("Expected Row for FETCH FORWARD 1");
}
},
)
.await?;

// Test FETCH NEXT - should return row "2"
self.test_simple_query(
r#"FETCH NEXT IN test_fetch_directions;"#.to_string(),
|messages| {
assert_eq!(messages.len(), 2); // 1 row + completion
if let SimpleQueryMessage::Row(row) = &messages[0] {
assert_eq!(row.get(0), Some("2"));
} else {
panic!("Expected Row for FETCH NEXT");
}
},
)
.await?;

// Test FETCH FORWARD 5 - should return rows 3-7
self.test_simple_query(
r#"FETCH FORWARD 5 IN test_fetch_directions;"#.to_string(),
|messages| {
assert_eq!(messages.len(), 6); // 5 rows + completion
if let SimpleQueryMessage::Row(row) = &messages[0] {
assert_eq!(row.get(0), Some("3"));
} else {
panic!("Expected Row for FETCH FORWARD 5, first row");
}
if let SimpleQueryMessage::Row(row) = &messages[4] {
assert_eq!(row.get(0), Some("7"));
} else {
panic!("Expected Row for FETCH FORWARD 5, last row");
}
},
)
.await?;

// Test FETCH ALL - should return remaining rows (8-100 = 93 rows)
self.test_simple_query(
r#"FETCH ALL IN test_fetch_directions;"#.to_string(),
|messages| {
// 93 rows + 1 completion
assert_eq!(messages.len(), 94);
if let SimpleQueryMessage::Row(row) = &messages[0] {
assert_eq!(row.get(0), Some("8"));
} else {
panic!("Expected Row for FETCH ALL, first row");
}
if let SimpleQueryMessage::Row(row) = &messages[92] {
assert_eq!(row.get(0), Some("100"));
} else {
panic!("Expected Row for FETCH ALL, last row");
}
},
)
.await?;

self.test_simple_query(r#"CLOSE test_fetch_directions;"#.to_string(), |_| {})
.await?;

Ok(())
}

async fn test_fetch_forward_all(&self) -> RunResult<()> {
self.test_simple_query(
r#"DECLARE test_forward_all CURSOR WITH HOLD FOR SELECT generate_series(1, 10);"#
.to_string(),
|_| {},
)
.await?;

self.test_simple_query(
r#"FETCH FORWARD ALL IN test_forward_all;"#.to_string(),
|messages| {
assert_eq!(messages.len(), 11); // 10 rows + 1 completion
},
)
.await?;

self.test_simple_query(r#"CLOSE test_forward_all;"#.to_string(), |_| {})
.await?;

Ok(())
}

// Tableau Desktop uses it
async fn test_simple_cursors_without_hold(&self) -> RunResult<()> {
// without hold is default behaviour
@@ -1175,6 +1278,8 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite {
self.test_stream_single().await?;
self.test_portal_pagination().await?;
self.test_simple_cursors().await?;
self.test_fetch_directions().await?;
self.test_fetch_forward_all().await?;
self.test_simple_cursors_without_hold().await?;
self.test_simple_cursors_close_specific().await?;
self.test_simple_cursors_close_all().await?;
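
An aside on the expected counts in test_fetch_directions: each FETCH answers with the fetched rows plus one completion message, and FETCH ALL drains whatever the cursor has not yet returned, which is why the final assertion expects 94 messages (rows 8 through 100 plus completion). A self-contained sketch of that arithmetic (illustrative only, not code from this PR; the fetch helper below is hypothetical):

// Sketch: models the cursor position for SELECT generate_series(1, 100).
// `None` stands for FETCH ALL; `Some(n)` for FETCH FORWARD n (FETCH NEXT is Some(1)).
fn fetch(pos: &mut u64, count: Option<u64>, total: u64) -> Vec<u64> {
    let end = count.map_or(total, |n| (*pos + n).min(total));
    let rows: Vec<u64> = ((*pos + 1)..=end).collect();
    *pos = end;
    rows
}

fn main() {
    let mut pos = 0;
    assert_eq!(fetch(&mut pos, Some(1), 100), vec![1]);             // FETCH FORWARD 1
    assert_eq!(fetch(&mut pos, Some(1), 100), vec![2]);             // FETCH NEXT
    assert_eq!(fetch(&mut pos, Some(5), 100), vec![3, 4, 5, 6, 7]); // FETCH FORWARD 5
    assert_eq!(fetch(&mut pos, None, 100).len(), 93);               // FETCH ALL: rows 8..=100
}
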
45 changes: 45 additions & 0 deletions rust/cubesql/cubesql/src/sql/postgres/ast_helpers.rs
@@ -0,0 +1,45 @@
use std::sync::Arc;

use pg_srv::protocol;
use sqlparser::ast::Value;

use super::error::ConnectionError;
use crate::transport::SpanId;

pub fn parse_fetch_limit(
limit: &Value,
span_id: &Option<Arc<SpanId>>,
) -> Result<usize, ConnectionError> {
match limit {
Value::Number(v, negative) => {
if *negative {
return Err(ConnectionError::Protocol(
protocol::ErrorResponse::error(
protocol::ErrorCode::ObjectNotInPrerequisiteState,
"cursor can only scan forward".to_string(),
)
.into(),
span_id.clone(),
));
}
v.parse::<usize>().map_err(|err| {
ConnectionError::Protocol(
protocol::ErrorResponse::error(
protocol::ErrorCode::ProtocolViolation,
format!(r#"Unable to parse number "{}" for fetch limit: {}"#, v, err),
)
.into(),
span_id.clone(),
)
})
}
other => Err(ConnectionError::Protocol(
protocol::ErrorResponse::error(
protocol::ErrorCode::ProtocolViolation,
format!("FETCH limit must be a number, got: {}", other),
)
.into(),
span_id.clone(),
)),
}
}
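
A hypothetical call site for parse_fetch_limit, sketched under the assumption that the FETCH handler receives the limit as a sqlparser Value whose second field flags a negative literal, as the destructuring above suggests. The resolve_limit wrapper is illustrative, not part of the PR, and assumes it lives in the cubesql crate next to parse_fetch_limit:

use sqlparser::ast::Value;

fn resolve_limit() -> Result<usize, ConnectionError> {
    // FETCH FORWARD 5 arrives as an unsigned numeric literal.
    let limit = Value::Number("5".to_string(), false);
    assert_eq!(parse_fetch_limit(&limit, &None)?, 5);

    // A negative count would mean scanning backwards, which is rejected
    // with "cursor can only scan forward".
    let backwards = Value::Number("1".to_string(), true);
    assert!(parse_fetch_limit(&backwards, &None).is_err());

    Ok(5)
}
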
188 changes: 188 additions & 0 deletions rust/cubesql/cubesql/src/sql/postgres/error.rs
@@ -0,0 +1,188 @@
use std::{backtrace::Backtrace, sync::Arc};

use datafusion::{arrow::error::ArrowError, error::DataFusionError};
use pg_srv::{
protocol::{self, ErrorResponse},
ProtocolError,
};

use crate::{compile::CompilationError, transport::SpanId, CubeError};

#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
#[error("CubeError: {0}")]
Cube(CubeError, Option<Arc<SpanId>>),
#[error("DataFusionError: {0}")]
DataFusion(DataFusionError, Option<Arc<SpanId>>),
#[error("ArrowError: {0}")]
Arrow(ArrowError, Option<Arc<SpanId>>),
#[error("CompilationError: {0}")]
CompilationError(CompilationError, Option<Arc<SpanId>>),
#[error("ProtocolError: {0}")]
Protocol(ProtocolError, Option<Arc<SpanId>>),
}

impl ConnectionError {
/// Returns the Backtrace from any variant of the enum
pub fn backtrace(&self) -> Option<&Backtrace> {
match &self {
ConnectionError::Cube(e, _) => e.backtrace(),
ConnectionError::CompilationError(e, _) => e.backtrace(),
ConnectionError::Protocol(e, _) => e.backtrace(),
ConnectionError::DataFusion(_, _) | ConnectionError::Arrow(_, _) => None,
}
}

/// Converts the error to a protocol::ErrorResponse, which is useful for writing a response to the client
pub fn to_error_response(self) -> protocol::ErrorResponse {
match self {
ConnectionError::Cube(e, _) => Self::cube_to_error_response(&e),
ConnectionError::DataFusion(e, _) => Self::df_to_error_response(&e),
ConnectionError::Arrow(e, _) => Self::arrow_to_error_response(&e),
ConnectionError::CompilationError(e, _) => {
fn to_error_response(e: CompilationError) -> protocol::ErrorResponse {
match e {
CompilationError::Internal(_, _, _) => protocol::ErrorResponse::error(
protocol::ErrorCode::InternalError,
e.to_string(),
),
CompilationError::User(_, _) => protocol::ErrorResponse::error(
protocol::ErrorCode::InvalidSqlStatement,
e.to_string(),
),
CompilationError::Unsupported(_, _) => protocol::ErrorResponse::error(
protocol::ErrorCode::FeatureNotSupported,
e.to_string(),
),
CompilationError::Fatal(_, _) => protocol::ErrorResponse::fatal(
protocol::ErrorCode::InternalError,
e.to_string(),
),
}
}

to_error_response(e)
}
ConnectionError::Protocol(e, _) => e.to_error_response(),
}
}

pub fn with_span_id(self, span_id: Option<Arc<SpanId>>) -> Self {
match self {
ConnectionError::Cube(e, _) => ConnectionError::Cube(e, span_id),
ConnectionError::DataFusion(e, _) => ConnectionError::DataFusion(e, span_id),
ConnectionError::Arrow(e, _) => ConnectionError::Arrow(e, span_id),
ConnectionError::CompilationError(e, _) => {
ConnectionError::CompilationError(e, span_id)
}
ConnectionError::Protocol(e, _) => ConnectionError::Protocol(e, span_id),
}
}

pub fn span_id(&self) -> Option<Arc<SpanId>> {
match self {
ConnectionError::Cube(_, span_id) => span_id.clone(),
ConnectionError::DataFusion(_, span_id) => span_id.clone(),
ConnectionError::Arrow(_, span_id) => span_id.clone(),
ConnectionError::CompilationError(_, span_id) => span_id.clone(),
ConnectionError::Protocol(_, span_id) => span_id.clone(),
}
}

fn cube_to_error_response(e: &CubeError) -> protocol::ErrorResponse {
let message = e.to_string();
// Remove `Error: ` prefix that can come from JS
let message = if let Some(message) = message.strip_prefix("Error: ") {
message.to_string()
} else {
message
};
protocol::ErrorResponse::error(protocol::ErrorCode::InternalError, message)
}

fn df_to_error_response(e: &DataFusionError) -> protocol::ErrorResponse {
match e {
DataFusionError::ArrowError(arrow_err) => {
return Self::arrow_to_error_response(arrow_err);
}
DataFusionError::External(err) => {
if let Some(cube_err) = err.downcast_ref::<CubeError>() {
return Self::cube_to_error_response(cube_err);
}
}
_ => {}
}
protocol::ErrorResponse::error(
protocol::ErrorCode::InternalError,
format!("Post-processing Error: {}", e),
)
}

fn arrow_to_error_response(e: &ArrowError) -> protocol::ErrorResponse {
match e {
ArrowError::ExternalError(err) => {
if let Some(df_err) = err.downcast_ref::<DataFusionError>() {
return Self::df_to_error_response(df_err);
}
if let Some(cube_err) = err.downcast_ref::<CubeError>() {
return Self::cube_to_error_response(cube_err);
}
}
_ => {}
}
protocol::ErrorResponse::error(
protocol::ErrorCode::InternalError,
format!("Post-processing Error: {}", e),
)
}
}

impl From<CubeError> for ConnectionError {
fn from(e: CubeError) -> Self {
ConnectionError::Cube(e, None)
}
}

impl From<CompilationError> for ConnectionError {
fn from(e: CompilationError) -> Self {
ConnectionError::CompilationError(e, None)
}
}

impl From<ProtocolError> for ConnectionError {
fn from(e: ProtocolError) -> Self {
ConnectionError::Protocol(e, None)
}
}

impl From<tokio::task::JoinError> for ConnectionError {
fn from(e: tokio::task::JoinError) -> Self {
ConnectionError::Cube(e.into(), None)
}
}

impl From<DataFusionError> for ConnectionError {
fn from(e: DataFusionError) -> Self {
ConnectionError::DataFusion(e, None)
}
}

impl From<ArrowError> for ConnectionError {
fn from(e: ArrowError) -> Self {
ConnectionError::Arrow(e, None)
}
}

/// Auto-conversion from std::io::Error to ConnectionError, sugar
impl From<std::io::Error> for ConnectionError {
fn from(e: std::io::Error) -> Self {
ConnectionError::Protocol(e.into(), None)
}
}

/// Auto-conversion from ErrorResponse to ConnectionError, sugar
impl From<ErrorResponse> for ConnectionError {
fn from(e: ErrorResponse) -> Self {
ConnectionError::Protocol(e.into(), None)
}
}
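
The From impls above are what let handler code use `?` freely on CubeError, DataFusionError, ArrowError, and protocol results; a span id can be attached afterwards and the error rendered for the wire via to_error_response. A minimal sketch of that flow (the attach_span helper is hypothetical, not part of the PR):

fn attach_span(
    result: Result<(), CubeError>,
    span_id: Option<Arc<SpanId>>,
) -> Result<(), ConnectionError> {
    // From<CubeError> produces ConnectionError::Cube(e, None); with_span_id fills in the span.
    result.map_err(|e| ConnectionError::from(e).with_span_id(span_id))
}

// At the top of the connection loop, a caught error becomes a wire message:
// let response: protocol::ErrorResponse = err.to_error_response();
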
4 changes: 2 additions & 2 deletions rust/cubesql/cubesql/src/sql/postgres/extended.rs
@@ -14,7 +14,7 @@ use pg_srv::{protocol, BindValue, PgTypeId, ProtocolError};
use sqlparser::ast;
use std::{fmt, pin::Pin, sync::Arc};

use crate::sql::shim::{ConnectionError, QueryPlanExt};
use super::{shim::QueryPlanExt, ConnectionError};
use datafusion::{
arrow::array::Array, dataframe::DataFrame as DFDataFrame,
physical_plan::SendableRecordBatchStream,
@@ -599,7 +599,7 @@ mod tests {
};
use pg_srv::protocol::Format;

use crate::sql::{extended::PortalFrom, shim::ConnectionError};
use crate::sql::{error::ConnectionError, extended::PortalFrom};
use datafusion::{
arrow::{
array::{ArrayRef, StringArray},
3 changes: 3 additions & 0 deletions rust/cubesql/cubesql/src/sql/postgres/mod.rs
@@ -1,9 +1,12 @@
pub(crate) mod ast_helpers;
pub(crate) mod error;
pub(crate) mod extended;
pub mod pg_auth_service;
pub(crate) mod pg_type;
pub(crate) mod service;
pub(crate) mod shim;
pub(crate) mod writer;

pub use error::ConnectionError;
pub use pg_type::*;
pub use service::*;