Skip to content

Commit 4ff620d

Browse files
authored
refactor: catalog names, error handling and set statements (#80)
* refactor: do not cache catalog name * fix: allow ; in show statements * chore: print error for accept
1 parent b1b5f62 commit 4ff620d

File tree

2 files changed

+53
-113
lines changed

2 files changed

+53
-113
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 39 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,18 @@ pub struct DfSessionService {
5858
session_context: Arc<SessionContext>,
5959
parser: Arc<Parser>,
6060
timezone: Arc<Mutex<String>>,
61-
catalog_name: String,
6261
}
6362

6463
impl DfSessionService {
65-
pub fn new(session_context: SessionContext, catalog_name: Option<String>) -> DfSessionService {
64+
pub fn new(session_context: SessionContext) -> DfSessionService {
6665
let session_context = Arc::new(session_context);
6766
let parser = Arc::new(Parser {
6867
session_context: session_context.clone(),
6968
});
70-
let catalog_name = catalog_name.unwrap_or_else(|| {
71-
session_context
72-
.catalog_names()
73-
.first()
74-
.cloned()
75-
.unwrap_or_else(|| "datafusion".to_string())
76-
});
7769
DfSessionService {
7870
session_context,
7971
parser,
8072
timezone: Arc::new(Mutex::new("UTC".to_string())),
81-
catalog_name,
8273
}
8374
}
8475

@@ -103,35 +94,40 @@ impl DfSessionService {
10394

10495
// Mock pg_namespace response
10596
async fn mock_pg_namespace<'a>(&self) -> PgWireResult<QueryResponse<'a>> {
106-
let fields = vec![FieldInfo::new(
97+
let fields = Arc::new(vec![FieldInfo::new(
10798
"nspname".to_string(),
10899
None,
109100
None,
110101
Type::VARCHAR,
111102
FieldFormat::Text,
112-
)];
103+
)]);
113104

114-
let row = {
115-
let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
116-
encoder.encode_field(&Some(&self.catalog_name))?; // Return catalog_name as a schema
117-
encoder.finish()
118-
};
119-
120-
let row_stream = futures::stream::once(async move { row });
121-
Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
105+
let fields_ref = fields.clone();
106+
let rows = self
107+
.session_context
108+
.catalog_names()
109+
.into_iter()
110+
.map(move |name| {
111+
let mut encoder = pgwire::api::results::DataRowEncoder::new(fields_ref.clone());
112+
encoder.encode_field(&Some(&name))?; // Return catalog_name as a schema
113+
encoder.finish()
114+
});
115+
116+
let row_stream = futures::stream::iter(rows);
117+
Ok(QueryResponse::new(fields.clone(), Box::pin(row_stream)))
122118
}
123119

124120
async fn try_respond_set_time_zone<'a>(
125121
&self,
126122
query_lower: &str,
127-
) -> PgWireResult<Option<Vec<Response<'a>>>> {
123+
) -> PgWireResult<Option<Response<'a>>> {
128124
if query_lower.starts_with("set time zone") {
129125
let parts: Vec<&str> = query_lower.split_whitespace().collect();
130126
if parts.len() >= 4 {
131127
let tz = parts[3].trim_matches('"');
132128
let mut timezone = self.timezone.lock().await;
133129
*timezone = tz.to_string();
134-
Ok(Some(vec![Response::Execution(Tag::new("SET"))]))
130+
Ok(Some(Response::Execution(Tag::new("SET"))))
135131
} else {
136132
Err(PgWireError::UserError(Box::new(
137133
pgwire::error::ErrorInfo::new(
@@ -149,32 +145,33 @@ impl DfSessionService {
149145
async fn try_respond_show_statements<'a>(
150146
&self,
151147
query_lower: &str,
152-
) -> PgWireResult<Option<Vec<Response<'a>>>> {
148+
) -> PgWireResult<Option<Response<'a>>> {
153149
if query_lower.starts_with("show ") {
154-
match query_lower {
150+
match query_lower.strip_suffix(";").unwrap_or(query_lower) {
155151
"show time zone" => {
156152
let timezone = self.timezone.lock().await.clone();
157153
let resp = Self::mock_show_response("TimeZone", &timezone)?;
158-
Ok(Some(vec![Response::Query(resp)]))
154+
Ok(Some(Response::Query(resp)))
159155
}
160156
"show server_version" => {
161157
let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
162-
Ok(Some(vec![Response::Query(resp)]))
158+
Ok(Some(Response::Query(resp)))
163159
}
164160
"show transaction_isolation" => {
165161
let resp =
166162
Self::mock_show_response("transaction_isolation", "read uncommitted")?;
167-
Ok(Some(vec![Response::Query(resp)]))
163+
Ok(Some(Response::Query(resp)))
168164
}
169165
"show catalogs" => {
170166
let catalogs = self.session_context.catalog_names();
171167
let value = catalogs.join(", ");
172168
let resp = Self::mock_show_response("Catalogs", &value)?;
173-
Ok(Some(vec![Response::Query(resp)]))
169+
Ok(Some(Response::Query(resp)))
174170
}
175171
"show search_path" => {
176-
let resp = Self::mock_show_response("search_path", &self.catalog_name)?;
177-
Ok(Some(vec![Response::Query(resp)]))
172+
let default_catalog = "datafusion";
173+
let resp = Self::mock_show_response("search_path", default_catalog)?;
174+
Ok(Some(Response::Query(resp)))
178175
}
179176
_ => Err(PgWireError::UserError(Box::new(
180177
pgwire::error::ErrorInfo::new(
@@ -192,31 +189,31 @@ impl DfSessionService {
192189
async fn try_respond_information_schema<'a>(
193190
&self,
194191
query_lower: &str,
195-
) -> PgWireResult<Option<Vec<Response<'a>>>> {
192+
) -> PgWireResult<Option<Response<'a>>> {
196193
if query_lower.contains("information_schema.schemata") {
197194
let df = schemata_df(&self.session_context)
198195
.await
199196
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
200197
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
201-
return Ok(Some(vec![Response::Query(resp)]));
198+
return Ok(Some(Response::Query(resp)));
202199
} else if query_lower.contains("information_schema.tables") {
203200
let df = tables_df(&self.session_context)
204201
.await
205202
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
206203
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
207-
return Ok(Some(vec![Response::Query(resp)]));
204+
return Ok(Some(Response::Query(resp)));
208205
} else if query_lower.contains("information_schema.columns") {
209206
let df = columns_df(&self.session_context)
210207
.await
211208
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
212209
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
213-
return Ok(Some(vec![Response::Query(resp)]));
210+
return Ok(Some(Response::Query(resp)));
214211
}
215212

216213
// Handle pg_catalog.pg_namespace for pgcli compatibility
217214
if query_lower.contains("pg_catalog.pg_namespace") {
218215
let resp = self.mock_pg_namespace().await?;
219-
return Ok(Some(vec![Response::Query(resp)]));
216+
return Ok(Some(Response::Query(resp)));
220217
}
221218

222219
Ok(None)
@@ -233,15 +230,15 @@ impl SimpleQueryHandler for DfSessionService {
233230
log::debug!("Received query: {}", query); // Log the query for debugging
234231

235232
if let Some(resp) = self.try_respond_set_time_zone(&query_lower).await? {
236-
return Ok(resp);
233+
return Ok(vec![resp]);
237234
}
238235

239236
if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
240-
return Ok(resp);
237+
return Ok(vec![resp]);
241238
}
242239

243240
if let Some(resp) = self.try_respond_information_schema(&query_lower).await? {
244-
return Ok(resp);
241+
return Ok(vec![resp]);
245242
}
246243

247244
let df = self
@@ -352,67 +349,12 @@ impl ExtendedQueryHandler for DfSessionService {
352349
.to_string();
353350
log::debug!("Received extended query: {}", query); // Log for debugging
354351

355-
if query.starts_with("show ") {
356-
match query.as_str() {
357-
"show time zone" => {
358-
let timezone = self.timezone.lock().await.clone();
359-
let resp = Self::mock_show_response("TimeZone", &timezone)?;
360-
return Ok(Response::Query(resp));
361-
}
362-
"show server_version" => {
363-
let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
364-
return Ok(Response::Query(resp));
365-
}
366-
"show transaction_isolation" => {
367-
let resp =
368-
Self::mock_show_response("transaction_isolation", "read uncommitted")?;
369-
return Ok(Response::Query(resp));
370-
}
371-
"show catalogs" => {
372-
let catalogs = self.session_context.catalog_names();
373-
let value = catalogs.join(", ");
374-
let resp = Self::mock_show_response("Catalogs", &value)?;
375-
return Ok(Response::Query(resp));
376-
}
377-
"show search_path" => {
378-
let resp = Self::mock_show_response("search_path", &self.catalog_name)?;
379-
return Ok(Response::Query(resp));
380-
}
381-
_ => {
382-
return Err(PgWireError::UserError(Box::new(
383-
pgwire::error::ErrorInfo::new(
384-
"ERROR".to_string(),
385-
"42704".to_string(),
386-
format!("Unrecognized SHOW command: {}", query),
387-
),
388-
)));
389-
}
390-
}
391-
}
392-
393-
if query.contains("information_schema.schemata") {
394-
let df = schemata_df(&self.session_context)
395-
.await
396-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
397-
let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
398-
return Ok(Response::Query(resp));
399-
} else if query.contains("information_schema.tables") {
400-
let df = tables_df(&self.session_context)
401-
.await
402-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
403-
let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
404-
return Ok(Response::Query(resp));
405-
} else if query.contains("information_schema.columns") {
406-
let df = columns_df(&self.session_context)
407-
.await
408-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
409-
let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
410-
return Ok(Response::Query(resp));
352+
if let Some(resp) = self.try_respond_show_statements(&query).await? {
353+
return Ok(resp);
411354
}
412355

413-
if query.contains("pg_catalog.pg_namespace") {
414-
let resp = self.mock_pg_namespace().await?;
415-
return Ok(Response::Query(resp));
356+
if let Some(resp) = self.try_respond_information_schema(&query).await? {
357+
return Ok(resp);
416358
}
417359

418360
let plan = &portal.statement.statement;

datafusion-postgres/src/lib.rs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,9 @@ pub async fn serve(
3939
session_context: SessionContext,
4040
opts: &ServerOptions,
4141
) -> Result<(), std::io::Error> {
42-
// Get the first catalog name from the session context
43-
let catalog_name = session_context
44-
.catalog_names() // Fixed: Removed .catalog_list()
45-
.first()
46-
.cloned();
47-
4842
// Create the handler factory with the session context and catalog name
4943
let factory = Arc::new(HandlerFactory(Arc::new(DfSessionService::new(
5044
session_context,
51-
catalog_name,
5245
))));
5346

5447
// Bind to the specified host and port
@@ -58,15 +51,20 @@ pub async fn serve(
5851

5952
// Accept incoming connections
6053
loop {
61-
if let Ok((socket, addr)) = listener.accept().await {
62-
let factory_ref = factory.clone();
63-
println!("Accepted connection from {}", addr);
54+
match listener.accept().await {
55+
Ok((socket, addr)) => {
56+
let factory_ref = factory.clone();
57+
println!("Accepted connection from {}", addr);
6458

65-
tokio::spawn(async move {
66-
if let Err(e) = process_socket(socket, None, factory_ref).await {
67-
eprintln!("Error processing socket: {}", e);
68-
}
69-
});
70-
};
59+
tokio::spawn(async move {
60+
if let Err(e) = process_socket(socket, None, factory_ref).await {
61+
eprintln!("Error processing socket: {}", e);
62+
}
63+
});
64+
}
65+
Err(e) => {
66+
eprintln!("Error accept socket: {}", e);
67+
}
68+
}
7169
}
7270
}

0 commit comments

Comments
 (0)