Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit e2905cd

Browse files
committed
function: make array_length not fail with literal null
Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
1 parent bc23348 commit e2905cd

File tree

3 files changed

+16
-9
lines changed

3 files changed

+16
-9
lines changed

engine_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,14 @@ var queries = []struct {
12301230
`SELECT (NULL+1)`,
12311231
[]sql.Row{{nil}},
12321232
},
1233+
{
1234+
`SELECT ARRAY_LENGTH(null)`,
1235+
[]sql.Row{{nil}},
1236+
},
1237+
{
1238+
`SELECT ARRAY_LENGTH("foo")`,
1239+
[]sql.Row{{nil}},
1240+
},
12331241
}
12341242

12351243
func TestQueries(t *testing.T) {

sql/expression/function/arraylength.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package function // import "github.com/src-d/go-mysql-server/sql/expression/func
22

33
import (
44
"fmt"
5-
"reflect"
65

76
"github.com/src-d/go-mysql-server/sql"
87
"github.com/src-d/go-mysql-server/sql/expression"
@@ -39,7 +38,7 @@ func (f *ArrayLength) TransformUp(fn sql.TransformExprFunc) (sql.Expression, err
3938
// Eval implements the Expression interface.
4039
func (f *ArrayLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
4140
if t := f.Child.Type(); !sql.IsArray(t) && t != sql.JSON {
42-
return nil, sql.ErrInvalidType.New(f.Child.Type().Type().String())
41+
return nil, nil
4342
}
4443

4544
child, err := f.Child.Eval(ctx, row)
@@ -53,7 +52,7 @@ func (f *ArrayLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
5352

5453
array, ok := child.([]interface{})
5554
if !ok {
56-
return nil, sql.ErrInvalidType.New(reflect.TypeOf(child))
55+
return nil, nil
5756
}
5857

5958
return int32(len(array)), nil

sql/expression/function/arraylength_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ package function
33
import (
44
"testing"
55

6-
"github.com/stretchr/testify/require"
7-
errors "gopkg.in/src-d/go-errors.v1"
86
"github.com/src-d/go-mysql-server/sql"
97
"github.com/src-d/go-mysql-server/sql/expression"
8+
"github.com/stretchr/testify/require"
9+
errors "gopkg.in/src-d/go-errors.v1"
1010
)
1111

1212
func TestArrayLength(t *testing.T) {
@@ -19,7 +19,7 @@ func TestArrayLength(t *testing.T) {
1919
err *errors.Kind
2020
}{
2121
{"array is nil", sql.NewRow(nil), nil, nil},
22-
{"array is not of right type", sql.NewRow(5), nil, sql.ErrInvalidType},
22+
{"array is not of right type", sql.NewRow(5), nil, nil},
2323
{"array is ok", sql.NewRow([]interface{}{1, 2, 3}), int32(3), nil},
2424
}
2525

@@ -40,7 +40,7 @@ func TestArrayLength(t *testing.T) {
4040

4141
f = NewArrayLength(expression.NewGetField(0, sql.Tuple(sql.Int64, sql.Int64), "", false))
4242
require := require.New(t)
43-
_, err := f.Eval(sql.NewEmptyContext(), []interface{}{int64(1), int64(2)})
44-
require.Error(err)
45-
require.True(sql.ErrInvalidType.Is(err))
43+
v, err := f.Eval(sql.NewEmptyContext(), []interface{}{int64(1), int64(2)})
44+
require.NoError(err)
45+
require.Nil(v)
4646
}

0 commit comments

Comments
 (0)