Skip to content

Commit 7ab9e78

Browse files
authored
Merge pull request #2455 from dolthub/fulghum/com_binlog_prototype
Feature: `gtid_subtract()` function
2 parents cf70da4 + cf2a2bb commit 7ab9e78

File tree

3 files changed

+273
-0
lines changed

3 files changed

+273
-0
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/vitess/go/mysql"
21+
22+
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/types"
24+
)
25+
26+
// GtidSubtract implements MySQL's built-in gtid_subtract() function.
27+
// https://dev.mysql.com/doc/refman/8.0/en/gtid-functions.html#function_gtid-subtract
28+
type GtidSubtract struct {
29+
gtid1 sql.Expression
30+
gtid2 sql.Expression
31+
}
32+
33+
var _ sql.FunctionExpression = (*GtidSubtract)(nil)
34+
var _ sql.CollationCoercible = (*GtidSubtract)(nil)
35+
36+
func NewGtidSubtract(gtid1, gtid2 sql.Expression) sql.Expression {
37+
return &GtidSubtract{gtid1, gtid2}
38+
}
39+
40+
// FunctionName implements sql.FunctionExpression
41+
func (gs *GtidSubtract) FunctionName() string {
42+
return "gtid_subtract"
43+
}
44+
45+
// Description implements sql.FunctionExpression
46+
func (gs *GtidSubtract) Description() string {
47+
return "Given two sets of global transaction identifiers set1 and set2, " +
48+
"returns only those GTIDs from set1 that are not in set2. Returns NULL if set1 or set2 is NULL."
49+
}
50+
51+
// Type implements the Expression interface.
52+
func (gs *GtidSubtract) Type() sql.Type { return types.LongText }
53+
54+
// CollationCoercibility implements the interface sql.CollationCoercible.
55+
func (gs *GtidSubtract) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
56+
collation, coercibility = sql.GetCoercibility(ctx, gs.gtid1)
57+
nextCollation, nextCoercibility := sql.GetCoercibility(ctx, gs.gtid2)
58+
return sql.ResolveCoercibility(collation, coercibility, nextCollation, nextCoercibility)
59+
}
60+
61+
// IsNullable implements the Expression interface.
62+
func (gs *GtidSubtract) IsNullable() bool {
63+
return gs.gtid1.IsNullable() || gs.gtid2.IsNullable()
64+
}
65+
66+
func (gs *GtidSubtract) String() string {
67+
return fmt.Sprintf("%s(%s, %s)", gs.FunctionName(), gs.gtid1, gs.gtid2)
68+
}
69+
70+
func (gs *GtidSubtract) DebugString() string {
71+
return fmt.Sprintf("%s(%s, %s)", gs.FunctionName(), gs.gtid1, gs.gtid2)
72+
}
73+
74+
// WithChildren implements the Expression interface.
75+
func (gs *GtidSubtract) WithChildren(children ...sql.Expression) (sql.Expression, error) {
76+
if len(children) != 2 {
77+
return nil, sql.ErrInvalidChildrenNumber.New(gs, len(children), 2)
78+
}
79+
return NewGtidSubtract(children[0], children[1]), nil
80+
}
81+
82+
// Resolved implements the Expression interface.
83+
func (gs *GtidSubtract) Resolved() bool {
84+
return gs.gtid1.Resolved() && gs.gtid2.Resolved()
85+
}
86+
87+
// Children implements the Expression interface.
88+
func (gs *GtidSubtract) Children() []sql.Expression {
89+
return []sql.Expression{gs.gtid1, gs.gtid2}
90+
}
91+
92+
// Eval implements the Expression interface.
93+
func (gs *GtidSubtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
94+
if gs.gtid1 == nil || gs.gtid2 == nil {
95+
return nil, nil
96+
}
97+
98+
left, err := gs.gtid1.Eval(ctx, row)
99+
if err != nil {
100+
return nil, err
101+
}
102+
if left == nil {
103+
return nil, nil
104+
}
105+
106+
right, err := gs.gtid2.Eval(ctx, row)
107+
if err != nil {
108+
return nil, err
109+
}
110+
if right == nil {
111+
return nil, nil
112+
}
113+
114+
if _, ok := left.(string); !ok {
115+
return nil, sql.ErrInvalidType.New(gs.gtid1)
116+
}
117+
if _, ok := right.(string); !ok {
118+
return nil, sql.ErrInvalidType.New(gs.gtid2)
119+
}
120+
121+
gtidSet1, err := mysql.ParseMysql56GTIDSet(left.(string))
122+
if err != nil {
123+
return nil, err
124+
}
125+
126+
gtidSet2, err := mysql.ParseMysql56GTIDSet(right.(string))
127+
if err != nil {
128+
return nil, err
129+
}
130+
131+
newGtidSet := gtidSet1.Subtract(gtidSet2)
132+
return newGtidSet.String(), nil
133+
}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"testing"
19+
20+
"github.com/dolthub/vitess/go/sqltypes"
21+
"github.com/stretchr/testify/require"
22+
23+
"github.com/dolthub/go-mysql-server/sql"
24+
"github.com/dolthub/go-mysql-server/sql/expression"
25+
"github.com/dolthub/go-mysql-server/sql/types"
26+
)
27+
28+
func TestGtidSubtract(t *testing.T) {
29+
tests := []struct {
30+
left, right sql.Expression
31+
expected any
32+
error string
33+
}{
34+
// NULL cases
35+
{
36+
left: nil,
37+
right: nil,
38+
expected: nil,
39+
},
40+
{
41+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
42+
right: nil,
43+
expected: nil,
44+
},
45+
{
46+
left: nil,
47+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
48+
expected: nil,
49+
},
50+
{
51+
left: expression.NewLiteral(nil, types.Null),
52+
right: expression.NewLiteral(nil, types.Null),
53+
expected: nil,
54+
},
55+
{
56+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
57+
right: expression.NewLiteral(nil, types.Null),
58+
expected: nil,
59+
},
60+
{
61+
left: expression.NewLiteral(nil, types.Null),
62+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
63+
expected: nil,
64+
},
65+
66+
// Error cases
67+
{
68+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
69+
right: newStringLiteral("not a parseable SID:not a valid interval"),
70+
error: "invalid MySQL 5.6 GTID set (\"not a parseable SID:not a valid interval\"): invalid MySQL 5.6 SID \"not a parseable SID\"",
71+
},
72+
{
73+
left: expression.NewLiteral(42, types.Int32),
74+
right: expression.NewLiteral(42, types.Int32),
75+
error: "invalid type: 42",
76+
},
77+
{
78+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
79+
right: expression.NewLiteral(42, types.Int32),
80+
error: "invalid type: 42",
81+
},
82+
{
83+
left: expression.NewLiteral(42, types.Int32),
84+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
85+
error: "invalid type: 42",
86+
},
87+
88+
// MySQL documentation cases
89+
{
90+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
91+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21"),
92+
expected: "3e11fa47-71ca-11e1-9e33-c80aa9429562:22-57",
93+
},
94+
{
95+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
96+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:20-25"),
97+
expected: "3e11fa47-71ca-11e1-9e33-c80aa9429562:26-57",
98+
},
99+
{
100+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
101+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:23-24"),
102+
expected: "3e11fa47-71ca-11e1-9e33-c80aa9429562:21-22:25-57",
103+
},
104+
{
105+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
106+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
107+
expected: "",
108+
},
109+
110+
// Additional cases
111+
{
112+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
113+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:20-21"),
114+
expected: "3e11fa47-71ca-11e1-9e33-c80aa9429562:22-57",
115+
},
116+
{
117+
left: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57"),
118+
right: newStringLiteral("3E11FA47-71CA-11E1-9E33-C80AA9429562:57-58"),
119+
expected: "3e11fa47-71ca-11e1-9e33-c80aa9429562:21-56",
120+
},
121+
}
122+
123+
for _, test := range tests {
124+
f := NewGtidSubtract(test.left, test.right)
125+
t.Run(f.String(), func(t *testing.T) {
126+
res, err := f.Eval(sql.NewEmptyContext(), nil)
127+
if test.error != "" {
128+
require.Equal(t, test.error, err.Error())
129+
} else {
130+
require.NoError(t, err)
131+
require.Equal(t, test.expected, res)
132+
}
133+
})
134+
}
135+
}
136+
137+
func newStringLiteral(s string) sql.Expression {
138+
return expression.NewLiteral(s, types.MustCreateStringWithDefaults(sqltypes.VarChar, 100))
139+
}

sql/expression/function/registry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ var BuiltIns = []sql.Function{
9999
sql.Function1{Name: "from_unixtime", Fn: NewFromUnixtime},
100100
sql.FunctionN{Name: "greatest", Fn: NewGreatest},
101101
sql.Function0{Name: "group_concat", Fn: aggregation.NewEmptyGroupConcat},
102+
sql.Function2{Name: "gtid_subtract", Fn: NewGtidSubtract},
102103
sql.Function1{Name: "hex", Fn: NewHex},
103104
sql.Function1{Name: "hour", Fn: NewHour},
104105
sql.Function3{Name: "if", Fn: NewIf},

0 commit comments

Comments
 (0)