Skip to content

Commit

Permalink
expression: remove "self" field in "baseBuiltinFunc" completely (#4766)
Browse files Browse the repository at this point in the history
* expression: remove "self" field in "baseBuiltinFunc" completely

* address comment
  • Loading branch information
zz-jason authored and coocood committed Oct 13, 2017
1 parent 2efa9ba commit 64bc8cb
Show file tree
Hide file tree
Showing 31 changed files with 511 additions and 526 deletions.
56 changes: 3 additions & 53 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package expression

import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
Expand All @@ -31,12 +30,9 @@ import (

// baseBuiltinFunc will be contained in every struct that implement builtinFunc interface.
type baseBuiltinFunc struct {
args []Expression
ctx context.Context
tp *types.FieldType
// self points to the built-in function signature which contains this baseBuiltinFunc.
// TODO: self will be removed after all built-in function signatures implement EvalXXX().
self builtinFunc
args []Expression
ctx context.Context
tp *types.FieldType
pbCode tipb.ScalarFuncSig
}

Expand Down Expand Up @@ -155,52 +151,10 @@ func newBaseBuiltinFuncWithTp(ctx context.Context, args []Expression, retType ty
}
}

func (b *baseBuiltinFunc) setSelf(f builtinFunc) builtinFunc {
b.self = f
return f
}

func (b *baseBuiltinFunc) getArgs() []Expression {
return b.args
}

// eval should only be called in test files, and it should be removed after all tests being rewritten.
func (b *baseBuiltinFunc) eval(row []types.Datum) (d types.Datum, err error) {
var (
res interface{}
isNull bool
)
switch b.tp.EvalType() {
case types.ETInt:
var intRes int64
intRes, isNull, err = b.self.evalInt(row)
if mysql.HasUnsignedFlag(b.tp.Flag) {
res = uint64(intRes)
} else {
res = intRes
}
case types.ETReal:
res, isNull, err = b.self.evalReal(row)
case types.ETDecimal:
res, isNull, err = b.self.evalDecimal(row)
case types.ETDatetime, types.ETTimestamp:
res, isNull, err = b.self.evalTime(row)
case types.ETDuration:
res, isNull, err = b.self.evalDuration(row)
case types.ETJson:
res, isNull, err = b.self.evalJSON(row)
case types.ETString:
res, isNull, err = b.self.evalString(row)
}

if isNull || err != nil {
d.SetValue(nil)
return d, errors.Trace(err)
}
d.SetValue(res)
return
}

func (b *baseBuiltinFunc) evalInt(row []types.Datum) (int64, bool, error) {
panic("baseBuiltinFunc.evalInt() should never be called.")
}
Expand Down Expand Up @@ -263,8 +217,6 @@ func (b *baseBuiltinFunc) getCtx() context.Context {

// builtinFunc stands for a particular function signature.
type builtinFunc interface {
// eval evaluates result of builtinFunc by given row.
eval(row []types.Datum) (d types.Datum, err error)
// evalInt evaluates int result of builtinFunc by given row.
evalInt(row []types.Datum) (val int64, isNull bool, err error)
// evalReal evaluates real representation of builtinFunc by given row.
Expand All @@ -287,8 +239,6 @@ type builtinFunc interface {
getCtx() context.Context
// getRetTp returns the return type of the built-in function.
getRetTp() *types.FieldType
// setSelf sets a pointer to itself.
setSelf(builtinFunc) builtinFunc
// setPbCode sets pbCode for signature.
setPbCode(tipb.ScalarFuncSig)
// PbCode returns PbCode of this signature.
Expand Down
34 changes: 17 additions & 17 deletions expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,13 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx context.Context, args []Ex
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
sig := &builtinArithmeticPlusRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_PlusReal)
return sig.setSelf(sig), nil
return sig, nil
} else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
sig := &builtinArithmeticPlusDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_PlusDecimal)
return sig.setSelf(sig), nil
return sig, nil
} else {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
if mysql.HasUnsignedFlag(args[0].GetType().Flag) || mysql.HasUnsignedFlag(args[1].GetType().Flag) {
Expand All @@ -160,7 +160,7 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx context.Context, args []Ex
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
sig := &builtinArithmeticPlusIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_PlusInt)
return sig.setSelf(sig), nil
return sig, nil
}
}

Expand Down Expand Up @@ -269,13 +269,13 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx context.Context, args []E
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
sig := &builtinArithmeticMinusRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MinusReal)
return sig.setSelf(sig), nil
return sig, nil
} else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
sig := &builtinArithmeticMinusDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MinusDecimal)
return sig.setSelf(sig), nil
return sig, nil
} else {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
Expand All @@ -284,7 +284,7 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx context.Context, args []E
}
sig := &builtinArithmeticMinusIntSig{baseBuiltinFunc: bf}
sig.setPbCode(tipb.ScalarFuncSig_MinusInt)
return sig.setSelf(sig), nil
return sig, nil
}
}

