@@ -58,27 +58,18 @@ pub struct DfSessionService {
58
58
session_context : Arc < SessionContext > ,
59
59
parser : Arc < Parser > ,
60
60
timezone : Arc < Mutex < String > > ,
61
- catalog_name : String ,
62
61
}
63
62
64
63
impl DfSessionService {
65
- pub fn new ( session_context : SessionContext , catalog_name : Option < String > ) -> DfSessionService {
64
+ pub fn new ( session_context : SessionContext ) -> DfSessionService {
66
65
let session_context = Arc :: new ( session_context) ;
67
66
let parser = Arc :: new ( Parser {
68
67
session_context : session_context. clone ( ) ,
69
68
} ) ;
70
- let catalog_name = catalog_name. unwrap_or_else ( || {
71
- session_context
72
- . catalog_names ( )
73
- . first ( )
74
- . cloned ( )
75
- . unwrap_or_else ( || "datafusion" . to_string ( ) )
76
- } ) ;
77
69
DfSessionService {
78
70
session_context,
79
71
parser,
80
72
timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
81
- catalog_name,
82
73
}
83
74
}
84
75
@@ -103,35 +94,40 @@ impl DfSessionService {
103
94
104
95
// Mock pg_namespace response
105
96
async fn mock_pg_namespace < ' a > ( & self ) -> PgWireResult < QueryResponse < ' a > > {
106
- let fields = vec ! [ FieldInfo :: new(
97
+ let fields = Arc :: new ( vec ! [ FieldInfo :: new(
107
98
"nspname" . to_string( ) ,
108
99
None ,
109
100
None ,
110
101
Type :: VARCHAR ,
111
102
FieldFormat :: Text ,
112
- ) ] ;
103
+ ) ] ) ;
113
104
114
- let row = {
115
- let mut encoder = pgwire:: api:: results:: DataRowEncoder :: new ( Arc :: new ( fields. clone ( ) ) ) ;
116
- encoder. encode_field ( & Some ( & self . catalog_name ) ) ?; // Return catalog_name as a schema
117
- encoder. finish ( )
118
- } ;
119
-
120
- let row_stream = futures:: stream:: once ( async move { row } ) ;
121
- Ok ( QueryResponse :: new ( Arc :: new ( fields) , Box :: pin ( row_stream) ) )
105
+ let fields_ref = fields. clone ( ) ;
106
+ let rows = self
107
+ . session_context
108
+ . catalog_names ( )
109
+ . into_iter ( )
110
+ . map ( move |name| {
111
+ let mut encoder = pgwire:: api:: results:: DataRowEncoder :: new ( fields_ref. clone ( ) ) ;
112
+ encoder. encode_field ( & Some ( & name) ) ?; // Return catalog_name as a schema
113
+ encoder. finish ( )
114
+ } ) ;
115
+
116
+ let row_stream = futures:: stream:: iter ( rows) ;
117
+ Ok ( QueryResponse :: new ( fields. clone ( ) , Box :: pin ( row_stream) ) )
122
118
}
123
119
124
120
async fn try_respond_set_time_zone < ' a > (
125
121
& self ,
126
122
query_lower : & str ,
127
- ) -> PgWireResult < Option < Vec < Response < ' a > > > > {
123
+ ) -> PgWireResult < Option < Response < ' a > > > {
128
124
if query_lower. starts_with ( "set time zone" ) {
129
125
let parts: Vec < & str > = query_lower. split_whitespace ( ) . collect ( ) ;
130
126
if parts. len ( ) >= 4 {
131
127
let tz = parts[ 3 ] . trim_matches ( '"' ) ;
132
128
let mut timezone = self . timezone . lock ( ) . await ;
133
129
* timezone = tz. to_string ( ) ;
134
- Ok ( Some ( vec ! [ Response :: Execution ( Tag :: new( "SET" ) ) ] ) )
130
+ Ok ( Some ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
135
131
} else {
136
132
Err ( PgWireError :: UserError ( Box :: new (
137
133
pgwire:: error:: ErrorInfo :: new (
@@ -149,32 +145,33 @@ impl DfSessionService {
149
145
async fn try_respond_show_statements < ' a > (
150
146
& self ,
151
147
query_lower : & str ,
152
- ) -> PgWireResult < Option < Vec < Response < ' a > > > > {
148
+ ) -> PgWireResult < Option < Response < ' a > > > {
153
149
if query_lower. starts_with ( "show " ) {
154
- match query_lower {
150
+ match query_lower. strip_suffix ( ";" ) . unwrap_or ( query_lower ) {
155
151
"show time zone" => {
156
152
let timezone = self . timezone . lock ( ) . await . clone ( ) ;
157
153
let resp = Self :: mock_show_response ( "TimeZone" , & timezone) ?;
158
- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
154
+ Ok ( Some ( Response :: Query ( resp) ) )
159
155
}
160
156
"show server_version" => {
161
157
let resp = Self :: mock_show_response ( "server_version" , "15.0 (DataFusion)" ) ?;
162
- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
158
+ Ok ( Some ( Response :: Query ( resp) ) )
163
159
}
164
160
"show transaction_isolation" => {
165
161
let resp =
166
162
Self :: mock_show_response ( "transaction_isolation" , "read uncommitted" ) ?;
167
- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
163
+ Ok ( Some ( Response :: Query ( resp) ) )
168
164
}
169
165
"show catalogs" => {
170
166
let catalogs = self . session_context . catalog_names ( ) ;
171
167
let value = catalogs. join ( ", " ) ;
172
168
let resp = Self :: mock_show_response ( "Catalogs" , & value) ?;
173
- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
169
+ Ok ( Some ( Response :: Query ( resp) ) )
174
170
}
175
171
"show search_path" => {
176
- let resp = Self :: mock_show_response ( "search_path" , & self . catalog_name ) ?;
177
- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
172
+ let default_catalog = "datafusion" ;
173
+ let resp = Self :: mock_show_response ( "search_path" , default_catalog) ?;
174
+ Ok ( Some ( Response :: Query ( resp) ) )
178
175
}
179
176
_ => Err ( PgWireError :: UserError ( Box :: new (
180
177
pgwire:: error:: ErrorInfo :: new (
@@ -192,31 +189,31 @@ impl DfSessionService {
192
189
async fn try_respond_information_schema < ' a > (
193
190
& self ,
194
191
query_lower : & str ,
195
- ) -> PgWireResult < Option < Vec < Response < ' a > > > > {
192
+ ) -> PgWireResult < Option < Response < ' a > > > {
196
193
if query_lower. contains ( "information_schema.schemata" ) {
197
194
let df = schemata_df ( & self . session_context )
198
195
. await
199
196
. map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
200
197
let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
201
- return Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) ) ;
198
+ return Ok ( Some ( Response :: Query ( resp) ) ) ;
202
199
} else if query_lower. contains ( "information_schema.tables" ) {
203
200
let df = tables_df ( & self . session_context )
204
201
. await
205
202
. map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
206
203
let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
207
- return Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) ) ;
204
+ return Ok ( Some ( Response :: Query ( resp) ) ) ;
208
205
} else if query_lower. contains ( "information_schema.columns" ) {
209
206
let df = columns_df ( & self . session_context )
210
207
. await
211
208
. map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
212
209
let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
213
- return Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) ) ;
210
+ return Ok ( Some ( Response :: Query ( resp) ) ) ;
214
211
}
215
212
216
213
// Handle pg_catalog.pg_namespace for pgcli compatibility
217
214
if query_lower. contains ( "pg_catalog.pg_namespace" ) {
218
215
let resp = self . mock_pg_namespace ( ) . await ?;
219
- return Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) ) ;
216
+ return Ok ( Some ( Response :: Query ( resp) ) ) ;
220
217
}
221
218
222
219
Ok ( None )
@@ -233,15 +230,15 @@ impl SimpleQueryHandler for DfSessionService {
233
230
log:: debug!( "Received query: {}" , query) ; // Log the query for debugging
234
231
235
232
if let Some ( resp) = self . try_respond_set_time_zone ( & query_lower) . await ? {
236
- return Ok ( resp) ;
233
+ return Ok ( vec ! [ resp] ) ;
237
234
}
238
235
239
236
if let Some ( resp) = self . try_respond_show_statements ( & query_lower) . await ? {
240
- return Ok ( resp) ;
237
+ return Ok ( vec ! [ resp] ) ;
241
238
}
242
239
243
240
if let Some ( resp) = self . try_respond_information_schema ( & query_lower) . await ? {
244
- return Ok ( resp) ;
241
+ return Ok ( vec ! [ resp] ) ;
245
242
}
246
243
247
244
let df = self
@@ -352,67 +349,12 @@ impl ExtendedQueryHandler for DfSessionService {
352
349
. to_string ( ) ;
353
350
log:: debug!( "Received extended query: {}" , query) ; // Log for debugging
354
351
355
- if query. starts_with ( "show " ) {
356
- match query. as_str ( ) {
357
- "show time zone" => {
358
- let timezone = self . timezone . lock ( ) . await . clone ( ) ;
359
- let resp = Self :: mock_show_response ( "TimeZone" , & timezone) ?;
360
- return Ok ( Response :: Query ( resp) ) ;
361
- }
362
- "show server_version" => {
363
- let resp = Self :: mock_show_response ( "server_version" , "15.0 (DataFusion)" ) ?;
364
- return Ok ( Response :: Query ( resp) ) ;
365
- }
366
- "show transaction_isolation" => {
367
- let resp =
368
- Self :: mock_show_response ( "transaction_isolation" , "read uncommitted" ) ?;
369
- return Ok ( Response :: Query ( resp) ) ;
370
- }
371
- "show catalogs" => {
372
- let catalogs = self . session_context . catalog_names ( ) ;
373
- let value = catalogs. join ( ", " ) ;
374
- let resp = Self :: mock_show_response ( "Catalogs" , & value) ?;
375
- return Ok ( Response :: Query ( resp) ) ;
376
- }
377
- "show search_path" => {
378
- let resp = Self :: mock_show_response ( "search_path" , & self . catalog_name ) ?;
379
- return Ok ( Response :: Query ( resp) ) ;
380
- }
381
- _ => {
382
- return Err ( PgWireError :: UserError ( Box :: new (
383
- pgwire:: error:: ErrorInfo :: new (
384
- "ERROR" . to_string ( ) ,
385
- "42704" . to_string ( ) ,
386
- format ! ( "Unrecognized SHOW command: {}" , query) ,
387
- ) ,
388
- ) ) ) ;
389
- }
390
- }
391
- }
392
-
393
- if query. contains ( "information_schema.schemata" ) {
394
- let df = schemata_df ( & self . session_context )
395
- . await
396
- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
397
- let resp = datatypes:: encode_dataframe ( df, & portal. result_column_format ) . await ?;
398
- return Ok ( Response :: Query ( resp) ) ;
399
- } else if query. contains ( "information_schema.tables" ) {
400
- let df = tables_df ( & self . session_context )
401
- . await
402
- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
403
- let resp = datatypes:: encode_dataframe ( df, & portal. result_column_format ) . await ?;
404
- return Ok ( Response :: Query ( resp) ) ;
405
- } else if query. contains ( "information_schema.columns" ) {
406
- let df = columns_df ( & self . session_context )
407
- . await
408
- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
409
- let resp = datatypes:: encode_dataframe ( df, & portal. result_column_format ) . await ?;
410
- return Ok ( Response :: Query ( resp) ) ;
352
+ if let Some ( resp) = self . try_respond_show_statements ( & query) . await ? {
353
+ return Ok ( resp) ;
411
354
}
412
355
413
- if query. contains ( "pg_catalog.pg_namespace" ) {
414
- let resp = self . mock_pg_namespace ( ) . await ?;
415
- return Ok ( Response :: Query ( resp) ) ;
356
+ if let Some ( resp) = self . try_respond_information_schema ( & query) . await ? {
357
+ return Ok ( resp) ;
416
358
}
417
359
418
360
let plan = & portal. statement . statement ;
0 commit comments