diff --git a/session/session.go b/session/session.go index bf0ef97128bd1..0def5e85355f0 100644 --- a/session/session.go +++ b/session/session.go @@ -1215,8 +1215,10 @@ func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) } // ParseWithParams parses a query string, with arguments, to raw ast.StmtNode. +// Note that it will not do escaping if no variable arguments are passed. func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) { - sql, err := EscapeSQL(sql, args...) + var err error + sql, err = sqlexec.EscapeSQL(sql, args...) if err != nil { return nil, err } diff --git a/session/session_test.go b/session/session_test.go index 7d2ab385a0f4b..9a953a3438484 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -3843,25 +3843,35 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { defer func() { se.GetSessionVars().InRestrictedSQL = origin }() - _, err := exec.ParseWithParams(context.Background(), "SELECT 4") + _, err := exec.ParseWithParams(context.TODO(), "SELECT 4") c.Assert(err, IsNil) // test charset attack - stmts, err := exec.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") + stmt, err := exec.ParseWithParams(context.TODO(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") c.Assert(err, IsNil) var sb strings.Builder ctx := format.NewRestoreCtx(0, &sb) - err = stmts.Restore(ctx) + err = stmt.Restore(ctx) c.Assert(err, IsNil) // FIXME: well... so the restore function is vulnerable... 
c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\xbf' OR 1=1 /* LIMIT 1") // test invalid sql - _, err = exec.ParseWithParams(context.Background(), "SELECT") + _, err = exec.ParseWithParams(context.TODO(), "SELECT") c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*") // test invalid arguments to escape - _, err = exec.ParseWithParams(context.Background(), "SELECT %?") + _, err = exec.ParseWithParams(context.TODO(), "SELECT %?, %?", 3) c.Assert(err, ErrorMatches, "missing arguments.*") + + // test noescape + stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3") + c.Assert(err, IsNil) + + sb.Reset() + ctx = format.NewRestoreCtx(0, &sb) + err = stmt.Restore(ctx) + c.Assert(err, IsNil) + c.Assert(sb.String(), Equals, "SELECT 3") } diff --git a/util/sqlexec/utils.go b/util/sqlexec/utils.go new file mode 100644 index 0000000000000..1ffc29b72d8e0 --- /dev/null +++ b/util/sqlexec/utils.go @@ -0,0 +1,260 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlexec + +import ( + "encoding/json" + "io" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util/hack" +) + +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash will escape []byte into the buffer, with backslash. 
+func escapeBytesBackslash(buf []byte, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringBackslash will escape string into the buffer, with backslash. +func escapeStringBackslash(buf []byte, v string) []byte { + return escapeBytesBackslash(buf, hack.Slice(v)) +} + +// escapeSQL is the internal impl of EscapeSQL and FormatSQL. +func escapeSQL(sql string, args ...interface{}) ([]byte, error) { + buf := make([]byte, 0, len(sql)) + argPos := 0 + for i := 0; i < len(sql); i++ { + q := strings.IndexByte(sql[i:], '%') + if q == -1 { + buf = append(buf, sql[i:]...) + break + } + buf = append(buf, sql[i:i+q]...) + i += q + + ch := byte(0) + if i+1 < len(sql) { + ch = sql[i+1] // get the specifier + } + switch ch { + case 'n': + if argPos >= len(args) { + return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + v, ok := arg.(string) + if !ok { + return nil, errors.Errorf("expect a string identifier, got %v", arg) + } + buf = append(buf, '`') + buf = append(buf, strings.Replace(v, "`", "``", -1)...) + buf = append(buf, '`') + i++ // skip specifier + case '?': + if argPos >= len(args) { + return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) 
+ } else { + switch v := arg.(type) { + case int: + buf = strconv.AppendInt(buf, int64(v), 10) + case int8: + buf = strconv.AppendInt(buf, int64(v), 10) + case int16: + buf = strconv.AppendInt(buf, int64(v), 10) + case int32: + buf = strconv.AppendInt(buf, int64(v), 10) + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint8: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint16: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint32: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint64: + buf = strconv.AppendUint(buf, v, 10) + case float32: + buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999") + buf = append(buf, '\'') + } + case json.RawMessage: + buf = append(buf, '\'') + buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) 
+ buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, v) + buf = append(buf, '\'') + case []string: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, k) + buf = append(buf, '\'') + } + case []float32: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, float64(k), 'g', -1, 32) + } + case []float64: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, k, 'g', -1, 64) + } + default: + return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg) + } + } + i++ // skip specifier + case '%': + buf = append(buf, '%') + i++ // skip specifier + default: + buf = append(buf, '%') + } + } + return buf, nil +} + +// EscapeSQL will escape input arguments into the sql string, doing necessary processing. +// It works like printf() in C; the following format specifiers are supported: +// 1. %?: automatic conversion by the type of arguments. E.g. []string -> 's1','s2',... (no surrounding parentheses are added) +// 2. %%: output % +// 3. %n: for identifiers, for example ("use %n", db) +// But it does not prevent you from doing EscapeSQL("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". +// It is still your responsibility to write safe SQL. +func EscapeSQL(sql string, args ...interface{}) (string, error) { + str, err := escapeSQL(sql, args...) + return string(str), err +} + +// MustEscapeSQL is a helper around EscapeSQL. The error returned from escapeSQL can be avoided statically if you do not pass interface{}. +func MustEscapeSQL(sql string, args ...interface{}) string { + r, err := EscapeSQL(sql, args...) + if err != nil { + panic(err) + } + return r +} + +// FormatSQL is the io.Writer version of EscapeSQL. Please refer to EscapeSQL for details. 
+func FormatSQL(w io.Writer, sql string, args ...interface{}) error { + buf, err := escapeSQL(sql, args...) + if err != nil { + return err + } + _, err = w.Write(buf) + return err +} + +// MustFormatSQL is a helper around FormatSQL, like MustEscapeSQL. But it requires the writer to be a *strings.Builder, +// which will not return an error from w.Write(...). +func MustFormatSQL(w *strings.Builder, sql string, args ...interface{}) { + err := FormatSQL(w, sql, args...) + if err != nil { + panic(err) + } +} diff --git a/util/sqlexec/utils_test.go b/util/sqlexec/utils_test.go new file mode 100644 index 0000000000000..a8a912a33978f --- /dev/null +++ b/util/sqlexec/utils_test.go @@ -0,0 +1,430 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlexec + +import ( + "encoding/json" + "strings" + "testing" + "time" + + . 
"github.com/pingcap/check" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testUtilsSuite{}) + +type testUtilsSuite struct{} + +func (s *testUtilsSuite) TestReserveBuffer(c *C) { + res0 := reserveBuffer(nil, 0) + c.Assert(res0, HasLen, 0) + + res1 := reserveBuffer(res0, 3) + c.Assert(res1, HasLen, 3) + res1[1] = 3 + + res2 := reserveBuffer(res1, 9) + c.Assert(res2, HasLen, 12) + c.Assert(cap(res2), Equals, 15) + c.Assert(res2[:3], DeepEquals, res1) +} + +func (s *testUtilsSuite) TestEscapeBackslash(c *C) { + type TestCase struct { + name string + input []byte + output []byte + } + tests := []TestCase{ + { + name: "normal", + input: []byte("hello"), + output: []byte("hello"), + }, + { + name: "0", + input: []byte("he\x00lo"), + output: []byte("he\\0lo"), + }, + { + name: "break line", + input: []byte("he\nlo"), + output: []byte("he\\nlo"), + }, + { + name: "carry", + input: []byte("he\rlo"), + output: []byte("he\\rlo"), + }, + { + name: "substitute", + input: []byte("he\x1alo"), + output: []byte("he\\Zlo"), + }, + { + name: "single quote", + input: []byte("he'lo"), + output: []byte("he\\'lo"), + }, + { + name: "double quote", + input: []byte("he\"lo"), + output: []byte("he\\\"lo"), + }, + { + name: "back slash", + input: []byte("he\\lo"), + output: []byte("he\\\\lo"), + }, + { + name: "double escape", + input: []byte("he\x00lo\""), + output: []byte("he\\0lo\\\""), + }, + { + name: "chinese", + input: []byte("中文?"), + output: []byte("中文?"), + }, + } + for _, t := range tests { + commentf := Commentf("%s", t.name) + c.Assert(escapeBytesBackslash(nil, t.input), DeepEquals, t.output, commentf) + c.Assert(escapeStringBackslash(nil, string(t.input)), DeepEquals, t.output, commentf) + } +} + +func (s *testUtilsSuite) TestEscapeSQL(c *C) { + type TestCase struct { + name string + input string + params []interface{} + output string + err string + } + time2, err := time.Parse("2006-01-02 15:04:05", "2018-01-23 04:03:05") + c.Assert(err, IsNil) + tests := 
[]TestCase{ + { + name: "normal 1", + input: "select * from 1", + params: []interface{}{}, + output: "select * from 1", + err: "", + }, + { + name: "normal 2", + input: "WHERE source != 'builtin'", + params: []interface{}{}, + output: "WHERE source != 'builtin'", + err: "", + }, + { + name: "discard extra arguments", + input: "select * from 1", + params: []interface{}{4, 5, "rt"}, + output: "select * from 1", + err: "", + }, + { + name: "%? missing arguments", + input: "select %? from %?", + params: []interface{}{4}, + err: "missing arguments.*", + }, + { + name: "nil", + input: "select %?", + params: []interface{}{nil}, + output: "select NULL", + err: "", + }, + { + name: "int", + input: "select %?", + params: []interface{}{int(3)}, + output: "select 3", + err: "", + }, + { + name: "int8", + input: "select %?", + params: []interface{}{int8(4)}, + output: "select 4", + err: "", + }, + { + name: "int16", + input: "select %?", + params: []interface{}{int16(5)}, + output: "select 5", + err: "", + }, + { + name: "int32", + input: "select %?", + params: []interface{}{int32(6)}, + output: "select 6", + err: "", + }, + { + name: "int64", + input: "select %?", + params: []interface{}{int64(7)}, + output: "select 7", + err: "", + }, + { + name: "uint", + input: "select %?", + params: []interface{}{uint(8)}, + output: "select 8", + err: "", + }, + { + name: "uint8", + input: "select %?", + params: []interface{}{uint8(9)}, + output: "select 9", + err: "", + }, + { + name: "uint16", + input: "select %?", + params: []interface{}{uint16(10)}, + output: "select 10", + err: "", + }, + { + name: "uint32", + input: "select %?", + params: []interface{}{uint32(11)}, + output: "select 11", + err: "", + }, + { + name: "uint64", + input: "select %?", + params: []interface{}{uint64(12)}, + output: "select 12", + err: "", + }, + { + name: "float32", + input: "select %?", + params: []interface{}{float32(0.13)}, + output: "select 0.13", + err: "", + }, + { + name: "float64", + input: "select 
%?", + params: []interface{}{float64(0.14)}, + output: "select 0.14", + err: "", + }, + { + name: "bool on", + input: "select %?", + params: []interface{}{true}, + output: "select 1", + err: "", + }, + { + name: "bool off", + input: "select %?", + params: []interface{}{false}, + output: "select 0", + err: "", + }, + { + name: "time 0", + input: "select %?", + params: []interface{}{time.Time{}}, + output: "select '0000-00-00'", + err: "", + }, + { + name: "time 1", + input: "select %?", + params: []interface{}{time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC)}, + output: "select '2019-01-01 00:00:00'", + err: "", + }, + { + name: "time 2", + input: "select %?", + params: []interface{}{time2}, + output: "select '2018-01-23 04:03:05'", + err: "", + }, + { + name: "time 3", + input: "select %?", + params: []interface{}{time.Unix(0, 888888888).UTC()}, + output: "select '1970-01-01 00:00:00.888888'", + err: "", + }, + { + name: "empty byte slice1", + input: "select %?", + params: []interface{}{[]byte(nil)}, + output: "select NULL", + err: "", + }, + { + name: "empty byte slice2", + input: "select %?", + params: []interface{}{[]byte{}}, + output: "select _binary''", + err: "", + }, + { + name: "byte slice", + input: "select %?", + params: []interface{}{[]byte{2, 3}}, + output: "select _binary'\x02\x03'", + err: "", + }, + { + name: "string", + input: "select %?", + params: []interface{}{"33"}, + output: "select '33'", + }, + { + name: "string slice", + input: "select %?", + params: []interface{}{[]string{"33", "44"}}, + output: "select '33','44'", + }, + { + name: "raw json", + input: "select %?", + params: []interface{}{json.RawMessage(`{"h": "hello"}`)}, + output: "select '{\\\"h\\\": \\\"hello\\\"}'", + }, + { + name: "unsupported args", + input: "select %?", + params: []interface{}{make(chan byte)}, + err: "unsupported 1-th argument.*", + }, + { + name: "mixed arguments", + input: "select %?, %?, %?", + params: []interface{}{"33", 44, time.Time{}}, + output: "select '33', 
44, '0000-00-00'", + }, + { + name: "simple injection", + input: "select %?", + params: []interface{}{"0; drop database"}, + output: "select '0; drop database'", + }, + { + name: "identifier, wrong arg", + input: "use %n", + params: []interface{}{3}, + err: "expect a string identifier.*", + }, + { + name: "identifier", + input: "use %n", + params: []interface{}{"table`"}, + output: "use `table```", + err: "", + }, + { + name: "%n missing arguments", + input: "use %n", + params: []interface{}{}, + err: "missing arguments.*", + }, + { + name: "% escape", + input: "select * from t where val = '%%?'", + params: []interface{}{}, + output: "select * from t where val = '%?'", + err: "", + }, + { + name: "unknown specifier", + input: "%v", + params: []interface{}{}, + output: "%v", + err: "", + }, + { + name: "truncated specifier ", + input: "rv %", + params: []interface{}{}, + output: "rv %", + err: "", + }, + { + name: "float32 slice", + input: "select %?", + params: []interface{}{[]float32{33.1, 0.44}}, + output: "select 33.1,0.44", + }, + { + name: "float64 slice", + input: "select %?", + params: []interface{}{[]float64{55.2, 0.66}}, + output: "select 55.2,0.66", + }, + } + for _, t := range tests { + comment := Commentf("%s", t.name) + r3 := new(strings.Builder) + r1, e1 := escapeSQL(t.input, t.params...) + r2, e2 := EscapeSQL(t.input, t.params...) + e3 := FormatSQL(r3, t.input, t.params...) 
+ if t.err == "" { + c.Assert(e1, IsNil, comment) + c.Assert(string(r1), Equals, t.output, comment) + c.Assert(e2, IsNil, comment) + c.Assert(r2, Equals, t.output, comment) + c.Assert(e3, IsNil, comment) + c.Assert(r3.String(), Equals, t.output, comment) + } else { + c.Assert(e1, NotNil, comment) + c.Assert(e1, ErrorMatches, t.err, comment) + c.Assert(e2, NotNil, comment) + c.Assert(e2, ErrorMatches, t.err, comment) + c.Assert(e3, NotNil, comment) + c.Assert(e3, ErrorMatches, t.err, comment) + } + } +} + +func (s *testUtilsSuite) TestMustUtils(c *C) { + c.Assert(func() { + MustEscapeSQL("%?") + }, PanicMatches, "missing arguments.*") + + c.Assert(func() { + sql := new(strings.Builder) + MustFormatSQL(sql, "%?") + }, PanicMatches, "missing arguments.*") + + sql := new(strings.Builder) + MustFormatSQL(sql, "t") + MustEscapeSQL("tt") +}