Skip to content

Commit e5c55b0

Browse files
authored
Query parser: encode ?? as ? (#154)
1 parent 75ce343 commit e5c55b0

File tree

2 files changed

+31
-16
lines changed

2 files changed

+31
-16
lines changed

src/query.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ impl Query {
4545
/// during query execution (`execute()`, `fetch()` etc).
4646
///
4747
/// WARNING: This means that the query must not have any extra `?`, even if
48-
/// they are in a string literal!
48+
/// they are in a string literal! Use `??` to have plain `?` in query.
4949
///
5050
/// [`Serialize`]: serde::Serialize
5151
/// [`Identifier`]: crate::sql::Identifier

src/sql/mod.rs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub(crate) enum SqlBuilder {
2121
pub(crate) enum Part {
2222
Arg,
2323
Fields,
24+
Str(&'static str),
2425
Text(String),
2526
}
2627

@@ -45,20 +46,28 @@ impl fmt::Display for SqlBuilder {
4546

4647
impl SqlBuilder {
4748
pub(crate) fn new(template: &str) -> Self {
48-
let mut iter = template.split('?');
49-
let prefix = String::from(iter.next().unwrap());
50-
let mut parts = vec![Part::Text(prefix)];
49+
let mut parts = Vec::new();
50+
let mut rest = template;
51+
while let Some(idx) = rest.find('?') {
52+
if rest[idx + 1..].starts_with('?') {
53+
parts.push(Part::Text(rest[..idx + 1].to_string()));
54+
rest = &rest[idx + 2..];
55+
continue;
56+
} else if idx != 0 {
57+
parts.push(Part::Text(rest[..idx].to_string()));
58+
}
5159

52-
for s in iter {
53-
let text = if let Some(text) = s.strip_prefix("fields") {
60+
rest = &rest[idx + 1..];
61+
if let Some(restfields) = rest.strip_prefix("fields") {
5462
parts.push(Part::Fields);
55-
text
63+
rest = restfields;
5664
} else {
5765
parts.push(Part::Arg);
58-
s
59-
};
66+
}
67+
}
6068

61-
parts.push(Part::Text(text.into()));
69+
if !rest.is_empty() {
70+
parts.push(Part::Text(rest.to_string()));
6271
}
6372

6473
SqlBuilder::InProgress(parts)
@@ -96,16 +105,12 @@ impl SqlBuilder {
96105
}
97106
}
98107

99-
pub(crate) fn append(&mut self, suffix: &str) {
108+
pub(crate) fn append(&mut self, suffix: &'static str) {
100109
let Self::InProgress(parts) = self else {
101110
return;
102111
};
103112

104-
if let Some(Part::Text(text)) = parts.last_mut() {
105-
text.push_str(suffix);
106-
} else {
107-
// Do nothing, it will fail in `finish()`.
108-
}
113+
parts.push(Part::Str(suffix));
109114
}
110115

111116
pub(crate) fn finish(mut self) -> Result<String> {
@@ -114,6 +119,7 @@ impl SqlBuilder {
114119
if let Self::InProgress(parts) = &self {
115120
for part in parts {
116121
match part {
122+
Part::Str(text) => sql.push_str(text),
117123
Part::Text(text) => sql.push_str(text),
118124
Part::Arg => {
119125
self.error("unbound query argument");
@@ -223,6 +229,15 @@ mod tests {
223229
);
224230
}
225231

232+
#[test]
233+
fn question_escape() {
234+
let sql = SqlBuilder::new("SELECT 1 FROM test WHERE a IN 'a??b'");
235+
assert_eq!(
236+
sql.finish().unwrap(),
237+
r"SELECT 1 FROM test WHERE a IN 'a?b'"
238+
);
239+
}
240+
226241
#[test]
227242
fn option_as_null() {
228243
let mut sql = SqlBuilder::new("SELECT 1 FROM test WHERE a = ?");

0 commit comments

Comments
 (0)