Skip to content

Commit 1044326

Browse files
authored
Merge pull request #258 from cipherstash/error-handling-tests
Refactor error handling and expand tests
2 parents ed66b08 + a9d2319 commit 1044326

File tree

4 files changed

+93
-37
lines changed

4 files changed

+93
-37
lines changed

packages/cipherstash-proxy-integration/src/select/unmappable.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ mod tests {
1010
/// Test ensures that unmappable SQL statements return an error
1111
///
1212
#[tokio::test]
13-
async fn unmappable_error() {
13+
async fn unmappable_table_not_found() {
1414
let client = connect_with_tls(PROXY).await;
1515

1616
let sql = "SELECT blah FROM vtha";
@@ -21,4 +21,43 @@ mod tests {
2121
"Expected unmappble SQL statement to return an error",
2222
);
2323
}
24+
25+
#[tokio::test]
26+
async fn unmappable_column_not_found() {
27+
let client = connect_with_tls(PROXY).await;
28+
29+
let sql = "SELECT blah FROM encrypted";
30+
let result = client.query(sql, &[]).await;
31+
32+
assert!(
33+
result.is_err(),
34+
"Expected unmappble SQL statement to return an error",
35+
);
36+
}
37+
38+
#[tokio::test]
39+
async fn unmappable_native_cannot_be_unified_with_encrypted() {
40+
let client = connect_with_tls(PROXY).await;
41+
42+
let sql = "SELECT * FROM encrypted WHERE plaintext = encrypted_text";
43+
let result = client.query(sql, &[]).await;
44+
45+
assert!(
46+
result.is_err(),
47+
"Expected unmappble SQL statement to return an error",
48+
);
49+
}
50+
51+
#[tokio::test]
52+
async fn unmappable_syntax_error() {
53+
let client = connect_with_tls(PROXY).await;
54+
55+
let sql = "SELECT *, FROM encrypted";
56+
let result = client.query(sql, &[]).await;
57+
58+
assert!(
59+
result.is_err(),
60+
"Expected unmappble SQL statement to return an error",
61+
);
62+
}
2463
}

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,12 @@ where
116116
// No mapping needed, don't change the bytes
117117
Ok(None) => (),
118118
Err(err) => {
119-
bytes = self.to_database_exception(err)?;
119+
warn!(
120+
client_id = self.context.client_id,
121+
msg = "Query Error",
122+
error = ?err.to_string(),
123+
);
124+
self.error_handler(err).await?;
120125
}
121126
}
122127
}
@@ -131,39 +136,14 @@ where
131136
Ok(Some(mapped)) => bytes = mapped,
132137
// No mapping needed, don't change the bytes
133138
Ok(None) => (),
134-
Err(err) => match err {
135-
Error::Mapping(MappingError::InvalidSqlStatement(_)) => {
136-
warn!(target: PROTOCOL,
137-
client_id = self.context.client_id,
138-
msg = "MappingError::SqlParse",
139-
error = ?err,
140-
);
141-
142-
let error_response =
143-
ErrorResponse::invalid_sql_statement(err.to_string());
144-
145-
self.send_error_response(error_response)?;
146-
}
147-
Error::Encrypt(EncryptError::UnknownColumn {
148-
ref table,
149-
ref column,
150-
}) => {
151-
warn!(target: PROTOCOL,
152-
client_id = self.context.client_id,
153-
msg = "EncryptError::UnknownColumn",
154-
);
155-
let error_response =
156-
ErrorResponse::unknown_column(err.to_string(), table, column);
157-
self.send_error_response(error_response)?;
158-
}
159-
_ => {
160-
warn!(target: PROTOCOL,
139+
Err(err) => {
140+
warn!(
161141
client_id = self.context.client_id,
162-
msg = "build_frontend_exception",
163-
);
164-
bytes = self.to_database_exception(err)?;
165-
}
166-
},
142+
msg = "Parse Error",
143+
error = ?err.to_string(),
144+
);
145+
self.error_handler(err).await?;
146+
}
167147
}
168148
}
169149
Code::Bind => {
@@ -216,6 +196,19 @@ where
216196
Ok(())
217197
}
218198

199+
pub async fn error_handler(&mut self, err: Error) -> Result<(), Error> {
200+
let error_response = match err {
201+
Error::Mapping(err) => ErrorResponse::invalid_sql_statement(err.to_string()),
202+
Error::Encrypt(EncryptError::UnknownColumn {
203+
ref table,
204+
ref column,
205+
}) => ErrorResponse::unknown_column(err.to_string(), table, column),
206+
_ => ErrorResponse::system_error(err.to_string()),
207+
};
208+
self.send_error_response(error_response)?;
209+
Ok(())
210+
}
211+
219212
pub async fn write_to_server(&mut self, bytes: BytesMut) -> Result<(), Error> {
220213
let sent: u64 = bytes.len() as u64;
221214
counter!(SERVER_BYTES_SENT_TOTAL).increment(sent);

packages/cipherstash-proxy/src/postgresql/handler.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ pub async fn handler(
138138
let message = ProtocolError::ClientAuthenticationFailed.to_string();
139139
error!(msg = message);
140140

141-
let message = ErrorResponse::invalid_password(&message);
141+
let message = ErrorResponse::invalid_password(message);
142142
let bytes = BytesMut::try_from(message)?;
143143
client_stream.write_all(&bytes).await?;
144144
}

packages/cipherstash-proxy/src/postgresql/messages/error_response.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pub const CODE_RAISE_EXCEPTION: &str = "P0001";
1717
pub const CODE_SYNTAX_ERROR: &str = "42601";
1818
pub const CODE_INVALID_TEXT_REPRESENTATION: &str = "22P02";
1919
pub const CODE_IDLE_SESSION_TIMEOUT: &str = "57P05";
20+
pub const CODE_SYSTEM_ERROR: &str = "58000";
2021

2122
///
2223
/// ErrorResponse (B)
@@ -82,7 +83,7 @@ impl ErrorResponse {
8283
}
8384
}
8485

85-
pub fn invalid_password(message: &str) -> Self {
86+
pub fn invalid_password(message: String) -> Self {
8687
Self {
8788
fields: vec![
8889
Field {
@@ -99,7 +100,7 @@ impl ErrorResponse {
99100
},
100101
Field {
101102
code: ErrorResponseCode::Message,
102-
value: message.to_string(),
103+
value: message,
103104
},
104105
],
105106
}
@@ -238,6 +239,29 @@ impl ErrorResponse {
238239
}
239240
}
240241

242+
pub fn system_error(message: String) -> Self {
243+
Self {
244+
fields: vec![
245+
Field {
246+
code: ErrorResponseCode::Severity,
247+
value: "FATAL".to_string(),
248+
},
249+
Field {
250+
code: ErrorResponseCode::SeverityLegacy,
251+
value: "FATAL".to_string(),
252+
},
253+
Field {
254+
code: ErrorResponseCode::Code,
255+
value: CODE_SYSTEM_ERROR.to_string(),
256+
},
257+
Field {
258+
code: ErrorResponseCode::Message,
259+
value: message,
260+
},
261+
],
262+
}
263+
}
264+
241265
pub fn tls_required() -> Self {
242266
Self {
243267
fields: vec![

0 commit comments

Comments
 (0)