Skip to content

Commit

Permalink
fix: scanner and valuer
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Sep 18, 2024
1 parent 090944b commit c0f2da9
Show file tree
Hide file tree
Showing 15 changed files with 416 additions and 295 deletions.
67 changes: 57 additions & 10 deletions codegen/code_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,25 @@ func (g *Generator) buildCompositeKeys(importPkgs *Package, table *tableInfo) {
}

func (g *Generator) buildValuer(importPkgs *Package, table *tableInfo) {
// ptrs := lo.Filter(table.columns, func(v *columnInfo, _ int) bool {
// return v.isPtr()
// })
g.L("func (v " + table.goName + ") Values() []any {")
// if len(ptrs) > 0 {
// g.L("values := make([]any, ", len(table.columns), ")")
// for _, f := range table.columns {
// if f.isPtr() {
// g.L("if v." + f.goPath + " == nil {")
// log.Println(f.t.Underlying().String(), f.columnPos)
// g.L("values[", f.columnPos, "] = nil")
// g.L("}")
// } else {
// g.L("values[", f.columnPos, "] = ", g.valuer(importPkgs, "v."+f.GoPath(), f.Type()))
// }
// }
// g.L("return values")
// } else {
// }
g.WriteString("return []any{")
for i, f := range table.columns {
if i > 0 {
Expand All @@ -369,15 +387,33 @@ func (g *Generator) buildValuer(importPkgs *Package, table *tableInfo) {
}

func (g *Generator) buildScanner(importPkgs *Package, table *tableInfo) {
ptrs := lo.Filter(table.columns, func(v *columnInfo, _ int) bool {
return v.isPtr()
})
g.L("func (v *" + table.goName + ") Addrs() []any {")
g.WriteString("return []any{")
for i, f := range table.columns {
if i > 0 {
g.WriteByte(',')
if len(ptrs) > 0 {
g.L("addrs := make([]any, ", len(table.columns), ")")
for _, f := range table.columns {
if f.isPtr() {
g.L("if v." + f.goPath + " == nil {")
g.L("v."+f.goPath+" = new(", Expr(strings.TrimPrefix(f.t.String(), "*")).Format(importPkgs, ExprParams{}), ")")
g.L("}")
g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "v."+f.GoPath(), f.Type()))
} else {
g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "&v."+f.GoPath(), f.Type()))
}
}
g.WriteString(g.scanner(importPkgs, "&v."+f.GoPath(), f.Type()))
g.L("return addrs")
} else {
g.WriteString("return []any{")
for i, f := range table.columns {
if i > 0 {
g.WriteByte(',')
}
g.WriteString(g.scanner(importPkgs, "&v."+f.GoPath(), f.Type()))
}
g.WriteString("}\n")
}
g.WriteString("}\n")
g.L("}")
}

Expand Down Expand Up @@ -554,26 +590,37 @@ func (g *Generator) buildUpdateByPK(importPkgs *Package, t *tableInfo) {
}

func (g *Generator) valuer(importPkgs *Package, goPath string, t types.Type) string {
utype, isPtr := underlyingType(t)
if columnType, ok := g.columnTypes[t.String()]; ok {
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath})
} else if _, wrong := types.MissingMethod(t, goSqlValuer, true); wrong {
} else if _, wrong := types.MissingMethod(utype, goSqlValuer, true); wrong {
if isPtr {
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
}
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
} else if columnType, ok := g.columnDataType(t); ok && columnType.Valuer != "" {
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, Len: arraySize(t)})
} else if isImplemented(t, textMarshaler) {
} else if isImplemented(utype, textMarshaler) {
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
}
return Expr(g.defaultColumnTypes["*"].Valuer).Format(importPkgs, ExprParams{GoPath: goPath})
}

func (g *Generator) scanner(importPkgs *Package, goPath string, t types.Type) string {
ptr, isPtr := pointerType(t)
if columnType, ok := g.columnTypes[t.String()]; ok && columnType.Scanner != "" {
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath})
} else if types.Implements(newPointer(t), goSqlScanner) {
} else if isImplemented(ptr, goSqlScanner) {
if isPtr {
return Expr("(database/sql.Scanner)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
}
return Expr("(database/sql.Scanner)({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
} else if columnType, ok := g.columnDataType(t); ok && columnType.Scanner != "" {
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, Len: arraySize(t)})
} else if isImplemented(newPointer(t), textUnmarshaler) {
} else if isImplemented(ptr, textUnmarshaler) {
if isPtr {
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
}
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
}
return Expr(g.defaultColumnTypes["*"].Scanner).Format(importPkgs, ExprParams{GoPath: goPath})
Expand Down
4 changes: 4 additions & 0 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,10 @@ func parseGoPackage(
}

if gen.config.Migration != nil {
if err := os.MkdirAll(gen.config.Migration.Dir, os.ModePerm); err != nil {
return err
}

if err := gen.genMigrations(schemas); err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions codegen/dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type GoColumn interface {
Size() int
GoPath() string
GoType() string
AutoIncr() bool
// Type() types.Type
Nullable() bool

Expand Down
Loading

0 comments on commit c0f2da9

Please sign in to comment.