Skip to content

Commit

Permalink
Merge pull request #109 from si3nloong/feat/array
Browse files Browse the repository at this point in the history
feat: support array and refactor the expression template string
  • Loading branch information
si3nloong committed Jul 6, 2024
2 parents dafc147 + 66ef049 commit f0c5f12
Show file tree
Hide file tree
Showing 39 changed files with 592 additions and 189 deletions.
4 changes: 2 additions & 2 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (i columnInfo) SQLValuer() sequel.QueryFunc {
return nil
}
return func(placeholder string) string {
return strings.Replace(i.model.SQLValuer, "{placeholder}", placeholder, 1)
return strings.Replace(i.model.SQLValuer, "{{.}}", placeholder, 1)
}
}

Expand All @@ -246,7 +246,7 @@ func (i columnInfo) SQLScanner() sequel.QueryFunc {
return nil
}
return func(column string) string {
return strings.Replace(i.model.SQLScanner, "{column}", column, 1)
return strings.Replace(i.model.SQLScanner, "{{.}}", column, 1)
}
}

Expand Down
55 changes: 52 additions & 3 deletions codegen/expr.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
package codegen

import (
"fmt"
"bytes"
"go/types"
"path/filepath"
"regexp"
"strings"
"text/template"
)

var (
pkgRegexp = regexp.MustCompile(`(?i)((?:[a-z][a-z0-9_.-]*/)*[a-z][a-z0-9_.-]*)\.[a-z]\w*`)
)

// Currently this template is support following expression:
//
// - {{.}} - current go path
// - {{goPath}} - go path
// - {{addrOfGoPath}} - address of go path
// - {{len}} - go array size
type Expr string

// Possible values:
Expand All @@ -20,8 +27,50 @@ type Expr string
// time.Time(v)
// string(v)

