diff --git a/codegen/templates/db.go.tpl b/codegen/templates/db.go.tpl index 964c9a3..fa11bc8 100644 --- a/codegen/templates/db.go.tpl +++ b/codegen/templates/db.go.tpl @@ -1,5 +1,6 @@ {{- reserveImport "context" }} {{- reserveImport "database/sql" }} +{{- reserveImport "database/sql/driver" }} {{- reserveImport "strings" }} {{- reserveImport "strconv" }} {{- reserveImport "sync" }} @@ -1514,7 +1515,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) { var ( args = make([]any, len(s.args)) idx int - i int + i = 1 ) copy(args, s.args) @@ -1523,7 +1524,8 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) { {{ if isStaticVar -}} idx = strings.Index(str, "?") {{ else -}} - idx = strings.Index(str, wrapVar(i)) + placeholder := wrapVar(i) + idx = strings.Index(str, placeholder) {{ end -}} if idx < 0 { f.Write(unsafe.Slice(unsafe.StringData(str), len(str))) @@ -1531,9 +1533,13 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) { } f.Write([]byte(str[:idx])) - v := toStr(args[0]) + v := strf(args[0]) f.Write(unsafe.Slice(unsafe.StringData(v), len(v))) + {{ if isStaticVar -}} str = str[idx+1:] + {{ else -}} + str = str[idx+len(placeholder):] + {{ end -}} args = args[1:] i++ } @@ -1572,7 +1578,7 @@ func wrapVar(i int) string { } {{ end }} -func toStr(v any) string { +func strf(v any) string { switch vi := v.(type) { case string: return strconv.Quote(vi) @@ -1588,6 +1594,9 @@ func toStr(v any) string { return strconv.Quote(vi.Format(time.RFC3339)) case sql.RawBytes: return unsafe.String(unsafe.SliceData(vi), len(vi)) + case driver.Valuer: + val, _ := vi.Value() + return strf(val) default: panic("unreachable") } diff --git a/examples/db/mysql/db.go b/examples/db/mysql/db.go index d3bcae9..fc343db 100755 --- a/examples/db/mysql/db.go +++ b/examples/db/mysql/db.go @@ -3,6 +3,7 @@ package mysqldb import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" "iter" @@ -1097,7 +1098,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) { var ( args = make([]any, len(s.args)) idx int - i int + i = 1 ) copy(args, s.args) @@ -1110,7 +1111,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) { } f.Write([]byte(str[:idx])) - v := toStr(args[0]) + v := strf(args[0]) f.Write(unsafe.Slice(unsafe.StringData(v), len(v))) str = str[idx+1:] args = args[1:] @@ -1145,7 +1146,7 @@ func dbName(model any) string { return "" } -func toStr(v any) string { +func strf(v any) string { switch vi := v.(type) { case string: return strconv.Quote(vi) @@ -1161,6 +1162,9 @@ func toStr(v any) string { return strconv.Quote(vi.Format(time.RFC3339)) case sql.RawBytes: return unsafe.String(unsafe.SliceData(vi), len(vi)) + case driver.Valuer: + val, _ := vi.Value() + return strf(val) default: panic("unreachable") } diff --git a/examples/db/postgres/db.go b/examples/db/postgres/db.go index 49a110c..70492d5 100755 --- a/examples/db/postgres/db.go +++ b/examples/db/postgres/db.go @@ -3,6 +3,7 @@ package postgresdb import ( "context" "database/sql" + "database/sql/driver" "fmt" "iter" "strconv" @@ -1215,22 +1216,23 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) { var ( args = make([]any, len(s.args)) idx int - i int + i = 1 ) copy(args, s.args) for { - idx = strings.Index(str, wrapVar(i)) + placeholder := wrapVar(i) + idx = strings.Index(str, placeholder) if idx < 0 { f.Write(unsafe.Slice(unsafe.StringData(str), len(str))) break } f.Write([]byte(str[:idx])) - v := toStr(args[0]) + v := strf(args[0]) f.Write(unsafe.Slice(unsafe.StringData(v), len(v))) - str = str[idx+1:] + str = str[idx+len(placeholder):] args = args[1:] i++ } @@ -1267,7 +1269,7 @@ func wrapVar(i int) string { return `$` + strconv.Itoa(i) } -func toStr(v any) string { +func strf(v any) string { switch vi := v.(type) { case string: return strconv.Quote(vi) @@ -1283,6 +1285,9 @@ func toStr(v any) string { return strconv.Quote(vi.Format(time.RFC3339)) case sql.RawBytes: return unsafe.String(unsafe.SliceData(vi), len(vi)) + case driver.Valuer: + val, _ := vi.Value() + return strf(val) default: panic("unreachable") } diff --git a/examples/db/sqlite/db.go b/examples/db/sqlite/db.go index 992b82f..efcff29 100755 --- a/examples/db/sqlite/db.go +++ b/examples/db/sqlite/db.go @@ -3,6 +3,7 @@ package sqlite import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" "iter" @@ -1097,7 +1098,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) { var ( args = make([]any, len(s.args)) idx int - i int + i = 1 ) copy(args, s.args) @@ -1110,7 +1111,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) { } f.Write([]byte(str[:idx])) - v := toStr(args[0]) + v := strf(args[0]) f.Write(unsafe.Slice(unsafe.StringData(v), len(v))) str = str[idx+1:] args = args[1:] @@ -1145,7 +1146,7 @@ func dbName(model any) string { return "" } -func toStr(v any) string { +func strf(v any) string { switch vi := v.(type) { case string: return strconv.Quote(vi) @@ -1161,6 +1162,9 @@ func toStr(v any) string { return strconv.Quote(vi.Format(time.RFC3339)) case sql.RawBytes: return unsafe.String(unsafe.SliceData(vi), len(vi)) + case driver.Valuer: + val, _ := vi.Value() + return strf(val) default: panic("unreachable") }