diff --git a/codegen/dialect/mysql/column_type.go b/codegen/dialect/mysql/column_type.go index e11da77..af4b799 100644 --- a/codegen/dialect/mysql/column_type.go +++ b/codegen/dialect/mysql/column_type.go @@ -10,7 +10,7 @@ import ( ) func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType { - return map[string]*dialect.ColumnType{ + dataTypes := map[string]*dialect.ColumnType{ "rune": { DataType: s.columnDataType("CHAR(1)"), Valuer: "(string)({{goPath}})", @@ -26,56 +26,6 @@ func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType { Valuer: "(bool)({{goPath}})", Scanner: "github.com/si3nloong/sqlgen/sequel/types.Bool({{addrOfGoPath}})", }, - "int": { - DataType: s.columnDataType("INTEGER", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "int8": { - DataType: s.columnDataType("TINYINT", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "int16": { - DataType: s.columnDataType("SMALLINT", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "int32": { - DataType: s.columnDataType("MEDIUMINT", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "int64": { - DataType: s.columnDataType("BIGINT", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "uint": { - DataType: s.columnDataType("INTEGER UNSIGNED", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "uint8": { - DataType: s.columnDataType("TINYINT UNSIGNED", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "uint16": { - DataType: s.columnDataType("SMALLINT UNSIGNED", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "uint32": { - DataType: s.columnDataType("MEDIUMINT UNSIGNED", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, - "uint64": { - DataType: s.columnDataType("BIGINT UNSIGNED", int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - }, "float32": { DataType: s.columnDataType("FLOAT", int64(0)), Valuer: "(float64)({{goPath}})", @@ -94,80 +44,32 @@ func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType { "*string": { DataType: s.columnDataType("VARCHAR(255)"), Valuer: "github.com/si3nloong/sqlgen/sequel/types.String({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfString({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.String({{addrOfGoPath}})", }, "*[]byte": { DataType: s.columnDataType("BLOB"), Valuer: "github.com/si3nloong/sqlgen/sequel/types.String({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfString({{addrOfGoPath}})"}, + Scanner: "github.com/si3nloong/sqlgen/sequel/types.String({{addrOfGoPath}})", + }, "*bool": { DataType: s.columnDataType("BOOL"), Valuer: "github.com/si3nloong/sqlgen/sequel/types.Bool({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfBool({{addrOfGoPath}})"}, - "*uint": { - DataType: s.columnDataType("INTEGER UNSIGNED"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*uint8": { - DataType: s.columnDataType("TINYINT UNSIGNED"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*uint16": { - DataType: s.columnDataType("SMALLINT UNSIGNED"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*uint32": { - DataType: s.columnDataType("MEDIUMINT UNSIGNED"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*uint64": { - DataType: s.columnDataType("BIGINT UNSIGNED"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*int": { - DataType: s.columnDataType("INTEGER"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*int8": { - DataType: s.columnDataType("TINYINT"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*int16": { - DataType: s.columnDataType("SMALLINT"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*int32": { - DataType: s.columnDataType("MEDIUMINT"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", - }, - "*int64": { - DataType: s.columnDataType("BIGINT"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfInt({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Bool({{addrOfGoPath}})", }, "*float32": { DataType: s.columnDataType("FLOAT"), Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfFloat({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", }, "*float64": { DataType: s.columnDataType("FLOAT"), Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfFloat({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", }, "*time.Time": { DataType: s.columnDataType("TIMESTAMP"), Valuer: "github.com/si3nloong/sqlgen/sequel/types.Time({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfTime({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Time({{addrOfGoPath}})", }, "sql.RawBytes": { DataType: s.columnDataType("TEXT"), @@ -274,6 +176,33 @@ func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType { Scanner: "github.com/si3nloong/sqlgen/sequel/types.JSONUnmarshaler({{addrOfGoPath}})", }, } + s.mapIntegers(dataTypes) + return dataTypes +} + +func (s *mysqlDriver) mapIntegers(dict map[string]*dialect.ColumnType) { + types := [][2]string{ + {"int", "INTEGER"}, {"int8", "TINYINT"}, {"int16", "SMALLINT"}, {"int32", "MEDIUMINT"}, {"int64", "BIGINT"}, + {"uint", "INTEGER UNSIGNED"}, {"uint8", "TINYINT UNSIGNED"}, {"uint16", "SMALLINT UNSIGNED"}, {"uint32", "MEDIUMINT UNSIGNED"}, {"uint64", "BIGINT UNSIGNED"}, + } + for _, t := range types { + dict[t[0]] = &dialect.ColumnType{ + DataType: s.columnDataType(t[1], int64(0)), + Valuer: "(int64)({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + } + } + types = [][2]string{ + {"*int", "INTEGER"}, {"*int8", "TINYINT"}, {"*int16", "SMALLINT"}, {"*int32", "MEDIUMINT"}, {"*int64", "BIGINT"}, + {"*uint", "INTEGER UNSIGNED"}, {"*uint8", "TINYINT UNSIGNED"}, {"*uint16", "SMALLINT UNSIGNED"}, {"*uint32", "MEDIUMINT UNSIGNED"}, {"*uint64", "BIGINT UNSIGNED"}, + } + for _, t := range types { + dict[t[0]] = &dialect.ColumnType{ + DataType: s.columnDataType(t[1], int64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + } + } } func (*mysqlDriver) columnDataType(dataType string, defaultValue ...any) func(dialect.GoColumn) string { diff --git a/examples/go.mod b/examples/go.mod index ff06d36..da98837 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -8,7 +8,7 @@ require ( github.com/gofrs/uuid/v5 v5.0.0 github.com/google/uuid v1.6.0 github.com/jaswdr/faker v1.16.0 - github.com/lib/pq v1.10.7 + github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.16 github.com/paulmach/orb v0.11.1 github.com/si3nloong/sqlgen v1.0.0-alpha.3.0.20231118095154-390f9683bb93 diff --git a/examples/go.sum b/examples/go.sum index a960793..60e21ef 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -56,8 +56,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= -github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= diff --git a/examples/testcase/struct-field/pointer/generated.go b/examples/testcase/struct-field/pointer/generated.go index a4f63d1..2752f1f 100755 --- a/examples/testcase/struct-field/pointer/generated.go +++ b/examples/testcase/struct-field/pointer/generated.go @@ -28,67 +28,67 @@ func (v *Ptr) Addrs() []any { if v.Str == nil { v.Str = new(string) } - addrs[1] = types.PtrOfString(&v.Str) + addrs[1] = types.String(v.Str) if v.Bytes == nil { v.Bytes = new([]byte) } - addrs[2] = types.PtrOfString(&v.Bytes) + addrs[2] = types.String(v.Bytes) if v.Bool == nil { v.Bool = new(bool) } - addrs[3] = types.PtrOfBool(&v.Bool) + addrs[3] = types.Bool(v.Bool) if v.Int == nil { v.Int = new(int) } - addrs[4] = types.PtrOfInt(&v.Int) + addrs[4] = types.Integer(v.Int) if v.Int8 == nil { v.Int8 = new(int8) } - addrs[5] = types.PtrOfInt(&v.Int8) + addrs[5] = types.Integer(v.Int8) if v.Int16 == nil { v.Int16 = new(int16) } - addrs[6] = types.PtrOfInt(&v.Int16) + addrs[6] = types.Integer(v.Int16) if v.Int32 == nil { v.Int32 = new(int32) } - addrs[7] = types.PtrOfInt(&v.Int32) + addrs[7] = types.Integer(v.Int32) if v.Int64 == nil { v.Int64 = new(int64) } - addrs[8] = types.PtrOfInt(&v.Int64) + addrs[8] = types.Integer(v.Int64) if v.Uint == nil { v.Uint = new(uint) } - addrs[9] = types.PtrOfInt(&v.Uint) + addrs[9] = types.Integer(v.Uint) if v.Uint8 == nil { v.Uint8 = new(uint8) } - addrs[10] = types.PtrOfInt(&v.Uint8) + addrs[10] = types.Integer(v.Uint8) if v.Uint16 == nil { v.Uint16 = new(uint16) } - addrs[11] = types.PtrOfInt(&v.Uint16) + addrs[11] = types.Integer(v.Uint16) if v.Uint32 == nil { v.Uint32 = new(uint32) } - addrs[12] = types.PtrOfInt(&v.Uint32) + addrs[12] = types.Integer(v.Uint32) if v.Uint64 == nil { v.Uint64 = new(uint64) } - addrs[13] = types.PtrOfInt(&v.Uint64) + addrs[13] = types.Integer(v.Uint64) if v.F32 == nil { v.F32 = new(float32) } - addrs[14] = types.PtrOfFloat(&v.F32) + addrs[14] = types.Float(v.F32) if v.F64 == nil { v.F64 = new(float64) } - addrs[15] = types.PtrOfFloat(&v.F64) + addrs[15] = types.Float(v.F64) if v.Time == nil { v.Time = new(time.Time) } - addrs[16] = types.PtrOfTime(&v.Time) + addrs[16] = types.Time(v.Time) return addrs } func (Ptr) InsertPlaceholders(row int) string { diff --git a/examples/testcase/struct-field/pointer/generated.go.tpl b/examples/testcase/struct-field/pointer/generated.go.tpl index a4f63d1..2752f1f 100644 --- a/examples/testcase/struct-field/pointer/generated.go.tpl +++ b/examples/testcase/struct-field/pointer/generated.go.tpl @@ -28,67 +28,67 @@ func (v *Ptr) Addrs() []any { if v.Str == nil { v.Str = new(string) } - addrs[1] = types.PtrOfString(&v.Str) + addrs[1] = types.String(v.Str) if v.Bytes == nil { v.Bytes = new([]byte) } - addrs[2] = types.PtrOfString(&v.Bytes) + addrs[2] = types.String(v.Bytes) if v.Bool == nil { v.Bool = new(bool) } - addrs[3] = types.PtrOfBool(&v.Bool) + addrs[3] = types.Bool(v.Bool) if v.Int == nil { v.Int = new(int) } - addrs[4] = types.PtrOfInt(&v.Int) + addrs[4] = types.Integer(v.Int) if v.Int8 == nil { v.Int8 = new(int8) } - addrs[5] = types.PtrOfInt(&v.Int8) + addrs[5] = types.Integer(v.Int8) if v.Int16 == nil { v.Int16 = new(int16) } - addrs[6] = types.PtrOfInt(&v.Int16) + addrs[6] = types.Integer(v.Int16) if v.Int32 == nil { v.Int32 = new(int32) } - addrs[7] = types.PtrOfInt(&v.Int32) + addrs[7] = types.Integer(v.Int32) if v.Int64 == nil { v.Int64 = new(int64) } - addrs[8] = types.PtrOfInt(&v.Int64) + addrs[8] = types.Integer(v.Int64) if v.Uint == nil { v.Uint = new(uint) } - addrs[9] = types.PtrOfInt(&v.Uint) + addrs[9] = types.Integer(v.Uint) if v.Uint8 == nil { v.Uint8 = new(uint8) } - addrs[10] = types.PtrOfInt(&v.Uint8) + addrs[10] = types.Integer(v.Uint8) if v.Uint16 == nil { v.Uint16 = new(uint16) } - addrs[11] = types.PtrOfInt(&v.Uint16) + addrs[11] = types.Integer(v.Uint16) if v.Uint32 == nil { v.Uint32 = new(uint32) } - addrs[12] = types.PtrOfInt(&v.Uint32) + addrs[12] = types.Integer(v.Uint32) if v.Uint64 == nil { v.Uint64 = new(uint64) } - addrs[13] = types.PtrOfInt(&v.Uint64) + addrs[13] = types.Integer(v.Uint64) if v.F32 == nil { v.F32 = new(float32) } - addrs[14] = types.PtrOfFloat(&v.F32) + addrs[14] = types.Float(v.F32) if v.F64 == nil { v.F64 = new(float64) } - addrs[15] = types.PtrOfFloat(&v.F64) + addrs[15] = types.Float(v.F64) if v.Time == nil { v.Time = new(time.Time) } - addrs[16] = types.PtrOfTime(&v.Time) + addrs[16] = types.Time(v.Time) return addrs } func (Ptr) InsertPlaceholders(row int) string { diff --git a/examples/testcase/struct-field/sql/generated.go b/examples/testcase/struct-field/sql/generated.go index 5a5d321..0916977 100755 --- a/examples/testcase/struct-field/sql/generated.go +++ b/examples/testcase/struct-field/sql/generated.go @@ -36,7 +36,7 @@ func (v *AutoPkLocation) Addrs() []any { if v.PtrGeoPoint == nil { v.PtrGeoPoint = new(orb.Point) } - addrs[2] = types.JSONUnmarshaler(&v.PtrGeoPoint) + addrs[2] = types.JSONUnmarshaler(v.PtrGeoPoint) if v.PtrUUID == nil { v.PtrUUID = new(uuid.UUID) } diff --git a/examples/testcase/struct-field/sql/generated.go.tpl b/examples/testcase/struct-field/sql/generated.go.tpl index 5a5d321..0916977 100644 --- a/examples/testcase/struct-field/sql/generated.go.tpl +++ b/examples/testcase/struct-field/sql/generated.go.tpl @@ -36,7 +36,7 @@ func (v *AutoPkLocation) Addrs() []any { if v.PtrGeoPoint == nil { v.PtrGeoPoint = new(orb.Point) } - addrs[2] = types.JSONUnmarshaler(&v.PtrGeoPoint) + addrs[2] = types.JSONUnmarshaler(v.PtrGeoPoint) if v.PtrUUID == nil { v.PtrUUID = new(uuid.UUID) } diff --git a/examples/testcase/struct-field/valuer/generated.go b/examples/testcase/struct-field/valuer/generated.go index a66011a..0db9979 100755 --- a/examples/testcase/struct-field/valuer/generated.go +++ b/examples/testcase/struct-field/valuer/generated.go @@ -23,7 +23,7 @@ func (v *B) Addrs() []any { if v.PtrValue == nil { v.PtrValue = new(anyType) } - addrs[2] = types.JSONUnmarshaler(&v.PtrValue) + addrs[2] = types.JSONUnmarshaler(v.PtrValue) addrs[3] = types.String(&v.N) return addrs } diff --git a/examples/testcase/struct-field/valuer/generated.go.tpl b/examples/testcase/struct-field/valuer/generated.go.tpl index a66011a..0db9979 100644 --- a/examples/testcase/struct-field/valuer/generated.go.tpl +++ b/examples/testcase/struct-field/valuer/generated.go.tpl @@ -23,7 +23,7 @@ func (v *B) Addrs() []any { if v.PtrValue == nil { v.PtrValue = new(anyType) } - addrs[2] = types.JSONUnmarshaler(&v.PtrValue) + addrs[2] = types.JSONUnmarshaler(v.PtrValue) addrs[3] = types.String(&v.N) return addrs } diff --git a/sequel/types/ptr_bool.go b/sequel/types/ptr_bool.go deleted file mode 100644 index 3a0d7ab..0000000 --- a/sequel/types/ptr_bool.go +++ /dev/null @@ -1,55 +0,0 @@ -package types - -import ( - "fmt" - "strconv" - "unsafe" -) - -type ptrOfBoolLike[T ~bool] struct { - addr **T -} - -func PtrOfBool[T ~bool](v **T) ptrOfBoolLike[T] { - return ptrOfBoolLike[T]{addr: v} -} - -func (p ptrOfBoolLike[T]) Interface() *T { - if p.addr == nil { - return nil - } - return *p.addr -} - -func (p ptrOfBoolLike[T]) Scan(v any) error { - if v == nil { - (*p.addr) = nil - return nil - } - - switch vi := v.(type) { - case []byte: - b, err := strconv.ParseBool(unsafe.String(unsafe.SliceData(vi), len(vi))) - if err != nil { - return err - } - val := T(b) - *p.addr = &val - case string: - b, err := strconv.ParseBool(vi) - if err != nil { - return err - } - val := T(b) - *p.addr = &val - case bool: - val := T(vi) - *p.addr = &val - case int64: - val := T(vi != 0) - *p.addr = &val - default: - return fmt.Errorf(`sequel/types: unable to scan %T to *bool`, vi) - } - return nil -} diff --git a/sequel/types/ptr_float.go b/sequel/types/ptr_float.go deleted file mode 100644 index ace9e4f..0000000 --- a/sequel/types/ptr_float.go +++ /dev/null @@ -1,57 +0,0 @@ -package types - -import ( - "fmt" - "strconv" - "unsafe" - - "golang.org/x/exp/constraints" -) - -type ptrOfFloatLike[T constraints.Float] struct { - addr **T -} - -func PtrOfFloat[T constraints.Float](v **T) ptrOfFloatLike[T] { - return ptrOfFloatLike[T]{addr: v} -} - -func (p ptrOfFloatLike[T]) Interface() *T { - if p.addr == nil { - return nil - } - return *p.addr -} - -func (p ptrOfFloatLike[T]) Scan(v any) error { - if v == nil { - (*p.addr) = nil - return nil - } - - switch vi := v.(type) { - case []byte: - f, err := strconv.ParseFloat(unsafe.String(unsafe.SliceData(vi), len(vi)), 64) - if err != nil { - return err - } - val := T(f) - *p.addr = &val - case string: - f, err := strconv.ParseFloat(vi, 64) - if err != nil { - return err - } - val := T(f) - *p.addr = &val - case float64: - val := T(vi) - *p.addr = &val - case int64: - val := T(vi) - *p.addr = &val - default: - return fmt.Errorf(`sequel/types: unable to scan %T to *float`, vi) - } - return nil -} diff --git a/sequel/types/ptr_int.go b/sequel/types/ptr_int.go deleted file mode 100644 index 07a92eb..0000000 --- a/sequel/types/ptr_int.go +++ /dev/null @@ -1,47 +0,0 @@ -package types - -import ( - "fmt" - "strconv" - "unsafe" - - "golang.org/x/exp/constraints" -) - -type ptrOfIntLike[T constraints.Integer] struct { - addr **T -} - -func PtrOfInt[T constraints.Integer](v **T) ptrOfIntLike[T] { - return ptrOfIntLike[T]{addr: v} -} - -func (p ptrOfIntLike[T]) Interface() *T { - if p.addr == nil { - return nil - } - return *p.addr -} - -func (p ptrOfIntLike[T]) Scan(v any) error { - if v == nil { - (*p.addr) = nil - return nil - } - - switch vi := v.(type) { - case []byte: - i, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, 64) - if err != nil { - return err - } - val := T(i) - *p.addr = &val - case int64: - val := T(vi) - *p.addr = &val - default: - return fmt.Errorf(`sequel/types: unable to scan %T to *int`, vi) - } - return nil -} diff --git a/sequel/types/ptr_string.go b/sequel/types/ptr_string.go deleted file mode 100644 index 7c9ac25..0000000 --- a/sequel/types/ptr_string.go +++ /dev/null @@ -1,37 +0,0 @@ -package types - -import "fmt" - -type ptrOfStrLike[T StringLikeType] struct { - addr **T -} - -func PtrOfString[T StringLikeType](v **T) ptrOfStrLike[T] { - return ptrOfStrLike[T]{addr: v} -} - -func (p ptrOfStrLike[T]) Interface() *T { - if p.addr == nil { - return nil - } - return *p.addr -} - -func (p ptrOfStrLike[T]) Scan(v any) error { - if v == nil { - (*p.addr) = nil - return nil - } - - switch vi := v.(type) { - case string: - val := T(vi) - *p.addr = &val - case []byte: - val := T(vi) - *p.addr = &val - default: - return fmt.Errorf(`sequel/types: unable to scan %T to *string`, vi) - } - return nil -} diff --git a/sequel/types/ptr_time.go b/sequel/types/ptr_time.go deleted file mode 100644 index b6fb12d..0000000 --- a/sequel/types/ptr_time.go +++ /dev/null @@ -1,77 +0,0 @@ -package types - -import ( - "fmt" - "regexp" - "time" - "unsafe" -) - -var ( - ddmmyyyy = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}$`) - ddmmyyyyhhmmss = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}:\d{2}$`) - ddmmyyyyhhmmsstz = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}:\d{2}\.\d+$`) -) - -type ptrOfTime[T time.Time] struct { - addr **T -} - -func PtrOfTime[T time.Time](v **T) ptrOfTime[T] { - return ptrOfTime[T]{addr: v} -} - -func (p ptrOfTime[T]) Interface() *T { - if p.addr == nil { - return nil - } - return *p.addr -} - -func (p ptrOfTime[T]) Scan(v any) error { - if v == nil { - (*p.addr) = nil - return nil - } - - switch vi := v.(type) { - case []byte: - t, err := parseTime(unsafe.String(unsafe.SliceData(vi), len(vi))) - if err != nil { - return err - } - val := T(t) - *p.addr = &val - case string: - t, err := parseTime(vi) - if err != nil { - return err - } - val := T(t) - *p.addr = &val - case time.Time: - val := T(vi) - *p.addr = &val - default: - return fmt.Errorf(`sequel/types: unable to scan %T to *time.Time`, vi) - } - return nil -} - -func parseTime(str string) (time.Time, error) { - var ( - t time.Time - err error - ) - switch { - case ddmmyyyy.MatchString(str): - t, err = time.Parse("2006-01-02", str) - case ddmmyyyyhhmmss.MatchString(str): - t, err = time.Parse("2006-01-02 15:04:05", str) - case ddmmyyyyhhmmsstz.MatchString(str): - t, err = time.Parse("2006-01-02 15:04:05.999999", str) - default: - t, err = time.Parse(time.RFC3339Nano, str) - } - return t, err -} diff --git a/sequel/types/time.go b/sequel/types/time.go index 74e5dd2..d35b7f8 100644 --- a/sequel/types/time.go +++ b/sequel/types/time.go @@ -4,10 +4,17 @@ import ( "database/sql" "database/sql/driver" "fmt" + "regexp" "time" "unsafe" ) +var ( + ddmmyyyy = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}$`) + ddmmyyyyhhmmss = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}:\d{2}$`) + ddmmyyyyhhmmsstz = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}:\d{2}\.\d+$`) +) + type datetime[T time.Time] struct { addr *T strictType bool @@ -65,3 +72,21 @@ func (s datetime[T]) Scan(v any) error { *s.addr = val return nil } + +func parseTime(str string) (time.Time, error) { + var ( + t time.Time + err error + ) + switch { + case ddmmyyyy.MatchString(str): + t, err = time.Parse("2006-01-02", str) + case ddmmyyyyhhmmss.MatchString(str): + t, err = time.Parse("2006-01-02 15:04:05", str) + case ddmmyyyyhhmmsstz.MatchString(str): + t, err = time.Parse("2006-01-02 15:04:05.999999", str) + default: + t, err = time.Parse(time.RFC3339Nano, str) + } + return t, err +}