Expand Down Expand Up @@ -390,26 +390,26 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx context.Context, args
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
sig := &builtinArithmeticMultiplyRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyReal)
return sig.setSelf(sig), nil
return sig, nil
} else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
sig := &builtinArithmeticMultiplyDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyDecimal)
return sig.setSelf(sig), nil
return sig, nil
} else {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
sig := &builtinArithmeticMultiplyIntUnsignedSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyInt)
return sig.setSelf(sig), nil
return sig, nil
}
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
sig := &builtinArithmeticMultiplyIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyInt)
return sig.setSelf(sig), nil
return sig, nil
}
}

Expand Down Expand Up @@ -504,12 +504,12 @@ func (c *arithmeticDivideFunctionClass) getFunction(ctx context.Context, args []
c.setType4DivReal(bf.tp)
sig := &builtinArithmeticDivideRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_DivideReal)
return sig.setSelf(sig), nil
return sig, nil
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
c.setType4DivDecimal(bf.tp, lhsTp, rhsTp)
sig := &builtinArithmeticDivideDecimalSig{bf}
return sig.setSelf(sig), nil
return sig, nil
}

type builtinArithmeticDivideRealSig struct{ baseBuiltinFunc }
Expand Down Expand Up @@ -572,14 +572,14 @@ func (c *arithmeticIntDivideFunctionClass) getFunction(ctx context.Context, args
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticIntDivideIntSig{bf}
return sig.setSelf(sig), nil
return sig, nil
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETDecimal, types.ETDecimal)
if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticIntDivideDecimalSig{bf}
return sig.setSelf(sig), nil
return sig, nil
}

type builtinArithmeticIntDivideIntSig struct{ baseBuiltinFunc }
Expand Down Expand Up @@ -692,22 +692,22 @@ func (c *arithmeticModFunctionClass) getFunction(ctx context.Context, args []Exp
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModRealSig{bf}
return sig.setSelf(sig), nil
return sig, nil
} else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, true)
if mysql.HasUnsignedFlag(lhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModDecimalSig{bf}
return sig.setSelf(sig), nil
return sig, nil
} else {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
if mysql.HasUnsignedFlag(lhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModIntSig{bf}
return sig.setSelf(sig), nil
return sig, nil
}
}

Expand Down
8 changes: 4 additions & 4 deletions expression/builtin_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func (s *testEvaluatorSuite) TestArithmeticMultiply(c *C) {
sig, err := funcs[ast.Mul].getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tc.args...)))
c.Assert(err, IsNil)
c.Assert(sig, NotNil)
val, err := sig.eval(nil)
val, err := evalBuiltinFunc(sig, nil)
c.Assert(err, IsNil)
c.Assert(val, testutil.DatumEquals, types.NewDatum(tc.expect))
}
Expand Down Expand Up @@ -328,7 +328,7 @@ func (s *testEvaluatorSuite) TestArithmeticDivide(c *C) {
sig, err := funcs[ast.Div].getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tc.args...)))
c.Assert(err, IsNil)
c.Assert(sig, NotNil)
val, err := sig.eval(nil)
val, err := evalBuiltinFunc(sig, nil)
c.Assert(err, IsNil)
c.Assert(val, testutil.DatumEquals, types.NewDatum(tc.expect))
}
Expand Down Expand Up @@ -422,7 +422,7 @@ func (s *testEvaluatorSuite) TestArithmeticIntDivide(c *C) {
sig, err := funcs[ast.IntDiv].getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tc.args...)))
c.Assert(err, IsNil)
c.Assert(sig, NotNil)
val, err := sig.eval(nil)
val, err := evalBuiltinFunc(sig, nil)
c.Assert(err, IsNil)
c.Assert(val, testutil.DatumEquals, types.NewDatum(tc.expect))
}
Expand Down Expand Up @@ -532,7 +532,7 @@ func (s *testEvaluatorSuite) TestArithmeticMod(c *C) {
sig, err := funcs[ast.Mod].getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tc.args...)))
c.Assert(err, IsNil)
c.Assert(sig, NotNil)
val, err := sig.eval(nil)
val, err := evalBuiltinFunc(sig, nil)
c.Assert(err, IsNil)
c.Assert(val, testutil.DatumEquals, types.NewDatum(tc.expect))
}
Expand Down
28 changes: 14 additions & 14 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (c *castAsIntFunctionClass) getFunction(ctx context.Context, args []Express
if IsHybridType(args[0]) {
sig = &builtinCastIntAsIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastIntAsInt)
return sig.setSelf(sig), nil
return sig, nil
}
argTp := args[0].GetType().EvalType()
switch argTp {
Expand Down Expand Up @@ -149,7 +149,7 @@ func (c *castAsIntFunctionClass) getFunction(ctx context.Context, args []Express
default:
panic("unsupported types.EvalType in castAsIntFunctionClass")
}
return sig.setSelf(sig), nil
return sig, nil
}

