diff --git a/parser/parser.y b/parser/parser.y index 47e62547cf8d6..15480fc1cab1b 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -2881,9 +2881,9 @@ FunctionCallKeyword: } FunctionCallNonKeyword: - "ABS" '(' Expression ')' + "ABS" '(' ExpressionListOpt ')' { - $$ = &ast.FuncCallExpr{FnName: model.NewCIStr($1), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: model.NewCIStr($1), Args: $3.([]ast.ExprNode)} } | "ADDTIME" '(' Expression ',' Expression ')' { diff --git a/plan/typeinferer.go b/plan/typeinferer.go index db035e2fb3bfa..1b4c20a7ced8d 100644 --- a/plan/typeinferer.go +++ b/plan/typeinferer.go @@ -292,6 +292,9 @@ func (v *typeInferrer) getFsp(x *ast.FuncCallExpr) int { return 0 } +// handleFuncCallExpr ... +// TODO: (zhexuany) this function contains too much redundant things. Maybe replace with a map like +// we did for error in mysql package. func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { var ( tp *types.FieldType @@ -299,12 +302,20 @@ func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { ) switch x.FnName.L { case ast.Abs, ast.Ifnull, ast.Nullif: + if len(x.Args) == 0 { + tp = types.NewFieldType(mysql.TypeNull) + break + } tp = x.Args[0].GetType() // TODO: We should cover all types. if x.FnName.L == ast.Abs && tp.Tp == mysql.TypeDatetime { tp = types.NewFieldType(mysql.TypeDouble) } case ast.Round: + if len(x.Args) == 0 { + tp = types.NewFieldType(mysql.TypeNull) + break + } t := x.Args[0].GetType().Tp switch t { case mysql.TypeBit, mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLonglong: @@ -323,15 +334,21 @@ func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { for i := 1; i < len(x.Args); i++ { tp = mergeCmpType(tp, x.Args[i].GetType()) } + } else { + tp = types.NewFieldType(mysql.TypeNull) } case ast.Ceil, ast.Ceiling, ast.Floor: - t := x.Args[0].GetType().Tp - if t == mysql.TypeNull || t == mysql.TypeFloat || t == mysql.TypeDouble || t == mysql.TypeVarchar || - t == mysql.TypeTinyBlob || t == mysql.TypeMediumBlob || t == mysql.TypeLongBlob || - t == mysql.TypeBlob || t == mysql.TypeVarString || t == mysql.TypeString { - tp = types.NewFieldType(mysql.TypeDouble) + if len(x.Args) > 0 { + t := x.Args[0].GetType().Tp + if t == mysql.TypeNull || t == mysql.TypeFloat || t == mysql.TypeDouble || t == mysql.TypeVarchar || + t == mysql.TypeTinyBlob || t == mysql.TypeMediumBlob || t == mysql.TypeLongBlob || + t == mysql.TypeBlob || t == mysql.TypeVarString || t == mysql.TypeString { + tp = types.NewFieldType(mysql.TypeDouble) + } else { + tp = types.NewFieldType(mysql.TypeLonglong) + } } else { - tp = types.NewFieldType(mysql.TypeLonglong) + tp = types.NewFieldType(mysql.TypeNull) } case ast.FromUnixTime: if len(x.Args) == 1 { diff --git a/plan/typeinferer_test.go b/plan/typeinferer_test.go index 0f3fb7a05a3ea..1b88a1e8500d6 100644 --- a/plan/typeinferer_test.go +++ b/plan/typeinferer_test.go @@ -122,6 +122,7 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { // Functions {"version()", mysql.TypeVarString, charset.CharsetUTF8, 0}, {"count(c_int)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, + {"abs()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag}, {"abs(1)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, {"abs(1.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag}, {"abs(cast(\"20150817015609\" as DATETIME))", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag}, @@ -225,6 +226,7 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { {"interval(1, 2, 3)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, {"interval(1.0, 2.0, 3.0)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, {"interval('1', '2', '3')", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, + // {"round()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag}, {"round(null, 2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag}, {"round('1.2', 2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag}, {"round(1e2, 2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag}, @@ -252,6 +254,11 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { {"unix_timestamp('2015-11-13 10:20:19')", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, {"to_days('2015-11-13')", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, {"to_days(950501)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, + // {"ceiling()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag}, + {"ceiling(1.23)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, + // {"ceil()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag}, + {"ceil(1.23)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, + // {"floor()", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag}, {"floor(1.23)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, {"field('foo', null)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, {"find_in_set('foo', 'foo,bar')", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},