From bed94b0d50aeae5cd17d41510ad0d5083a40dbd2 Mon Sep 17 00:00:00 2001 From: si3nloong Date: Sat, 28 Sep 2024 08:49:30 +0800 Subject: [PATCH] fix: pgtype and types --- codegen/dialect/mysql/column_type.go | 16 +- codegen/dialect/postgres/column_type.go | 142 ++++++++++-------- examples/testcase/main/generated.go | 2 +- examples/testcase/main/generated.go.tpl | 2 +- .../testcase/struct-field/alias/generated.go | 2 +- .../struct-field/alias/generated.go.tpl | 2 +- .../struct-field/pointer/generated.go | 14 +- .../struct-field/pointer/generated.go.tpl | 14 +- .../struct-field/primitive/generated.go | 2 +- .../struct-field/primitive/generated.go.tpl | 2 +- .../testcase/struct-field/slice/generated.go | 2 +- .../struct-field/slice/generated.go.tpl | 2 +- sequel/encoding/marshaler.go | 14 +- sequel/types/bool.go | 12 +- sequel/types/float.go | 88 ++++++++--- sequel/types/float_slice.go | 19 ++- sequel/types/float_test.go | 17 ++- sequel/types/int.go | 111 +++++++++++--- sequel/types/pgtype/bool_array.go | 109 ++++---------- sequel/types/pgtype/byte_array.go | 79 +++++----- sequel/types/pgtype/float_array.go | 85 ++++++----- sequel/types/pgtype/int_array.go | 35 +---- sequel/types/pgtype/string_array.go | 83 +++------- sequel/types/pgtype/uint_array.go | 45 ++---- sequel/types/uint.go | 91 +++++++++++ 25 files changed, 544 insertions(+), 446 deletions(-) create mode 100644 sequel/types/uint.go diff --git a/codegen/dialect/mysql/column_type.go b/codegen/dialect/mysql/column_type.go index bf822f2..a811de6 100644 --- a/codegen/dialect/mysql/column_type.go +++ b/codegen/dialect/mysql/column_type.go @@ -34,12 +34,12 @@ func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType { "float32": { DataType: s.columnDataType("FLOAT", int64(0)), Valuer: "(float64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float32({{addrOfGoPath}})", }, "float64": { DataType: s.columnDataType("FLOAT", int64(0)), Valuer: "(float64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float64({{addrOfGoPath}})", }, "time.Time": { DataType: s.columnDataType("TIMESTAMP", sql.RawBytes("CURRENT_TIMESTAMP")), @@ -63,13 +63,13 @@ func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType { }, "*float32": { DataType: s.columnDataType("FLOAT"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float32({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float32({{addrOfGoPath}})", }, "*float64": { DataType: s.columnDataType("FLOAT"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float64({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float64({{addrOfGoPath}})", }, "*time.Time": { DataType: s.columnDataType("TIMESTAMP"), @@ -168,12 +168,12 @@ func (s *mysqlDriver) ColumnDataTypes() map[string]*dialect.ColumnType { "[]float32": { DataType: s.columnDataType("JSON"), Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalFloatList({{goPath}},-1)", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.FloatSlice({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float32Slice({{addrOfGoPath}})", }, "[]float64": { DataType: s.columnDataType("JSON"), Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalFloatList({{goPath}},-1)", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.FloatSlice({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float64Slice({{addrOfGoPath}})", }, "*": { DataType: s.columnDataType("JSON"), diff --git a/codegen/dialect/postgres/column_type.go b/codegen/dialect/postgres/column_type.go index 657b5fa..0d681ba 100644 --- a/codegen/dialect/postgres/column_type.go +++ b/codegen/dialect/postgres/column_type.go @@ -41,62 +41,62 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { "int8": { DataType: s.intDataType("int2", int64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int8({{addrOfGoPath}})", }, "int16": { DataType: s.intDataType("int2", int64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int16({{addrOfGoPath}})", }, "int32": { DataType: s.intDataType("int4", int64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int32({{addrOfGoPath}})", }, "int64": { DataType: s.columnDataType("bigint", int64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int64({{addrOfGoPath}})", }, "int": { DataType: s.intDataType("int4", int64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int({{addrOfGoPath}})", }, "uint8": { DataType: s.intDataType("int2", uint64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint8({{addrOfGoPath}})", }, "uint16": { DataType: s.intDataType("int2", uint64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint16({{addrOfGoPath}})", }, "uint32": { DataType: s.intDataType("int4", uint64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint32({{addrOfGoPath}})", }, "uint64": { DataType: s.intDataType("bigint", uint64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint64({{addrOfGoPath}})", }, "uint": { DataType: s.intDataType("int4", uint64(0)), Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint({{addrOfGoPath}})", }, "float32": { DataType: s.columnDataType("real"), Valuer: "(float64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float32({{addrOfGoPath}})", }, "float64": { DataType: s.columnDataType("double precision"), Valuer: "(float64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float64({{addrOfGoPath}})", }, "time.Time": { DataType: s.columnDataType("timestamptz(6)", sql.RawBytes(`NOW()`)), @@ -119,15 +119,65 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { Valuer: "github.com/si3nloong/sqlgen/sequel/types.Bool({{addrOfGoPath}})", Scanner: "github.com/si3nloong/sqlgen/sequel/types.Bool({{addrOfGoPath}})", }, + "*int8": { + DataType: s.intDataType("int2", int64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int8({{addrOfGoPath}})", + }, + "*int16": { + DataType: s.intDataType("int2", int64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int16({{addrOfGoPath}})", + }, + "*int32": { + DataType: s.intDataType("int4", int64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int32({{addrOfGoPath}})", + }, + "*int64": { + DataType: s.columnDataType("bigint", int64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Int64({{addrOfGoPath}})", + }, + "*int": { + DataType: s.intDataType("int4", int64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", + }, + "*uint8": { + DataType: s.intDataType("int2", uint64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint8({{addrOfGoPath}})", + }, + "*uint16": { + DataType: s.intDataType("int2", uint64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint16({{addrOfGoPath}})", + }, + "*uint32": { + DataType: s.intDataType("int4", uint64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint32({{addrOfGoPath}})", + }, + "*uint64": { + DataType: s.intDataType("bigint", uint64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint64({{addrOfGoPath}})", + }, + "*uint": { + DataType: s.intDataType("int4", uint64(0)), + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Uint({{addrOfGoPath}})", + }, "*float32": { DataType: s.columnDataType("real"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float32({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float32({{addrOfGoPath}})", }, "*float64": { DataType: s.columnDataType("double precision"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types.Float64({{addrOfGoPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float64({{addrOfGoPath}})", }, "*time.Time": { DataType: s.columnDataType("timestamptz(6)"), @@ -161,7 +211,7 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { "[]string": { DataType: s.columnDataType("text[]"), Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArray[{{elemType}}]({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArrayScanner({{addrOfGoPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArray[{{elemType}}])({{addrOfGoPath}})", }, "[]rune": { DataType: s.columnDataType("text"), @@ -176,7 +226,7 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { "[]bool": { DataType: s.columnDataType("bool[]"), Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.BoolArray[{{elemType}}]({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.BoolArrayScanner({{addrOfGoPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.BoolArray[{{elemType}}])({{addrOfGoPath}})", }, "[][]byte": { DataType: s.columnDataType("bytea"), @@ -196,52 +246,52 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { "[]int": { DataType: s.columnDataType("int4[]"), Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayValue({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArray[{{elemType}}])({{addrOfGoPath}})", }, "[]int8": { DataType: s.columnDataType("int2[]"), Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayValue({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArray[{{elemType}}])({{addrOfGoPath}})", }, "[]int16": { DataType: s.columnDataType("int2[]"), Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int16Array[{{elemType}}]({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArray[{{elemType}}])({{addrOfGoPath}})", }, "[]int32": { DataType: s.columnDataType("int4[]"), Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int32Array[{{elemType}}]({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArray[{{elemType}}])({{addrOfGoPath}})", }, "[]int64": { DataType: s.columnDataType("int8[]"), Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int64Array[{{elemType}}]({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArray[{{elemType}}])({{addrOfGoPath}})", }, "[]uint": { DataType: s.columnDataType("int4[]"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", + Valuer: "(github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArray[{{elemType}}])({{goPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArray[{{elemType}}])({{addrOfGoPath}})", }, "[]uint8": { DataType: s.columnDataType("int2[]"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", + Valuer: "(github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArray[{{elemType}}])({{goPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.Uint8Array[{{elemType}}])({{addrOfGoPath}})", }, "[]uint16": { DataType: s.columnDataType("int2[]"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", + Valuer: "(github.com/si3nloong/sqlgen/sequel/types/pgtype.Uint16Array[{{elemType}}])({{goPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.Uint16Array[{{elemType}}])({{addrOfGoPath}})", }, "[]uint32": { DataType: s.columnDataType("int4[]"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", + Valuer: "(github.com/si3nloong/sqlgen/sequel/types/pgtype.Uint32Array[{{elemType}}])({{goPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.Uint32Array[{{elemType}}])({{addrOfGoPath}})", }, "[]uint64": { DataType: s.columnDataType("int8[]"), - Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", + Valuer: "(github.com/si3nloong/sqlgen/sequel/types/pgtype.Uint64Array[{{elemType}}])({{goPath}})", + Scanner: "(*github.com/si3nloong/sqlgen/sequel/types/pgtype.Uint64Array[{{elemType}}])({{addrOfGoPath}})", }, "*": { DataType: s.columnDataType("json"), @@ -249,35 +299,9 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { Scanner: "github.com/si3nloong/sqlgen/sequel/types.JSONUnmarshaler({{addrOfGoPath}})", }, } - s.mapIntegers(dataTypes) return dataTypes } -func (s *postgresDriver) mapIntegers(dict map[string]*dialect.ColumnType) { - types := [][2]string{ - {"int", "int4"}, {"int8", "int2"}, {"int16", "int2"}, {"int32", "int4"}, {"int64", "bigint"}, - {"uint", "int4"}, {"uint8", "int2"}, {"uint16", "int2"}, {"uint32", "int4"}, {"uint64", "bigint"}, - } - for _, t := range types { - dict[t[0]] = &dialect.ColumnType{ - DataType: s.intDataType(t[1], int64(0)), - Valuer: "(int64)({{goPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - } - } - types = [][2]string{ - {"*int", "int4"}, {"*int8", "int2"}, {"*int16", "int2"}, {"*int32", "int4"}, {"*int64", "bigint"}, - {"*uint", "int4"}, {"*uint8", "int2"}, {"*uint16", "int2"}, {"*uint32", "int4"}, {"*uint64", "bigint"}, - } - for _, t := range types { - dict[t[0]] = &dialect.ColumnType{ - DataType: s.intDataType(t[1], int64(0)), - Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})", - } - } -} - func (s *postgresDriver) intDataType(dataType string, defaultValue ...any) func(dialect.GoColumn) string { return func(column dialect.GoColumn) string { str := dataType diff --git a/examples/testcase/main/generated.go b/examples/testcase/main/generated.go index 6867609..4a9c45d 100755 --- a/examples/testcase/main/generated.go +++ b/examples/testcase/main/generated.go @@ -48,7 +48,7 @@ func (v HouseUnit) Values() []any { return []any{(int64)(v.No), (time.Time)(v.BuildTime), types.JSONMarshaler(v.Address), (int64)(v.Kind), (int64)(v.Type), (int64)(v.Chan), types.JSONMarshaler(v.Inner), types.JSONMarshaler(v.Arr), encoding.MarshalFloatList(v.Slice, -1), types.JSONMarshaler(v.Map)} } func (v *HouseUnit) Addrs() []any { - return []any{types.Integer(&v.No), (*time.Time)(&v.BuildTime), types.JSONUnmarshaler(&v.Address), types.Integer(&v.Kind), types.Integer(&v.Type), types.Integer(&v.Chan), types.JSONUnmarshaler(&v.Inner), types.JSONUnmarshaler(&v.Arr), types.FloatSlice(&v.Slice), types.JSONUnmarshaler(&v.Map)} + return []any{types.Integer(&v.No), (*time.Time)(&v.BuildTime), types.JSONUnmarshaler(&v.Address), types.Integer(&v.Kind), types.Integer(&v.Type), types.Integer(&v.Chan), types.JSONUnmarshaler(&v.Inner), types.JSONUnmarshaler(&v.Arr), types.Float64Slice(&v.Slice), types.JSONUnmarshaler(&v.Map)} } func (HouseUnit) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?)" diff --git a/examples/testcase/main/generated.go.tpl b/examples/testcase/main/generated.go.tpl index 6867609..4a9c45d 100644 --- a/examples/testcase/main/generated.go.tpl +++ b/examples/testcase/main/generated.go.tpl @@ -48,7 +48,7 @@ func (v HouseUnit) Values() []any { return []any{(int64)(v.No), (time.Time)(v.BuildTime), types.JSONMarshaler(v.Address), (int64)(v.Kind), (int64)(v.Type), (int64)(v.Chan), types.JSONMarshaler(v.Inner), types.JSONMarshaler(v.Arr), encoding.MarshalFloatList(v.Slice, -1), types.JSONMarshaler(v.Map)} } func (v *HouseUnit) Addrs() []any { - return []any{types.Integer(&v.No), (*time.Time)(&v.BuildTime), types.JSONUnmarshaler(&v.Address), types.Integer(&v.Kind), types.Integer(&v.Type), types.Integer(&v.Chan), types.JSONUnmarshaler(&v.Inner), types.JSONUnmarshaler(&v.Arr), types.FloatSlice(&v.Slice), types.JSONUnmarshaler(&v.Map)} + return []any{types.Integer(&v.No), (*time.Time)(&v.BuildTime), types.JSONUnmarshaler(&v.Address), types.Integer(&v.Kind), types.Integer(&v.Type), types.Integer(&v.Chan), types.JSONUnmarshaler(&v.Inner), types.JSONUnmarshaler(&v.Arr), types.Float64Slice(&v.Slice), types.JSONUnmarshaler(&v.Map)} } func (HouseUnit) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?)" diff --git a/examples/testcase/struct-field/alias/generated.go b/examples/testcase/struct-field/alias/generated.go index cc338e6..3eb6965 100755 --- a/examples/testcase/struct-field/alias/generated.go +++ b/examples/testcase/struct-field/alias/generated.go @@ -23,7 +23,7 @@ func (v AliasStruct) Values() []any { return []any{(float64)(v.B), (int64)(v.pk.ID), (string)(v.Header), string(v.Raw), (string)(v.Text), (driver.Valuer)(v.NullStr), (time.Time)(v.model.Created), (time.Time)(v.model.Updated)} } func (v *AliasStruct) Addrs() []any { - return []any{types.Float(&v.B), types.Integer(&v.pk.ID), types.String(&v.Header), types.String(&v.Raw), types.String(&v.Text), (sql.Scanner)(&v.NullStr), (*time.Time)(&v.model.Created), (*time.Time)(&v.model.Updated)} + return []any{types.Float64(&v.B), types.Integer(&v.pk.ID), types.String(&v.Header), types.String(&v.Raw), types.String(&v.Text), (sql.Scanner)(&v.NullStr), (*time.Time)(&v.model.Created), (*time.Time)(&v.model.Updated)} } func (AliasStruct) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?)" diff --git a/examples/testcase/struct-field/alias/generated.go.tpl b/examples/testcase/struct-field/alias/generated.go.tpl index cc338e6..3eb6965 100644 --- a/examples/testcase/struct-field/alias/generated.go.tpl +++ b/examples/testcase/struct-field/alias/generated.go.tpl @@ -23,7 +23,7 @@ func (v AliasStruct) Values() []any { return []any{(float64)(v.B), (int64)(v.pk.ID), (string)(v.Header), string(v.Raw), (string)(v.Text), (driver.Valuer)(v.NullStr), (time.Time)(v.model.Created), (time.Time)(v.model.Updated)} } func (v *AliasStruct) Addrs() []any { - return []any{types.Float(&v.B), types.Integer(&v.pk.ID), types.String(&v.Header), types.String(&v.Raw), types.String(&v.Text), (sql.Scanner)(&v.NullStr), (*time.Time)(&v.model.Created), (*time.Time)(&v.model.Updated)} + return []any{types.Float64(&v.B), types.Integer(&v.pk.ID), types.String(&v.Header), types.String(&v.Raw), types.String(&v.Text), (sql.Scanner)(&v.NullStr), (*time.Time)(&v.model.Created), (*time.Time)(&v.model.Updated)} } func (AliasStruct) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?)" diff --git a/examples/testcase/struct-field/pointer/generated.go b/examples/testcase/struct-field/pointer/generated.go index 2752f1f..b8a46df 100755 --- a/examples/testcase/struct-field/pointer/generated.go +++ b/examples/testcase/struct-field/pointer/generated.go @@ -20,7 +20,7 @@ func (Ptr) Columns() []string { return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time"} } func (v Ptr) Values() []any { - return []any{(int64)(v.ID), types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float(v.F32), types.Float(v.F64), types.Time(v.Time)} + return []any{(int64)(v.ID), types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time)} } func (v *Ptr) Addrs() []any { addrs := make([]any, 17) @@ -80,11 +80,11 @@ func (v *Ptr) Addrs() []any { if v.F32 == nil { v.F32 = new(float32) } - addrs[14] = types.Float(v.F32) + addrs[14] = types.Float32(v.F32) if v.F64 == nil { v.F64 = new(float64) } - addrs[15] = types.Float(v.F64) + addrs[15] = types.Float64(v.F64) if v.Time == nil { v.Time = new(time.Time) } @@ -95,13 +95,13 @@ func (Ptr) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" } func (v Ptr) InsertOneStmt() (string, []any) { - return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float(v.F32), types.Float(v.F64), types.Time(v.Time)} + return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time)} } func (v Ptr) FindOneByPKStmt() (string, []any) { return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)} } func (v Ptr) UpdateOneByPKStmt() (string, []any) { - return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float(v.F32), types.Float(v.F64), types.Time(v.Time), (int64)(v.ID)} + return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), (int64)(v.ID)} } func (v Ptr) GetID() sequel.ColumnValuer[int64] { return sequel.Column("id", v.ID, func(val int64) driver.Value { return (int64)(val) }) @@ -146,10 +146,10 @@ func (v Ptr) GetUint64() sequel.ColumnValuer[*uint64] { return sequel.Column("uint_64", v.Uint64, func(val *uint64) driver.Value { return types.Integer(val) }) } func (v Ptr) GetF32() sequel.ColumnValuer[*float32] { - return sequel.Column("f_32", v.F32, func(val *float32) driver.Value { return types.Float(val) }) + return sequel.Column("f_32", v.F32, func(val *float32) driver.Value { return types.Float32(val) }) } func (v Ptr) GetF64() sequel.ColumnValuer[*float64] { - return sequel.Column("f_64", v.F64, func(val *float64) driver.Value { return types.Float(val) }) + return sequel.Column("f_64", v.F64, func(val *float64) driver.Value { return types.Float64(val) }) } func (v Ptr) GetTime() sequel.ColumnValuer[*time.Time] { return sequel.Column("time", v.Time, func(val *time.Time) driver.Value { return types.Time(val) }) diff --git a/examples/testcase/struct-field/pointer/generated.go.tpl b/examples/testcase/struct-field/pointer/generated.go.tpl index 2752f1f..b8a46df 100644 --- a/examples/testcase/struct-field/pointer/generated.go.tpl +++ b/examples/testcase/struct-field/pointer/generated.go.tpl @@ -20,7 +20,7 @@ func (Ptr) Columns() []string { return []string{"id", "str", "bytes", "bool", "int", "int_8", "int_16", "int_32", "int_64", "uint", "uint_8", "uint_16", "uint_32", "uint_64", "f_32", "f_64", "time"} } func (v Ptr) Values() []any { - return []any{(int64)(v.ID), types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float(v.F32), types.Float(v.F64), types.Time(v.Time)} + return []any{(int64)(v.ID), types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time)} } func (v *Ptr) Addrs() []any { addrs := make([]any, 17) @@ -80,11 +80,11 @@ func (v *Ptr) Addrs() []any { if v.F32 == nil { v.F32 = new(float32) } - addrs[14] = types.Float(v.F32) + addrs[14] = types.Float32(v.F32) if v.F64 == nil { v.F64 = new(float64) } - addrs[15] = types.Float(v.F64) + addrs[15] = types.Float64(v.F64) if v.Time == nil { v.Time = new(time.Time) } @@ -95,13 +95,13 @@ func (Ptr) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" } func (v Ptr) InsertOneStmt() (string, []any) { - return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float(v.F32), types.Float(v.F64), types.Time(v.Time)} + return "INSERT INTO ptr (str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time)} } func (v Ptr) FindOneByPKStmt() (string, []any) { return "SELECT id,str,bytes,bool,int,int_8,int_16,int_32,int_64,uint,uint_8,uint_16,uint_32,uint_64,f_32,f_64,time FROM ptr WHERE id = ? LIMIT 1;", []any{(int64)(v.ID)} } func (v Ptr) UpdateOneByPKStmt() (string, []any) { - return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float(v.F32), types.Float(v.F64), types.Time(v.Time), (int64)(v.ID)} + return "UPDATE ptr SET str = ?,bytes = ?,bool = ?,int = ?,int_8 = ?,int_16 = ?,int_32 = ?,int_64 = ?,uint = ?,uint_8 = ?,uint_16 = ?,uint_32 = ?,uint_64 = ?,f_32 = ?,f_64 = ?,time = ? WHERE id = ?;", []any{types.String(v.Str), types.String(v.Bytes), types.Bool(v.Bool), types.Integer(v.Int), types.Integer(v.Int8), types.Integer(v.Int16), types.Integer(v.Int32), types.Integer(v.Int64), types.Integer(v.Uint), types.Integer(v.Uint8), types.Integer(v.Uint16), types.Integer(v.Uint32), types.Integer(v.Uint64), types.Float32(v.F32), types.Float64(v.F64), types.Time(v.Time), (int64)(v.ID)} } func (v Ptr) GetID() sequel.ColumnValuer[int64] { return sequel.Column("id", v.ID, func(val int64) driver.Value { return (int64)(val) }) @@ -146,10 +146,10 @@ func (v Ptr) GetUint64() sequel.ColumnValuer[*uint64] { return sequel.Column("uint_64", v.Uint64, func(val *uint64) driver.Value { return types.Integer(val) }) } func (v Ptr) GetF32() sequel.ColumnValuer[*float32] { - return sequel.Column("f_32", v.F32, func(val *float32) driver.Value { return types.Float(val) }) + return sequel.Column("f_32", v.F32, func(val *float32) driver.Value { return types.Float32(val) }) } func (v Ptr) GetF64() sequel.ColumnValuer[*float64] { - return sequel.Column("f_64", v.F64, func(val *float64) driver.Value { return types.Float(val) }) + return sequel.Column("f_64", v.F64, func(val *float64) driver.Value { return types.Float64(val) }) } func (v Ptr) GetTime() sequel.ColumnValuer[*time.Time] { return sequel.Column("time", v.Time, func(val *time.Time) driver.Value { return types.Time(val) }) diff --git a/examples/testcase/struct-field/primitive/generated.go b/examples/testcase/struct-field/primitive/generated.go index 3ca26e1..9f14bd8 100755 --- a/examples/testcase/struct-field/primitive/generated.go +++ b/examples/testcase/struct-field/primitive/generated.go @@ -18,7 +18,7 @@ func (v Primitive) Values() []any { return []any{(string)(v.Str), string(v.Bytes), (bool)(v.Bool), (int64)(v.Int), (int64)(v.Int8), (int64)(v.Int16), (int64)(v.Int32), (int64)(v.Int64), (int64)(v.Uint), (int64)(v.Uint8), (int64)(v.Uint16), (int64)(v.Uint32), (int64)(v.Uint64), (float64)(v.F32), (float64)(v.F64), (time.Time)(v.Time)} } func (v *Primitive) Addrs() []any { - return []any{types.String(&v.Str), types.String(&v.Bytes), types.Bool(&v.Bool), types.Integer(&v.Int), types.Integer(&v.Int8), types.Integer(&v.Int16), types.Integer(&v.Int32), types.Integer(&v.Int64), types.Integer(&v.Uint), types.Integer(&v.Uint8), types.Integer(&v.Uint16), types.Integer(&v.Uint32), types.Integer(&v.Uint64), types.Float(&v.F32), types.Float(&v.F64), (*time.Time)(&v.Time)} + return []any{types.String(&v.Str), types.String(&v.Bytes), types.Bool(&v.Bool), types.Integer(&v.Int), types.Integer(&v.Int8), types.Integer(&v.Int16), types.Integer(&v.Int32), types.Integer(&v.Int64), types.Integer(&v.Uint), types.Integer(&v.Uint8), types.Integer(&v.Uint16), types.Integer(&v.Uint32), types.Integer(&v.Uint64), types.Float32(&v.F32), types.Float64(&v.F64), (*time.Time)(&v.Time)} } func (Primitive) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" diff --git a/examples/testcase/struct-field/primitive/generated.go.tpl b/examples/testcase/struct-field/primitive/generated.go.tpl index 3ca26e1..9f14bd8 100644 --- a/examples/testcase/struct-field/primitive/generated.go.tpl +++ b/examples/testcase/struct-field/primitive/generated.go.tpl @@ -18,7 +18,7 @@ func (v Primitive) Values() []any { return []any{(string)(v.Str), string(v.Bytes), (bool)(v.Bool), (int64)(v.Int), (int64)(v.Int8), (int64)(v.Int16), (int64)(v.Int32), (int64)(v.Int64), (int64)(v.Uint), (int64)(v.Uint8), (int64)(v.Uint16), (int64)(v.Uint32), (int64)(v.Uint64), (float64)(v.F32), (float64)(v.F64), (time.Time)(v.Time)} } func (v *Primitive) Addrs() []any { - return []any{types.String(&v.Str), types.String(&v.Bytes), types.Bool(&v.Bool), types.Integer(&v.Int), types.Integer(&v.Int8), types.Integer(&v.Int16), types.Integer(&v.Int32), types.Integer(&v.Int64), types.Integer(&v.Uint), types.Integer(&v.Uint8), types.Integer(&v.Uint16), types.Integer(&v.Uint32), types.Integer(&v.Uint64), types.Float(&v.F32), types.Float(&v.F64), (*time.Time)(&v.Time)} + return []any{types.String(&v.Str), types.String(&v.Bytes), types.Bool(&v.Bool), types.Integer(&v.Int), types.Integer(&v.Int8), types.Integer(&v.Int16), types.Integer(&v.Int32), types.Integer(&v.Int64), types.Integer(&v.Uint), types.Integer(&v.Uint8), types.Integer(&v.Uint16), types.Integer(&v.Uint32), types.Integer(&v.Uint64), types.Float32(&v.F32), types.Float64(&v.F64), (*time.Time)(&v.Time)} } func (Primitive) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" diff --git a/examples/testcase/struct-field/slice/generated.go b/examples/testcase/struct-field/slice/generated.go index 93b8a79..1fb5581 100755 --- a/examples/testcase/struct-field/slice/generated.go +++ b/examples/testcase/struct-field/slice/generated.go @@ -23,7 +23,7 @@ func (v Slice) Values() []any { return []any{(int64)(v.ID), encoding.MarshalBoolSlice(v.BoolList), encoding.MarshalStringSlice(v.StrList), encoding.MarshalStringSlice(v.CustomStrList), encoding.MarshalIntSlice(v.IntList), encoding.MarshalIntSlice(v.Int8List), encoding.MarshalIntSlice(v.Int16List), encoding.MarshalIntSlice(v.Int32List), encoding.MarshalIntSlice(v.Int64List), encoding.MarshalUintSlice(v.UintList), encoding.MarshalUintSlice(v.Uint8List), encoding.MarshalUintSlice(v.Uint16List), encoding.MarshalUintSlice(v.Uint32List), encoding.MarshalUintSlice(v.Uint64List), encoding.MarshalFloatList(v.F32List, -1), encoding.MarshalFloatList(v.F64List, -1)} } func (v *Slice) Addrs() []any { - return []any{types.Integer(&v.ID), types.BoolSlice(&v.BoolList), types.StringSlice(&v.StrList), types.StringSlice(&v.CustomStrList), types.IntSlice(&v.IntList), types.IntSlice(&v.Int8List), types.IntSlice(&v.Int16List), types.IntSlice(&v.Int32List), types.IntSlice(&v.Int64List), types.UintSlice(&v.UintList), types.UintSlice(&v.Uint8List), types.UintSlice(&v.Uint16List), types.UintSlice(&v.Uint32List), types.UintSlice(&v.Uint64List), types.FloatSlice(&v.F32List), types.FloatSlice(&v.F64List)} + return []any{types.Integer(&v.ID), types.BoolSlice(&v.BoolList), types.StringSlice(&v.StrList), types.StringSlice(&v.CustomStrList), types.IntSlice(&v.IntList), types.IntSlice(&v.Int8List), types.IntSlice(&v.Int16List), types.IntSlice(&v.Int32List), types.IntSlice(&v.Int64List), types.UintSlice(&v.UintList), types.UintSlice(&v.Uint8List), types.UintSlice(&v.Uint16List), types.UintSlice(&v.Uint32List), types.UintSlice(&v.Uint64List), types.Float32Slice(&v.F32List), types.Float64Slice(&v.F64List)} } func (Slice) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" diff --git a/examples/testcase/struct-field/slice/generated.go.tpl b/examples/testcase/struct-field/slice/generated.go.tpl index 93b8a79..1fb5581 100644 --- a/examples/testcase/struct-field/slice/generated.go.tpl +++ b/examples/testcase/struct-field/slice/generated.go.tpl @@ -23,7 +23,7 @@ func (v Slice) Values() []any { return []any{(int64)(v.ID), encoding.MarshalBoolSlice(v.BoolList), encoding.MarshalStringSlice(v.StrList), encoding.MarshalStringSlice(v.CustomStrList), encoding.MarshalIntSlice(v.IntList), encoding.MarshalIntSlice(v.Int8List), encoding.MarshalIntSlice(v.Int16List), encoding.MarshalIntSlice(v.Int32List), encoding.MarshalIntSlice(v.Int64List), encoding.MarshalUintSlice(v.UintList), encoding.MarshalUintSlice(v.Uint8List), encoding.MarshalUintSlice(v.Uint16List), encoding.MarshalUintSlice(v.Uint32List), encoding.MarshalUintSlice(v.Uint64List), encoding.MarshalFloatList(v.F32List, -1), encoding.MarshalFloatList(v.F64List, -1)} } func (v *Slice) Addrs() []any { - return []any{types.Integer(&v.ID), types.BoolSlice(&v.BoolList), types.StringSlice(&v.StrList), types.StringSlice(&v.CustomStrList), types.IntSlice(&v.IntList), types.IntSlice(&v.Int8List), types.IntSlice(&v.Int16List), types.IntSlice(&v.Int32List), types.IntSlice(&v.Int64List), types.UintSlice(&v.UintList), types.UintSlice(&v.Uint8List), types.UintSlice(&v.Uint16List), types.UintSlice(&v.Uint32List), types.UintSlice(&v.Uint64List), types.FloatSlice(&v.F32List), types.FloatSlice(&v.F64List)} + return []any{types.Integer(&v.ID), types.BoolSlice(&v.BoolList), types.StringSlice(&v.StrList), types.StringSlice(&v.CustomStrList), types.IntSlice(&v.IntList), types.IntSlice(&v.Int8List), types.IntSlice(&v.Int16List), types.IntSlice(&v.Int32List), types.IntSlice(&v.Int64List), types.UintSlice(&v.UintList), types.UintSlice(&v.Uint8List), types.UintSlice(&v.Uint16List), types.UintSlice(&v.Uint32List), types.UintSlice(&v.Uint64List), types.Float32Slice(&v.F32List), types.Float64Slice(&v.F64List)} } func (Slice) InsertPlaceholders(row int) string { return "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" diff --git a/sequel/encoding/marshaler.go b/sequel/encoding/marshaler.go index 63847c1..eb252e3 100644 --- a/sequel/encoding/marshaler.go +++ b/sequel/encoding/marshaler.go @@ -30,14 +30,10 @@ func MarshalStringSlice[V ~[]byte | ~string](list []V, enclose ...[2]byte) strin return blr.String() } -func MarshalIntSlice[V constraints.Signed](list []V, enclose ...[2]byte) string { +func MarshalIntSlice[V constraints.Signed](list []V) string { blr := strpool.AcquireString() defer strpool.ReleaseString(blr) - if len(enclose) > 0 { - blr.WriteByte(enclose[0][0]) - } else { - blr.WriteByte('[') - } + blr.WriteByte('[') for i := range list { if i > 0 { blr.WriteString("," + strconv.FormatInt((int64)(list[i]), 10)) @@ -45,11 +41,7 @@ func MarshalIntSlice[V constraints.Signed](list []V, enclose ...[2]byte) string blr.WriteString(strconv.FormatInt((int64)(list[i]), 10)) } } - if len(enclose) > 0 { - blr.WriteByte(enclose[0][1]) - } else { - blr.WriteByte(']') - } + blr.WriteByte(']') return blr.String() } diff --git a/sequel/types/bool.go b/sequel/types/bool.go index 5d39193..1265b82 100644 --- a/sequel/types/bool.go +++ b/sequel/types/bool.go @@ -1,7 +1,6 @@ package types import ( - "database/sql" "database/sql/driver" "fmt" "strconv" @@ -13,13 +12,8 @@ type boolLike[T ~bool] struct { strictType bool } -var ( - _ sql.Scanner = (*boolLike[bool])(nil) - _ driver.Valuer = (*boolLike[bool])(nil) -) - // Bool returns a sql.Scanner -func Bool[T ~bool](addr *T, strict ...bool) *boolLike[T] { +func Bool[T ~bool](addr *T, strict ...bool) ValueScanner[T] { var strictType bool if len(strict) > 0 { strictType = strict[0] @@ -34,13 +28,15 @@ func (b boolLike[T]) Interface() T { return *b.addr } +// Value implements the driver.Valuer interface. func (b boolLike[T]) Value() (driver.Value, error) { if b.addr == nil { return nil, nil } - return bool(*b.addr), nil + return (bool)(*b.addr), nil } +// Scan implements the sql.Scanner interface. func (b *boolLike[T]) Scan(v any) error { var val T switch vi := v.(type) { diff --git a/sequel/types/float.go b/sequel/types/float.go index 25cdb2a..1e64f83 100644 --- a/sequel/types/float.go +++ b/sequel/types/float.go @@ -1,63 +1,117 @@ package types import ( - "database/sql" "database/sql/driver" "fmt" "strconv" "unsafe" - - "golang.org/x/exp/constraints" ) -type floatLike[T constraints.Float] struct { +type float32Like[T ~float32] struct { addr *T strictType bool } -var ( - _ sql.Scanner = (*floatLike[float32])(nil) - _ driver.Valuer = (*floatLike[float32])(nil) -) - // Float returns a sql.Scanner -func Float[T constraints.Float](addr *T, strict ...bool) *floatLike[T] { +func Float32[T ~float32](addr *T, strict ...bool) ValueScanner[T] { var strictType bool if len(strict) > 0 { strictType = strict[0] } - return &floatLike[T]{addr: addr, strictType: strictType} + return &float32Like[T]{addr: addr, strictType: strictType} } -func (f floatLike[T]) Interface() T { +func (f float32Like[T]) Interface() T { if f.addr == nil { return *new(T) } return *f.addr } -func (f floatLike[T]) Value() (driver.Value, error) { +func (f float32Like[T]) Value() (driver.Value, error) { if f.addr == nil { return nil, nil } return float64(*f.addr), nil } -func (f *floatLike[T]) Scan(v any) error { +func (f *float32Like[T]) Scan(v any) error { var val T switch vi := v.(type) { case float64: val = T(vi) case int64: val = T(vi) - case uint64: + case nil: + f.addr = nil + return nil + default: + if f.strictType { + return fmt.Errorf(`sequel/types: unable to scan %T to ~float32`, vi) + } + + switch vi := v.(type) { + case []byte: + f, err := strconv.ParseFloat(unsafe.String(unsafe.SliceData(vi), len(vi)), 64) + if err != nil { + return err + } + val = T(f) + case string: + f, err := strconv.ParseFloat(vi, 32) + if err != nil { + return err + } + val = T(f) + default: + return fmt.Errorf(`sequel/types: unable to scan %T to ~float32`, vi) + } + } + *f.addr = val + return nil +} + +type float64Like[T ~float64] struct { + addr *T + strictType bool +} + +// Float returns a sql.Scanner +func Float64[T ~float64](addr *T, strict ...bool) ValueScanner[T] { + var strictType bool + if len(strict) > 0 { + strictType = strict[0] + } + return &float64Like[T]{addr: addr, strictType: strictType} +} + +func (f float64Like[T]) Interface() T { + if f.addr == nil { + return *new(T) + } + return *f.addr +} + +func (f float64Like[T]) Value() (driver.Value, error) { + if f.addr == nil { + return nil, nil + } + return float64(*f.addr), nil +} + +func (f *float64Like[T]) Scan(v any) error { + var val T + switch vi := v.(type) { + case float64: + val = T(vi) + case int64: val = T(vi) case nil: f.addr = nil return nil default: if f.strictType { - return fmt.Errorf(`sequel/types: unable to scan %T to ~float`, vi) + return fmt.Errorf(`sequel/types: unable to scan %T to ~float64`, vi) } switch vi := v.(type) { @@ -74,7 +128,7 @@ func (f *floatLike[T]) Scan(v any) error { } val = T(f) default: - return fmt.Errorf(`sequel/types: unable to scan %T to ~float`, vi) + return fmt.Errorf(`sequel/types: unable to scan %T to ~float64`, vi) } } *f.addr = val diff --git a/sequel/types/float_slice.go b/sequel/types/float_slice.go index eb4ba5f..c0a3aae 100644 --- a/sequel/types/float_slice.go +++ b/sequel/types/float_slice.go @@ -14,7 +14,8 @@ import ( ) type floatList[T constraints.Float] struct { - v *[]T + v *[]T + prec int } var ( @@ -24,8 +25,12 @@ var ( _ sql.Scanner = (*floatList[float64])(nil) ) -func FloatSlice[T constraints.Float](v *[]T) floatList[T] { - return floatList[T]{v: v} +func Float32Slice[T constraints.Float](v *[]T) floatList[T] { + return floatList[T]{v: v, prec: 32} +} + +func Float64Slice[T constraints.Float](v *[]T) floatList[T] { + return floatList[T]{v: v, prec: 64} } func (s floatList[T]) Value() (driver.Value, error) { @@ -57,11 +62,11 @@ func (s *floatList[T]) Scan(v any) error { ) for i := range paths { b = bytes.TrimSpace(paths[i]) - f64, err := strconv.ParseFloat(unsafe.String(unsafe.SliceData(b), len(b)), 64) + f, err := strconv.ParseFloat(unsafe.String(unsafe.SliceData(b), len(b)), s.prec) if err != nil { return err } - values[i] = T(f64) + values[i] = T(f) } *s.v = values case string: @@ -84,11 +89,11 @@ func (s *floatList[T]) Scan(v any) error { ) for i := range paths { b = strings.TrimSpace(paths[i]) - f64, err := strconv.ParseFloat(b, 64) + f, err := strconv.ParseFloat(b, s.prec) if err != nil { return err } - values[i] = T(f64) + values[i] = T(f) } *s.v = values case nil: diff --git a/sequel/types/float_test.go b/sequel/types/float_test.go index 2113872..8ee3a4e 100644 --- a/sequel/types/float_test.go +++ b/sequel/types/float_test.go @@ -10,9 +10,24 @@ func TestFloat(t *testing.T) { t.Run("Scan with primitive types", func(t *testing.T) { t.Run("float32", func(t *testing.T) { var f32 float32 - v := Float(&f32) + v := Float32(&f32) require.NoError(t, v.Scan(float64(81.20022))) require.Equal(t, float32(81.20022), v.Interface()) }) + + t.Run("~float32", func(t *testing.T) { + type F32 float32 + var f F32 + v := Float32(&f) + require.NoError(t, v.Scan(float64(81.20022))) + require.Equal(t, F32(81.20022), v.Interface()) + }) + + t.Run("float64", func(t *testing.T) { + var f64 float64 + v := Float64(&f64) + require.NoError(t, v.Scan(float64(81.20022))) + require.Equal(t, float64(81.20022), v.Interface()) + }) }) } diff --git a/sequel/types/int.go b/sequel/types/int.go index 18732a3..471fc37 100644 --- a/sequel/types/int.go +++ b/sequel/types/int.go @@ -14,6 +14,57 @@ type intLike[T constraints.Integer] struct { strictType bool } +func (i intLike[T]) Interface() T { + if i.addr == nil { + return *new(T) + } + return *i.addr +} + +func (i intLike[T]) Value() (driver.Value, error) { + if i.addr == nil { + return nil, nil + } + return int64(*i.addr), nil +} + +func (i *intLike[T]) Scan(v any) error { + var val T + switch vi := v.(type) { + case []byte: + m, err := strconv.Atoi(unsafe.String(unsafe.SliceData(vi), len(vi))) + if err != nil { + return err + } + val = T(m) + case int64: + val = T(vi) + case nil: + i.addr = nil + return nil + + default: + if i.strictType { + return fmt.Errorf(`sequel/types: unable to scan %T to ~int`, vi) + } + + switch vi := v.(type) { + case string: + m, err := strconv.Atoi(vi) + if err != nil { + return err + } + val = T(m) + case float64: + val = T(vi) + default: + return fmt.Errorf(`sequel/types: unable to scan %T to ~int`, vi) + } + } + *i.addr = val + return nil +} + func Integer[T constraints.Integer](addr *T, strict ...bool) ValueScanner[T] { var strictType bool if len(strict) > 0 { @@ -22,25 +73,55 @@ func Integer[T constraints.Integer](addr *T, strict ...bool) ValueScanner[T] { return &intLike[T]{addr: addr, strictType: strictType} } -func (i intLike[T]) Interface() T { +func Int8[T ~int8](addr *T, strict ...bool) ValueScanner[T] { + return newFixedSizeInt(addr, 8, strict...) +} + +func Int16[T ~int16](addr *T, strict ...bool) ValueScanner[T] { + return newFixedSizeInt(addr, 16, strict...) +} + +func Int32[T ~int32](addr *T, strict ...bool) ValueScanner[T] { + return newFixedSizeInt(addr, 32, strict...) +} + +func Int64[T ~int64](addr *T, strict ...bool) ValueScanner[T] { + return newFixedSizeInt(addr, 64, strict...) +} + +func newFixedSizeInt[T constraints.Signed](addr *T, bitSize int, strict ...bool) ValueScanner[T] { + var strictType bool + if len(strict) > 0 { + strictType = strict[0] + } + return &fixedSizeIntLike[T]{addr: addr, bitSize: bitSize, strictType: strictType} +} + +type fixedSizeIntLike[T constraints.Signed] struct { + addr *T + bitSize int + strictType bool +} + +func (i fixedSizeIntLike[T]) Interface() T { if i.addr == nil { return *new(T) } return *i.addr } -func (i intLike[T]) Value() (driver.Value, error) { +func (i fixedSizeIntLike[T]) Value() (driver.Value, error) { if i.addr == nil { return nil, nil } return int64(*i.addr), nil } -func (i *intLike[T]) Scan(v any) error { +func (i *fixedSizeIntLike[T]) Scan(v any) error { var val T switch vi := v.(type) { case []byte: - m, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, 64) + m, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, i.bitSize) if err != nil { return err } @@ -58,31 +139,11 @@ func (i *intLike[T]) Scan(v any) error { switch vi := v.(type) { case string: - m, err := strconv.ParseInt(string(vi), 10, 64) + m, err := strconv.ParseInt(vi, 10, i.bitSize) if err != nil { return err } val = T(m) - case uint64: - val = T(vi) - case uint32: - val = T(vi) - case uint16: - val = T(vi) - case uint8: - val = T(vi) - case uint: - val = T(vi) - case int32: - val = T(vi) - case int16: - val = T(vi) - case int8: - val = T(vi) - case int: - val = T(vi) - case float32: - val = T(vi) case float64: val = T(vi) default: diff --git a/sequel/types/pgtype/bool_array.go b/sequel/types/pgtype/bool_array.go index 9c6a3c8..4048074 100644 --- a/sequel/types/pgtype/bool_array.go +++ b/sequel/types/pgtype/bool_array.go @@ -4,70 +4,44 @@ import ( "database/sql" "database/sql/driver" "fmt" - "strconv" - "unsafe" ) -func BoolArrayValue[T ~bool](b []T) driver.Valuer { - return BoolArray[T](b) -} - -func BoolArrayScanner[T ~bool](b *[]T) sql.Scanner { - return &boolArray[T]{b: b} -} +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +type BoolArray[T ~bool] []T -type boolArray[T ~bool] struct { - b *[]T -} +var ( + _ sql.Scanner = (*BoolArray[bool])(nil) + _ driver.Valuer = (*BoolArray[bool])(nil) +) -// Scan implements the sql.Scanner interface. -func (a *boolArray[T]) Scan(src any) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a.b = nil - return nil +// Value implements the driver.Valuer interface. +func (a BoolArray[T]) Value() (driver.Value, error) { + if a == nil { + return nil, nil } - return fmt.Errorf("pgtype: cannot convert %T to BoolArray", src) -} -func (a boolArray[T]) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "BoolArray") - if err != nil { - return err - } - if *a.b != nil && len(elems) == 0 { - *a.b = (*a.b)[:0] - } else { - b := make(BoolArray[T], len(elems)) - for i, v := range elems { - if len(v) != 1 { - return fmt.Errorf("pgtype: could not parse boolean array index %d: invalid boolean %q", i, v) - } - switch v[0] { - case 't': - b[i] = true - case 'f': - b[i] = false - default: - f, err := strconv.ParseBool(unsafe.String(unsafe.SliceData(v), len(v))) - if err != nil { - return fmt.Errorf("pgtype: could not parse boolean array index %d: invalid boolean %q", i, v) - } - b[i] = (T)(f) + if n := len(a); n > 0 { + // There will be exactly two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1+2*n) + + for i := 0; i < n; i++ { + b[2*i] = ',' + if a[i] { + b[1+2*i] = 't' + } else { + b[1+2*i] = 'f' } } - *a.b = b + + b[0] = '{' + b[2*n] = '}' + + return string(b), nil } - return nil + return "{}", nil } -// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. -type BoolArray[T ~bool] []T - // Scan implements the sql.Scanner interface. func (a *BoolArray[T]) Scan(src interface{}) error { switch src := src.(type) { @@ -108,32 +82,3 @@ func (a *BoolArray[T]) scanBytes(src []byte) error { } return nil } - -// Value implements the driver.Valuer interface. -func (a BoolArray[T]) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be exactly two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1+2*n) - - for i := 0; i < n; i++ { - b[2*i] = ',' - if a[i] { - b[1+2*i] = 't' - } else { - b[1+2*i] = 'f' - } - } - - b[0] = '{' - b[2*n] = '}' - - return string(b), nil - } - - return "{}", nil -} diff --git a/sequel/types/pgtype/byte_array.go b/sequel/types/pgtype/byte_array.go index 77bbdf1..dca87fc 100644 --- a/sequel/types/pgtype/byte_array.go +++ b/sequel/types/pgtype/byte_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql" "database/sql/driver" "encoding/hex" "fmt" @@ -9,46 +10,16 @@ import ( ) // ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. -type ByteaArray [][]byte +type ByteaArray[T ~[]byte] []T -// Scan implements the sql.Scanner interface. -func (a *ByteaArray) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) -} - -func (a *ByteaArray) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(ByteaArray, len(elems)) - for i, v := range elems { - b[i], err = parseBytea(v) - if err != nil { - return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) - } - } - *a = b - } - return nil -} +var ( + _ sql.Scanner = (*ByteaArray[[]byte])(nil) + _ driver.Valuer = (*ByteaArray[[]byte])(nil) +) // Value implements the driver.Valuer interface. It uses the "hex" format which // is only supported on PostgreSQL 9.0 or newer. -func (a ByteaArray) Value() (driver.Value, error) { +func (a ByteaArray[T]) Value() (driver.Value, error) { if a == nil { return nil, nil } @@ -75,10 +46,44 @@ func (a ByteaArray) Value() (driver.Value, error) { return string(b), nil } - return "{}", nil } +// Scan implements the sql.Scanner interface. +func (a *ByteaArray[T]) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) +} + +func (a *ByteaArray[T]) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(ByteaArray[T], len(elems)) + for i, v := range elems { + b[i], err = parseBytea(v) + if err != nil { + return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + } + } + *a = b + } + return nil +} + // Parse a bytea value received from the server. Both "hex" and the legacy // "escape" format are supported. func parseBytea(s []byte) (result []byte, err error) { diff --git a/sequel/types/pgtype/float_array.go b/sequel/types/pgtype/float_array.go index d0fa265..2137644 100644 --- a/sequel/types/pgtype/float_array.go +++ b/sequel/types/pgtype/float_array.go @@ -6,12 +6,35 @@ import ( "strconv" ) -// Float64Array represents a one-dimensional array of the PostgreSQL double +// Float32Array represents a one-dimensional array of the PostgreSQL double // precision type. -type Float64Array[T ~float64] []T +type Float32Array[T ~float32] []T + +// Value implements the driver.Valuer interface. +func (a Float32Array[T]) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, (float64)(a[0]), 'f', -1, 32) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, (float64)(a[i]), 'f', -1, 32) + } + + return string(append(b, '}')), nil + } + return "{}", nil +} // Scan implements the sql.Scanner interface. -func (a *Float64Array[T]) Scan(src interface{}) error { +func (a *Float32Array[T]) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -21,21 +44,20 @@ func (a *Float64Array[T]) Scan(src interface{}) error { *a = nil return nil } - - return fmt.Errorf("pgtype: cannot convert %T to Float64Array", src) + return fmt.Errorf("pgtype: cannot convert %T to Float32Array", src) } -func (a *Float64Array[T]) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Float64Array") +func (a *Float32Array[T]) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float32Array") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { - b := make(Float64Array[T], len(elems)) + b := make(Float32Array[T], len(elems)) for i, v := range elems { - f, err := strconv.ParseFloat(string(v), 64) + f, err := strconv.ParseFloat(string(v), 32) if err != nil { return fmt.Errorf("pgtype: parsing array element index %d: %v", i, err) } @@ -46,6 +68,10 @@ func (a *Float64Array[T]) scanBytes(src []byte) error { return nil } +// Float64Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float64Array[T ~float64] []T + // Value implements the driver.Valuer interface. func (a Float64Array[T]) Value() (driver.Value, error) { if a == nil { @@ -69,12 +95,8 @@ func (a Float64Array[T]) Value() (driver.Value, error) { return "{}", nil } -// Float32Array represents a one-dimensional array of the PostgreSQL double -// precision type. -type Float32Array[T ~float32] []T - // Scan implements the sql.Scanner interface. -func (a *Float32Array[T]) Scan(src interface{}) error { +func (a *Float64Array[T]) Scan(src interface{}) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -84,20 +106,21 @@ func (a *Float32Array[T]) Scan(src interface{}) error { *a = nil return nil } - return fmt.Errorf("pgtype: cannot convert %T to Float32Array", src) + + return fmt.Errorf("pgtype: cannot convert %T to Float64Array", src) } -func (a *Float32Array[T]) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Float32Array") +func (a *Float64Array[T]) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float64Array") if err != nil { return err } if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { - b := make(Float32Array[T], len(elems)) + b := make(Float64Array[T], len(elems)) for i, v := range elems { - f, err := strconv.ParseFloat(string(v), 32) + f, err := strconv.ParseFloat(string(v), 64) if err != nil { return fmt.Errorf("pgtype: parsing array element index %d: %v", i, err) } @@ -107,27 +130,3 @@ func (a *Float32Array[T]) scanBytes(src []byte) error { } return nil } - -// Value implements the driver.Valuer interface. -func (a Float32Array[T]) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendFloat(b, (float64)(a[0]), 'f', -1, 32) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendFloat(b, (float64)(a[i]), 'f', -1, 32) - } - - return string(append(b, '}')), nil - } - - return "{}", nil -} diff --git a/sequel/types/pgtype/int_array.go b/sequel/types/pgtype/int_array.go index 4498332..e91379a 100644 --- a/sequel/types/pgtype/int_array.go +++ b/sequel/types/pgtype/int_array.go @@ -1,7 +1,6 @@ package pgtype import ( - "database/sql" "database/sql/driver" "fmt" "strconv" @@ -19,26 +18,9 @@ type ( Int64Array[T ~int64] []T ) -func IntArrayValue[T constraints.Signed](a []T) driver.Valuer { - return intArray[T]{v: &a} -} - -func IntArrayScanner[T constraints.Signed](a *[]T) sql.Scanner { - return intArray[T]{v: a} -} - -type intArray[T constraints.Signed] struct { - v *[]T -} - // Value implements the driver.Valuer interface. -func (a intArray[T]) Value() (driver.Value, error) { - return intArrayValue(*a.v) -} - -// Scan implements the sql.Scanner interface. -func (a intArray[T]) Scan(src any) error { - return arrayScan(a.v, src, "IntArray") +func (a IntArray[T]) Value() (driver.Value, error) { + return intArrayValue(a) } // Scan implements the sql.Scanner interface. @@ -47,7 +29,7 @@ func (a *IntArray[T]) Scan(src any) error { } // Value implements the driver.Valuer interface. -func (a IntArray[T]) Value() (driver.Value, error) { +func (a Int8Array[T]) Value() (driver.Value, error) { return intArrayValue(a) } @@ -57,7 +39,7 @@ func (a *Int8Array[T]) Scan(src any) error { } // Value implements the driver.Valuer interface. -func (a Int8Array[T]) Value() (driver.Value, error) { +func (a Int16Array[T]) Value() (driver.Value, error) { return intArrayValue(a) } @@ -67,7 +49,7 @@ func (a *Int16Array[T]) Scan(src any) error { } // Value implements the driver.Valuer interface. -func (a Int16Array[T]) Value() (driver.Value, error) { +func (a Int32Array[T]) Value() (driver.Value, error) { return intArrayValue(a) } @@ -77,7 +59,7 @@ func (a *Int32Array[T]) Scan(src any) error { } // Value implements the driver.Valuer interface. -func (a Int32Array[T]) Value() (driver.Value, error) { +func (a Int64Array[T]) Value() (driver.Value, error) { return intArrayValue(a) } @@ -86,11 +68,6 @@ func (a *Int64Array[T]) Scan(src any) error { return arrayScan(a, src, "Int64Array") } -// Value implements the driver.Valuer interface. -func (a Int64Array[T]) Value() (driver.Value, error) { - return intArrayValue(a) -} - func scanBytes[T constraints.Integer, Arr interface{ ~[]T }](a *Arr, src []byte, t string) error { elems, err := scanLinearArray(src, []byte{','}, t) if err != nil { diff --git a/sequel/types/pgtype/string_array.go b/sequel/types/pgtype/string_array.go index 6da109b..d8a39fb 100644 --- a/sequel/types/pgtype/string_array.go +++ b/sequel/types/pgtype/string_array.go @@ -1,58 +1,38 @@ package pgtype import ( - "database/sql" "database/sql/driver" "fmt" ) -func StringArrayScanner[T ~string](a *[]T) sql.Scanner { - return &stringArray[T]{v: a} -} - -type stringArray[T ~string] struct { - v *[]T -} +// StringArray represents a one-dimensional array of the PostgreSQL character types. +type StringArray[T ~string] []T -// Scan implements the sql.Scanner interface. -func (a *stringArray[T]) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a.v = nil - return nil +// Value implements the driver.Valuer interface. +func (a StringArray[T]) Value() (driver.Value, error) { + if a == nil { + return nil, nil } - return fmt.Errorf("pgtype: cannot convert %T to StringArray", src) -} -func (a stringArray[T]) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "StringArray") - if err != nil { - return err - } - if *a.v != nil && len(elems) == 0 { - *a.v = (*a.v)[:0] - } else { - b := make(StringArray[T], len(elems)) - for i, v := range elems { - if v == nil { - return fmt.Errorf("pgtype: parsing array element index %d: cannot convert nil to string", i) - } - b[i] = T(v) + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedBytes(b, []byte(a[0])) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedBytes(b, []byte(a[i])) } - *a.v = b + + return string(append(b, '}')), nil } - return nil + return "{}", nil } -// StringArray represents a one-dimensional array of the PostgreSQL character types. -type StringArray[T ~string] []T - // Scan implements the sql.Scanner interface. -func (a *StringArray[T]) Scan(src interface{}) error { +func (a *StringArray[T]) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -84,26 +64,3 @@ func (a *StringArray[T]) scanBytes(src []byte) error { } return nil } - -// Value implements the driver.Valuer interface. -func (a StringArray[T]) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, 2*N bytes of quotes, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+3*n) - b[0] = '{' - - b = appendArrayQuotedBytes(b, []byte(a[0])) - for i := 1; i < n; i++ { - b = append(b, ',') - b = appendArrayQuotedBytes(b, []byte(a[i])) - } - - return string(append(b, '}')), nil - } - return "{}", nil -} diff --git a/sequel/types/pgtype/uint_array.go b/sequel/types/pgtype/uint_array.go index d22a769..01363cd 100644 --- a/sequel/types/pgtype/uint_array.go +++ b/sequel/types/pgtype/uint_array.go @@ -1,7 +1,6 @@ package pgtype import ( - "database/sql" "database/sql/driver" "golang.org/x/exp/constraints" @@ -9,33 +8,16 @@ import ( // Uint64Array represents a one-dimensional array of the PostgreSQL integer types. type ( - UintArray[T ~uint] []T - Uint8Array[T ~uint8] []T - Uint16Array[T ~uint16] []T - Uint32Array[T ~uint32] []T - Uint64Array[T ~uint64] []T + UintArray[T constraints.Unsigned] []T + Uint8Array[T ~uint8] []T + Uint16Array[T ~uint16] []T + Uint32Array[T ~uint32] []T + Uint64Array[T ~uint64] []T ) -func UintArrayValue[T constraints.Unsigned](a []T) driver.Valuer { - return uintArray[T]{v: &a} -} - -func UintArrayScanner[T constraints.Unsigned](a *[]T) sql.Scanner { - return uintArray[T]{v: a} -} - -type uintArray[T constraints.Unsigned] struct { - v *[]T -} - // Value implements the driver.Valuer interface. -func (a uintArray[T]) Value() (driver.Value, error) { - return uintArrayValue(*a.v) -} - -// Scan implements the sql.Scanner interface. -func (a uintArray[T]) Scan(src any) error { - return arrayScan(a.v, src, "UintArray") +func (a UintArray[T]) Value() (driver.Value, error) { + return uintArrayValue(a) } // Scan implements the sql.Scanner interface. @@ -44,7 +26,7 @@ func (a *UintArray[T]) Scan(src any) error { } // Value implements the driver.Valuer interface. -func (a UintArray[T]) Value() (driver.Value, error) { +func (a Uint8Array[T]) Value() (driver.Value, error) { return uintArrayValue(a) } @@ -54,7 +36,7 @@ func (a *Uint8Array[T]) Scan(src any) error { } // Value implements the driver.Valuer interface. -func (a Uint8Array[T]) Value() (driver.Value, error) { +func (a Uint16Array[T]) Value() (driver.Value, error) { return uintArrayValue(a) } @@ -64,7 +46,7 @@ func (a *Uint16Array[T]) Scan(src any) error { } // Value implements the driver.Valuer interface. -func (a Uint16Array[T]) Value() (driver.Value, error) { +func (a Uint32Array[T]) Value() (driver.Value, error) { return uintArrayValue(a) } @@ -74,7 +56,7 @@ func (a *Uint32Array[T]) Scan(src any) error { } // Value implements the driver.Valuer interface. -func (a Uint32Array[T]) Value() (driver.Value, error) { +func (a Uint64Array[T]) Value() (driver.Value, error) { return uintArrayValue(a) } @@ -82,8 +64,3 @@ func (a Uint32Array[T]) Value() (driver.Value, error) { func (a *Uint64Array[T]) Scan(src any) error { return arrayScan(a, src, "Uint64Array") } - -// Value implements the driver.Valuer interface. -func (a Uint64Array[T]) Value() (driver.Value, error) { - return uintArrayValue(a) -} diff --git a/sequel/types/uint.go b/sequel/types/uint.go new file mode 100644 index 0000000..3d073e2 --- /dev/null +++ b/sequel/types/uint.go @@ -0,0 +1,91 @@ +package types + +import ( + "database/sql/driver" + "fmt" + "strconv" + "unsafe" + + "golang.org/x/exp/constraints" +) + +func Uint8[T ~uint8](addr *T, strict ...bool) ValueScanner[T] { + return newFixedSizeUint(addr, 8, strict...) +} + +func Uint16[T ~uint16](addr *T, strict ...bool) ValueScanner[T] { + return newFixedSizeUint(addr, 16, strict...) +} + +func Uint32[T ~uint32](addr *T, strict ...bool) ValueScanner[T] { + return newFixedSizeUint(addr, 32, strict...) +} + +func Uint64[T ~uint64](addr *T, strict ...bool) ValueScanner[T] { + return newFixedSizeUint(addr, 64, strict...) +} + +func newFixedSizeUint[T constraints.Unsigned](addr *T, bitSize int, strict ...bool) ValueScanner[T] { + var strictType bool + if len(strict) > 0 { + strictType = strict[0] + } + return &fixedSizeUintLike[T]{addr: addr, bitSize: bitSize, strictType: strictType} +} + +type fixedSizeUintLike[T constraints.Unsigned] struct { + addr *T + bitSize int + strictType bool +} + +func (i fixedSizeUintLike[T]) Interface() T { + if i.addr == nil { + return *new(T) + } + return *i.addr +} + +func (i fixedSizeUintLike[T]) Value() (driver.Value, error) { + if i.addr == nil { + return nil, nil + } + return int64(*i.addr), nil +} + +func (i *fixedSizeUintLike[T]) Scan(v any) error { + var val T + switch vi := v.(type) { + case []byte: + m, err := strconv.ParseUint(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, i.bitSize) + if err != nil { + return err + } + val = T(m) + case int64: + val = T(vi) + case nil: + i.addr = nil + return nil + + default: + if i.strictType { + return fmt.Errorf(`sequel/types: unable to scan %T to ~uint`, vi) + } + + switch vi := v.(type) { + case string: + m, err := strconv.ParseUint(vi, 10, i.bitSize) + if err != nil { + return err + } + val = T(m) + case float64: + val = T(vi) + default: + return fmt.Errorf(`sequel/types: unable to scan %T to ~uint`, vi) + } + } + *i.addr = val + return nil +}