Skip to content

Commit

Permalink
fix(bug): generated codes
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Oct 1, 2024
1 parent 72efacb commit 7d60066
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 15 deletions.
17 changes: 13 additions & 4 deletions codegen/templates/db.go.tpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{{- reserveImport "context" }}
{{- reserveImport "database/sql" }}
{{- reserveImport "database/sql/driver" }}
{{- reserveImport "strings" }}
{{- reserveImport "strconv" }}
{{- reserveImport "sync" }}
Expand Down Expand Up @@ -1514,7 +1515,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) {
var (
args = make([]any, len(s.args))
idx int
i int
i = 1
)
copy(args, s.args)
Expand All @@ -1523,17 +1524,22 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) {
{{ if isStaticVar -}}
idx = strings.Index(str, "?")
{{ else -}}
idx = strings.Index(str, wrapVar(i))
placeholder := wrapVar(i)
idx = strings.Index(str, placeholder)
{{ end -}}
if idx < 0 {
f.Write(unsafe.Slice(unsafe.StringData(str), len(str)))
break
}
f.Write([]byte(str[:idx]))
v := toStr(args[0])
v := strf(args[0])
f.Write(unsafe.Slice(unsafe.StringData(v), len(v)))
{{ if isStaticVar -}}
str = str[idx+1:]
{{ else -}}
str = str[idx+len(placeholder):]
{{ end -}}
args = args[1:]
i++
}
Expand Down Expand Up @@ -1572,7 +1578,7 @@ func wrapVar(i int) string {
}
{{ end }}
func toStr(v any) string {
func strf(v any) string {
switch vi := v.(type) {
case string:
return strconv.Quote(vi)
Expand All @@ -1588,6 +1594,9 @@ func toStr(v any) string {
return strconv.Quote(vi.Format(time.RFC3339))
case sql.RawBytes:
return unsafe.String(unsafe.SliceData(vi), len(vi))
case driver.Valuer:
val, _ := vi.Value()
return strf(val)
default:
panic("unreachable")
}
Expand Down
10 changes: 7 additions & 3 deletions examples/db/mysql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mysqldb
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"iter"
Expand Down Expand Up @@ -1097,7 +1098,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) {
var (
args = make([]any, len(s.args))
idx int
i int
i = 1
)

copy(args, s.args)
Expand All @@ -1110,7 +1111,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) {
}

f.Write([]byte(str[:idx]))
v := toStr(args[0])
v := strf(args[0])
f.Write(unsafe.Slice(unsafe.StringData(v), len(v)))
str = str[idx+1:]
args = args[1:]
Expand Down Expand Up @@ -1145,7 +1146,7 @@ func dbName(model any) string {
return ""
}

func toStr(v any) string {
func strf(v any) string {
switch vi := v.(type) {
case string:
return strconv.Quote(vi)
Expand All @@ -1161,6 +1162,9 @@ func toStr(v any) string {
return strconv.Quote(vi.Format(time.RFC3339))
case sql.RawBytes:
return unsafe.String(unsafe.SliceData(vi), len(vi))
case driver.Valuer:
val, _ := vi.Value()
return strf(val)
default:
panic("unreachable")
}
Expand Down
15 changes: 10 additions & 5 deletions examples/db/postgres/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgresdb
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"iter"
"strconv"
Expand Down Expand Up @@ -1215,22 +1216,23 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) {
var (
args = make([]any, len(s.args))
idx int
i int
i = 1
)

copy(args, s.args)

for {
idx = strings.Index(str, wrapVar(i))
placeholder := wrapVar(i)
idx = strings.Index(str, placeholder)
if idx < 0 {
f.Write(unsafe.Slice(unsafe.StringData(str), len(str)))
break
}

f.Write([]byte(str[:idx]))
v := toStr(args[0])
v := strf(args[0])
f.Write(unsafe.Slice(unsafe.StringData(v), len(v)))
str = str[idx+1:]
str = str[idx+len(placeholder):]
args = args[1:]
i++
}
Expand Down Expand Up @@ -1267,7 +1269,7 @@ func wrapVar(i int) string {
return `$` + strconv.Itoa(i)
}

func toStr(v any) string {
func strf(v any) string {
switch vi := v.(type) {
case string:
return strconv.Quote(vi)
Expand All @@ -1283,6 +1285,9 @@ func toStr(v any) string {
return strconv.Quote(vi.Format(time.RFC3339))
case sql.RawBytes:
return unsafe.String(unsafe.SliceData(vi), len(vi))
case driver.Valuer:
val, _ := vi.Value()
return strf(val)
default:
panic("unreachable")
}
Expand Down
10 changes: 7 additions & 3 deletions examples/db/sqlite/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sqlite
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"iter"
Expand Down Expand Up @@ -1097,7 +1098,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) {
var (
args = make([]any, len(s.args))
idx int
i int
i = 1
)

copy(args, s.args)
Expand All @@ -1110,7 +1111,7 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) {
}

f.Write([]byte(str[:idx]))
v := toStr(args[0])
v := strf(args[0])
f.Write(unsafe.Slice(unsafe.StringData(v), len(v)))
str = str[idx+1:]
args = args[1:]
Expand Down Expand Up @@ -1145,7 +1146,7 @@ func dbName(model any) string {
return ""
}

func toStr(v any) string {
func strf(v any) string {
switch vi := v.(type) {
case string:
return strconv.Quote(vi)
Expand All @@ -1161,6 +1162,9 @@ func toStr(v any) string {
return strconv.Quote(vi.Format(time.RFC3339))
case sql.RawBytes:
return unsafe.String(unsafe.SliceData(vi), len(vi))
case driver.Valuer:
val, _ := vi.Value()
return strf(val)
default:
panic("unreachable")
}
Expand Down

0 comments on commit 7d60066

Please sign in to comment.