Skip to content

Commit

Permalink
refactor: sql types
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Sep 26, 2024
1 parent 4f2ab89 commit a8e78a8
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 42 deletions.
20 changes: 9 additions & 11 deletions sequel/types/bool_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,22 @@ type boolList[T ~bool] struct {
}

var (
_ driver.Valuer = boolList[bool]{}
_ sql.Scanner = boolList[bool]{}
_ driver.Valuer = (*boolList[bool])(nil)
_ sql.Scanner = (*boolList[bool])(nil)
)

func BoolSlice[T ~bool](v *[]T) boolList[T] {
return boolList[T]{v: v}
}

func (s boolList[T]) Value() (driver.Value, error) {
if (*s.v) == nil {
val := make([]byte, len(nullBytes))
copy(val, nullBytes)
return val, nil
if s.v == nil || *s.v == nil {
return nil, nil
}
return encoding.MarshalBoolSlice(*(s.v)), nil
return encoding.MarshalBoolSlice(*s.v), nil
}

func (s boolList[T]) Scan(v any) error {
func (s *boolList[T]) Scan(v any) error {
switch vi := v.(type) {
case []byte:
if bytes.Equal(vi, nullBytes) {
Expand Down Expand Up @@ -80,17 +78,17 @@ func (s boolList[T]) Scan(v any) error {
var (
paths = strings.Split(vi, ",")
values = make([]T, len(paths))
b string
)
for i := range paths {
b = strings.TrimSpace(paths[i])
flag, err := strconv.ParseBool(b)
flag, err := strconv.ParseBool(strings.TrimSpace(paths[i]))
if err != nil {
return err
}
values[i] = T(flag)
}
*s.v = values
case nil:
*s.v = nil
default:
return fmt.Errorf(`sequel/types: unsupported scan type %T for []~bool`, vi)
}
Expand Down
17 changes: 16 additions & 1 deletion sequel/types/bool_slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,25 @@ func TestBoolSlice(t *testing.T) {
require.NoError(t, v.Scan(nullBytes))
value, err := v.Value()
require.NoError(t, err)
require.Equal(t, nullBytes, value)
require.Nil(t, value)
})

t.Run("custom bool type", func(t *testing.T) {
type Bool bool
var bList = []Bool{true, false}
var v = BoolSlice(&bList)
require.NoError(t, v.Scan(`[false, false, true]`))
value, err := v.Value()
require.NoError(t, err)
require.Equal(t, `[false,false,true]`, value)
})

t.Run("Scan with nil value", func(t *testing.T) {
var bList = []bool{true, false, true}
var v = BoolSlice(&bList)
require.NoError(t, v.Scan(nil))
value, err := v.Value()
require.NoError(t, err)
require.Nil(t, value)
})
}
21 changes: 20 additions & 1 deletion sequel/types/float_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,40 @@ package types

import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
"strconv"
"strings"
"unsafe"

"github.com/si3nloong/sqlgen/sequel/encoding"
"golang.org/x/exp/constraints"
)

type floatList[T constraints.Float] struct {
v *[]T
}

var (
_ driver.Valuer = (*floatList[float32])(nil)
_ sql.Scanner = (*floatList[float32])(nil)
_ driver.Valuer = (*floatList[float64])(nil)
_ sql.Scanner = (*floatList[float64])(nil)
)

func FloatSlice[T constraints.Float](v *[]T) floatList[T] {
return floatList[T]{v: v}
}

func (s floatList[T]) Scan(v any) error {
func (s floatList[T]) Value() (driver.Value, error) {
if s.v == nil || *s.v == nil {
return nil, nil
}
return encoding.MarshalFloatList(*s.v, 64), nil
}

func (s *floatList[T]) Scan(v any) error {
switch vi := v.(type) {
case []byte:
if bytes.Equal(vi, nullBytes) {
Expand Down Expand Up @@ -74,6 +91,8 @@ func (s floatList[T]) Scan(v any) error {
values[i] = T(f64)
}
*s.v = values
case nil:
*s.v = nil
default:
return fmt.Errorf(`sequel/types: unsupported scan type %T for []~float`, vi)
}
Expand Down
15 changes: 8 additions & 7 deletions sequel/types/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,13 @@ func (i intLike[T]) Value() (driver.Value, error) {
return int64(*i.addr), nil
}

func (i intLike[T]) Scan(v any) error {
func (i *intLike[T]) Scan(v any) error {
var val T
switch vi := v.(type) {
case []byte:
m, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, 64)
if err != nil {
return err
}
val = T(m)
case int64:
val = T(vi)
case nil:
i.addr = nil
return nil

default:
Expand All @@ -62,6 +57,12 @@ func (i intLike[T]) Scan(v any) error {
}

switch vi := v.(type) {
case []byte:
m, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, 64)
if err != nil {
return err
}
val = T(m)
case string:
m, err := strconv.ParseInt(string(vi), 10, 64)
if err != nil {
Expand Down
13 changes: 12 additions & 1 deletion sequel/types/int_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package types

import (
"bytes"
"database/sql/driver"
"fmt"
"strconv"
"strings"
"unsafe"

"github.com/si3nloong/sqlgen/sequel/encoding"
"golang.org/x/exp/constraints"
)

Expand All @@ -18,7 +20,14 @@ func IntSlice[T constraints.Signed](v *[]T) intList[T] {
return intList[T]{v: v}
}

func (s intList[T]) Scan(v any) error {
func (s intList[T]) Value() (driver.Value, error) {
if s.v == nil || *s.v == nil {
return nil, nil
}
return encoding.MarshalIntSlice(*s.v), nil
}

func (s *intList[T]) Scan(v any) error {
switch vi := v.(type) {
case []byte:
if bytes.Equal(vi, nullBytes) {
Expand Down Expand Up @@ -74,6 +83,8 @@ func (s intList[T]) Scan(v any) error {
values[i] = T(i64)
}
*s.v = values
case nil:
*s.v = nil
default:
return fmt.Errorf(`sequel/types: unsupported scan type %T for []~int`, vi)
}
Expand Down
18 changes: 18 additions & 0 deletions sequel/types/int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,22 @@ func TestInteger(t *testing.T) {
require.NoError(t, err)
require.Equal(t, int64(88), value)
})

t.Run("Integer with new(int)", func(t *testing.T) {
var ptr = new(int)
v := Integer(ptr)

t.Run("Value", func(t *testing.T) {
value, err := v.Value()
require.NoError(t, err)
require.Empty(t, value)
})

t.Run("Scan", func(t *testing.T) {
require.NoError(t, v.Scan(nil))
value, err := v.Value()
require.NoError(t, err)
require.Nil(t, value)
})
})
}
20 changes: 19 additions & 1 deletion sequel/types/str_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,35 @@ package types

import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
"strings"

"github.com/si3nloong/sqlgen/sequel/encoding"
)

type strSlice[T ~string] struct {
v *[]T
}

var (
_ driver.Valuer = (*strSlice[string])(nil)
_ sql.Scanner = (*strSlice[string])(nil)
)

func StringSlice[T ~string](v *[]T) strSlice[T] {
return strSlice[T]{v: v}
}

func (s strSlice[T]) Scan(v any) error {
func (s strSlice[T]) Value() (driver.Value, error) {
if s.v == nil || *s.v == nil {
return nil, nil
}
return encoding.MarshalStringSlice(*s.v), nil
}

func (s *strSlice[T]) Scan(v any) error {
switch vi := v.(type) {
case []byte:
if bytes.Equal(vi, nullBytes) {
Expand Down Expand Up @@ -54,6 +70,8 @@ func (s strSlice[T]) Scan(v any) error {
values[i] = (T)(strings.Trim(b[i], `"`))
}
*s.v = values
case nil:
*s.v = nil
default:
return fmt.Errorf(`sequel/types: unsupported scan type %T for []~string`, vi)
}
Expand Down
5 changes: 4 additions & 1 deletion sequel/types/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ func (s strLike[T]) Value() (driver.Value, error) {
return string(*s.addr), nil
}

func (s strLike[T]) Scan(v any) error {
func (s *strLike[T]) Scan(v any) error {
var val T
switch vi := v.(type) {
case string:
val = T(vi)
case []byte:
val = T(vi)
case nil:
s.addr = nil
return nil
default:
if s.strictType {
return fmt.Errorf(`sequel/types: unable to scan %T to ~string`, vi)
Expand Down
35 changes: 17 additions & 18 deletions sequel/types/time.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,39 @@ var (
ddmmyyyyhhmmsstz = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}:\d{2}\.\d+$`)
)

type datetime[T time.Time] struct {
type timestamp[T time.Time] struct {
addr *T
strictType bool
}

var (
_ sql.Scanner = (*datetime[time.Time])(nil)
_ driver.Valuer = (*datetime[time.Time])(nil)
_ sql.Scanner = (*timestamp[time.Time])(nil)
_ driver.Valuer = (*timestamp[time.Time])(nil)
)

func Time[T time.Time](addr *T, strict ...bool) datetime[T] {
func Time[T time.Time](addr *T, strict ...bool) timestamp[T] {
var strictType bool
if len(strict) > 0 {
strictType = strict[0]
}
return datetime[T]{addr: addr, strictType: strictType}
return timestamp[T]{addr: addr, strictType: strictType}
}

func (dt datetime[T]) Interface() T {
if dt.addr == nil {
func (t timestamp[T]) Interface() T {
if t.addr == nil {
return *new(T)
}
return *dt.addr
return *t.addr
}

func (dt datetime[T]) Value() (driver.Value, error) {
if dt.addr == nil {
func (t timestamp[T]) Value() (driver.Value, error) {
if t.addr == nil {
return nil, nil
}
return time.Time(*dt.addr), nil
return time.Time(*t.addr), nil
}

func (s datetime[T]) Scan(v any) error {
func (t *timestamp[T]) Scan(v any) error {
var val T
switch vi := v.(type) {
case []byte:
Expand All @@ -66,18 +66,17 @@ func (s datetime[T]) Scan(v any) error {
val = T(vi)
case int64:
val = T(time.Unix(vi, 0))
case nil:
t.addr = nil
return nil
default:
return fmt.Errorf(`sequel/types: unsupported scan type %T for time.Time`, vi)
}
*s.addr = val
*t.addr = val
return nil
}

func parseTime(str string) (time.Time, error) {
var (
t time.Time
err error
)
func parseTime(str string) (t time.Time, err error) {
switch {
case ddmmyyyy.MatchString(str):
t, err = time.Parse("2006-01-02", str)
Expand Down
4 changes: 3 additions & 1 deletion sequel/types/uint_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func UintSlice[T constraints.Unsigned](v *[]T) uintList[T] {
return uintList[T]{v: v}
}

func (s uintList[T]) Scan(v any) error {
func (s *uintList[T]) Scan(v any) error {
switch vi := v.(type) {
case []byte:
if bytes.Equal(vi, nullBytes) {
Expand Down Expand Up @@ -74,6 +74,8 @@ func (s uintList[T]) Scan(v any) error {
values[i] = T(u64)
}
*s.v = values
case nil:
*s.v = nil
default:
return fmt.Errorf(`sequel/types: unsupported scan type %T for []~uint`, vi)
}
Expand Down

0 comments on commit a8e78a8

Please sign in to comment.