Skip to content

Commit

Permalink
fix: go path for nested pointer of struct field
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Sep 29, 2024
1 parent 0ad6992 commit 7eb0f11
Show file tree
Hide file tree
Showing 15 changed files with 198 additions and 92 deletions.
32 changes: 23 additions & 9 deletions codegen/code_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,23 @@ func (g *Generator) buildValuer(importPkgs *Package, table *tableInfo) {
g.L("values := make([]any,", len(table.columns), ")")
for _, f := range table.columns {
if f.isPtr() {
g.L("if v." + f.goPath + " != nil {")
g.L("values[", f.columnPos, "] = ", g.valuer(importPkgs, "*v."+f.GoPath(), assertAsPtr[types.Pointer](f.GoType()).Elem()))
g.L("}")
paths := f.GoPaths()
goPath := "v."
queue := []string{}
for i := range paths {
if i > 0 {
goPath += "." + paths[i]
} else {
goPath += paths[i]
}
g.L("if " + goPath + " != nil {")
queue = append(queue, "}")
}
g.L("values[", f.columnPos, "] = ", g.valuer(importPkgs, "*"+goPath, assertAsPtr[types.Pointer](f.GoType()).Elem()))
for len(queue) > 0 {
g.L(queue[0])
queue = queue[1:]
}
} else {
g.L("values[", f.columnPos, "] = ", g.valuer(importPkgs, "v."+f.GoPath(), f.GoType()))
}
Expand All @@ -400,7 +414,7 @@ func (g *Generator) buildValuer(importPkgs *Package, table *tableInfo) {
if i > 0 {
g.WriteByte(',')
}
g.WriteString(g.valuer(importPkgs, "v."+f.goPath, f.t))
g.WriteString(g.valuer(importPkgs, "v."+f.GoPath(), f.t))
}
g.L("}")
}
Expand All @@ -416,12 +430,12 @@ func (g *Generator) buildScanner(importPkgs *Package, table *tableInfo) {
g.L("addrs := make([]any, ", len(table.columns), ")")
for _, f := range table.columns {
if f.isPtr() {
g.L("if v." + f.goPath + " == nil {")
g.L("v."+f.goPath+" = new(", Expr(strings.TrimPrefix(f.t.String(), "*")).Format(importPkgs, ExprParams{}), ")")
g.L("if v." + f.GoPath() + " == nil {")
g.L("v."+f.GoPath()+" = new(", Expr(strings.TrimPrefix(f.t.String(), "*")).Format(importPkgs, ExprParams{}), ")")
g.L("}")
g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "v."+f.goPath, f.t))
g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "v."+f.GoPath(), f.t))
} else {
g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "&v."+f.goPath, f.t))
g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "&v."+f.GoPath(), f.t))
}
}
g.L("return addrs")
Expand All @@ -431,7 +445,7 @@ func (g *Generator) buildScanner(importPkgs *Package, table *tableInfo) {
if i > 0 {
g.WriteByte(',')
}
g.WriteString(g.scanner(importPkgs, "&v."+f.goPath, f.t))
g.WriteString(g.scanner(importPkgs, "&v."+f.GoPath(), f.t))
}
g.WriteString("}\n")
}
Expand Down
63 changes: 37 additions & 26 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ const (
)

type typeQueue struct {
path string
idx []int
t *ast.StructType
pkg *packages.Package
// path string
paths []string
idx []int
t *ast.StructType
pkg *packages.Package
}

func Generate(c *Config) error {
Expand Down Expand Up @@ -363,7 +364,6 @@ func parseGoPackage(
structCaches = append(structCaches, structCache{name: typeSpec.Name, t: v, pkg: importPkg})
}
}

