Skip to content

Commit

Permalink
Fix proto.RedisError in slices
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Feb 22, 2018
1 parent 71ed499 commit 56dea1f
Show file tree
Hide file tree
Showing 14 changed files with 122 additions and 105 deletions.
4 changes: 2 additions & 2 deletions cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,8 @@ var _ = Describe("ClusterClient", func() {
Describe("ClusterClient failover", func() {
BeforeEach(func() {
opt = redisClusterOptions()
opt.MinRetryBackoff = 100 * time.Millisecond
opt.MaxRetryBackoff = 3 * time.Second
opt.MinRetryBackoff = 250 * time.Millisecond
opt.MaxRetryBackoff = time.Second
client = cluster.clusterClient(opt)

_ = client.ForEachSlave(func(slave *redis.Client) error {
Expand Down
3 changes: 2 additions & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/go-redis/redis/internal"
"github.com/go-redis/redis/internal/pool"
"github.com/go-redis/redis/internal/proto"
"github.com/go-redis/redis/internal/util"
)

type Cmder interface {
Expand Down Expand Up @@ -436,7 +437,7 @@ func NewStringCmd(args ...interface{}) *StringCmd {
}

func (cmd *StringCmd) Val() string {
return internal.BytesToString(cmd.val)
return util.BytesToString(cmd.val)
}

func (cmd *StringCmd) Result() (string, error) {
Expand Down
3 changes: 2 additions & 1 deletion commands.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package redis

import (
"errors"
"io"
"time"

Expand Down Expand Up @@ -1802,7 +1803,7 @@ func (c *cmdable) shutdown(modifier string) *StatusCmd {
}
} else {
// Server did not quit. String reply contains the reason.
cmd.err = internal.RedisError(cmd.val)
cmd.err = errors.New(cmd.val)
cmd.val = ""
}
return cmd
Expand Down
4 changes: 2 additions & 2 deletions commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
. "github.com/onsi/gomega"

"github.com/go-redis/redis"
"github.com/go-redis/redis/internal"
"github.com/go-redis/redis/internal/proto"
)

var _ = Describe("Commands", func() {
Expand Down Expand Up @@ -3000,7 +3000,7 @@ var _ = Describe("Commands", func() {
nil,
).Result()
Expect(err).NotTo(HaveOccurred())
Expect(vals).To(Equal([]interface{}{int64(12), internal.RedisError("error"), "abc"}))
Expect(vals).To(Equal([]interface{}{int64(12), proto.RedisError("error"), "abc"}))
})

})
Expand Down
10 changes: 3 additions & 7 deletions internal/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@ import (
"io"
"net"
"strings"
)

const Nil = RedisError("redis: nil")

type RedisError string

func (e RedisError) Error() string { return string(e) }
"github.com/go-redis/redis/internal/proto"
)

func IsRetryableError(err error, retryNetError bool) bool {
if IsNetworkError(err) {
Expand All @@ -30,7 +26,7 @@ func IsRetryableError(err error, retryNetError bool) bool {
}

func IsRedisError(err error) bool {
_, ok := err.(RedisError)
_, ok := err.(proto.RedisError)
return ok
}

Expand Down
48 changes: 21 additions & 27 deletions internal/proto/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"io"
"strconv"

"github.com/go-redis/redis/internal"
"github.com/go-redis/redis/internal/util"
)

const bytesAllocLimit = 1024 * 1024 // 1mb
Expand All @@ -19,6 +19,16 @@ const (
ArrayReply = '*'
)

//------------------------------------------------------------------------------

const Nil = RedisError("redis: nil")

type RedisError string

func (e RedisError) Error() string { return string(e) }

//------------------------------------------------------------------------------

type MultiBulkParse func(*Reader, int64) (interface{}, error)

type Reader struct {
Expand Down Expand Up @@ -66,7 +76,7 @@ func (r *Reader) ReadLine() ([]byte, error) {
return nil, fmt.Errorf("redis: reply is empty")
}
if isNilReply(line) {
return nil, internal.Nil
return nil, Nil
}
return line, nil
}
Expand All @@ -83,7 +93,7 @@ func (r *Reader) ReadReply(m MultiBulkParse) (interface{}, error) {
case StatusReply:
return parseStatusValue(line), nil
case IntReply:
return parseInt(line[1:], 10, 64)
return util.ParseInt(line[1:], 10, 64)
case StringReply:
return r.readTmpBytesValue(line)
case ArrayReply:
Expand All @@ -105,7 +115,7 @@ func (r *Reader) ReadIntReply() (int64, error) {
case ErrorReply:
return 0, ParseErrorReply(line)
case IntReply:
return parseInt(line[1:], 10, 64)
return util.ParseInt(line[1:], 10, 64)
default:
return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line)
}
Expand Down Expand Up @@ -151,7 +161,7 @@ func (r *Reader) ReadFloatReply() (float64, error) {
if err != nil {
return 0, err
}
return parseFloat(b, 64)
return util.ParseFloat(b, 64)
}

func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) {
Expand Down Expand Up @@ -221,7 +231,7 @@ func (r *Reader) ReadScanReply() ([]string, uint64, error) {

func (r *Reader) readTmpBytesValue(line []byte) ([]byte, error) {
if isNilReply(line) {
return nil, internal.Nil
return nil, Nil
}

replyLen, err := strconv.Atoi(string(line[1:]))
Expand All @@ -241,15 +251,15 @@ func (r *Reader) ReadInt() (int64, error) {
if err != nil {
return 0, err
}
return parseInt(b, 10, 64)
return util.ParseInt(b, 10, 64)
}

func (r *Reader) ReadUint() (uint64, error) {
b, err := r.ReadTmpBytesReply()
if err != nil {
return 0, err
}
return parseUint(b, 10, 64)
return util.ParseUint(b, 10, 64)
}

// --------------------------------------------------------------------
Expand Down Expand Up @@ -303,7 +313,7 @@ func isNilReply(b []byte) bool {
}

func ParseErrorReply(line []byte) error {
return internal.RedisError(string(line[1:]))
return RedisError(string(line[1:]))
}

func parseStatusValue(line []byte) []byte {
Expand All @@ -312,23 +322,7 @@ func parseStatusValue(line []byte) []byte {

func parseArrayLen(line []byte) (int64, error) {
if isNilReply(line) {
return 0, internal.Nil
return 0, Nil
}
return parseInt(line[1:], 10, 64)
}

func atoi(b []byte) (int, error) {
return strconv.Atoi(internal.BytesToString(b))
}

func parseInt(b []byte, base int, bitSize int) (int64, error) {
return strconv.ParseInt(internal.BytesToString(b), base, bitSize)
}

func parseUint(b []byte, base int, bitSize int) (uint64, error) {
return strconv.ParseUint(internal.BytesToString(b), base, bitSize)
}

func parseFloat(b []byte, bitSize int) (float64, error) {
return strconv.ParseFloat(internal.BytesToString(b), bitSize)
return util.ParseInt(line[1:], 10, 64)
}
63 changes: 48 additions & 15 deletions internal/proto/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,96 +5,96 @@ import (
"fmt"
"reflect"

"github.com/go-redis/redis/internal"
"github.com/go-redis/redis/internal/util"
)

func Scan(b []byte, v interface{}) error {
switch v := v.(type) {
case nil:
return fmt.Errorf("redis: Scan(nil)")
case *string:
*v = internal.BytesToString(b)
*v = util.BytesToString(b)
return nil
case *[]byte:
*v = b
return nil
case *int:
var err error
*v, err = atoi(b)
*v, err = util.Atoi(b)
return err
case *int8:
n, err := parseInt(b, 10, 8)
n, err := util.ParseInt(b, 10, 8)
if err != nil {
return err
}
*v = int8(n)
return nil
case *int16:
n, err := parseInt(b, 10, 16)
n, err := util.ParseInt(b, 10, 16)
if err != nil {
return err
}
*v = int16(n)
return nil
case *int32:
n, err := parseInt(b, 10, 32)
n, err := util.ParseInt(b, 10, 32)
if err != nil {
return err
}
*v = int32(n)
return nil
case *int64:
n, err := parseInt(b, 10, 64)
n, err := util.ParseInt(b, 10, 64)
if err != nil {
return err
}
*v = n
return nil
case *uint:
n, err := parseUint(b, 10, 64)
n, err := util.ParseUint(b, 10, 64)
if err != nil {
return err
}
*v = uint(n)
return nil
case *uint8:
n, err := parseUint(b, 10, 8)
n, err := util.ParseUint(b, 10, 8)
if err != nil {
return err
}
*v = uint8(n)
return nil
case *uint16:
n, err := parseUint(b, 10, 16)
n, err := util.ParseUint(b, 10, 16)
if err != nil {
return err
}
*v = uint16(n)
return nil
case *uint32:
n, err := parseUint(b, 10, 32)
n, err := util.ParseUint(b, 10, 32)
if err != nil {
return err
}
*v = uint32(n)
return nil
case *uint64:
n, err := parseUint(b, 10, 64)
n, err := util.ParseUint(b, 10, 64)
if err != nil {
return err
}
*v = n
return nil
case *float32:
n, err := parseFloat(b, 32)
n, err := util.ParseFloat(b, 32)
if err != nil {
return err
}
*v = float32(n)
return err
case *float64:
var err error
*v, err = parseFloat(b, 64)
*v, err = util.ParseFloat(b, 64)
return err
case *bool:
*v = len(b) == 1 && b[0] == '1'
Expand All @@ -120,7 +120,7 @@ func ScanSlice(data []string, slice interface{}) error {
return fmt.Errorf("redis: ScanSlice(non-slice %T)", slice)
}

next := internal.MakeSliceNextElemFunc(v)
next := makeSliceNextElemFunc(v)
for i, s := range data {
elem := next()
if err := Scan([]byte(s), elem.Addr().Interface()); err != nil {
Expand All @@ -131,3 +131,36 @@ func ScanSlice(data []string, slice interface{}) error {

return nil
}

func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
elemType := v.Type().Elem()

if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
elem := v.Index(v.Len() - 1)
if elem.IsNil() {
elem.Set(reflect.New(elemType))
}
return elem.Elem()
}

elem := reflect.New(elemType)
v.Set(reflect.Append(v, elem))
return elem.Elem()
}
}

zero := reflect.Zero(elemType)
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
return v.Index(v.Len() - 1)
}

v.Set(reflect.Append(v, zero))
return v.Index(v.Len() - 1)
}
}
Loading

0 comments on commit 56dea1f

Please sign in to comment.