Skip to content

Commit

Permalink
pgwire: update all callsites to use new send implementation
Browse files Browse the repository at this point in the history
In some cases I had to explicitly clone some objects to please the
borrow checker, as the new `send` has to get a mutable ref to
coord_client/session.
  • Loading branch information
andrioni committed Feb 16, 2022
1 parent 2eae0b8 commit 57c64d9
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 78 deletions.
139 changes: 61 additions & 78 deletions src/pgwire/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,10 @@ where
if let Some(relation_desc) = &stmt_desc.relation_desc {
if !stmt_desc.is_copy {
let formats = vec![mz_pgrepr::Format::Text; stmt_desc.arity()];
self.conn
.send(BackendMessage::RowDescription(
message::encode_row_description(relation_desc, &formats),
))
.await?;
self.send(BackendMessage::RowDescription(
message::encode_row_description(relation_desc, &formats),
))
.await?;
}
}

Expand Down Expand Up @@ -485,7 +484,7 @@ where
}

if num_stmts == 0 {
self.conn.send(BackendMessage::EmptyQueryResponse).await?;
self.send(BackendMessage::EmptyQueryResponse).await?;
}

self.ready().await
Expand Down Expand Up @@ -550,7 +549,7 @@ where
.await
{
Ok(()) => {
self.conn.send(BackendMessage::ParseComplete).await?;
self.send(BackendMessage::ParseComplete).await?;
Ok(State::Ready)
}
Err(e) => {
Expand All @@ -574,12 +573,11 @@ where
async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
let resp = self.coord_client.end_transaction(action).await;
if let Err(err) = resp {
self.conn
.send(BackendMessage::ErrorResponse(ErrorResponse::from_coord(
Severity::Error,
err,
)))
.await?;
self.send(BackendMessage::ErrorResponse(ErrorResponse::from_coord(
Severity::Error,
err,
)))
.await?;
}
Ok(())
}
Expand Down Expand Up @@ -704,7 +702,7 @@ where
.await;
}

self.conn.send(BackendMessage::BindComplete).await?;
self.send(BackendMessage::BindComplete).await?;
Ok(State::Ready)
}

Expand Down Expand Up @@ -788,11 +786,8 @@ where
// we must remember the number of rows that were returned. Use this tag to
// remember that information and return it.
PortalState::Completed(Some(tag)) => {
self.conn
.send(BackendMessage::CommandComplete {
tag: tag.to_string(),
})
.await?;
let tag = tag.to_string();
self.send(BackendMessage::CommandComplete { tag }).await?;
Ok(State::Ready)
}
PortalState::Completed(None) => {
Expand Down Expand Up @@ -822,20 +817,20 @@ where
.await
}
};
self.conn
.send(BackendMessage::ParameterDescription(
stmt.desc()
.param_types
.iter()
.map(mz_pgrepr::Type::from)
.collect(),
))
.await?;
// Cloning to avoid a mutable borrow issue because `send` also uses `coord_client`
let parameter_desc = BackendMessage::ParameterDescription(
stmt.desc()
.param_types
.iter()
.map(mz_pgrepr::Type::from)
.collect(),
);
// Claim that all results will be output in text format, even
// though the true result formats are not yet known. A bit
// weird, but this is the behavior that PostgreSQL specifies.
let formats = vec![mz_pgrepr::Format::Text; stmt.desc().arity()];
self.conn.send(describe_rows(stmt.desc(), &formats)).await?;
let row_desc = describe_rows(&stmt.desc(), &formats);
self.send_all([parameter_desc, row_desc]).await?;
Ok(State::Ready)
}

Expand All @@ -849,7 +844,7 @@ where
.map(|portal| describe_rows(&portal.desc, &portal.result_formats));
match row_desc {
Some(row_desc) => {
self.conn.send(row_desc).await?;
self.send(row_desc).await?;
Ok(State::Ready)
}
None => {
Expand All @@ -864,13 +859,13 @@ where

async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
self.coord_client.session().remove_prepared_statement(&name);
self.conn.send(BackendMessage::CloseComplete).await?;
self.send(BackendMessage::CloseComplete).await?;
Ok(State::Ready)
}

async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
self.coord_client.session().remove_portal(&name);
self.conn.send(BackendMessage::CloseComplete).await?;
self.send(BackendMessage::CloseComplete).await?;
Ok(State::Ready)
}