return true
})
}
Expand Down Expand Up @@ -408,11 +408,18 @@ func parseGoPackage(
// If the field is embedded struct
// `Type` can be either *ast.Ident or *ast.SelectorExpr
if fi.Names == nil {
t := fi.Type
var (
t = fi.Type
path string
)

// If it's an embedded struct with pointer
// we need to get the underlying type
if ut := assertAsPtr[ast.StarExpr](fi.Type); ut != nil {
path = "*"
t = ut.X
}

switch vi := t.(type) {
// Local struct
case *ast.Ident:
Expand All @@ -433,21 +440,22 @@ func parseGoPackage(
continue
}

path := types.ExprString(vi)
if f.path != "" {
path = f.path + "." + path
}
path += types.ExprString(vi)
// if f.path != "" {
// path = f.path + "." + path
// }
t := obj.Decl.(*ast.TypeSpec)
q = append(q, typeQueue{path: path, idx: append(f.idx, i), t: t.Type.(*ast.StructType), pkg: f.pkg})
q = append(q, typeQueue{paths: append(f.paths, path), idx: append(f.idx, i), t: t.Type.(*ast.StructType), pkg: f.pkg})

structFields = append(structFields, &structFieldType{
index: append(f.idx, i),
exported: vi.IsExported(),
embedded: true,
name: types.ExprString(vi),
tag: tag,
path: path,
t: f.pkg.TypesInfo.TypeOf(fi.Type),
// path: path,
paths: append(f.paths, path),
t: f.pkg.TypesInfo.TypeOf(fi.Type),
})
continue

Expand Down Expand Up @@ -478,23 +486,25 @@ func parseGoPackage(
continue
}

path := types.ExprString(vi.Sel)
if f.path != "" {
path = f.path + "." + path
}
path += types.ExprString(vi.Sel)
// if f.path != "" {
// path = f.path + "." + path
// }

// If it's a embedded struct, we continue on next loop
if st := assertAsPtr[ast.StructType](decl.Type); st != nil {
q = append(q, typeQueue{path: path, idx: append(f.idx, i), t: st, pkg: importPkg})
paths := append(f.paths, path)
q = append(q, typeQueue{paths: paths, idx: append(f.idx, i), t: st, pkg: importPkg})

structFields = append(structFields, &structFieldType{
index: append(f.idx, i),
exported: vi.Sel.IsExported(),
embedded: true,
name: types.ExprString(vi.Sel),
tag: tag,
path: path,
t: f.pkg.TypesInfo.TypeOf(fi.Type),
// path: path,
paths: paths,
t: f.pkg.TypesInfo.TypeOf(fi.Type),
})
}
continue
Expand Down Expand Up @@ -528,18 +538,19 @@ func parseGoPackage(

for j, n := range fi.Names {
path := types.ExprString(n)
if f.path != "" {
path = f.path + "." + path
}
// if f.path != "" {
// path = f.path + "." + path
// }

structFields = append(structFields, &structFieldType{
index: append(f.idx, i+j),
exported: n.IsExported(),
name: types.ExprString(n),
tag: tag,
enums: goEnum,
path: path,
t: f.pkg.TypesInfo.TypeOf(fi.Type),
// path: path,
paths: append(f.paths, path),
t: f.pkg.TypesInfo.TypeOf(fi.Type),
})
}
}
Expand Down Expand Up @@ -592,7 +603,7 @@ func parseGoPackage(
for _, f := range s.fields {
column := new(columnInfo)
column.goName = f.name
column.goPath = f.path
column.goPaths = f.paths
column.t = f.t
column.columnName = rename(f.name)
column.columnPos = pos
Expand Down
2 changes: 1 addition & 1 deletion codegen/dialect/postgres/column_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
"int": {
DataType: s.intDataType("int4", int64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
"uint8": {
DataType: s.intDataType("int2", uint64(0)),
Expand Down
31 changes: 26 additions & 5 deletions codegen/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ type structType struct {
}

type structFieldType struct {
name string
index []int
path string
name string
index []int
// path string
paths []string
t types.Type
enums *enum
exported bool
Expand Down Expand Up @@ -131,7 +132,7 @@ type goTag struct {
// Some of the default behaviour is not able to override, such as go size, go enum, go tags, go path, go name, go nullable
type columnInfo struct {
goName string
goPath string
goPaths []string
columnName string
columnPos int
size int
Expand All @@ -158,7 +159,27 @@ func (c *columnInfo) GoName() string {
}

func (c *columnInfo) GoPath() string {
return c.goPath
return strings.Join(lo.Map(c.goPaths, func(v string, _ int) string {
if v[0] == '*' {
return v[1:]
}
return v
}), ".")
}

func (c *columnInfo) GoPaths() []string {
var goPath string
paths := []string{}
for _, path := range c.goPaths {
if path[0] == '*' {
paths = append(paths, goPath+path[1:])
goPath = ""
continue
}
goPath += path
}
paths = append(paths, goPath)
return paths
}

func (c *columnInfo) GoType() types.Type {
Expand Down
26 changes: 15 additions & 11 deletions examples/testcase/struct-field/pointer/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (v Ptr) PK() (string, int, any) {
return "id", 0, (int64)(v.ID)
}
func (Ptr) Columns() []string {
return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time", "nested", "embeded_time"}
return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time", "nested", "embedded_time"}
}
func (v Ptr) Values() []any {
values := make([]any, 19)
Expand Down Expand Up @@ -73,8 +73,12 @@ func (v Ptr) Values() []any {
if v.Nested != nil {
values[17] = types.JSONMarshaler(*v.Nested)
}
if v.embeded.EmbededTime != nil {
values[18] = (time.Time)(*v.embeded.EmbededTime)
if v.deepNested != nil {
if v.deepNested.embedded != nil {
if v.deepNested.embedded.EmbeddedTime != nil {
values[18] = (time.Time)(*v.deepNested.embedded.EmbeddedTime)
}
}
}
return values
}
Expand Down Expand Up @@ -149,23 +153,23 @@ func (v *Ptr) Addrs() []any {
v.Nested = new(nested)
}
addrs[17] = types.JSONUnmarshaler(v.Nested)
if v.embeded.EmbededTime == nil {
v.embeded.EmbededTime = new(time.Time)
if v.deepNested.embedded.EmbeddedTime == nil {
v.deepNested.embedded.EmbeddedTime = new(time.Time)
}
addrs[18] = types.Time(v.embeded.EmbededTime)
addrs[18] = types.Time(v.deepNested.embedded.EmbeddedTime)
return addrs
}
func (Ptr) InsertPlaceholders(row int) string {
return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"
}
func (v Ptr) InsertOneStmt() (string, []any) {
return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embeded_time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.embeded.EmbededTime)}
return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embedded_time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.deepNested.embedded.EmbeddedTime)}
}
func (v Ptr) FindOneByPKStmt() (string, []any) {
return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embeded_time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)}
return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embedded_time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)}
}
func (v Ptr) UpdateOneByPKStmt() (string, []any) {
return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ?,nested = ?,embeded_time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.embeded.EmbededTime), (int64)(v.ID)}
return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ?,nested = ?,embedded_time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.deepNested.embedded.EmbeddedTime), (int64)(v.ID)}
}
func (v Ptr) GetID() sequel.ColumnValuer[int64] {
return sequel.Column("id", v.ID, func(val int64) driver.Value { return (int64)(val) })
Expand Down Expand Up @@ -221,6 +225,6 @@ func (v Ptr) GetTime() sequel.ColumnValuer[*time.Time] {
func (v Ptr) GetNested() sequel.ColumnValuer[*nested] {
return sequel.Column("nested", v.Nested, func(val *nested) driver.Value { return types.JSONMarshaler(val) })
}
func (v Ptr) GetEmbededTime() sequel.ColumnValuer[*time.Time] {
return sequel.Column("embeded_time", v.embeded.EmbededTime, func(val *time.Time) driver.Value { return types.Time(val) })
func (v Ptr) GetEmbeddedTime() sequel.ColumnValuer[*time.Time] {
return sequel.Column("embedded_time", v.deepNested.embedded.EmbeddedTime, func(val *time.Time) driver.Value { return types.Time(val) })
}
26 changes: 15 additions & 11 deletions examples/testcase/struct-field/pointer/generated.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (v Ptr) PK() (string, int, any) {
return "id", 0, (int64)(v.ID)
}
func (Ptr) Columns() []string {
return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time", "nested", "embeded_time"}
return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time", "nested", "embedded_time"}
}
func (v Ptr) Values() []any {
values := make([]any, 19)
Expand Down Expand Up @@ -73,8 +73,12 @@ func (v Ptr) Values() []any {
if v.Nested != nil {
values[17] = types.JSONMarshaler(*v.Nested)
}
if v.embeded.EmbededTime != nil {
values[18] = (time.Time)(*v.embeded.EmbededTime)
if v.deepNested != nil {
if v.deepNested.embedded != nil {
if v.deepNested.embedded.EmbeddedTime != nil {
values[18] = (time.Time)(*v.deepNested.embedded.EmbeddedTime)
}
}
}
return values
}
Expand Down Expand Up @@ -149,23 +153,23 @@ func (v *Ptr) Addrs() []any {
v.Nested = new(nested)
}
addrs[17] = types.JSONUnmarshaler(v.Nested)
if v.embeded.EmbededTime == nil {
v.embeded.EmbededTime = new(time.Time)
if v.deepNested.embedded.EmbeddedTime == nil {
v.deepNested.embedded.EmbeddedTime = new(time.Time)
}
addrs[18] = types.Time(v.embeded.EmbededTime)
addrs[18] = types.Time(v.deepNested.embedded.EmbeddedTime)
return addrs
}
func (Ptr) InsertPlaceholders(row int) string {
return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"
}
func (v Ptr) InsertOneStmt() (string, []any) {
return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embeded_time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.embeded.EmbededTime)}
return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embedded_time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.deepNested.embedded.EmbeddedTime)}
}
func (v Ptr) FindOneByPKStmt() (string, []any) {
return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embeded_time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)}
return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embedded_time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)}
}
func (v Ptr) UpdateOneByPKStmt() (string, []any) {
return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ?,nested = ?,embeded_time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.embeded.EmbededTime), (int64)(v.ID)}
return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ?,nested = ?,embedded_time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.deepNested.embedded.EmbeddedTime), (int64)(v.ID)}
}
func (v Ptr) GetID() sequel.ColumnValuer[int64] {
return sequel.Column("id", v.ID, func(val int64) driver.Value { return (int64)(val) })
Expand Down Expand Up @@ -221,6 +225,6 @@ func (v Ptr) GetTime() sequel.ColumnValuer[*time.Time] {
func (v Ptr) GetNested() sequel.ColumnValuer[*nested] {
return sequel.Column("nested", v.Nested, func(val *nested) driver.Value { return types.JSONMarshaler(val) })
}
func (v Ptr) GetEmbededTime() sequel.ColumnValuer[*time.Time] {
return sequel.Column("embeded_time", v.embeded.EmbededTime, func(val *time.Time) driver.Value { return types.Time(val) })
func (v Ptr) GetEmbeddedTime() sequel.ColumnValuer[*time.Time] {
return sequel.Column("embedded_time", v.deepNested.embedded.EmbeddedTime, func(val *time.Time) driver.Value { return types.Time(val) })
}
Loading

0 comments on commit 7eb0f11

Please sign in to comment.