Skip to content

Commit

Permalink
allowed zero argument in typeinferer (pingcap#3137)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhexuany authored and coocood committed Apr 25, 2017
1 parent 5c707ec commit 8aab91f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 8 deletions.
4 changes: 2 additions & 2 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -2881,9 +2881,9 @@ FunctionCallKeyword:
}

FunctionCallNonKeyword:
"ABS" '(' Expression ')'
"ABS" '(' ExpressionListOpt ')'
{
$$ = &ast.FuncCallExpr{FnName: model.NewCIStr($1), Args: []ast.ExprNode{$3.(ast.ExprNode)}}
$$ = &ast.FuncCallExpr{FnName: model.NewCIStr($1), Args: $3.([]ast.ExprNode)}
}
| "ADDTIME" '(' Expression ',' Expression ')'
{
Expand Down
29 changes: 23 additions & 6 deletions plan/typeinferer.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,19 +292,30 @@ func (v *typeInferrer) getFsp(x *ast.FuncCallExpr) int {
return 0
}

// handleFuncCallExpr ...
// TODO: (zhexuany) this function contains too much redundant things. Maybe replace with a map like
// we did for error in mysql package.
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
var (
tp *types.FieldType
chs = charset.CharsetBin
)
switch x.FnName.L {
case ast.Abs, ast.Ifnull, ast.Nullif:
if len(x.Args) == 0 {
tp = types.NewFieldType(mysql.TypeNull)
break
}
tp = x.Args[0].GetType()
// TODO: We should cover all types.
if x.FnName.L == ast.Abs && tp.Tp == mysql.TypeDatetime {
tp = types.NewFieldType(mysql.TypeDouble)
}
case ast.Round:
if len(x.Args) == 0 {
tp = types.NewFieldType(mysql.TypeNull)
break
}
t := x.Args[0].GetType().Tp
switch t {
case mysql.TypeBit, mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLonglong:
Expand All @@ -323,15 +334,21 @@ func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
for i := 1; i < len(x.Args); i++ {
tp = mergeCmpType(tp, x.Args[i].GetType())
}
} else {
tp = types.NewFieldType(mysql.TypeNull)
}
case ast.Ceil, ast.Ceiling, ast.Floor:
t := x.Args[0].GetType().Tp
if t == mysql.TypeNull || t == mysql.TypeFloat || t == mysql.TypeDouble || t == mysql.TypeVarchar ||
t == mysql.TypeTinyBlob || t == mysql.TypeMediumBlob || t == mysql.TypeLongBlob ||
t == mysql.TypeBlob || t == mysql.TypeVarString || t == mysql.TypeString {
tp = types.NewFieldType(mysql.TypeDouble)
if len(x.Args) > 0 {
t := x.Args[0].GetType().Tp
if t == mysql.TypeNull || t == mysql.TypeFloat || t == mysql.TypeDouble || t == mysql.TypeVarchar ||
t == mysql.TypeTinyBlob || t == mysql.TypeMediumBlob || t == mysql.TypeLongBlob ||
t == mysql.TypeBlob || t == mysql.TypeVarString || t == mysql.TypeString {
tp = types.NewFieldType(mysql.TypeDouble)
} else {
tp = types.NewFieldType(mysql.TypeLonglong)
}
} else {
tp = types.NewFieldType(mysql.TypeLonglong)
tp = types.NewFieldType(mysql.TypeNull)
}
case ast.FromUnixTime:
if len(x.Args) == 1 {
Expand Down
7 changes: 7 additions & 0 deletions plan/typeinferer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) {
// Functions
{"version()", mysql.TypeVarString, charset.CharsetUTF8, 0},
{"count(c_int)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"abs()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag},
{"abs(1)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"abs(1.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag},
{"abs(cast(\"20150817015609\" as DATETIME))", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag},
Expand Down Expand Up @@ -225,6 +226,7 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) {
{"interval(1, 2, 3)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"interval(1.0, 2.0, 3.0)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"interval('1', '2', '3')", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
// {"round()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag},
{"round(null, 2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag},
{"round('1.2', 2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag},
{"round(1e2, 2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag},
Expand Down Expand Up @@ -252,6 +254,11 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) {
{"unix_timestamp('2015-11-13 10:20:19')", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"to_days('2015-11-13')", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"to_days(950501)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
// {"ceiling()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag},
{"ceiling(1.23)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
// {"ceil()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag},
{"ceil(1.23)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
// {"floor()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag},
{"floor(1.23)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"field('foo', null)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"find_in_set('foo', 'foo,bar')", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
Expand Down

0 comments on commit 8aab91f

Please sign in to comment.