Skip to content

parse literals more like Postgres #1807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 26, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions sql/parser/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,22 @@ var (
)

func encodeSQLString(buf []byte, in []byte) []byte {
buf = append(buf, '\'')
for _, ch := range in {
if encodedChar := encodeMap[ch]; encodedChar == dontEscape {
buf = append(buf, ch)
} else {
buf = append(buf, '\\')
buf = append(buf, encodedChar)
// See http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html
start := 0
for i, ch := range in {
if encodedChar := encodeMap[ch]; encodedChar != dontEscape {
if start == 0 {
buf = append(buf, 'e', '\'') // begin e'xxx' string
}
buf = append(buf, in[start:i]...)
buf = append(buf, '\\', encodedChar)
start = i + 1
}
}
if start == 0 {
buf = append(buf, '\'') // begin 'xxx' string if nothing was escaped
}
buf = append(buf, in[start:]...)
buf = append(buf, '\'')
return buf
}
Expand All @@ -59,18 +66,18 @@ func encodeSQLIdent(buf *bytes.Buffer, s string) {
return
}

// The only characters we need to escape are '"' and '\\'.
// The only character that requires escaping is a double quote.
_ = buf.WriteByte('"')
start := 0
for i, n := 0, len(s); i < n; i++ {
ch := s[i]
if ch == '"' || ch == '\\' {
if ch == '"' {
if start != i {
_, _ = buf.WriteString(s[start:i])
}
start = i + 1
_ = buf.WriteByte('\\')
_ = buf.WriteByte(ch)
_ = buf.WriteByte(ch) // add extra copy of ch
}
}
if start < len(s) {
Expand Down Expand Up @@ -99,14 +106,13 @@ func encodeSQLBytes(buf []byte, v []byte) []byte {
func init() {
encodeRef := map[byte]byte{
'\x00': '0',
'\'': '\'',
'"': '"',
'\b': 'b',
'\f': 'f',
'\n': 'n',
'\r': 'r',
'\t': 't',
'\\': '\\',
'\'': '\'',
}

for i := range encodeMap {
Expand Down
6 changes: 3 additions & 3 deletions sql/parser/expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func TestQualifiedNameString(t *testing.T) {
// and is then followed by [a-zA-Z0-9$_] or extended ascii.
{"foo$09", "foo$09"},
{"_Ab10", "_Ab10"},
// Everything else quotes the string and escapes '"' and '\\'.
// Everything else quotes the string and escapes double quotes.
{".foobar", `".foobar"`},
{`".foobar"`, `"\".foobar\""`},
{`\".foobar\"`, `"\\\".foobar\\\""`},
{`".foobar"`, `""".foobar"""`},
{`\".foobar\"`, `"\"".foobar\"""`},
}

for _, tc := range testCases {
Expand Down
34 changes: 25 additions & 9 deletions sql/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,21 @@ func TestParse(t *testing.T) {
{`SELECT a FROM t`},
{`SELECT a.b FROM t`},
{`SELECT 'a' FROM t`},
{`SELECT 'a\'a' FROM t`},

{`SELECT 'a' AS "12345"`},
{`SELECT 'a' AS clnm`},

{`SELECT 'a\\na' FROM t`},
{`SELECT '\\n' FROM t`},
// Escaping may change since the scanning process loses information
// (you can write e'\'' or ''''), but these are the idempotent cases.
// Generally, a string containing any character that needs escaping
// (which includes \ and ') is printed as an escaped e'...' string.
{`SELECT e'a\'a' FROM t`},
{`SELECT e'a\\\\na' FROM t`},
{`SELECT e'\\\\n' FROM t`},
{`SELECT "a""a" FROM t`},
{`SELECT a FROM "t\n"`}, // no escaping in sql identifiers
{`SELECT a FROM "t"""`}, // no escaping in sql identifiers

{`SELECT "FROM" FROM t`},
{`SELECT CAST(1 AS TEXT)`},
{`SELECT FROM t AS bar`},
Expand Down Expand Up @@ -246,11 +254,19 @@ func TestParse2(t *testing.T) {
// {`SELECT 010 FROM t`, ``},
// {`SELECT 0xf0 FROM t`, ``},
// {`SELECT 0xF0 FROM t`, ``},
// Escaped string literals are not always escaped the same.
{`SELECT 'a''a' FROM t`,
`SELECT 'a\'a' FROM t`},
{`SELECT "a""a" FROM t`,
`SELECT "a\"a" FROM t`},
// Escaped string literals are not always escaped the same because
// '''' and e'\'' scan to the same token. It's more convenient to
// prefer escaping ' and \, so we do that.
{`SELECT 'a''a'`,
`SELECT e'a\'a'`},
{`SELECT 'a\a'`,
`SELECT e'a\\a'`},
{`SELECT 'a\n'`,
`SELECT e'a\\n'`},
{"SELECT '\n'",
`SELECT e'\n'`},
{"SELECT '\n\\'",
`SELECT e'\n\\'`},
{`SELECT "a'a" FROM t`,
`SELECT "a'a" FROM t`},
// Comments are stripped.
Expand Down Expand Up @@ -319,7 +335,7 @@ func TestParseSyntax(t *testing.T) {
{`SELECT 1 FROM t FOR SHARE`},
{`SELECT 1 FROM t FOR KEY SHARE`},
{`SELECT ((1)) FROM t WHERE ((a)) IN (((1))) AND ((a, b)) IN ((((1, 1))), ((2, 2)))`},
{`SELECT '\'\"\b\n\r\t\\' FROM t`},
{`SELECT e'\'\"\b\n\r\t\\' FROM t`},
{`SELECT '\x' FROM t`},
{`SELECT 1 FROM t GROUP BY a`},
{`SELECT 1 FROM t ORDER BY a`},
Expand Down
37 changes: 25 additions & 12 deletions sql/parser/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
)

const eof = -1
const errUnterminated = "unterminated string"
const errUnsupportedEscape = "octal, hex and unicode escape not supported"

type scanner struct {
in string
Expand Down Expand Up @@ -498,29 +500,40 @@ func (s *scanner) scanString(lval *sqlSymType, ch int, allowEscapes bool) bool {

case '\\':
t := s.peek()
// We always allow the quote character and "\" to be escaped.
if t == ch || t == '\\' {
lval.str += s.in[start : s.pos-1]
start = s.pos
s.pos++
continue
}
if allowEscapes {
lval.str += s.in[start : s.pos-1]
if t == ch || t == '\\' {
start = s.pos
s.pos++
continue
}

switch t {
case 'b', 'f', 'n', 'r', 't', '\'', '"':
lval.str += s.in[start : s.pos-1]
// TODO(pmattis): Handle other back-slash escapes? Octal? Hexadecimal?
// Unicode?
case 'b', 'f', 'n', 'r', 't', '\'':
lval.str += string(decodeMap[byte(t)])
s.pos++
start = s.pos
continue
case 'x', 'u', 'U':
fallthrough
case '0', '1', '2', '3', '4', '5', '6', '7':
lval.id = ERROR
lval.str = errUnsupportedEscape
return false
}
// TODO(pmattis): Handle other back-slash escapes? Octal? Hexadecimal?
// Unicode?

// If we end up here, it's a redundant escape - simply drop the
// backslash. For example, e'\"' is equivalent to e'"', and
// e'\a\b' to e'a\b'. This is what Postgres does:
// http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE
start = s.pos
}

case eof:
lval.id = ERROR
lval.str = "unterminated string"
lval.str = errUnterminated
return false
}
}
Expand Down
24 changes: 19 additions & 5 deletions sql/parser/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package parser

import (
"reflect"
"strings"
"testing"
)

Expand Down Expand Up @@ -216,10 +217,11 @@ func TestScanString(t *testing.T) {
{`'a''b'`, `a'b`},
{`"a" "b"`, `a`},
{`'a' 'b'`, `a`},
{`'\n'`, "\\n"},
{`"\""`, `"`},
{`'\''`, `'`},
{`'\0\'\"\b\f\n\r\t\\'`, `\0'\"\b\f\n\r\t\`},
{`'\n'`, `\n`},
{`e'\n'`, "\n"},
{`'\\n'`, `\\n`},
{`'\'''`, `\'`},
{`'\0\'`, `\0\`},
{`"a"
"b"`, `ab`},
{`"a"
Expand All @@ -228,7 +230,19 @@ func TestScanString(t *testing.T) {
'b'`, `ab`},
{`'a'
"b"`, `a`},
{`e'foo\"\'\\\b\f\n\r\tbar'`, "foo\"'\\\b\f\n\r\tbar"},
{`e'\"'`, `"`}, // redundant escape
{`e'\a'`, `a`}, // redundant escape
{"'\n\\'", "\n\\"},
{`e'foo\"\'\\\b\f\n\r\tbar'`,
strings.Join([]string{`foo"'\`, "\b\f\n\r\t", `bar`}, "")},
{`e'\\0'`, `\0`},
{`'\0'`, `\0`},
{`e'\0'`, errUnsupportedEscape},
{`"''"`, `''`},
{`'""'''`, `""'`},
{`""""`, `"`},
{`''''`, `'`},
{`''''''`, `''`},
}
for _, d := range testData {
s := newScanner(d.sql)
Expand Down