Expand Down Expand Up @@ -995,9 +990,7 @@ where

async fn ready(&mut self) -> Result<State, io::Error> {
let txn_state = self.coord_client.session().transaction().into();
self.conn
.send(BackendMessage::ReadyForQuery(txn_state))
.await?;
self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
self.flush().await
}

Expand All @@ -1018,7 +1011,7 @@ where
// variable, or rustc barfs out a completely inscrutable
// error: https://github.com/rust-lang/rust/issues/64960.
let tag = format!($($arg)*);
self.conn.send(BackendMessage::CommandComplete { tag }).await?;
self.send(BackendMessage::CommandComplete { tag }).await?;
Ok(State::Ready)
}};
}
Expand All @@ -1028,7 +1021,7 @@ where
if $existed {
let msg =
ErrorResponse::notice($code, concat!($type, " already exists, skipping"));
self.conn.send(msg).await?;
self.send(msg).await?;
}
command_complete!("CREATE {}", $type.to_uppercase())
}};
Expand Down Expand Up @@ -1091,7 +1084,7 @@ where
ExecuteResponse::DroppedView => command_complete!("DROP VIEW"),
ExecuteResponse::DroppedType => command_complete!("DROP TYPE"),
ExecuteResponse::EmptyQuery => {
self.conn.send(BackendMessage::EmptyQueryResponse).await?;
self.send(BackendMessage::EmptyQueryResponse).await?;
Ok(State::Ready)
}
ExecuteResponse::Fetch {
Expand Down Expand Up @@ -1163,7 +1156,7 @@ where
None
};
if let Some(msg) = msg {
self.conn.send(msg).await?;
self.send(msg).await?;
}
command_complete!("SET")
}
Expand All @@ -1173,7 +1166,7 @@ where
SqlState::ACTIVE_SQL_TRANSACTION,
"there is already a transaction in progress",
);
self.conn.send(msg).await?;
self.send(msg).await?;
}
command_complete!("BEGIN")
}
Expand All @@ -1186,7 +1179,7 @@ where
SqlState::NO_ACTIVE_SQL_TRANSACTION,
"there is no transaction in progress",
);
self.conn.send(msg).await?;
self.send(msg).await?;
}
command_complete!("{}", tag)
}
Expand All @@ -1200,7 +1193,7 @@ where
msg.hint =
Some("Wrap your TAIL statement in `COPY (TAIL ...) TO STDOUT`.".into())
}
self.conn.send(msg).await?;
self.send(msg).await?;
self.conn.flush().await?;
}
let row_desc =
Expand Down Expand Up @@ -1386,11 +1379,10 @@ where
// let mut batch_rows = batch_rows;
// Drain panics if it's > len, so cap it.
let drain_rows = cmp::min(want_rows, batch_rows.len());
self.conn
.send_all(batch_rows.drain(..drain_rows).map(|row| {
BackendMessage::DataRow(mz_pgrepr::values_from_row(row, row_desc.typ()))
}))
.await?;
self.send_all(batch_rows.drain(..drain_rows).map(|row| {
BackendMessage::DataRow(mz_pgrepr::values_from_row(row, row_desc.typ()))
}))
.await?;
total_sent_rows += drain_rows;
want_rows -= drain_rows;
// If we have sent the number of requested rows, put the remainder of the batch
Expand Down Expand Up @@ -1435,7 +1427,7 @@ where
.expect("valid fetch portal")
});
let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
self.conn.send(response_message).await?;
self.send(response_message).await?;
Ok(State::Ready)
}

