Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ GOTEST=$(GOCMD) test
GOGET=$(GOCMD) get
GOMOD=$(GOCMD) mod
GOFMT=$(GOCMD) fmt
GODOC=godoc

.PHONY: all test coverage
all: test coverage examples
Expand Down Expand Up @@ -48,3 +49,11 @@ test: get
coverage: get test
$(GOTEST) -race -coverprofile=coverage.txt -covermode=atomic ./redisai

godoc:
$(GOGET) -u golang.org/x/tools/...
echo "Open browser tab on localhost:6060"
$(GODOC)


fmt:
$(GOFMT) ./...
14 changes: 10 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@ module github.com/RedisAI/redisai-go
go 1.12

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-lintpack/lintpack v0.5.2 // indirect
github.com/golangci/errcheck v0.0.0-20181223084120-ef45e06d44b6 // indirect
github.com/golangci/go-tools v0.0.0-20190318055746-e32c54105b7c // indirect
github.com/golangci/goconst v0.0.0-20180610141641-041c5f2b40f3 // indirect
github.com/golangci/gocyclo v0.0.0-20180528144436-0a533e8fa43d // indirect
github.com/golangci/golangci-lint v1.40.0 // indirect
github.com/golangci/golangci-lint v1.42.0 // indirect
github.com/golangci/ineffassign v0.0.0-20190609212857-42439a7714cc // indirect
github.com/golangci/prealloc v0.0.0-20180630174525-215b22d4de21 // indirect
github.com/gomodule/redigo v1.8.2
github.com/google/go-cmp v0.5.4
github.com/google/go-cmp v0.5.5
github.com/jmoiron/sqlx v1.2.1-0.20190826204134-d7d95172beb5 // indirect
github.com/sanposhiho/wastedassign v1.0.0 // indirect
github.com/shirou/gopsutil v0.0.0-20190901111213-e4ec7b275ada // indirect
github.com/stretchr/testify v1.7.0
github.com/tommy-muehle/go-mnd v1.3.1-0.20200224220436-e6f9a994e8fa // indirect
github.com/ugorji/go v1.1.4 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gopkg.in/airbrake/gobrake.v2 v2.0.9 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4 // indirect
)
320 changes: 236 additions & 84 deletions go.sum

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions redisai/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,19 @@ func (c *Client) ModelSetFromModel(keyName string, model ModelInterface) (err er
// - position 1 the device used to execute the model as a String
// - position 2 the model's tag as a String
// - position 3 a blob containing the serialized model (when called with the BLOB argument) as a String
// - position 4 the maximum size of any batch of incoming requests.
// - position 5 the minimum size of any batch of incoming requests.
// - position 6 array reply with one or more names of the model's input nodes (applicable only for TensorFlow models).
// - position 7 array reply with one or more names of the model's output nodes (applicable only for TensorFlow models).
func (c *Client) ModelGet(keyName string) (data []interface{}, err error) {
var reply interface{}
data = make([]interface{}, 4)
data = make([]interface{}, 8)
args := modelGetFlatArgs(keyName)
reply, err = c.DoOrSend("AI.MODELGET", args, nil)
if err != nil || reply == nil {
return
}
data[0], data[1], data[2], data[3], err = modelGetParseReply(reply)
data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], err = modelGetParseReply(reply)
return
}