func (e Expr) Format(pkg *Package, args ...any) string {
type ExprParams struct {
// You may pass `&v.Path` or `v.Path` or any relevant go path,
// we will check whether it's addr of the go path
GoPath string
Len int
}

func (e Expr) Format(pkg *Package, args ...ExprParams) string {
params := ExprParams{}
if len(args) > 0 {
params = args[0]
}

actualGoPath := params.GoPath
// If the Go path is an address, we trim it out
// This will ease the use of `addrOfGoPath` function
if len(params.GoPath) > 0 && params.GoPath[0] == '&' {
params.GoPath = params.GoPath[1:]
}

funcMap := template.FuncMap{
"goPath": func() string {
return params.GoPath
},
"addrOfGoPath": func() string {
return "&" + params.GoPath
},
}
if params.Len > 0 {
funcMap["len"] = func() int {
return params.Len
}
}

str := string(e)
tmpl, err := template.New("expression").Funcs(funcMap).Parse(str)
if err != nil {
panic(err)
}
buf := new(bytes.Buffer)
if err := tmpl.Execute(buf, actualGoPath); err != nil {
panic(err)
}
str = buf.String()
matches := pkgRegexp.FindStringSubmatch(str)
if len(matches) > 0 {
p, _ := pkg.Import(types.NewPackage(matches[1], filepath.Base(matches[1])))
Expand All @@ -31,5 +80,5 @@ func (e Expr) Format(pkg *Package, args ...any) string {
str = strings.Replace(str, matches[1]+".", "", -1)
}
}
return fmt.Sprintf(str, args...)
return str
}
8 changes: 4 additions & 4 deletions codegen/expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (
func TestExpr(t *testing.T) {
pkg := NewPackage("", "")

require.Equal(t, "string(v)", Expr(`string(%s)`).Format(pkg, "v"))
require.Equal(t, `time.Time(v)`, Expr(`time.Time(%s)`).Format(pkg, "v"))
require.Equal(t, `(*time.Time)(v)`, Expr(`(*time.Time)(%s)`).Format(pkg, "v"))
require.Equal(t, `(driver.Valuer)(v)`, Expr(`(database/sql/driver.Valuer)(%s)`).Format(pkg, "v"))
require.Equal(t, "string(v)", Expr(`string({{goPath}})`).Format(pkg, ExprParams{GoPath: "v"}))
require.Equal(t, `time.Time(v)`, Expr(`time.Time({{goPath}})`).Format(pkg, ExprParams{GoPath: "v"}))
require.Equal(t, `(*time.Time)(&v)`, Expr(`(*time.Time)({{addrOfGoPath}})`).Format(pkg, ExprParams{GoPath: "v"}))
require.Equal(t, `(driver.Valuer)(v)`, Expr(`(database/sql/driver.Valuer)({{goPath}})`).Format(pkg, ExprParams{GoPath: "v"}))

require.ElementsMatch(t, []*types.Package{
types.NewPackage("time", "time"),
Expand Down
38 changes: 23 additions & 15 deletions codegen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,33 +414,41 @@ func (g *Generator) buildScanner(importPkgs *Package, table *tableInfo) {

func (g *Generator) valuer(importPkgs *Package, goPath string, t types.Type) string {
if model, ok := g.config.Models[t.String()]; ok && model.Valuer != "" {
// TODO: do it better
return Expr(strings.Replace(model.Valuer, "{field}", goPath, -1)).Format(importPkgs)
return Expr(model.Valuer).Format(importPkgs, ExprParams{GoPath: goPath})
} else if _, wrong := types.MissingMethod(t, goSqlValuer, true); wrong {
return Expr("(database/sql/driver.Valuer)(%s)").Format(importPkgs, goPath)
} else if codec, _ := UnderlyingType(t); codec != nil {
return codec.Encoder.Format(importPkgs, goPath)
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
} else if codec, goType := UnderlyingType(t); codec != nil {
switch vi := goType.(type) {
case GoArray:
return codec.Encoder.Format(importPkgs, ExprParams{GoPath: goPath, Len: vi.Len()})
default:
return codec.Encoder.Format(importPkgs, ExprParams{GoPath: goPath})
}
} else if isImplemented(t, textMarshaler) {
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler(%s)").Format(importPkgs, goPath)
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
} else if isImplemented(t, binaryMarshaler) {
return Expr("github.com/si3nloong/sqlgen/sequel/types.BinaryMarshaler(%s)").Format(importPkgs, goPath)
return Expr("github.com/si3nloong/sqlgen/sequel/types.BinaryMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
} else {
return Expr("github.com/si3nloong/sqlgen/sequel/types.JSONMarshaler(%s)").Format(importPkgs, goPath)
return Expr("github.com/si3nloong/sqlgen/sequel/types.JSONMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
}
}

func (g *Generator) scanner(importPkgs *Package, goPath string, t types.Type) string {
if model, ok := g.config.Models[t.String()]; ok && model.Scanner != "" {
// TODO: do it better
return Expr(strings.Replace(model.Scanner, "{field}", goPath, -1)).Format(importPkgs)
return Expr(model.Scanner).Format(importPkgs, ExprParams{GoPath: goPath})
} else if types.Implements(newPointer(t), goSqlScanner) {
return Expr("(database/sql.Scanner)(%s)").Format(importPkgs, goPath)
} else if codec, _ := UnderlyingType(t); codec != nil {
return codec.Decoder.Format(importPkgs, goPath)
return Expr("(database/sql.Scanner)({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
} else if codec, goType := UnderlyingType(t); codec != nil {
switch vi := goType.(type) {
case GoArray:
return codec.Decoder.Format(importPkgs, ExprParams{GoPath: goPath, Len: vi.Len()})
default:
return codec.Decoder.Format(importPkgs, ExprParams{GoPath: goPath})
}
} else if isImplemented(types.NewPointer(t), textMarshaler) {
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler(%s)").Format(importPkgs, goPath)
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
}
return Expr("github.com/si3nloong/sqlgen/sequel/types.JSONUnmarshaler(%s)").Format(importPkgs, goPath)
return Expr("github.com/si3nloong/sqlgen/sequel/types.JSONUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
}

func (g *Generator) buildFindByPK(importPkgs *Package, table *tableInfo) {
Expand Down
52 changes: 47 additions & 5 deletions codegen/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,46 @@ package codegen
import (
"fmt"
"go/types"
"regexp"
"strconv"
)

func UnderlyingType(t types.Type) (codec *Mapping, typeStr string) {
type GoType interface {
String() string
}

type GoArray interface {
Len() int
}

type nNode struct {
t string
}

func (n nNode) String() string {
return n.t
}

type arrNode struct {
t string
size int
}

func (n arrNode) String() string {
return n.t
}
func (n arrNode) Len() int {
return n.size
}

var (
arrayRegexp = regexp.MustCompile(`^\[(\d+)\](rune|string|byte)$`)
)

func UnderlyingType(t types.Type) (*Mapping, GoType) {
var (
prev = t
prev = t
typeStr string
)

loop:
Expand Down Expand Up @@ -35,17 +70,24 @@ loop:
break loop
}
if v, ok := typeMap[typeStr]; ok {
return v, typeStr
return v, &nNode{t: typeStr}
}
if prev == t {
break loop
}
t = prev
}
if v, ok := typeMap[typeStr]; ok {
return v, typeStr
return v, &nNode{t: typeStr}
}
// Find fixed size array mapper
if matches := arrayRegexp.FindStringSubmatch(typeStr); len(matches) > 0 {
if v, ok := typeMap["[...]"+matches[2]]; ok {
len, _ := strconv.Atoi(matches[1])
return v, &arrNode{t: typeStr, size: len}
}
}
return nil, typeStr
return nil, &nNode{t: typeStr}
}

func assertAsPtr[T any](v any) *T {
Expand Down
Loading

0 comments on commit f0c5f12

Please sign in to comment.