Skip to content

Commit

Permalink
chore: update generated codes
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Sep 20, 2024
1 parent 540bec1 commit bd6676d
Show file tree
Hide file tree
Showing 31 changed files with 190 additions and 138 deletions.
36 changes: 18 additions & 18 deletions codegen/code_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (g *Generator) genModels(pkg *packages.Package, dstDir string, typeInferred
g.LogError(fmt.Errorf(`sqlgen: struct %q has function "DatabaseName" but wrong footprint`, t.goName))
} else if method != nil && !wrongType && t.dbName != "" {
g.L("func (" + t.goName + ") DatabaseName() string {")
g.L(`return ` + g.Quote(t.dbName))
g.L(`return ` + g.Quote(g.QuoteIdentifier(t.dbName)))
g.L("}")
}

Expand All @@ -122,7 +122,7 @@ func (g *Generator) genModels(pkg *packages.Package, dstDir string, typeInferred
g.LogError(fmt.Errorf(`sqlgen: struct %q has function "TableName" but wrong footprint`, t.goName))
} else if method != nil && !wrongType {
g.L("func (" + t.goName + ") TableName() string {")
g.L(`return ` + g.Quote(t.tableName))
g.L(`return ` + g.Quote(g.QuoteIdentifier(t.tableName)))
g.L("}")
} else {
// TODO: we need to do something when table name is declare by user
Expand All @@ -139,7 +139,7 @@ func (g *Generator) genModels(pkg *packages.Package, dstDir string, typeInferred
} else {
pk := t.keys[0]
g.L("func (v " + t.goName + ") PK() (string, int, any) {")
g.L(`return `, g.Quote(pk.ColumnName()), ", ", pk.ColumnPos(), ", ", g.valuer(importPkgs, "v."+pk.GoPath(), pk.Type()))
g.L(`return `, g.Quote(g.QuoteIdentifier(pk.ColumnName())), ", ", pk.ColumnPos(), ", ", g.valuer(importPkgs, "v."+pk.GoPath(), pk.Type()))
g.L("}")
}
}
Expand All @@ -154,7 +154,7 @@ func (g *Generator) genModels(pkg *packages.Package, dstDir string, typeInferred
if i > 0 {
g.WriteByte(',')
}
g.WriteString(g.Quote(f.ColumnName()))
g.WriteString(g.Quote(g.QuoteIdentifier(f.ColumnName())))
}
g.WriteString("}\n")
g.L("}")
Expand Down Expand Up @@ -244,16 +244,16 @@ func (g *Generator) genModels(pkg *packages.Package, dstDir string, typeInferred
matches := sqlFuncRegexp.FindStringSubmatch(sqlValuer("{}"))
if len(matches) > 4 {
g.L("func (v "+t.goName+") ", g.config.Getter.Prefix+f.GoName(), "() sequel.SQLColumnValuer[", typeStr, "] {")
g.L(`return sequel.SQLColumn`+specificType+`(`, g.Quote(f.ColumnName()), `, v.`, f.GoPath()+",", fmt.Sprintf(`func(placeholder string) string { return %q+ placeholder + %q}`, matches[1]+matches[2], matches[4]+matches[5]), `, func(val `, typeStr, `) driver.Value { return `, g.valuer(importPkgs, "val", f.Type()), ` })`)
g.L(`return sequel.SQLColumn`+specificType+`(`, g.Quote(g.QuoteIdentifier(f.ColumnName())), `, v.`, f.GoPath()+",", fmt.Sprintf(`func(placeholder string) string { return %q+ placeholder + %q}`, matches[1]+matches[2], matches[4]+matches[5]), `, func(val `, typeStr, `) driver.Value { return `, g.valuer(importPkgs, "val", f.Type()), ` })`)
g.L("}")
} else {
g.L("func (v "+t.goName+") ", g.config.Getter.Prefix+f.GoName(), "() sequel.ColumnValuer[", typeStr, "] {")
g.L(`return sequel.Column`, specificType, `(`, g.Quote(f.ColumnName()), `, v.`, f.GoPath(), `, func(val `, typeStr, `) driver.Value { return `, g.valuer(importPkgs, "val", f.Type()), ` })`)
g.L(`return sequel.Column`, specificType, `(`, g.Quote(g.QuoteIdentifier(f.ColumnName())), `, v.`, f.GoPath(), `, func(val `, typeStr, `) driver.Value { return `, g.valuer(importPkgs, "val", f.Type()), ` })`)
g.L("}")
}
} else {
g.L("func (v "+t.goName+") ", g.config.Getter.Prefix+f.GoName(), "() sequel.ColumnValuer[", typeStr, "] {")
g.L("return sequel.Column", specificType, "(", g.Quote(f.ColumnName()), ", v.", f.GoPath(), ", func(val ", typeStr, `) driver.Value { return `, g.valuer(importPkgs, "val", f.Type()), ` })`)
g.L("return sequel.Column", specificType, "(", g.Quote(g.QuoteIdentifier(f.ColumnName())), ", v.", f.GoPath(), ", func(val ", typeStr, `) driver.Value { return `, g.valuer(importPkgs, "val", f.Type()), ` })`)
g.L("}")
}
}
Expand Down Expand Up @@ -440,7 +440,7 @@ func (g *Generator) buildFindByPK(importPkgs *Package, t *tableInfo) {
if len(t.keys) > 1 {
// Composite primary key
keyCols := lo.Map(t.keys, func(v *columnInfo, _ int) string {
return v.ColumnName()
return g.QuoteIdentifier(v.ColumnName())
})
buf.WriteString("(" + strings.Join(keyCols, ",") + ")" + " = ")
buf.WriteByte('(')
Expand All @@ -456,7 +456,7 @@ func (g *Generator) buildFindByPK(importPkgs *Package, t *tableInfo) {
if i > 0 {
buf.WriteString(" AND ")
}
buf.WriteString(f.ColumnName() + " = " + g.dialect.QuoteVar(i+1))
buf.WriteString(g.QuoteIdentifier(f.ColumnName()) + " = " + g.dialect.QuoteVar(i+1))
}
}
buf.WriteString(" LIMIT 1;")
Expand All @@ -483,7 +483,7 @@ func (g *Generator) buildInsertOne(importPkgs *Package, t *tableInfo) {
if method, wrongType := t.Implements(sqlTabler); wrongType {
g.LogError(fmt.Errorf(`sqlgen: struct %q has function "TableName" but wrong footprint`, t.goName))
} else if method != nil {
buf.WriteString("INSERT INTO " + t.tableName)
buf.WriteString("INSERT INTO " + g.QuoteIdentifier(t.tableName))
} else {
query = g.Quote("INSERT INTO ") + "+ v.TableName() +"
}
Expand All @@ -492,7 +492,7 @@ func (g *Generator) buildInsertOne(importPkgs *Package, t *tableInfo) {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(f.ColumnName())
buf.WriteString(g.QuoteIdentifier(f.ColumnName()))
}
buf.WriteString(") VALUES (")
for i, f := range columns {
Expand Down Expand Up @@ -550,12 +550,12 @@ func (g *Generator) buildUpdateByPK(importPkgs *Package, t *tableInfo) {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(f.ColumnName() + " = " + g.sqlValuer(f, i))
buf.WriteString(g.QuoteIdentifier(f.ColumnName()) + " = " + g.sqlValuer(f, i))
}
buf.WriteString(" WHERE ")
if len(t.keys) > 1 {
keyCols := lo.Map(t.keys, func(v *columnInfo, _ int) string {
return v.ColumnName()
return g.QuoteIdentifier(v.ColumnName())
})
buf.WriteString("(" + strings.Join(keyCols, ",") + ")" + " = ")
buf.WriteByte('(')
Expand All @@ -571,7 +571,7 @@ func (g *Generator) buildUpdateByPK(importPkgs *Package, t *tableInfo) {
if i > 0 {
buf.WriteString(" AND ")
}
buf.WriteString(k.ColumnName() + " = " + g.sqlValuer(k, i+len(columns)))
buf.WriteString(g.QuoteIdentifier(k.ColumnName()) + " = " + g.sqlValuer(k, i+len(columns)))
}
}
buf.WriteByte(';')
Expand All @@ -591,7 +591,7 @@ func (g *Generator) buildUpdateByPK(importPkgs *Package, t *tableInfo) {

func (g *Generator) valuer(importPkgs *Package, goPath string, t types.Type) string {
utype, isPtr := underlyingType(t)
if columnType, ok := g.columnTypes[t.String()]; ok {
if columnType, ok := g.columnTypes[t.String()]; ok && columnType.Valuer != "" {
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath})
} else if _, wrong := types.MissingMethod(utype, goSqlValuer, true); wrong {
if isPtr {
Expand Down Expand Up @@ -630,12 +630,12 @@ func (g *Generator) sqlScanner(f *columnInfo) string {
if sqlScanner, ok := f.sqlScanner(); ok {
matches := sqlFuncRegexp.FindStringSubmatch(sqlScanner("{}"))
if len(matches) > 4 {
return matches[1] + matches[2] + f.ColumnName() + matches[4] + matches[5]
return matches[1] + matches[2] + g.QuoteIdentifier(f.ColumnName()) + matches[4] + matches[5]
} else {
return f.ColumnName()
return g.QuoteIdentifier(f.ColumnName())
}
}
return f.ColumnName()
return g.QuoteIdentifier(f.ColumnName())
}

func (g *Generator) sqlValuer(f *columnInfo, idx int) string {
Expand Down
1 change: 1 addition & 0 deletions codegen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ func (c *Config) initIfEmpty() {
}

func (c *Config) Merge(mapCfg *Config) *Config {
c.QuoteIdentifier = mapCfg.QuoteIdentifier
if mapCfg.Source != nil {
c.Source = append([]string{}, mapCfg.Source...)
}
Expand Down
30 changes: 15 additions & 15 deletions codegen/dialect/mysql/column_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
"[]string": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalStringSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.StringList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.StringSlice({{addrOfGoPath}})",
},
"[]byte": {
DataType: s.columnDataType("BLOB"),
Expand All @@ -205,68 +205,68 @@ func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
},
"[]bool": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalBoolList({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.BoolList({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalBoolSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.BoolSlice({{addrOfGoPath}})",
},
"[]int": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalIntSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntSlice({{addrOfGoPath}})",
},
"[]int8": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalIntSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntSlice({{addrOfGoPath}})",
},
"[]int16": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalIntSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntSlice({{addrOfGoPath}})",
},
"[]int32": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalIntSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntSlice({{addrOfGoPath}})",
},
"[]int64": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalIntSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.IntSlice({{addrOfGoPath}})",
},
"[]uint": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalUintSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintSlice({{addrOfGoPath}})",
},
"[]uint8": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalUintSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintSlice({{addrOfGoPath}})",
},
"[]uint16": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalUintSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintSlice({{addrOfGoPath}})",
},
"[]uint32": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalUintSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintSlice({{addrOfGoPath}})",
},
"[]uint64": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalUintSlice({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.UintSlice({{addrOfGoPath}})",
},
"[]float32": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalFloatList({{goPath}},-1)",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.FloatList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.FloatSlice({{addrOfGoPath}})",
},
"[]float64": {
DataType: s.columnDataType("JSON"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalFloatList({{goPath}},-1)",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.FloatList({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.FloatSlice({{addrOfGoPath}})",
},
"*": {
DataType: s.columnDataType("JSON"),
Expand Down
2 changes: 1 addition & 1 deletion codegen/sequel.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ type SQLColumnValuer[T any] interface {
SQLValue(placeholder string) string
}

type ColumnOrder interface {
type OrderByClause interface {
ColumnName() string
Asc() bool
}
14 changes: 9 additions & 5 deletions codegen/templates/db.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ type PaginateStmt struct {
Select []string
FromTable string
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
Limit uint16
}
Expand Down Expand Up @@ -1159,7 +1159,7 @@ type SelectStmt struct {
Select []string
FromTable string
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
GroupBy []string
Offset uint64
Limit uint16
Expand Down Expand Up @@ -1267,7 +1267,7 @@ type SelectOneStmt struct {
Select []string
FromTable string
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
GroupBy []string
}
Expand Down Expand Up @@ -1347,14 +1347,14 @@ type UpdateStmt struct {
Table string
Set []sequel.SetClause
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
Limit uint16
}
type DeleteStmt struct {
FromTable string
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
Limit uint16
}
Expand Down Expand Up @@ -1520,7 +1520,11 @@ func (s *sqlStmt) Format(f fmt.State, verb rune) {
copy(args, s.args)
for {
{{ if isStaticVar -}}
idx = strings.Index(str, "?")
{{ else -}}
idx = strings.Index(str, wrapVar(i))
{{ end -}}
if idx < 0 {
f.Write(unsafe.Slice(unsafe.StringData(str), len(str)))
break
Expand Down
4 changes: 2 additions & 2 deletions codegen/templates/operator.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ func Set[T any](f sequel.ColumnValuer[T], value ...T) sequel.SetClause {
}
}

func Asc[T any](f sequel.ColumnValuer[T]) sequel.ColumnOrder {
func Asc[T any](f sequel.ColumnValuer[T]) sequel.OrderByClause {
return sequel.OrderByColumn(f.ColumnName(), true)
}

func Desc[T any](f sequel.ColumnValuer[T]) sequel.ColumnOrder {
func Desc[T any](f sequel.ColumnValuer[T]) sequel.OrderByClause {
return sequel.OrderByColumn(f.ColumnName(), false)
}
10 changes: 5 additions & 5 deletions examples/db/mysql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ type PaginateStmt struct {
Select []string
FromTable string
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
Limit uint16
}

Expand Down Expand Up @@ -762,7 +762,7 @@ type SelectStmt struct {
Select []string
FromTable string
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
GroupBy []string
Offset uint64
Limit uint16
Expand Down Expand Up @@ -870,7 +870,7 @@ type SelectOneStmt struct {
Select []string
FromTable string
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
GroupBy []string
}

Expand Down Expand Up @@ -950,14 +950,14 @@ type UpdateStmt struct {
Table string
Set []sequel.SetClause
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
Limit uint16
}

type DeleteStmt struct {
FromTable string
Where sequel.WhereClause
OrderBy []sequel.ColumnOrder
OrderBy []sequel.OrderByClause
Limit uint16
}

Expand Down
4 changes: 2 additions & 2 deletions examples/db/mysql/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ func Set[T any](f sequel.ColumnValuer[T], value ...T) sequel.SetClause {
}
}

func Asc[T any](f sequel.ColumnValuer[T]) sequel.ColumnOrder {
func Asc[T any](f sequel.ColumnValuer[T]) sequel.OrderByClause {
return sequel.OrderByColumn(f.ColumnName(), true)
}

func Desc[T any](f sequel.ColumnValuer[T]) sequel.ColumnOrder {
func Desc[T any](f sequel.ColumnValuer[T]) sequel.OrderByClause {
return sequel.OrderByColumn(f.ColumnName(), false)
}
Loading

0 comments on commit bd6676d

Please sign in to comment.