From a8e78a88fd9704a8f4faa77c98d1adb934574021 Mon Sep 17 00:00:00 2001 From: si3nloong Date: Thu, 26 Sep 2024 16:47:18 +0800 Subject: [PATCH] refactor: sql types --- sequel/types/bool_slice.go | 20 +++++++++---------- sequel/types/bool_slice_test.go | 17 +++++++++++++++- sequel/types/float_slice.go | 21 +++++++++++++++++++- sequel/types/int.go | 15 +++++++------- sequel/types/int_slice.go | 13 +++++++++++- sequel/types/int_test.go | 18 +++++++++++++++++ sequel/types/str_slice.go | 20 ++++++++++++++++++- sequel/types/string.go | 5 ++++- sequel/types/time.go | 35 ++++++++++++++++----------------- sequel/types/uint_slice.go | 4 +++- 10 files changed, 126 insertions(+), 42 deletions(-) diff --git a/sequel/types/bool_slice.go b/sequel/types/bool_slice.go index 1cdb1f3..015c73b 100644 --- a/sequel/types/bool_slice.go +++ b/sequel/types/bool_slice.go @@ -18,8 +18,8 @@ type boolList[T ~bool] struct { } var ( - _ driver.Valuer = boolList[bool]{} - _ sql.Scanner = boolList[bool]{} + _ driver.Valuer = (*boolList[bool])(nil) + _ sql.Scanner = (*boolList[bool])(nil) ) func BoolSlice[T ~bool](v *[]T) boolList[T] { @@ -27,15 +27,13 @@ func BoolSlice[T ~bool](v *[]T) boolList[T] { } func (s boolList[T]) Value() (driver.Value, error) { - if (*s.v) == nil { - val := make([]byte, len(nullBytes)) - copy(val, nullBytes) - return val, nil + if s.v == nil || *s.v == nil { + return nil, nil } - return encoding.MarshalBoolSlice(*(s.v)), nil + return encoding.MarshalBoolSlice(*s.v), nil } -func (s boolList[T]) Scan(v any) error { +func (s *boolList[T]) Scan(v any) error { switch vi := v.(type) { case []byte: if bytes.Equal(vi, nullBytes) { @@ -80,17 +78,17 @@ func (s boolList[T]) Scan(v any) error { var ( paths = strings.Split(vi, ",") values = make([]T, len(paths)) - b string ) for i := range paths { - b = strings.TrimSpace(paths[i]) - flag, err := strconv.ParseBool(b) + flag, err := strconv.ParseBool(strings.TrimSpace(paths[i])) if err != nil { return err } values[i] = T(flag) } *s.v = values + case nil: + *s.v = nil default: return fmt.Errorf(`sequel/types: unsupported scan type %T for []~bool`, vi) } diff --git a/sequel/types/bool_slice_test.go b/sequel/types/bool_slice_test.go index d197efb..3e19e08 100644 --- a/sequel/types/bool_slice_test.go +++ b/sequel/types/bool_slice_test.go @@ -13,10 +13,25 @@ func TestBoolSlice(t *testing.T) { require.NoError(t, v.Scan(nullBytes)) value, err := v.Value() require.NoError(t, err) - require.Equal(t, nullBytes, value) + require.Nil(t, value) }) t.Run("custom bool type", func(t *testing.T) { + type Bool bool + var bList = []Bool{true, false} + var v = BoolSlice(&bList) + require.NoError(t, v.Scan(`[false, false, true]`)) + value, err := v.Value() + require.NoError(t, err) + require.Equal(t, `[false,false,true]`, value) + }) + t.Run("Scan with nil value", func(t *testing.T) { + var bList = []bool{true, false, true} + var v = BoolSlice(&bList) + require.NoError(t, v.Scan(nil)) + value, err := v.Value() + require.NoError(t, err) + require.Nil(t, value) }) } diff --git a/sequel/types/float_slice.go b/sequel/types/float_slice.go index ebc22d6..eb4ba5f 100644 --- a/sequel/types/float_slice.go +++ b/sequel/types/float_slice.go @@ -2,11 +2,14 @@ package types import ( "bytes" + "database/sql" + "database/sql/driver" "fmt" "strconv" "strings" "unsafe" + "github.com/si3nloong/sqlgen/sequel/encoding" "golang.org/x/exp/constraints" ) @@ -14,11 +17,25 @@ type floatList[T constraints.Float] struct { v *[]T } +var ( + _ driver.Valuer = (*floatList[float32])(nil) + _ sql.Scanner = (*floatList[float32])(nil) + _ driver.Valuer = (*floatList[float64])(nil) + _ sql.Scanner = (*floatList[float64])(nil) +) + func FloatSlice[T constraints.Float](v *[]T) floatList[T] { return floatList[T]{v: v} } -func (s floatList[T]) Scan(v any) error { +func (s floatList[T]) Value() (driver.Value, error) { + if s.v == nil || *s.v == nil { + return nil, nil + } + return encoding.MarshalFloatList(*s.v, 64), nil +} + +func (s *floatList[T]) Scan(v any) error { switch vi := v.(type) { case []byte: if bytes.Equal(vi, nullBytes) { @@ -74,6 +91,8 @@ func (s floatList[T]) Scan(v any) error { values[i] = T(f64) } *s.v = values + case nil: + *s.v = nil default: return fmt.Errorf(`sequel/types: unsupported scan type %T for []~float`, vi) } diff --git a/sequel/types/int.go b/sequel/types/int.go index 7a3ae9c..1f1c669 100644 --- a/sequel/types/int.go +++ b/sequel/types/int.go @@ -42,18 +42,13 @@ func (i intLike[T]) Value() (driver.Value, error) { return int64(*i.addr), nil } -func (i intLike[T]) Scan(v any) error { +func (i *intLike[T]) Scan(v any) error { var val T switch vi := v.(type) { - case []byte: - m, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, 64) - if err != nil { - return err - } - val = T(m) case int64: val = T(vi) case nil: + i.addr = nil return nil default: @@ -62,6 +57,12 @@ func (i intLike[T]) Scan(v any) error { } switch vi := v.(type) { + case []byte: + m, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, 64) + if err != nil { + return err + } + val = T(m) case string: m, err := strconv.ParseInt(string(vi), 10, 64) if err != nil { diff --git a/sequel/types/int_slice.go b/sequel/types/int_slice.go index f00c3b8..8632c37 100644 --- a/sequel/types/int_slice.go +++ b/sequel/types/int_slice.go @@ -2,11 +2,13 @@ package types import ( "bytes" + "database/sql/driver" "fmt" "strconv" "strings" "unsafe" + "github.com/si3nloong/sqlgen/sequel/encoding" "golang.org/x/exp/constraints" ) @@ -18,7 +20,14 @@ func IntSlice[T constraints.Signed](v *[]T) intList[T] { return intList[T]{v: v} } -func (s intList[T]) Scan(v any) error { +func (s intList[T]) Value() (driver.Value, error) { + if s.v == nil || *s.v == nil { + return nil, nil + } + return encoding.MarshalIntSlice(*s.v), nil +} + +func (s *intList[T]) Scan(v any) error { switch vi := v.(type) { case []byte: if bytes.Equal(vi, nullBytes) { @@ -74,6 +83,8 @@ func (s intList[T]) Scan(v any) error { values[i] = T(i64) } *s.v = values + case nil: + *s.v = nil default: return fmt.Errorf(`sequel/types: unsupported scan type %T for []~int`, vi) } diff --git a/sequel/types/int_test.go b/sequel/types/int_test.go index 4b27712..10ee6e3 100644 --- a/sequel/types/int_test.go +++ b/sequel/types/int_test.go @@ -44,4 +44,22 @@ func TestInteger(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(88), value) }) + + t.Run("Integer with new(int)", func(t *testing.T) { + var ptr = new(int) + v := Integer(ptr) + + t.Run("Value", func(t *testing.T) { + value, err := v.Value() + require.NoError(t, err) + require.Empty(t, value) + }) + + t.Run("Scan", func(t *testing.T) { + require.NoError(t, v.Scan(nil)) + value, err := v.Value() + require.NoError(t, err) + require.Nil(t, value) + }) + }) } diff --git a/sequel/types/str_slice.go b/sequel/types/str_slice.go index d61e9e7..b345eb3 100644 --- a/sequel/types/str_slice.go +++ b/sequel/types/str_slice.go @@ -2,19 +2,35 @@ package types import ( "bytes" + "database/sql" + "database/sql/driver" "fmt" "strings" + + "github.com/si3nloong/sqlgen/sequel/encoding" ) type strSlice[T ~string] struct { v *[]T } +var ( + _ driver.Valuer = (*strSlice[string])(nil) + _ sql.Scanner = (*strSlice[string])(nil) +) + func StringSlice[T ~string](v *[]T) strSlice[T] { return strSlice[T]{v: v} } -func (s strSlice[T]) Scan(v any) error { +func (s strSlice[T]) Value() (driver.Value, error) { + if s.v == nil || *s.v == nil { + return nil, nil + } + return encoding.MarshalStringSlice(*s.v), nil +} + +func (s *strSlice[T]) Scan(v any) error { switch vi := v.(type) { case []byte: if bytes.Equal(vi, nullBytes) { @@ -54,6 +70,8 @@ func (s strSlice[T]) Scan(v any) error { values[i] = (T)(strings.Trim(b[i], `"`)) } *s.v = values + case nil: + *s.v = nil default: return fmt.Errorf(`sequel/types: unsupported scan type %T for []~string`, vi) } diff --git a/sequel/types/string.go b/sequel/types/string.go index fe86ccb..339953a 100644 --- a/sequel/types/string.go +++ b/sequel/types/string.go @@ -43,13 +43,16 @@ func (s strLike[T]) Value() (driver.Value, error) { return string(*s.addr), nil } -func (s strLike[T]) Scan(v any) error { +func (s *strLike[T]) Scan(v any) error { var val T switch vi := v.(type) { case string: val = T(vi) case []byte: val = T(vi) + case nil: + s.addr = nil + return nil default: if s.strictType { return fmt.Errorf(`sequel/types: unable to scan %T to ~string`, vi) diff --git a/sequel/types/time.go b/sequel/types/time.go index d35b7f8..e16e2ff 100644 --- a/sequel/types/time.go +++ b/sequel/types/time.go @@ -15,39 +15,39 @@ var ( ddmmyyyyhhmmsstz = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}:\d{2}\.\d+$`) ) -type datetime[T time.Time] struct { +type timestamp[T time.Time] struct { addr *T strictType bool } var ( - _ sql.Scanner = (*datetime[time.Time])(nil) - _ driver.Valuer = (*datetime[time.Time])(nil) + _ sql.Scanner = (*timestamp[time.Time])(nil) + _ driver.Valuer = (*timestamp[time.Time])(nil) ) -func Time[T time.Time](addr *T, strict ...bool) datetime[T] { +func Time[T time.Time](addr *T, strict ...bool) timestamp[T] { var strictType bool if len(strict) > 0 { strictType = strict[0] } - return datetime[T]{addr: addr, strictType: strictType} + return timestamp[T]{addr: addr, strictType: strictType} } -func (dt datetime[T]) Interface() T { - if dt.addr == nil { +func (t timestamp[T]) Interface() T { + if t.addr == nil { return *new(T) } - return *dt.addr + return *t.addr } -func (dt datetime[T]) Value() (driver.Value, error) { - if dt.addr == nil { +func (t timestamp[T]) Value() (driver.Value, error) { + if t.addr == nil { return nil, nil } - return time.Time(*dt.addr), nil + return time.Time(*t.addr), nil } -func (s datetime[T]) Scan(v any) error { +func (t *timestamp[T]) Scan(v any) error { var val T switch vi := v.(type) { case []byte: @@ -66,18 +66,17 @@ func (s datetime[T]) Scan(v any) error { val = T(vi) case int64: val = T(time.Unix(vi, 0)) + case nil: + t.addr = nil + return nil default: return fmt.Errorf(`sequel/types: unsupported scan type %T for time.Time`, vi) } - *s.addr = val + *t.addr = val return nil } -func parseTime(str string) (time.Time, error) { - var ( - t time.Time - err error - ) +func parseTime(str string) (t time.Time, err error) { switch { case ddmmyyyy.MatchString(str): t, err = time.Parse("2006-01-02", str) diff --git a/sequel/types/uint_slice.go b/sequel/types/uint_slice.go index e363dd5..df33632 100644 --- a/sequel/types/uint_slice.go +++ b/sequel/types/uint_slice.go @@ -18,7 +18,7 @@ func UintSlice[T constraints.Unsigned](v *[]T) uintList[T] { return uintList[T]{v: v} } -func (s uintList[T]) Scan(v any) error { +func (s *uintList[T]) Scan(v any) error { switch vi := v.(type) { case []byte: if bytes.Equal(vi, nullBytes) { @@ -74,6 +74,8 @@ func (s uintList[T]) Scan(v any) error { values[i] = T(u64) } *s.v = values + case nil: + *s.v = nil default: return fmt.Errorf(`sequel/types: unsupported scan type %T for []~uint`, vi) }