Skip to content

Commit

Permalink
added a nulls.UUID type
Browse files Browse the repository at this point in the history
  • Loading branch information
markbates committed Jun 27, 2017
1 parent 7ec8a9f commit 32203da
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 9 deletions.
7 changes: 2 additions & 5 deletions nulls/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,10 @@ func (ns *String) UnmarshalJSON(text []byte) error {
if string(text) == "null" {
return nil
}
s := ""
err := json.Unmarshal(text, &s)
if err == nil {
ns.String = s
if err := json.Unmarshal(text, &ns.String); err == nil {
ns.Valid = true
}
return err
return nil
}

func (ns *String) UnmarshalText(text []byte) error {
Expand Down
18 changes: 14 additions & 4 deletions nulls/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/jmoiron/sqlx"
. "github.com/markbates/pop/nulls"
_ "github.com/mattn/go-sqlite3"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/require"
)

Expand All @@ -25,6 +26,7 @@ type Foo struct {
IntType Int `json:"intType" db:"int_type"`
Int32Type Int32 `json:"int32Type" db:"int32_type"`
UInt32Type UInt32 `json:"uint32Type" db:"uint32_type"`
UID UUID `json:"uid" db:"uid"`
}

const schema = `CREATE TABLE "main"."foos" (
Expand All @@ -37,9 +39,11 @@ const schema = `CREATE TABLE "main"."foos" (
"bytes" blob,
"int_type" integer,
"int32_type" integer,
"uint32_type" integer
"uint32_type" integer,
"uid" uuid
);`

var uid = "3c9228a9-8549-4d52-8261-dfe0ab0ee6d4"
var now = time.Now()

func newValidFoo() Foo {
Expand All @@ -54,6 +58,7 @@ func newValidFoo() Foo {
IntType: NewInt(2),
Int32Type: NewInt32(3),
UInt32Type: NewUInt32(5),
UID: NewUUID(uuid.FromStringOrNil(uid)),
}
}

Expand All @@ -65,7 +70,7 @@ func Test_TypesMarshalProperly(t *testing.T) {

ti, _ := json.Marshal(now)
ba, _ := json.Marshal(f.Bytes)
jsonString := fmt.Sprintf(`{"id":1,"name":"Mark","alive":true,"price":9.99,"birth":%s,"price32":3.33,"bytes":%s,"intType":2,"int32Type":3,"uint32Type":5}`, ti, ba)
jsonString := fmt.Sprintf(`{"id":1,"name":"Mark","alive":true,"price":9.99,"birth":%s,"price32":3.33,"bytes":%s,"intType":2,"int32Type":3,"uint32Type":5,"uid":"%s"}`, ti, ba, uid)

// check marshalling to json works:
data, _ := json.Marshal(f)
Expand All @@ -84,10 +89,11 @@ func Test_TypesMarshalProperly(t *testing.T) {
a.Equal(f.IntType.Int, 2)
a.Equal(f.Int32Type.Int32, int32(3))
a.Equal(f.UInt32Type.UInt32, uint32(5))
a.Equal(uid, f.UID.UUID.String())

// check marshalling nulls works:
f = Foo{}
jsonString = `{"id":null,"name":null,"alive":null,"price":null,"birth":null,"price32":null,"bytes":null,"intType":null,"int32Type":null,"uint32Type":null}`
jsonString = `{"id":null,"name":null,"alive":null,"price":null,"birth":null,"price32":null,"bytes":null,"intType":null,"int32Type":null,"uint32Type":null,"uid":null}`
data, _ = json.Marshal(f)
a.Equal(string(data), jsonString)

Expand All @@ -113,6 +119,8 @@ func Test_TypesMarshalProperly(t *testing.T) {
a.False(f.Int32Type.Valid)
a.Equal(f.UInt32Type.UInt32, uint32(0))
a.False(f.UInt32Type.Valid)
a.Equal(f.UID.UUID, uuid.Nil)
a.False(f.UID.Valid)
}

func Test_TypeSaveAndRetrieveProperly(t *testing.T) {
Expand Down Expand Up @@ -153,7 +161,8 @@ func Test_TypeSaveAndRetrieveProperly(t *testing.T) {
a.NoError(err)

f = newValidFoo()
tx.NamedExec("INSERT INTO foos (id, name, alive, price, birth, price32, bytes, int_type, int32_type, uint32_type) VALUES (:id, :name, :alive, :price, :birth, :price32, :bytes, :int_type, :int32_type, :uint32_type)", &f)
_, err = tx.NamedExec("INSERT INTO foos (id, name, alive, price, birth, price32, bytes, int_type, int32_type, uint32_type, uid) VALUES (:id, :name, :alive, :price, :birth, :price32, :bytes, :int_type, :int32_type, :uint32_type, :uid)", &f)
a.NoError(err)
f = Foo{}
tx.Get(&f, "select * from foos")
a.True(f.Alive.Valid)
Expand All @@ -176,6 +185,7 @@ func Test_TypeSaveAndRetrieveProperly(t *testing.T) {
a.Equal(f.IntType.Int, 2)
a.Equal(f.Int32Type.Int32, int32(3))
a.Equal(f.UInt32Type.UInt32, uint32(5))
a.Equal(f.UID.UUID.String(), uid)

tx.Rollback()
})
Expand Down
81 changes: 81 additions & 0 deletions nulls/uuid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package nulls

import (
"database/sql/driver"
"encoding/json"

"github.com/satori/go.uuid"
)

// UUID can be used with the standard sql package to represent a
// UUID value that can be NULL in the database
type UUID struct {
UUID uuid.UUID
Valid bool
}

func (ns UUID) Interface() interface{} {
if !ns.Valid {
return nil
}
return ns.UUID
}

// NewString returns a new, properly instantiated
// String object.
func NewUUID(u uuid.UUID) UUID {
return UUID{UUID: u, Valid: true}
}

// Value implements the driver.Valuer interface.
func (u UUID) Value() (driver.Value, error) {
if !u.Valid {
return nil, nil
}
// Delegate to UUID Value function
return u.UUID.Value()
}

// Scan implements the sql.Scanner interface.
func (u *UUID) Scan(src interface{}) error {
if src == nil {
u.UUID, u.Valid = uuid.Nil, false
return nil
}

// Delegate to UUID Scan function
u.Valid = true
return u.UUID.Scan(src)
}

// MarshalJSON marshals the underlying value to a
// proper JSON representation.
func (ns UUID) MarshalJSON() ([]byte, error) {
if ns.Valid {
return json.Marshal(ns.UUID.String())
}
return json.Marshal(nil)
}

// UnmarshalJSON will unmarshal a JSON value into
// the propert representation of that value.
func (ns *UUID) UnmarshalJSON(text []byte) error {
ns.Valid = false
ns.UUID = uuid.Nil
if string(text) == "null" {
return nil
}

s := ""
if err := json.Unmarshal(text, &s); err == nil {
if u, err := uuid.FromString(s); err == nil {
ns.UUID = u
ns.Valid = true
}
}
return nil
}

func (ns *UUID) UnmarshalText(text []byte) error {
return ns.UnmarshalJSON(text)
}

0 comments on commit 32203da

Please sign in to comment.