Expand Down Expand Up @@ -1465,12 +1457,11 @@ where
let column_formats = iter::repeat(encode_format)
.take(typ.column_types.len())
.collect();
self.conn
.send(BackendMessage::CopyOutResponse {
overall_format: encode_format,
column_formats,
})
.await?;
self.send(BackendMessage::CopyOutResponse {
overall_format: encode_format,
column_formats,
})
.await?;

// In Postgres, binary copy has a header that is followed (in the same
// CopyData) by the first row. In order to replicate their behavior, use a
Expand Down Expand Up @@ -1526,8 +1517,7 @@ where
count += rows.len();
for row in rows {
encode_fn(row, typ, &mut out)?;
self.conn
.send(BackendMessage::CopyData(mem::take(&mut out)))
self.send(BackendMessage::CopyData(mem::take(&mut out)))
.await?;
}
}
Expand All @@ -1540,16 +1530,13 @@ where
if let CopyFormat::Binary = format {
let trailer: i16 = -1;
out.extend(&trailer.to_be_bytes());
self.conn
.send(BackendMessage::CopyData(mem::take(&mut out)))
self.send(BackendMessage::CopyData(mem::take(&mut out)))
.await?;
}

let tag = format!("COPY {}", count);
self.conn.send(BackendMessage::CopyDone).await?;
self.conn
.send(BackendMessage::CommandComplete { tag })
.await?;
self.send(BackendMessage::CopyDone).await?;
self.send(BackendMessage::CommandComplete { tag }).await?;
Ok(State::Ready)
}

Expand Down Expand Up @@ -1582,12 +1569,11 @@ where

let typ = row_desc.typ();
let column_formats = vec![mz_pgrepr::Format::Text; typ.column_types.len()];
self.conn
.send(BackendMessage::CopyInResponse {
overall_format: mz_pgrepr::Format::Text,
column_formats,
})
.await?;
self.send(BackendMessage::CopyInResponse {
overall_format: mz_pgrepr::Format::Text,
column_formats,
})
.await?;
self.conn.flush().await?;

let mut data = Vec::new();
Expand Down Expand Up @@ -1650,9 +1636,7 @@ where
}

let tag = format!("COPY {}", count);
self.conn
.send(BackendMessage::CommandComplete { tag })
.await?;
self.send(BackendMessage::CommandComplete { tag }).await?;
}

Ok(next_state)
Expand All @@ -1667,7 +1651,7 @@ where
err.message
);
let is_fatal = err.severity.is_fatal();
self.conn.send(BackendMessage::ErrorResponse(err)).await?;
self.send(BackendMessage::ErrorResponse(err)).await?;
let txn = self.coord_client.session().transaction();
match txn {
// Error can be called from describe and parse and so might not be in an active
Expand All @@ -1694,12 +1678,11 @@ where
}

async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
self.conn
.send(BackendMessage::ErrorResponse(ErrorResponse::error(
SqlState::IN_FAILED_SQL_TRANSACTION,
"current transaction is aborted, commands ignored until end of transaction block",
)))
.await?;
self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
SqlState::IN_FAILED_SQL_TRANSACTION,
"current transaction is aborted, commands ignored until end of transaction block",
)))
.await?;
Ok(State::Drain)
}

Expand Down
29 changes: 29 additions & 0 deletions test/pgtest/client_min_messages.pt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Test logic related to filtering which messages are sent to clients based on severity levels

# Check default behavior
send
Query {"query": "COMMIT"}
----

until
ReadyForQuery
----
NoticeResponse {"fields":[{"typ":"S","value":"WARNING"},{"typ":"C","value":"25P01"},{"typ":"M","value":"there is no transaction in progress"}]}
CommandComplete {"tag":"COMMIT"}
ReadyForQuery {"status":"I"}


# Change client_min_messages and see that NoticeResponse is missing
send
Query {"query": "SET client_min_messages = ERROR"}
Query {"query": "COMMIT"}
----

until
ReadyForQuery
ReadyForQuery
----
CommandComplete {"tag":"SET"}
ReadyForQuery {"status":"I"}
CommandComplete {"tag":"COMMIT"}
ReadyForQuery {"status":"I"}

0 comments on commit 57c64d9

Please sign in to comment.