From 7eb0f11ea181639ce456050616c32c972b9296be Mon Sep 17 00:00:00 2001 From: si3nloong Date: Sun, 29 Sep 2024 13:34:37 +0800 Subject: [PATCH] fix: go path for nested pointer of struct field --- codegen/code_generator.go | 32 +++++++--- codegen/codegen.go | 63 +++++++++++-------- codegen/dialect/postgres/column_type.go | 2 +- codegen/types.go | 31 +++++++-- .../struct-field/pointer/generated.go | 26 ++++---- .../struct-field/pointer/generated.go.tpl | 26 ++++---- .../struct-field/pointer/generated_test.go | 62 +++++++++++++++--- .../testcase/struct-field/pointer/model.go | 10 ++- sequel/types/binary.go | 5 +- sequel/types/byte_test.go | 3 + sequel/types/pgtype/bool_array.go | 4 +- sequel/types/pgtype/float_array.go | 5 +- sequel/types/pgtype/string_array.go | 5 +- sequel/types/text.go | 10 +-- sequel/types/time.go | 6 -- 15 files changed, 198 insertions(+), 92 deletions(-) diff --git a/codegen/code_generator.go b/codegen/code_generator.go index b5063e1..52cd884 100644 --- a/codegen/code_generator.go +++ b/codegen/code_generator.go @@ -386,9 +386,23 @@ func (g *Generator) buildValuer(importPkgs *Package, table *tableInfo) { g.L("values := make([]any,", len(table.columns), ")") for _, f := range table.columns { if f.isPtr() { - g.L("if v." + f.goPath + " != nil {") - g.L("values[", f.columnPos, "] = ", g.valuer(importPkgs, "*v."+f.GoPath(), assertAsPtr[types.Pointer](f.GoType()).Elem())) - g.L("}") + paths := f.GoPaths() + goPath := "v." + queue := []string{} + for i := range paths { + if i > 0 { + goPath += "." + paths[i] + } else { + goPath += paths[i] + } + g.L("if " + goPath + " != nil {") + queue = append(queue, "}") + } + g.L("values[", f.columnPos, "] = ", g.valuer(importPkgs, "*"+goPath, assertAsPtr[types.Pointer](f.GoType()).Elem())) + for len(queue) > 0 { + g.L(queue[0]) + queue = queue[1:] + } } else { g.L("values[", f.columnPos, "] = ", g.valuer(importPkgs, "v."+f.GoPath(), f.GoType())) } @@ -400,7 +414,7 @@ func (g *Generator) buildValuer(importPkgs *Package, table *tableInfo) { if i > 0 { g.WriteByte(',') } - g.WriteString(g.valuer(importPkgs, "v."+f.goPath, f.t)) + g.WriteString(g.valuer(importPkgs, "v."+f.GoPath(), f.t)) } g.L("}") } @@ -416,12 +430,12 @@ func (g *Generator) buildScanner(importPkgs *Package, table *tableInfo) { g.L("addrs := make([]any, ", len(table.columns), ")") for _, f := range table.columns { if f.isPtr() { - g.L("if v." + f.goPath + " == nil {") - g.L("v."+f.goPath+" = new(", Expr(strings.TrimPrefix(f.t.String(), "*")).Format(importPkgs, ExprParams{}), ")") + g.L("if v." + f.GoPath() + " == nil {") + g.L("v."+f.GoPath()+" = new(", Expr(strings.TrimPrefix(f.t.String(), "*")).Format(importPkgs, ExprParams{}), ")") g.L("}") - g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "v."+f.goPath, f.t)) + g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "v."+f.GoPath(), f.t)) } else { - g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "&v."+f.goPath, f.t)) + g.L("addrs[", f.columnPos, "] = ", g.scanner(importPkgs, "&v."+f.GoPath(), f.t)) } } g.L("return addrs") @@ -431,7 +445,7 @@ func (g *Generator) buildScanner(importPkgs *Package, table *tableInfo) { if i > 0 { g.WriteByte(',') } - g.WriteString(g.scanner(importPkgs, "&v."+f.goPath, f.t)) + g.WriteString(g.scanner(importPkgs, "&v."+f.GoPath(), f.t)) } g.WriteString("}\n") } diff --git a/codegen/codegen.go b/codegen/codegen.go index a23aeec..deeb1ff 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -64,10 +64,11 @@ const ( ) type typeQueue struct { - path string - idx []int - t *ast.StructType - pkg *packages.Package + // path string + paths []string + idx []int + t *ast.StructType + pkg *packages.Package } func Generate(c *Config) error { @@ -363,7 +364,6 @@ func parseGoPackage( structCaches = append(structCaches, structCache{name: typeSpec.Name, t: v, pkg: importPkg}) } } - return true }) } @@ -408,11 +408,18 @@ func parseGoPackage( // If the field is embedded struct // `Type` can be either *ast.Ident or *ast.SelectorExpr if fi.Names == nil { - t := fi.Type + var ( + t = fi.Type + path string + ) + // If it's an embedded struct with pointer + // we need to get the underlying type if ut := assertAsPtr[ast.StarExpr](fi.Type); ut != nil { + path = "*" t = ut.X } + switch vi := t.(type) { // Local struct case *ast.Ident: @@ -433,12 +440,12 @@ func parseGoPackage( continue } - path := types.ExprString(vi) - if f.path != "" { - path = f.path + "." + path - } + path += types.ExprString(vi) + // if f.path != "" { + // path = f.path + "." + path + // } t := obj.Decl.(*ast.TypeSpec) - q = append(q, typeQueue{path: path, idx: append(f.idx, i), t: t.Type.(*ast.StructType), pkg: f.pkg}) + q = append(q, typeQueue{paths: append(f.paths, path), idx: append(f.idx, i), t: t.Type.(*ast.StructType), pkg: f.pkg}) structFields = append(structFields, &structFieldType{ index: append(f.idx, i), @@ -446,8 +453,9 @@ func parseGoPackage( embedded: true, name: types.ExprString(vi), tag: tag, - path: path, - t: f.pkg.TypesInfo.TypeOf(fi.Type), + // path: path, + paths: append(f.paths, path), + t: f.pkg.TypesInfo.TypeOf(fi.Type), }) continue @@ -478,14 +486,15 @@ func parseGoPackage( continue } - path := types.ExprString(vi.Sel) - if f.path != "" { - path = f.path + "." + path - } + path += types.ExprString(vi.Sel) + // if f.path != "" { + // path = f.path + "." + path + // } // If it's a embedded struct, we continue on next loop if st := assertAsPtr[ast.StructType](decl.Type); st != nil { - q = append(q, typeQueue{path: path, idx: append(f.idx, i), t: st, pkg: importPkg}) + paths := append(f.paths, path) + q = append(q, typeQueue{paths: paths, idx: append(f.idx, i), t: st, pkg: importPkg}) structFields = append(structFields, &structFieldType{ index: append(f.idx, i), @@ -493,8 +502,9 @@ func parseGoPackage( embedded: true, name: types.ExprString(vi.Sel), tag: tag, - path: path, - t: f.pkg.TypesInfo.TypeOf(fi.Type), + // path: path, + paths: paths, + t: f.pkg.TypesInfo.TypeOf(fi.Type), }) } continue @@ -528,9 +538,9 @@ func parseGoPackage( for j, n := range fi.Names { path := types.ExprString(n) - if f.path != "" { - path = f.path + "." + path - } + // if f.path != "" { + // path = f.path + "." + path + // } structFields = append(structFields, &structFieldType{ index: append(f.idx, i+j), @@ -538,8 +548,9 @@ func parseGoPackage( name: types.ExprString(n), tag: tag, enums: goEnum, - path: path, - t: f.pkg.TypesInfo.TypeOf(fi.Type), + // path: path, + paths: append(f.paths, path), + t: f.pkg.TypesInfo.TypeOf(fi.Type), }) } } @@ -592,7 +603,7 @@ func parseGoPackage( for _, f := range s.fields { column := new(columnInfo) column.goName = f.name - column.goPath = f.path + column.goPaths = f.paths column.t = f.t column.columnName = rename(f.name) column.columnPos = pos diff --git a/codegen/dialect/postgres/column_type.go b/codegen/dialect/postgres/column_type.go index d7d94f3..b38ff54 100644 --- a/codegen/dialect/postgres/column_type.go +++ b/codegen/dialect/postgres/column_type.go @@ -61,7 +61,7 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { "int": { DataType: s.intDataType("int4", int64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", }, "uint8": { DataType: s.intDataType("int2", uint64(0)), diff --git a/codegen/types.go b/codegen/types.go index 78e5d3c..56ce983 100644 --- a/codegen/types.go +++ b/codegen/types.go @@ -17,9 +17,10 @@ type structType struct { } type structFieldType struct { - name string - index []int - path string + name string + index []int + // path string + paths []string t types.Type enums *enum exported bool @@ -131,7 +132,7 @@ type goTag struct { // Some of the default behaviour is not able to override, such as go size, go enum, go tags, go path, go name, go nullable type columnInfo struct { goName string - goPath string + goPaths []string columnName string columnPos int size int @@ -158,7 +159,27 @@ func (c *columnInfo) GoName() string { } func (c *columnInfo) GoPath() string { - return c.goPath + return strings.Join(lo.Map(c.goPaths, func(v string, _ int) string { + if v[0] == '*' { + return v[1:] + } + return v + }), ".") +} + +func (c *columnInfo) GoPaths() []string { + var goPath string + paths := []string{} + for _, path := range c.goPaths { + if path[0] == '*' { + paths = append(paths, goPath+path[1:]) + goPath = "" + continue + } + goPath += path + } + paths = append(paths, goPath) + return paths } func (c *columnInfo) GoType() types.Type { diff --git a/examples/testcase/struct-field/pointer/generated.go b/examples/testcase/struct-field/pointer/generated.go index 2190d61..b0b3314 100755 --- a/examples/testcase/struct-field/pointer/generated.go +++ b/examples/testcase/struct-field/pointer/generated.go @@ -17,7 +17,7 @@ func (v Ptr) PK() (string, int, any) { return "id", 0, (int64)(v.ID) } func (Ptr) Columns() []string { - return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time", "nested", "embeded_time"} + return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time", "nested", "embedded_time"} } func (v Ptr) Values() []any { values := make([]any, 19) @@ -73,8 +73,12 @@ func (v Ptr) Values() []any { if v.Nested != nil { values[17] = types.JSONMarshaler(*v.Nested) } - if v.embeded.EmbededTime != nil { - values[18] = (time.Time)(*v.embeded.EmbededTime) + if v.deepNested != nil { + if v.deepNested.embedded != nil { + if v.deepNested.embedded.EmbeddedTime != nil { + values[18] = (time.Time)(*v.deepNested.embedded.EmbeddedTime) + } + } } return values } @@ -149,23 +153,23 @@ func (v *Ptr) Addrs() []any { v.Nested = new(nested) } addrs[17] = types.JSONUnmarshaler(v.Nested) - if v.embeded.EmbededTime == nil { - v.embeded.EmbededTime = new(time.Time) + if v.deepNested.embedded.EmbeddedTime == nil { + v.deepNested.embedded.EmbeddedTime = new(time.Time) } - addrs[18] = types.Time(v.embeded.EmbededTime) + addrs[18] = types.Time(v.deepNested.embedded.EmbeddedTime) return addrs } func (Ptr) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" } func (v Ptr) InsertOneStmt() (string, []any) { - return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embeded_time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.embeded.EmbededTime)} + return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embedded_time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.deepNested.embedded.EmbeddedTime)} } func (v Ptr) FindOneByPKStmt() (string, []any) { - return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embeded_time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)} + return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embedded_time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)} } func (v Ptr) UpdateOneByPKStmt() (string, []any) { - return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ?,nested = ?,embeded_time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.embeded.EmbededTime), (int64)(v.ID)} + return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ?,nested = ?,embedded_time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.deepNested.embedded.EmbeddedTime), (int64)(v.ID)} } func (v Ptr) GetID() sequel.ColumnValuer[int64] { return sequel.Column("id", v.ID, func(val int64) driver.Value { return (int64)(val) }) @@ -221,6 +225,6 @@ func (v Ptr) GetTime() sequel.ColumnValuer[*time.Time] { func (v Ptr) GetNested() sequel.ColumnValuer[*nested] { return sequel.Column("nested", v.Nested, func(val *nested) driver.Value { return types.JSONMarshaler(val) }) } -func (v Ptr) GetEmbededTime() sequel.ColumnValuer[*time.Time] { - return sequel.Column("embeded_time", v.embeded.EmbededTime, func(val *time.Time) driver.Value { return types.Time(val) }) +func (v Ptr) GetEmbeddedTime() sequel.ColumnValuer[*time.Time] { + return sequel.Column("embedded_time", v.deepNested.embedded.EmbeddedTime, func(val *time.Time) driver.Value { return types.Time(val) }) } diff --git a/examples/testcase/struct-field/pointer/generated.go.tpl b/examples/testcase/struct-field/pointer/generated.go.tpl index 2190d61..b0b3314 100644 --- a/examples/testcase/struct-field/pointer/generated.go.tpl +++ b/examples/testcase/struct-field/pointer/generated.go.tpl @@ -17,7 +17,7 @@ func (v Ptr) PK() (string, int, any) { return "id", 0, (int64)(v.ID) } func (Ptr) Columns() []string { - return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time", "nested", "embeded_time"} + return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time", "nested", "embedded_time"} } func (v Ptr) Values() []any { values := make([]any, 19) @@ -73,8 +73,12 @@ func (v Ptr) Values() []any { if v.Nested != nil { values[17] = types.JSONMarshaler(*v.Nested) } - if v.embeded.EmbededTime != nil { - values[18] = (time.Time)(*v.embeded.EmbededTime) + if v.deepNested != nil { + if v.deepNested.embedded != nil { + if v.deepNested.embedded.EmbeddedTime != nil { + values[18] = (time.Time)(*v.deepNested.embedded.EmbeddedTime) + } + } } return values } @@ -149,23 +153,23 @@ func (v *Ptr) Addrs() []any { v.Nested = new(nested) } addrs[17] = types.JSONUnmarshaler(v.Nested) - if v.embeded.EmbededTime == nil { - v.embeded.EmbededTime = new(time.Time) + if v.deepNested.embedded.EmbeddedTime == nil { + v.deepNested.embedded.EmbeddedTime = new(time.Time) } - addrs[18] = types.Time(v.embeded.EmbededTime) + addrs[18] = types.Time(v.deepNested.embedded.EmbeddedTime) return addrs } func (Ptr) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" } func (v Ptr) InsertOneStmt() (string, []any) { - return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embeded_time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.embeded.EmbededTime)} + return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embedded_time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.deepNested.embedded.EmbeddedTime)} } func (v Ptr) FindOneByPKStmt() (string, []any) { - return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embeded_time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)} + return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time,nested,embedded_time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)} } func (v Ptr) UpdateOneByPKStmt() (string, []any) { - return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ?,nested = ?,embeded_time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.embeded.EmbededTime), (int64)(v.ID)} + return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ?,nested = ?,embedded_time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), types.JSONMarshaler(v.Nested), types.Time(v.deepNested.embedded.EmbeddedTime), (int64)(v.ID)} } func (v Ptr) GetID() sequel.ColumnValuer[int64] { return sequel.Column("id", v.ID, func(val int64) driver.Value { return (int64)(val) }) @@ -221,6 +225,6 @@ func (v Ptr) GetTime() sequel.ColumnValuer[*time.Time] { func (v Ptr) GetNested() sequel.ColumnValuer[*nested] { return sequel.Column("nested", v.Nested, func(val *nested) driver.Value { return types.JSONMarshaler(val) }) } -func (v Ptr) GetEmbededTime() sequel.ColumnValuer[*time.Time] { - return sequel.Column("embeded_time", v.embeded.EmbededTime, func(val *time.Time) driver.Value { return types.Time(val) }) +func (v Ptr) GetEmbeddedTime() sequel.ColumnValuer[*time.Time] { + return sequel.Column("embedded_time", v.deepNested.embedded.EmbeddedTime, func(val *time.Time) driver.Value { return types.Time(val) }) } diff --git a/examples/testcase/struct-field/pointer/generated_test.go b/examples/testcase/struct-field/pointer/generated_test.go index 65311f5..bc7cdc3 100644 --- a/examples/testcase/struct-field/pointer/generated_test.go +++ b/examples/testcase/struct-field/pointer/generated_test.go @@ -2,17 +2,63 @@ package pointer import ( "testing" + "time" "github.com/stretchr/testify/require" ) func TestPointer(t *testing.T) { - ptr := Ptr{} - ptr.embeded = new(embeded) - values := ptr.Values() - require.Equal(t, 19, len(values)) - require.Equal(t, int64(0), values[0]) - for i := 1; i < 18; i++ { - require.Nil(t, values[i]) - } + t.Run("Every field is nil value", func(t *testing.T) { + ptr := Ptr{} + values := ptr.Values() + require.Equal(t, 19, len(values)) + require.Equal(t, int64(0), values[0]) + for i := 1; i < 19; i++ { + require.Nil(t, values[i]) + } + }) + + t.Run("nested values", func(t *testing.T) { + t.Run("deepNested has value but descendant is nil", func(t *testing.T) { + ptr := Ptr{} + ptr.deepNested = &deepNested{} + values := ptr.Values() + require.Equal(t, 19, len(values)) + require.Equal(t, int64(0), values[0]) + for i := 1; i < 19; i++ { + require.Nil(t, values[i]) + } + }) + + t.Run("embedded has value but descendant is nil", func(t *testing.T) { + ptr := Ptr{} + ptr.deepNested = &deepNested{ + &embedded{}, + } + values := ptr.Values() + require.Equal(t, 19, len(values)) + require.Equal(t, int64(0), values[0]) + for i := 1; i < 19; i++ { + require.Nil(t, values[i]) + } + }) + + t.Run("EmbeddedTime has value", func(t *testing.T) { + ptr := Ptr{} + ts := time.Now() + ptr.deepNested = &deepNested{ + &embedded{ + EmbeddedTime: &ts, + }, + } + values := ptr.Values() + require.Equal(t, 19, len(values)) + require.Equal(t, int64(0), values[0]) + for i := 1; i < 18; i++ { + require.Nil(t, values[i]) + } + require.NotNil(t, values[18]) + require.Equal(t, ts.Format(time.RFC3339), (values[18]).(time.Time).Format(time.RFC3339)) + }) + }) } diff --git a/examples/testcase/struct-field/pointer/model.go b/examples/testcase/struct-field/pointer/model.go index d7c6aea..628fda5 100644 --- a/examples/testcase/struct-field/pointer/model.go +++ b/examples/testcase/struct-field/pointer/model.go @@ -21,13 +21,17 @@ type Ptr struct { F64 *float64 Time *time.Time `sql:",size:6"` Nested *nested - *embeded + *deepNested } type nested struct { ID *int64 } -type embeded struct { - EmbededTime *time.Time +type embedded struct { + EmbeddedTime *time.Time +} + +type deepNested struct { + *embedded } diff --git a/sequel/types/binary.go b/sequel/types/binary.go index 4e255cb..8a161f1 100644 --- a/sequel/types/binary.go +++ b/sequel/types/binary.go @@ -1,6 +1,7 @@ package types import ( + "database/sql" "database/sql/driver" "encoding" "fmt" @@ -14,7 +15,7 @@ type binaryMarshaler[T interface { func BinaryMarshaler[T interface { encoding.BinaryMarshaler -}](addr T) binaryMarshaler[T] { +}](addr T) driver.Valuer { return binaryMarshaler[T]{v: addr} } @@ -32,7 +33,7 @@ type binaryUnmarshaler[T any, Ptr interface { func BinaryUnmarshaler[T any, Ptr interface { *T encoding.BinaryUnmarshaler -}](addr Ptr) binaryUnmarshaler[T, Ptr] { +}](addr Ptr) sql.Scanner { return binaryUnmarshaler[T, Ptr]{v: addr} } diff --git a/sequel/types/byte_test.go b/sequel/types/byte_test.go index 3a80e33..b9058fe 100644 --- a/sequel/types/byte_test.go +++ b/sequel/types/byte_test.go @@ -23,6 +23,7 @@ func TestBytesArray(t *testing.T) { t.Run("Scan with non-overflow & smaller length string", func(t *testing.T) { const size = 10 fsBytes := [size]byte{} + require.Equal(t, fsBytes, [10]byte{}) bytes := FixedSizeBytes(fsBytes[:], size) require.Equal(t, size, bytes.size) @@ -35,6 +36,7 @@ func TestBytesArray(t *testing.T) { t.Run("Scan with non-overflow & smaller length string", func(t *testing.T) { const size = 10 fsBytes := [size]byte{} + require.Equal(t, fsBytes, [10]byte{}) bytes := FixedSizeBytes(fsBytes[:], size) require.Equal(t, size, bytes.size) @@ -59,5 +61,6 @@ func TestBytesArray(t *testing.T) { require.Equal(t, size, bytes.size) require.Error(t, bytes.Scan(`hello world`)) + // require.Equal(t, fsBytes, [10]byte([]byte(`hello world`))) }) } diff --git a/sequel/types/pgtype/bool_array.go b/sequel/types/pgtype/bool_array.go index 4048074..d0acdd8 100644 --- a/sequel/types/pgtype/bool_array.go +++ b/sequel/types/pgtype/bool_array.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "unsafe" ) // BoolArray represents a one-dimensional array of the PostgreSQL boolean type. @@ -36,8 +37,7 @@ func (a BoolArray[T]) Value() (driver.Value, error) { b[0] = '{' b[2*n] = '}' - - return string(b), nil + return unsafe.String(unsafe.SliceData(b), len(b)), nil } return "{}", nil } diff --git a/sequel/types/pgtype/float_array.go b/sequel/types/pgtype/float_array.go index 2137644..ef843c4 100644 --- a/sequel/types/pgtype/float_array.go +++ b/sequel/types/pgtype/float_array.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "fmt" "strconv" + "unsafe" ) // Float32Array represents a one-dimensional array of the PostgreSQL double @@ -27,8 +28,8 @@ func (a Float32Array[T]) Value() (driver.Value, error) { b = append(b, ',') b = strconv.AppendFloat(b, (float64)(a[i]), 'f', -1, 32) } - - return string(append(b, '}')), nil + b = append(b, '}') + return unsafe.String(unsafe.SliceData(b), len(b)), nil } return "{}", nil } diff --git a/sequel/types/pgtype/string_array.go b/sequel/types/pgtype/string_array.go index d8a39fb..627d69d 100644 --- a/sequel/types/pgtype/string_array.go +++ b/sequel/types/pgtype/string_array.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "fmt" + "unsafe" ) // StringArray represents a one-dimensional array of the PostgreSQL character types. @@ -25,8 +26,8 @@ func (a StringArray[T]) Value() (driver.Value, error) { b = append(b, ',') b = appendArrayQuotedBytes(b, []byte(a[i])) } - - return string(append(b, '}')), nil + b = append(b, '}') + return unsafe.String(unsafe.SliceData(b), len(b)), nil } return "{}", nil } diff --git a/sequel/types/text.go b/sequel/types/text.go index 1603742..de4089b 100644 --- a/sequel/types/text.go +++ b/sequel/types/text.go @@ -1,9 +1,11 @@ package types import ( + "database/sql" "database/sql/driver" "encoding" "fmt" + "unsafe" ) type textMarshaler[T interface { @@ -14,7 +16,7 @@ type textMarshaler[T interface { func TextMarshaler[T interface { encoding.TextMarshaler -}](addr T) textMarshaler[T] { +}](addr T) driver.Valuer { return textMarshaler[T]{v: addr} } @@ -32,17 +34,17 @@ type textUnmarshaler[T any, Ptr interface { func TextUnmarshaler[T any, Ptr interface { *T encoding.TextUnmarshaler -}](addr Ptr) textUnmarshaler[T, Ptr] { +}](addr Ptr) sql.Scanner { return textUnmarshaler[T, Ptr]{v: addr} } func (b textUnmarshaler[T, Ptr]) Scan(v any) error { switch vi := v.(type) { case string: - return b.v.UnmarshalText([]byte(vi)) + return b.v.UnmarshalText(unsafe.Slice(unsafe.StringData(vi), len(vi))) case []byte: return b.v.UnmarshalText(vi) default: - return fmt.Errorf(`sequel/types: text must be []byte to unmarshal`) + return fmt.Errorf(`sequel/types: unable to unmarshal %T`, vi) } } diff --git a/sequel/types/time.go b/sequel/types/time.go index bfa6296..91d1890 100644 --- a/sequel/types/time.go +++ b/sequel/types/time.go @@ -1,7 +1,6 @@ package types import ( - "database/sql" "database/sql/driver" "fmt" "regexp" @@ -20,11 +19,6 @@ type timestamp[T time.Time] struct { strictType bool } -var ( - _ sql.Scanner = (*timestamp[time.Time])(nil) - _ driver.Valuer = (*timestamp[time.Time])(nil) -) - func Time[T time.Time](addr *T, strict ...bool) ValueScanner[T] { var strictType bool if len(strict) > 0 {