Skip to content

Commit 1c3c597

Browse files
committed
sql/expression/function: add concat and split functions
Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
1 parent f3a82f0 commit 1c3c597

File tree

4 files changed

+323
-0
lines changed

4 files changed

+323
-0
lines changed

sql/expression/function/concat.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
errors "gopkg.in/src-d/go-errors.v0"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql"
9+
)
10+
11+
// Concat joins several strings together.
12+
type Concat struct {
13+
args []sql.Expression
14+
}
15+
16+
// ErrConcatArrayWithOthers is returned when there are more than 1 argument in
17+
// concat and any of them is an array.
18+
var ErrConcatArrayWithOthers = errors.NewKind("can't concat a string array with any other elements")
19+
20+
// NewConcat creates a new Concat UDF.
21+
func NewConcat(args ...sql.Expression) (sql.Expression, error) {
22+
if len(args) == 0 {
23+
return nil, sql.ErrInvalidArgumentNumber.New("1 or more", 0)
24+
}
25+
26+
for _, arg := range args {
27+
if len(args) > 1 && sql.IsArray(arg.Type()) {
28+
return nil, ErrConcatArrayWithOthers.New()
29+
}
30+
31+
if sql.IsTuple(arg.Type()) {
32+
return nil, sql.ErrInvalidType.New("tuple")
33+
}
34+
}
35+
36+
return &Concat{args}, nil
37+
}
38+
39+
// Type implements the Expression interface.
40+
func (f *Concat) Type() sql.Type { return sql.Text }
41+
42+
// IsNullable implements the Expression interface.
43+
func (f *Concat) IsNullable() bool {
44+
for _, arg := range f.args {
45+
if arg.IsNullable() {
46+
return true
47+
}
48+
}
49+
return false
50+
}
51+
52+
func (f *Concat) String() string {
53+
var args = make([]string, len(f.args))
54+
for i, arg := range f.args {
55+
args[i] = arg.String()
56+
}
57+
return fmt.Sprintf("concat(%s)", strings.Join(args, ", "))
58+
}
59+
60+
// TransformUp implements the Expression interface.
61+
func (f *Concat) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
62+
var args = make([]sql.Expression, len(f.args))
63+
for i, arg := range f.args {
64+
arg, err := arg.TransformUp(fn)
65+
if err != nil {
66+
return nil, err
67+
}
68+
args[i] = arg
69+
}
70+
return fn(&Concat{args})
71+
}
72+
73+
// Resolved implements the Expression interface.
74+
func (f *Concat) Resolved() bool {
75+
for _, arg := range f.args {
76+
if !arg.Resolved() {
77+
return false
78+
}
79+
}
80+
return true
81+
}
82+
83+
// Children implements the Expression interface.
84+
func (f *Concat) Children() []sql.Expression { return f.args }
85+
86+
// Eval implements the Expression interface.
87+
func (f *Concat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
88+
var parts []string
89+
90+
for _, arg := range f.args {
91+
val, err := arg.Eval(ctx, row)
92+
if err != nil {
93+
return nil, err
94+
}
95+
96+
if val == nil {
97+
return nil, nil
98+
}
99+
100+
if sql.IsArray(arg.Type()) {
101+
val, err = sql.Array(sql.Text).Convert(val)
102+
if err != nil {
103+
return nil, err
104+
}
105+
106+
for _, v := range val.([]interface{}) {
107+
parts = append(parts, v.(string))
108+
}
109+
} else {
110+
val, err = sql.Text.Convert(val)
111+
if err != nil {
112+
return nil, err
113+
}
114+
115+
parts = append(parts, val.(string))
116+
}
117+
}
118+
119+
return strings.Join(parts, ""), nil
120+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestConcat(t *testing.T) {
12+
t.Run("concat multiple arguments", func(t *testing.T) {
13+
require := require.New(t)
14+
f, err := NewConcat(
15+
expression.NewLiteral("foo", sql.Text),
16+
expression.NewLiteral(5, sql.Text),
17+
expression.NewLiteral(true, sql.Boolean),
18+
)
19+
require.NoError(err)
20+
21+
v, err := f.Eval(sql.NewEmptyContext(), nil)
22+
require.NoError(err)
23+
require.Equal("foo5true", v)
24+
})
25+
26+
t.Run("some argument is nil", func(t *testing.T) {
27+
require := require.New(t)
28+
f, err := NewConcat(
29+
expression.NewLiteral("foo", sql.Text),
30+
expression.NewLiteral(nil, sql.Text),
31+
expression.NewLiteral(true, sql.Boolean),
32+
)
33+
require.NoError(err)
34+
35+
v, err := f.Eval(sql.NewEmptyContext(), nil)
36+
require.NoError(err)
37+
require.Equal(nil, v)
38+
})
39+
40+
t.Run("concat array", func(t *testing.T) {
41+
require := require.New(t)
42+
f, err := NewConcat(
43+
expression.NewLiteral([]interface{}{5, "bar", true}, sql.Array(sql.Text)),
44+
)
45+
require.NoError(err)
46+
47+
v, err := f.Eval(sql.NewEmptyContext(), nil)
48+
require.NoError(err)
49+
require.Equal("5bartrue", v)
50+
})
51+
}
52+
53+
func TestNewConcat(t *testing.T) {
54+
require := require.New(t)
55+
56+
_, err := NewConcat(expression.NewLiteral(nil, sql.Array(sql.Text)))
57+
require.NoError(err)
58+
59+
_, err = NewConcat(expression.NewLiteral(nil, sql.Array(sql.Text)), expression.NewLiteral(nil, sql.Int64))
60+
require.Error(err)
61+
require.True(ErrConcatArrayWithOthers.Is(err))
62+
63+
_, err = NewConcat(expression.NewLiteral(nil, sql.Tuple(sql.Text, sql.Text)))
64+
require.Error(err)
65+
require.True(sql.ErrInvalidType.Is(err))
66+
67+
_, err = NewConcat(
68+
expression.NewLiteral(nil, sql.Text),
69+
expression.NewLiteral(nil, sql.Boolean),
70+
expression.NewLiteral(nil, sql.Int64),
71+
expression.NewLiteral(nil, sql.Text),
72+
)
73+
require.NoError(err)
74+
}

sql/expression/function/split.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
// Split receives a string and returns the parts of it splitted by a
12+
// delimiter.
13+
type Split struct {
14+
expression.BinaryExpression
15+
}
16+
17+
// NewSplit creates a new Split UDF.
18+
func NewSplit(str, delimiter sql.Expression) sql.Expression {
19+
return &Split{expression.BinaryExpression{
20+
Left: str,
21+
Right: delimiter,
22+
}}
23+
}
24+
25+
// Eval implements the Expression interface.
26+
func (f *Split) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
27+
left, err := f.Left.Eval(ctx, row)
28+
if err != nil {
29+
return nil, err
30+
}
31+
32+
if left == nil {
33+
return nil, nil
34+
}
35+
36+
left, err = sql.Text.Convert(left)
37+
if err != nil {
38+
return nil, err
39+
}
40+
41+
right, err := f.Right.Eval(ctx, row)
42+
if err != nil {
43+
return nil, err
44+
}
45+
46+
if right == nil {
47+
return nil, nil
48+
}
49+
50+
right, err = sql.Text.Convert(right)
51+
if err != nil {
52+
return nil, err
53+
}
54+
55+
re, err := regexp.Compile(right.(string))
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
parts := re.Split(left.(string), -1)
61+
var result = make([]interface{}, len(parts))
62+
for i, part := range parts {
63+
result[i] = part
64+
}
65+
66+
return result, nil
67+
}
68+
69+
// Type implements the Expression interface.
70+
func (*Split) Type() sql.Type { return sql.Array(sql.Text) }
71+
72+
// IsNullable implements the Expression interface.
73+
func (f *Split) IsNullable() bool { return f.Left.IsNullable() || f.Right.IsNullable() }
74+
75+
func (f *Split) String() string {
76+
return fmt.Sprintf("split(%s, %s)", f.Left, f.Right)
77+
}
78+
79+
// TransformUp implements the Expression interface.
80+
func (f *Split) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
81+
left, err := f.Left.TransformUp(fn)
82+
if err != nil {
83+
return nil, err
84+
}
85+
86+
right, err := f.Right.TransformUp(fn)
87+
if err != nil {
88+
return nil, err
89+
}
90+
91+
return fn(NewSplit(left, right))
92+
}

sql/expression/function/split_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestSplit(t *testing.T) {
12+
testCases := []struct {
13+
name string
14+
input interface{}
15+
delimiter interface{}
16+
expected interface{}
17+
}{
18+
{"has delimiter", "a-b-c", "-", []interface{}{"a", "b", "c"}},
19+
{"regexp delimiter", "a--b----c-d", "-+", []interface{}{"a", "b", "c", "d"}},
20+
{"does not have delimiter", "a.b.c", "-", []interface{}{"a.b.c"}},
21+
{"input is nil", nil, "-", nil},
22+
{"delimiter is nil", "a-b-c", nil, nil},
23+
}
24+
25+
f := NewSplit(
26+
expression.NewGetField(0, sql.Text, "input", true),
27+
expression.NewGetField(1, sql.Text, "delimiter", true),
28+
)
29+
30+
for _, tt := range testCases {
31+
t.Run(tt.name, func(t *testing.T) {
32+
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tt.input, tt.delimiter))
33+
require.NoError(t, err)
34+
require.Equal(t, tt.expected, v)
35+
})
36+
}
37+
}

0 commit comments

Comments
 (0)