Expand Down
53 changes: 33 additions & 20 deletions redisai/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,43 +485,53 @@ func TestCommand_ModelGet(t *testing.T) {
name string
}
tests := []struct {
name string
args args
wantBackend string
wantDevice string
wantTag string
wantData []byte
wantErr bool
name string
args args
wantBackend string
wantDevice string
wantTag string
wantData []byte
wantBatchsize int64
wantMinbatchsize int64
wantInputs []string
wantOutputs []string
wantErr bool
}{
{keyModelUnexistent1, args{keyModelUnexistent1}, BackendTF, DeviceCPU, "", data, true},
{keyModel1, args{keyModel1}, BackendTF, DeviceCPU, "", data, false},
{keyModelUnexistent1, args{keyModelUnexistent1}, BackendTF, DeviceCPU, "", data, 0, 0, nil, nil, true},
{keyModel1, args{keyModel1}, BackendTF, DeviceCPU, "", data, 0, 0, []string{"transaction", "reference"}, []string{"output"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := createTestClient()
gotData, err := client.ModelGet(tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("ModelGetToModel() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("ModelGet() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if !reflect.DeepEqual(gotData[0], tt.wantBackend) {
t.Errorf("ModelGetToModel() gotBackend = %v, want %v. gotBackend Type %v, want Type %v.", gotData[0], tt.wantBackend, reflect.TypeOf(gotData[0]), reflect.TypeOf(tt.wantBackend))
t.Errorf("ModelGet() gotBackend = %v, want %v. gotBackend Type %v, want Type %v.", gotData[0], tt.wantBackend, reflect.TypeOf(gotData[0]), reflect.TypeOf(tt.wantBackend))
}
}
if !tt.wantErr {
if !reflect.DeepEqual(gotData[1], tt.wantDevice) {
t.Errorf("ModelGetToModel() gotDevice = %v, want %v. gotDevice Type %v, want Type %v.", gotData[1], tt.wantDevice, reflect.TypeOf(gotData[1]), reflect.TypeOf(tt.wantDevice))
t.Errorf("ModelGet() gotDevice = %v, want %v. gotDevice Type %v, want Type %v.", gotData[1], tt.wantDevice, reflect.TypeOf(gotData[1]), reflect.TypeOf(tt.wantDevice))
}
}
if !tt.wantErr {
if !reflect.DeepEqual(gotData[2], tt.wantTag) {
t.Errorf("ModelGetToModel() gotTag = %v, want %v. gotTag Type %v, want Type %v.", gotData[2], tt.wantTag, reflect.TypeOf(gotData[2]), reflect.TypeOf(tt.wantTag))
t.Errorf("ModelGet() gotTag = %v, want %v. gotTag Type %v, want Type %v.", gotData[2], tt.wantTag, reflect.TypeOf(gotData[2]), reflect.TypeOf(tt.wantTag))
}
}
if !tt.wantErr {
if !reflect.DeepEqual(gotData[3], tt.wantData) {
t.Errorf("ModelGetToModel() gotData = %v, want %v. gotData Type %v, want Type %v.", gotData[3], tt.wantData, reflect.TypeOf(gotData[3]), reflect.TypeOf(tt.wantData))
t.Errorf("ModelGet() gotData = %v, want %v. gotData Type %v, want Type %v.", gotData[3], tt.wantData, reflect.TypeOf(gotData[3]), reflect.TypeOf(tt.wantData))
}
if !reflect.DeepEqual(gotData[4], tt.wantBatchsize) {
t.Errorf("ModelGet() gotBatchsize = %v, want %v. gotBatchsize Type %v, want Type %v.", gotData[4], tt.wantBatchsize, reflect.TypeOf(gotData[4]), reflect.TypeOf(tt.wantBatchsize))
}
if !reflect.DeepEqual(gotData[5], tt.wantMinbatchsize) {
t.Errorf("ModelGet() gotMinbatchsize = %v, want %v. gotMinbatchsize Type %v, want Type %v.", gotData[5], tt.wantMinbatchsize, reflect.TypeOf(gotData[5]), reflect.TypeOf(tt.wantMinbatchsize))
}
if !reflect.DeepEqual(gotData[6], tt.wantInputs) {
t.Errorf("ModelGet() gotInputs = %v, want %v. gotInputs Type %v, want Type %v.", gotData[6], tt.wantInputs, reflect.TypeOf(gotData[6]), reflect.TypeOf(tt.wantInputs))
}
if !reflect.DeepEqual(gotData[7], tt.wantOutputs) {
t.Errorf("ModelGet() gotOutputs = %v, want %v. gotOutputs Type %v, want Type %v.", gotData[7], tt.wantOutputs, reflect.TypeOf(gotData[7]), reflect.TypeOf(tt.wantOutputs))
}
}

Expand Down Expand Up @@ -654,6 +664,9 @@ func TestCommand_FullFromModelFlow(t *testing.T) {
err = client.ModelGetToModel("financialNet1", model2)
assert.Nil(t, err)
assert.Equal(t, model1.Tag(), model2.Tag())
assert.Equal(t, model1.BatchSize(), model2.BatchSize())
assert.Equal(t, model1.MinBatchSize(), model2.MinBatchSize())
assert.Equal(t, model1.Outputs(), model2.Outputs())
}

func TestCommand_ScriptDel(t *testing.T) {
Expand Down
53 changes: 33 additions & 20 deletions redisai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ func modelSetInterfaceArgs(keyName string, modelInterface ModelInterface) redis.
}

func modelRunFlatArgs(name string, inputTensorNames, outputTensorNames []string) redis.Args {
args := redis.Args{}
args = args.Add(name)
args := redis.Args{name}
if len(inputTensorNames) > 0 {
args = args.Add("INPUTS").AddFlat(inputTensorNames)
}
Expand All @@ -83,53 +82,67 @@ func modelRunFlatArgs(name string, inputTensorNames, outputTensorNames []string)
}

func modelGetParseToInterface(reply interface{}, model ModelInterface) (err error) {
var backend string
var device string
var tag string
var blob []byte
backend, device, tag, blob, err = modelGetParseReply(reply)
backend, device, tag, blob, batchsize, minbatchsize, inputs, outputs, err := modelGetParseReply(reply)
if err != nil {
return err
}
model.SetBackend(backend)
model.SetDevice(device)
model.SetTag(tag)
model.SetBlob(blob)
model.SetBatchSize(batchsize)
model.SetMinBatchSize(minbatchsize)
model.SetInputs(inputs)
model.SetOutputs(outputs)
return
}

func modelGetParseReply(reply interface{}) (backend string, device string, tag string, blob []byte, err error) {
func modelGetParseReply(reply interface{}) (backend string, device string, tag string, blob []byte, batchsize int64, minbatchsize int64, inputs []string, outputs []string, err error) {
var replySlice []interface{}
var key string
inputs = nil
outputs = nil
replySlice, err = redis.Values(reply, err)
if err != nil {
return
}
for pos := 0; pos < len(replySlice); pos += 2 {
// we need this condition for after parsing err check
if err != nil {
break
}
key, err = redis.String(replySlice[pos], err)
if err != nil {
return
break
}
switch key {
case "backend":
backend, err = redis.String(replySlice[pos+1], err)
if err != nil {
return
}
case "device":
device, err = redis.String(replySlice[pos+1], err)
if err != nil {
return
}
case "blob":
blob, err = redis.Bytes(replySlice[pos+1], err)
if err != nil {
return
}
case "tag":
tag, err = redis.String(replySlice[pos+1], err)
if err != nil {
return
case "batchsize":
batchsize, err = redis.Int64(replySlice[pos+1], err)
case "minbatchsize":
minbatchsize, err = redis.Int64(replySlice[pos+1], err)
case "inputs":
// we need to create a temporary slice given redis.Strings creates by default a slice with capacity of the input slice even if it can't be parsed
// so the solution is to only use the replied slice of redis.Strings in case of success. Otherwise you can have inputs filled with []string(nil)
var temporaryInputs []string
temporaryInputs, err = redis.Strings(replySlice[pos+1], err)
if err == nil {
inputs = temporaryInputs
}
case "outputs":
// we need to create a temporary slice given redis.Strings creates by default a slice with capacity of the input slice even if it can't be parsed
// so the solution is to only use the replied slice of redis.Strings in case of success. Otherwise you can have outputs filled with []string(nil)
var temporaryOutputs []string
temporaryOutputs, err = redis.Strings(replySlice[pos+1], err)
if err == nil {
outputs = temporaryOutputs
}
}
}
Expand Down
58 changes: 41 additions & 17 deletions redisai/model_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package redisai

import (
"github.com/stretchr/testify/assert"
"reflect"
"testing"
)
Expand All @@ -10,37 +11,60 @@ func Test_modelGetParseReply(t *testing.T) {
reply interface{}
}
tests := []struct {
name string
args args
wantBackend string
wantDevice string
wantTag string
wantBlob []byte
wantErr bool
name string
args args
wantBackend string
wantDevice string
wantTag string
wantBlob []byte
wantBatchsize int64
wantMinbatchsize int64
wantInputs []string
wantOutputs []string
wantErr bool
}{
{"empty", args{}, "", "", "", nil, true},
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}}, "", "", "", nil, true},
{"negative-wrong-reply", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
{"negative-wrong-device", args{[]interface{}{[]byte("device"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
{"negative-wrong-blob", args{[]interface{}{[]byte("blob"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
{"empty", args{}, "", "", "", nil, 0, 0, nil, nil, true},
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}}, "", "", "", nil, 0, 0, nil, nil, true},
{"negative-wrong-reply", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, 0, 0, nil, nil, true},
{"positive-backend", args{[]interface{}{[]byte("backend"), []byte("TF")}}, "TF", "", "", nil, 0, 0, nil, nil, false},
{"negative-wrong-device", args{[]interface{}{[]byte("device"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, 0, 0, nil, nil, true},
{"positive-device", args{[]interface{}{[]byte("device"), []byte(DeviceGPU)}}, "", DeviceGPU, "", nil, 0, 0, nil, nil, false},
{"negative-wrong-batchsize", args{[]interface{}{[]byte("batchsize"), []interface{}{[]byte("1")}}}, "", "", "", nil, 0, 0, nil, nil, true},
{"positive-batchsize", args{[]interface{}{[]byte("batchsize"), int64(1)}}, "", "", "", nil, 1, 0, nil, nil, false},
{"negative-wrong-minbatchsize", args{[]interface{}{[]byte("minbatchsize"), []interface{}{[]byte("1")}}}, "", "", "", nil, 0, 0, nil, nil, true},
{"positive-minbatchsize", args{[]interface{}{[]byte("minbatchsize"), int64(1)}}, "", "", "", nil, 0, 1, nil, nil, false},
{"negative-wrong-inputs", args{[]interface{}{[]byte("inputs"), []interface{}{[]interface{}{[]byte("bar"), []byte("foo")}}}}, "", "", "", nil, 0, 0, nil, nil, true},
{"positive-inputs", args{[]interface{}{[]byte("inputs"), []interface{}{[]byte("bar"), []byte("foo")}}}, "", "", "", nil, 0, 0, []string{"bar", "foo"}, nil, false},
{"negative-wrong-output", args{[]interface{}{[]byte("output"), []interface{}{[]interface{}{[]byte("output")}}}}, "", "", "", nil, 0, 0, nil, nil, true},
{"positive-output", args{[]interface{}{[]byte("outputs"), []interface{}{[]byte("output")}}}, "", "", "", nil, 0, 0, nil, []string{"output"}, false},
{"negative-wrong-blob", args{[]interface{}{[]byte("blob"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, 0, 0, nil, nil, true},
{"positive-blob", args{[]interface{}{[]byte("blob"), []byte("blob")}}, "", "", "", []byte("blob"), 0, 0, nil, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotBackend, gotDevice, gotTag, gotBlob, gotErr := modelGetParseReply(tt.args.reply)
gotBackend, gotDevice, gotTag, gotBlob, gotBatchsize, gotMinbatchsize, gotInputs, gotOutputs, gotErr := modelGetParseReply(tt.args.reply)
if gotErr != nil && !tt.wantErr {
t.Errorf("modelGetParseReply() gotErr = %v, want %v", gotErr, tt.wantErr)
}
if gotBackend != tt.wantBackend {
t.Errorf("modelGetParseReply() gotBackend = %v, want %v", gotBackend, tt.wantBackend)
t.Errorf("modelGetParseReply() gotBackend = %v, want %v. gotErr = %v", gotBackend, tt.wantBackend, gotErr)
}
if gotDevice != tt.wantDevice {
t.Errorf("modelGetParseReply() gotDevice = %v, want %v", gotDevice, tt.wantDevice)
t.Errorf("modelGetParseReply() gotDevice = %v, want %v. gotErr = %v", gotDevice, tt.wantDevice, gotErr)
}
if gotTag != tt.wantTag {
t.Errorf("modelGetParseReply() gotTag = %v, want %v", gotTag, tt.wantTag)
t.Errorf("modelGetParseReply() gotTag = %v, want %v. gotErr = %v", gotTag, tt.wantTag, gotErr)
}
if gotBatchsize != tt.wantBatchsize {
t.Errorf("modelGetParseReply() gotBatchsize = %v, want %v. gotErr = %v", gotBatchsize, tt.wantBatchsize, gotErr)
}
if gotMinbatchsize != tt.wantMinbatchsize {
t.Errorf("modelGetParseReply() gotMinbatchsize = %v, want %v. gotErr = %v", gotMinbatchsize, tt.wantMinbatchsize, gotErr)
}
assert.EqualValues(t, gotInputs, tt.wantInputs, "modelGetParseReply() gotInputs = %v, want %v. gotErr = %v", gotInputs, tt.wantInputs, gotErr)
assert.EqualValues(t, gotOutputs, tt.wantOutputs, "modelGetParseReply() gotOutputs = %v, want %v. gotErr = %v", gotOutputs, tt.wantOutputs, gotErr)
if !reflect.DeepEqual(gotBlob, tt.wantBlob) {
t.Errorf("modelGetParseReply() gotBlob = %v, want %v", gotBlob, tt.wantBlob)
t.Errorf("modelGetParseReply() gotBlob = %v, want %v. gotErr = %v", gotBlob, tt.wantBlob, gotErr)
}
})
}
Expand Down