Skip to content

Commit

Permalink
support oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
tianfw committed Aug 27, 2023
1 parent 2f05e90 commit 38f9bc0
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 17 deletions.
2 changes: 2 additions & 0 deletions args.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ func (args *Args) compileArg(buf *stringBuilder, flavor Flavor, values []interfa
fmt.Fprintf(buf, "$%d", len(values)+1)
case SQLServer:
fmt.Fprintf(buf, "@p%d", len(values)+1)
case Oracle:
fmt.Fprintf(buf, ":%d", len(values)+1)
default:
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
}
Expand Down
9 changes: 7 additions & 2 deletions flavor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (
CQL
ClickHouse
Presto
Oracle
)

var (
Expand Down Expand Up @@ -58,6 +59,8 @@ func (f Flavor) String() string {
return "ClickHouse"
case Presto:
return "Presto"
case Oracle:
return "Oracle"
}

return "<invalid>"
Expand All @@ -84,6 +87,8 @@ func (f Flavor) Interpolate(sql string, args []interface{}) (string, error) {
return clickhouseInterpolate(sql, args...)
case Presto:
return prestoInterpolate(sql, args...)
case Oracle:
return oracleInterpolate(sql, args...)
}

return "", ErrInterpolateNotImplemented
Expand Down Expand Up @@ -140,7 +145,7 @@ func (f Flavor) Quote(name string) string {
switch f {
case MySQL, ClickHouse:
return fmt.Sprintf("`%s`", name)
case PostgreSQL, SQLServer, SQLite, Presto:
case PostgreSQL, SQLServer, SQLite, Presto, Oracle:
return fmt.Sprintf(`"%s"`, name)
case CQL:
return fmt.Sprintf("'%s'", name)
Expand All @@ -152,7 +157,7 @@ func (f Flavor) Quote(name string) string {
// PrepareInsertIgnore prepares the insert builder to build insert ignore SQL statement based on the sql flavor
func (f Flavor) PrepareInsertIgnore(table string, ib *InsertBuilder) {
switch ib.args.Flavor {
case MySQL:
case MySQL, Oracle:
ib.verb = "INSERT IGNORE"

case PostgreSQL:
Expand Down
18 changes: 18 additions & 0 deletions flavor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,21 @@ func ExampleFlavor_Interpolate_cql() {
// SELECT name FROM user WHERE id = 1234 AND name = 'Charmy Liu'
// <nil>
}

func ExampleFlavor_Interpolate_oracle() {
sb := Oracle.NewSelectBuilder()
sb.Select("name").From("user").Where(
sb.E("id", 1234),
sb.E("name", "Charmy Liu"),
sb.E("enabled", true),
)
sql, args := sb.Build()
query, err := Oracle.Interpolate(sql, args)

fmt.Println(query)
fmt.Println(err)

// Output:
// SELECT name FROM user WHERE id = 1234 AND name = 'Charmy Liu' AND enabled = 1
// <nil>
}
209 changes: 195 additions & 14 deletions interpolate.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,170 @@ func prestoInterpolate(query string, args ...interface{}) (string, error) {
return mysqlLikeInterpolate(Presto, query, args...)
}

// oraclelInterpolate parses query and replace all ":*" with encoded args.
// If there are more ":*" than len(args), returns ErrMissingArgs.
// Otherwise, if there are less ":*" than len(args), the redundant args are omitted.
func oracleInterpolate(query string, args ...interface{}) (string, error) {
// Roughly estimate the size to avoid useless memory allocation and copy.
buf := make([]byte, 0, len(query)+len(args)*20)

var quote rune
var dollarQuote string
var err error
var idx int64
max := len(args)
escaping := false
offset := 0
target := query
r, sz := utf8.DecodeRuneInString(target)

for ; sz != 0; r, sz = utf8.DecodeRuneInString(target) {
offset += sz
target = query[offset:]

if escaping {
escaping = false
continue
}

switch r {
case ':':
if quote != 0 {
if quote != ':' {
continue
}

// Try to find the end of dollar quote.
pos := offset

for r, sz = utf8.DecodeRuneInString(target); sz != 0 && r != ':'; r, sz = utf8.DecodeRuneInString(target) {
pos += sz
target = query[pos:]
}

if sz == 0 {
break
}

if r == ':' {
dq := query[offset : pos+sz]
offset = pos
target = query[offset:]

if dq == dollarQuote {
quote = 0
dollarQuote = ""
offset += sz
target = query[offset:]
}

continue
}

continue
}

oldSz := sz
pos := offset
r, sz = utf8.DecodeRuneInString(target)

if '1' <= r && r <= '9' {
// A placeholder is found.
pos += sz
target = query[pos:]

for r, sz = utf8.DecodeRuneInString(target); sz != 0 && '0' <= r && r <= '9'; r, sz = utf8.DecodeRuneInString(target) {
pos += sz
target = query[pos:]
}

idx, err = strconv.ParseInt(query[offset:pos], 10, strconv.IntSize)

if err != nil {
return "", err
}

if int(idx) >= max+1 {
return "", ErrInterpolateMissingArgs
}

buf = append(buf, query[:offset-oldSz]...)
buf, err = encodeValue(buf, args[idx-1], Oracle)

if err != nil {
return "", err
}

query = target
offset = 0

if sz == 0 {
break
}

continue
}

// Try to find the beginning of dollar quote.
for ; sz != 0 && r != ':' && unicode.IsLetter(r); r, sz = utf8.DecodeRuneInString(target) {
pos += sz
target = query[pos:]
}

if sz == 0 {
break
}

if !unicode.IsLetter(r) && r != ':' {
continue
}

pos += sz
quote = ':'
dollarQuote = query[offset:pos]
offset = pos
target = query[offset:]

case '\'':
if quote == '\'' {
// PostgreSQL uses two single quotes to represent one single quote.
r, sz = utf8.DecodeRuneInString(target)

if r == '\'' {
offset += sz
target = query[offset:]
continue
}

quote = 0
continue
}

if quote == 0 {
quote = '\''
}

case '"':
if quote == '"' {
quote = 0
continue
}

if quote == 0 {
quote = '"'
}

case '\\':
if quote == '\'' || quote == '"' {
escaping = true
}
}
}

buf = append(buf, query...)
return *(*string)(unsafe.Pointer(&buf)), nil
}

func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
switch v := arg.(type) {
case nil:
Expand All @@ -420,32 +584,35 @@ func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
// In SQL standard, the precision of fractional seconds in time literal is up to 6 digits.
// Round up v.
v = v.Add(500 * time.Nanosecond)
buf = append(buf, '\'')

switch flavor {
case MySQL:
buf = append(buf, v.Format("2006-01-02 15:04:05.999999")...)
buf = append(buf, v.Format("'2006-01-02 15:04:05.999999'")...)

case PostgreSQL:
buf = append(buf, v.Format("2006-01-02 15:04:05.999999 MST")...)
buf = append(buf, v.Format("'2006-01-02 15:04:05.999999 MST'")...)

case SQLite:
buf = append(buf, v.Format("2006-01-02 15:04:05.000")...)
buf = append(buf, v.Format("'2006-01-02 15:04:05.000'")...)

case SQLServer:
buf = append(buf, v.Format("2006-01-02 15:04:05.999999 Z07:00")...)
buf = append(buf, v.Format("'2006-01-02 15:04:05.999999 Z07:00'")...)

case CQL:
buf = append(buf, v.Format("2006-01-02 15:04:05.999999Z0700")...)
buf = append(buf, v.Format("'2006-01-02 15:04:05.999999Z0700'")...)

case ClickHouse:
buf = append(buf, v.Format("2006-01-02 15:04:05.999999")...)
buf = append(buf, v.Format("'2006-01-02 15:04:05.999999'")...)

case Presto:
buf = append(buf, v.Format("2006-01-02 15:04:05.000")...)
}
buf = append(buf, v.Format("'2006-01-02 15:04:05.000'")...)

buf = append(buf, '\'')
case Oracle:
buf = append(buf, "to_timestamp('"...)
buf = append(buf, v.Format("2006-01-02 15:04:05.999999")...)
buf = append(buf, "', 'YYYY-MM-DD HH24:MI:SS.FF')"...)

}

case fmt.Stringer:
buf = quoteStringValue(buf, v.String(), flavor)
Expand All @@ -455,10 +622,19 @@ func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {

switch k := primative.Kind(); k {
case reflect.Bool:
if primative.Bool() {
buf = append(buf, "TRUE"...)
} else {
buf = append(buf, "FALSE"...)
switch flavor {
case Oracle:
if primative.Bool() {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
default:
if primative.Bool() {
buf = append(buf, "TRUE"...)
} else {
buf = append(buf, "FALSE"...)
}
}

case reflect.Int:
Expand Down Expand Up @@ -553,6 +729,11 @@ func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
buf = append(buf, "from_hex('"...)
buf = appendHex(buf, data)
buf = append(buf, "')"...)

case Oracle:
buf = append(buf, "hextoraw('"...)
buf = appendHex(buf, data)
buf = append(buf, "')"...)
}

default:
Expand Down
41 changes: 41 additions & 0 deletions interpolate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,47 @@ func TestFlavorInterpolate(t *testing.T) {
"SELECT ?", []interface{}{errorValuer(1)},
"", ErrErrorValuer,
},

{
Oracle,
"SELECT * FROM a WHERE name = :3 AND state IN (:2, :4, :1, :6, :5)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)},
"SELECT * FROM a WHERE name = 8 AND state IN (42, -16, 'I\\'m fine', 64, 32)", nil,
},
{
Oracle,
"SELECT * FROM :abc::1:abc:1:1 WHERE name = \":1\" AND state IN (:2, ':1', :3, :6, :5, :4, :2) :3", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"},
"SELECT * FROM :abc::1:abc:1'\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'' WHERE name = \":1\" AND state IN (42, ':1', 8, 64, 32, 16, 42) 8", nil,
},
{
Oracle,
"SELECT :1, :2, :3, :4, :5, :6, :7, :8, :9, :11, :a", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil, 10, 11, 12},
"SELECT 1, 0, 1.234567, 9.87654321, NULL, hextoraw('49276D206279746573'), to_timestamp('2019-04-24 12:23:34.123457', 'YYYY-MM-DD HH24:MI:SS.FF'), '0000-00-00', NULL, 11, :a", nil,
},
{
Oracle,
"SELECT '\\':1', \"\\\":1\", `:1`, \\:1a, ::1::, :a :b: :a : :1:b:1:1 :a: :", []interface{}{Oracle},
"SELECT '\\':1', \"\\\":1\", `'Oracle'`, \\'Oracle'a, ::1::, :a :b: :a : :1:b:1'Oracle' :a: :", nil,
},
{
Oracle,
"SELECT * FROM a WHERE name = 'Huan''Du'':1' AND desc = :1", []interface{}{"c'mon"},
"SELECT * FROM a WHERE name = 'Huan''Du'':1' AND desc = 'c\\'mon'", nil,
},
{
Oracle,
"SELECT :1", nil,
"", ErrInterpolateMissingArgs,
},
{
Oracle,
"SELECT :1", []interface{}{complex(1, 2)},
"", ErrInterpolateUnsupportedArgs,
},
{
Oracle,
"SELECT :12345678901234567890", nil,
"", errOutOfRange,
},
}

for idx, c := range cases {
Expand Down
3 changes: 2 additions & 1 deletion struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"regexp"
"sort"
"strings"
)

var (
Expand Down Expand Up @@ -308,7 +309,7 @@ func (s *Struct) selectFromWithTags(table string, with, without []string) (sb *S
cols := make([]string, 0, len(tagged.ForRead))

for _, sf := range tagged.ForRead {
if s.Flavor != CQL {
if s.Flavor != CQL && !strings.ContainsRune(sf.Alias, '.') {
buf.WriteString(table)
buf.WriteRune('.')
}
Expand Down

0 comments on commit 38f9bc0

Please sign in to comment.