Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parse: fix the type of date/time parameters #48237

Merged
merged 3 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pkg/param/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "param",
srcs = ["binary_params.go"],
importpath = "github.com/pingcap/tidb/pkg/param",
visibility = ["//visibility:public"],
deps = [
"//pkg/errno",
"//pkg/expression",
"//pkg/parser/mysql",
"//pkg/types",
"//pkg/util/dbterror",
"//pkg/util/hack",
],
)
275 changes: 275 additions & 0 deletions pkg/param/binary_params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
// Copyright 2023 PingCAP, Inc.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of this file is copied from /pkg/internal/parse. Because it's not expected to be "internal" inside the server.

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package param

import (
"encoding/binary"
"fmt"
"math"

"github.com/pingcap/tidb/pkg/errno"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/dbterror"
"github.com/pingcap/tidb/pkg/util/hack"
)

var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType)

// BinaryParam stores the information decoded from the binary protocol
// It can be further parsed into `expression.Expression` through the `ExecArgs` function in this package
type BinaryParam struct {
Tp byte
IsUnsigned bool
IsNull bool
Val []byte
}

// ExecArgs parse execute arguments to datum slice.
func ExecArgs(typectx types.Context, binaryParams []BinaryParam) (params []expression.Expression, err error) {
var (
tmp interface{}
)

params = make([]expression.Expression, len(binaryParams))
args := make([]types.Datum, len(binaryParams))
for i := 0; i < len(args); i++ {
tp := binaryParams[i].Tp
isUnsigned := binaryParams[i].IsUnsigned

switch tp {
case mysql.TypeNull:
var nilDatum types.Datum
nilDatum.SetNull()
args[i] = nilDatum
continue

case mysql.TypeTiny:
if isUnsigned {
args[i] = types.NewUintDatum(uint64(binaryParams[i].Val[0]))
} else {
args[i] = types.NewIntDatum(int64(int8(binaryParams[i].Val[0])))
}
continue

case mysql.TypeShort, mysql.TypeYear:
valU16 := binary.LittleEndian.Uint16(binaryParams[i].Val)
if isUnsigned {
args[i] = types.NewUintDatum(uint64(valU16))
} else {
args[i] = types.NewIntDatum(int64(int16(valU16)))
}
continue

case mysql.TypeInt24, mysql.TypeLong:
valU32 := binary.LittleEndian.Uint32(binaryParams[i].Val)
if isUnsigned {
args[i] = types.NewUintDatum(uint64(valU32))
} else {
args[i] = types.NewIntDatum(int64(int32(valU32)))
}
continue

case mysql.TypeLonglong:
valU64 := binary.LittleEndian.Uint64(binaryParams[i].Val)
if isUnsigned {
args[i] = types.NewUintDatum(valU64)
} else {
args[i] = types.NewIntDatum(int64(valU64))
}
continue

case mysql.TypeFloat:
args[i] = types.NewFloat32Datum(math.Float32frombits(binary.LittleEndian.Uint32(binaryParams[i].Val)))
continue

case mysql.TypeDouble:
args[i] = types.NewFloat64Datum(math.Float64frombits(binary.LittleEndian.Uint64(binaryParams[i].Val)))
continue

case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime:
switch len(binaryParams[i].Val) {
case 0:
tmp = types.ZeroDatetimeStr
case 4:
_, tmp = binaryDate(0, binaryParams[i].Val)
case 7:
_, tmp = binaryDateTime(0, binaryParams[i].Val)
case 11:
_, tmp = binaryTimestamp(0, binaryParams[i].Val)
case 13:
_, tmp = binaryTimestampWithTZ(0, binaryParams[i].Val)
default:
err = mysql.ErrMalformPacket
return
}
// TODO: generate the time datum directly
var parseTime func(types.Context, string) (types.Time, error)
switch tp {
case mysql.TypeDate:
parseTime = types.ParseDate
case mysql.TypeDatetime:
parseTime = types.ParseDatetime
case mysql.TypeTimestamp:
// To be compatible with MySQL, even the type of parameter is
// TypeTimestamp, the return type should also be `Datetime`.
parseTime = types.ParseDatetime
}
var time types.Time
time, err = parseTime(typectx, tmp.(string))
err = typectx.HandleTruncate(err)
if err != nil {
return
}
args[i] = types.NewDatum(time)
continue

case mysql.TypeDuration:
switch len(binaryParams[i].Val) {
case 0:
tmp = "0"
case 8:
isNegative := binaryParams[i].Val[0]
if isNegative > 1 {
err = mysql.ErrMalformPacket
return
}
_, tmp = binaryDuration(1, binaryParams[i].Val, isNegative)
case 12:
isNegative := binaryParams[i].Val[0]
if isNegative > 1 {
err = mysql.ErrMalformPacket
return
}
_, tmp = binaryDurationWithMS(1, binaryParams[i].Val, isNegative)
default:
err = mysql.ErrMalformPacket
return
}
// TODO: generate the duration datum directly
var dur types.Duration
dur, _, err = types.ParseDuration(typectx, tmp.(string), types.MaxFsp)
err = typectx.HandleTruncate(err)
if err != nil {
return
}
args[i] = types.NewDatum(dur)
continue
case mysql.TypeNewDecimal:
if binaryParams[i].IsNull {
args[i] = types.NewDecimalDatum(nil)
} else {
var dec types.MyDecimal
err = typectx.HandleTruncate(dec.FromString(binaryParams[i].Val))
if err != nil {
return nil, err
}
args[i] = types.NewDecimalDatum(&dec)
}
continue
case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
if binaryParams[i].IsNull {
args[i] = types.NewBytesDatum(nil)
} else {
args[i] = types.NewBytesDatum(binaryParams[i].Val)
}
continue
case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString,
mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit:
if !binaryParams[i].IsNull {
tmp = string(hack.String(binaryParams[i].Val))
} else {
tmp = nil
}
args[i] = types.NewDatum(tmp)
continue
default:
err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp)
return
}
}