type castAsRealFunctionClass struct {
Expand All @@ -167,7 +167,7 @@ func (c *castAsRealFunctionClass) getFunction(ctx context.Context, args []Expres
if IsHybridType(args[0]) {
sig = &builtinCastRealAsRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastRealAsReal)
return sig.setSelf(sig), nil
return sig, nil
}
argTp := args[0].GetType().EvalType()
switch argTp {
Expand Down Expand Up @@ -195,7 +195,7 @@ func (c *castAsRealFunctionClass) getFunction(ctx context.Context, args []Expres
default:
panic("unsupported types.EvalType in castAsRealFunctionClass")
}
return sig.setSelf(sig), nil
return sig, nil
}

type castAsDecimalFunctionClass struct {
Expand All @@ -213,7 +213,7 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx context.Context, args []Exp
if IsHybridType(args[0]) {
sig = &builtinCastDecimalAsDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastDecimalAsDecimal)
return sig.setSelf(sig), nil
return sig, nil
}
argTp := args[0].GetType().EvalType()
switch argTp {
Expand Down Expand Up @@ -241,7 +241,7 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx context.Context, args []Exp
default:
panic("unsupported types.EvalType in castAsDecimalFunctionClass")
}
return sig.setSelf(sig), nil
return sig, nil
}

type castAsStringFunctionClass struct {
Expand All @@ -259,7 +259,7 @@ func (c *castAsStringFunctionClass) getFunction(ctx context.Context, args []Expr
if IsHybridType(args[0]) {
sig = &builtinCastStringAsStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastStringAsString)
return sig.setSelf(sig), nil
return sig, nil
}
argTp := args[0].GetType().EvalType()
switch argTp {
Expand Down Expand Up @@ -287,7 +287,7 @@ func (c *castAsStringFunctionClass) getFunction(ctx context.Context, args []Expr
default:
panic("unsupported types.EvalType in castAsStringFunctionClass")
}
return sig.setSelf(sig), nil
return sig, nil
}

type castAsTimeFunctionClass struct {
Expand All @@ -305,7 +305,7 @@ func (c *castAsTimeFunctionClass) getFunction(ctx context.Context, args []Expres
if IsHybridType(args[0]) {
sig = &builtinCastTimeAsTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastTimeAsTime)
return sig.setSelf(sig), nil
return sig, nil
}
argTp := args[0].GetType().EvalType()
switch argTp {
Expand Down Expand Up @@ -333,7 +333,7 @@ func (c *castAsTimeFunctionClass) getFunction(ctx context.Context, args []Expres
default:
panic("unsupported types.EvalType in castAsTimeFunctionClass")
}
return sig.setSelf(sig), nil
return sig, nil
}

type castAsDurationFunctionClass struct {
Expand All @@ -351,7 +351,7 @@ func (c *castAsDurationFunctionClass) getFunction(ctx context.Context, args []Ex
if IsHybridType(args[0]) {
sig = &builtinCastDurationAsDurationSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastDurationAsDuration)
return sig.setSelf(sig), nil
return sig, nil
}
argTp := args[0].GetType().EvalType()
switch argTp {
Expand Down Expand Up @@ -379,7 +379,7 @@ func (c *castAsDurationFunctionClass) getFunction(ctx context.Context, args []Ex
default:
panic("unsupported types.EvalType in castAsDurationFunctionClass")
}
return sig.setSelf(sig), nil
return sig, nil
}

type castAsJSONFunctionClass struct {
Expand All @@ -397,7 +397,7 @@ func (c *castAsJSONFunctionClass) getFunction(ctx context.Context, args []Expres
if IsHybridType(args[0]) {
sig = &builtinCastJSONAsJSONSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastJsonAsJson)
return sig.setSelf(sig), nil
return sig, nil
}
argTp := args[0].GetType().EvalType()
switch argTp {
Expand Down Expand Up @@ -425,7 +425,7 @@ func (c *castAsJSONFunctionClass) getFunction(ctx context.Context, args []Expres
default:
panic("unsupported types.EvalType in castAsJSONFunctionClass")
}
return sig.setSelf(sig), nil
return sig, nil
}

type builtinCastIntAsIntSig struct {
Expand Down
Loading

0 comments on commit 64bc8cb

Please sign in to comment.