Skip to content

Commit

Permalink
expression: check max_allowed_packet constraint for function insert (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
eurekaka authored and shenli committed Aug 29, 2018
1 parent eef448a commit 9d09f1d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
34 changes: 29 additions & 5 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -3145,21 +3145,30 @@ func (c *insertFunctionClass) getFunction(ctx sessionctx.Context, args []Express
bf.tp.Flen = mysql.MaxBlobWidth
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
SetBinFlagOrBinStr(args[3].GetType(), bf.tp)

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) {
sig = &builtinInsertBinarySig{bf}
sig = &builtinInsertBinarySig{bf, maxAllowedPacket}
} else {
sig = &builtinInsertSig{bf}
sig = &builtinInsertSig{bf, maxAllowedPacket}
}
return sig, nil
}

type builtinInsertBinarySig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinInsertBinarySig) Clone() builtinFunc {
newSig := &builtinInsertBinarySig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand Down Expand Up @@ -3191,18 +3200,26 @@ func (b *builtinInsertBinarySig) evalString(row types.Row) (string, bool, error)
}

if length > strLength-pos+1 || length < 0 {
return str[0:pos-1] + newstr, false, nil
length = strLength - pos + 1
}

if uint64(strLength-length+int64(len(newstr))) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("insert", b.maxAllowedPacket))
return "", true, nil
}

return str[0:pos-1] + newstr + str[pos+length-1:], false, nil
}

type builtinInsertSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinInsertSig) Clone() builtinFunc {
newSig := &builtinInsertSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand Down Expand Up @@ -3235,9 +3252,16 @@ func (b *builtinInsertSig) evalString(row types.Row) (string, bool, error) {
}

if length > runeLength-pos+1 || length < 0 {
return string(runes[0:pos-1]) + newstr, false, nil
length = runeLength - pos + 1
}

strHead := string(runes[0 : pos-1])
strTail := string(runes[pos+length-1:])
if uint64(len(strHead)+len(newstr)+len(strTail)) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("insert", b.maxAllowedPacket))
return "", true, nil
}
return string(runes[0:pos-1]) + newstr + string(runes[pos+length-1:]), false, nil
return strHead + newstr + strTail, false, nil
}

type instrFunctionClass struct {
Expand Down
45 changes: 45 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,51 @@ func (s *testEvaluatorSuite) TestRpadSig(c *C) {
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}

func (s *testEvaluatorSuite) TestInsertBinarySig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeLonglong},
{Tp: mysql.TypeLonglong},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 3}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
&Column{Index: 3, RetType: colTypes[3]},
}

base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
insert := &builtinInsertBinarySig{base, 3}

input := chunk.NewChunkWithCapacity(colTypes, 2)
input.AppendString(0, "abc")
input.AppendString(0, "abc")
input.AppendInt64(1, 3)
input.AppendInt64(1, 3)
input.AppendInt64(2, -1)
input.AppendInt64(2, -1)
input.AppendString(3, "d")
input.AppendString(3, "de")

res, isNull, err := insert.evalString(input.GetRow(0))
c.Assert(res, Equals, "abd")
c.Assert(isNull, IsFalse)
c.Assert(err, IsNil)

res, isNull, err = insert.evalString(input.GetRow(1))
c.Assert(res, Equals, "")
c.Assert(isNull, IsTrue)
c.Assert(err, IsNil)

warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err))
}

func (s *testEvaluatorSuite) TestInstr(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
Expand Down

0 comments on commit 9d09f1d

Please sign in to comment.