for i := range params {
ft := new(types.FieldType)
types.InferParamTypeFromUnderlyingValue(args[i].GetValue(), ft)
params[i] = &expression.Constant{Value: args[i], RetType: ft}
}
return
}

func binaryDate(pos int, paramValues []byte) (int, string) {
year := binary.LittleEndian.Uint16(paramValues[pos : pos+2])
pos += 2
month := paramValues[pos]
pos++
day := paramValues[pos]
pos++
return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day)
}

func binaryDateTime(pos int, paramValues []byte) (int, string) {
pos, date := binaryDate(pos, paramValues)
hour := paramValues[pos]
pos++
minute := paramValues[pos]
pos++
second := paramValues[pos]
pos++
return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second)
}

func binaryTimestamp(pos int, paramValues []byte) (int, string) {
pos, dateTime := binaryDateTime(pos, paramValues)
microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
pos += 4
return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond)
}

func binaryTimestampWithTZ(pos int, paramValues []byte) (int, string) {
pos, timestamp := binaryTimestamp(pos, paramValues)
tzShiftInMin := int16(binary.LittleEndian.Uint16(paramValues[pos : pos+2]))
tzShiftHour := tzShiftInMin / 60
tzShiftAbsMin := tzShiftInMin % 60
if tzShiftAbsMin < 0 {
tzShiftAbsMin = -tzShiftAbsMin
}
pos += 2
return pos, fmt.Sprintf("%s%+02d:%02d", timestamp, tzShiftHour, tzShiftAbsMin)
}

func binaryDuration(pos int, paramValues []byte, isNegative uint8) (int, string) {
sign := ""
if isNegative == 1 {
sign = "-"
}
days := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
pos += 4
hours := paramValues[pos]
pos++
minutes := paramValues[pos]
pos++
seconds := paramValues[pos]
pos++
return pos, fmt.Sprintf("%s%d %02d:%02d:%02d", sign, days, hours, minutes, seconds)
}

func binaryDurationWithMS(pos int, paramValues []byte,
isNegative uint8) (int, string) {
pos, dur := binaryDuration(pos, paramValues, isNegative)
microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
pos += 4
return pos, fmt.Sprintf("%s.%06d", dur, microSecond)
}
9 changes: 8 additions & 1 deletion pkg/server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
srcs = [
"conn.go",
"conn_stmt.go",
"conn_stmt_params.go",
"driver.go",
"driver_tidb.go",
"extension.go",
Expand Down Expand Up @@ -32,6 +33,7 @@ go_library(
"//pkg/infoschema",
"//pkg/kv",
"//pkg/metrics",
"//pkg/param",
"//pkg/parser",
"//pkg/parser/ast",
"//pkg/parser/auth",
Expand Down Expand Up @@ -76,6 +78,7 @@ go_library(
"//pkg/util/arena",
"//pkg/util/chunk",
"//pkg/util/cpuprofile",
"//pkg/util/dbterror",
"//pkg/util/dbterror/exeerrors",
"//pkg/util/execdetails",
"//pkg/util/fastrand",
Expand Down Expand Up @@ -125,6 +128,7 @@ go_test(
name = "server_test",
timeout = "short",
srcs = [
"conn_stmt_params_test.go",
"conn_stmt_test.go",
"conn_test.go",
"driver_tidb_test.go",
Expand All @@ -138,20 +142,23 @@ go_test(
data = glob(["testdata/**"]),
embed = [":server"],
flaky = True,
shard_count = 48,
shard_count = 50,
deps = [
"//pkg/config",
"//pkg/domain",
"//pkg/domain/infosync",
"//pkg/expression",
"//pkg/extension",
"//pkg/keyspace",
"//pkg/kv",
"//pkg/metrics",
"//pkg/param",
"//pkg/parser/ast",
"//pkg/parser/auth",
"//pkg/parser/charset",
"//pkg/parser/model",
"//pkg/parser/mysql",
"//pkg/parser/terror",
"//pkg/server/internal",
"//pkg/server/internal/column",
"//pkg/server/internal/handshake",
Expand Down
10 changes: 5 additions & 5 deletions pkg/server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/param"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/charset"
Expand Down Expand Up @@ -180,7 +180,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
)
cc.initInputEncoder(ctx)
numParams := stmt.NumParams()
args := make([]expression.Expression, numParams)
args := make([]param.BinaryParam, numParams)
if numParams > 0 {
nullBitmapLen := (numParams + 7) >> 3
if len(data) < (pos + nullBitmapLen + 1) {
Expand All @@ -206,7 +206,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramValues = data[pos+1:]
}

err = parse.ExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
// This `.Reset` resets the arguments, so it's fine to just ignore the error (and the it'll be reset again in the following routine)
errReset := stmt.Reset()
if errReset != nil {
Expand All @@ -227,7 +227,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
return err
}

func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}, args []expression.Expression, useCursor bool) (err error) {
func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}, args []param.BinaryParam, useCursor bool) (err error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{})
retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor)
Expand Down Expand Up @@ -262,7 +262,7 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}

// The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried.
// Currently the first return value is used to fallback to TiKV when TiFlash is down.
func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []expression.Expression, useCursor bool) (bool, error) {
func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []param.BinaryParam, useCursor bool) (bool, error) {
vars := (&cc.ctx).GetSessionVars()
prepStmt, err := vars.GetPreparedStmtByID(uint32(stmt.ID()))
if err != nil {
Expand Down
Loading