Skip to content

Commit

Permalink
types: more strict for types.StrictFlags (#47842)
Browse files Browse the repository at this point in the history
close #47829
  • Loading branch information
lcwangchao authored Oct 20, 2023
1 parent f874e3a commit 92749f7
Show file tree
Hide file tree
Showing 19 changed files with 73 additions and 65 deletions.
4 changes: 1 addition & 3 deletions br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,7 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
}
}
vars.StmtCtx.SetTimeZone(vars.Location())
vars.StmtCtx.SetTypeFlags(types.StrictFlags.
WithClipNegativeToZero(true),
)
vars.StmtCtx.SetTypeFlags(types.StrictFlags)
if err := vars.SetSystemVar("timestamp", strconv.FormatInt(options.Timestamp, 10)); err != nil {
logger.Warn("new session: failed to set timestamp",
log.ShortError(err))
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/tidb/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func (enc *tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum, _ *ta

case types.KindMysqlBit:
var buffer [20]byte
intValue, err := datum.GetBinaryLiteral().ToInt(types.DefaultNoWarningContext)
intValue, err := datum.GetBinaryLiteral().ToInt(types.DefaultStmtNoWarningContext)
if err != nil {
return err
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/ddl/backfilling_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ func initSessCtx(
sessCtx.GetSessionVars().StmtCtx.IgnoreZeroInDate = !sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode()
sessCtx.GetSessionVars().StmtCtx.NoZeroDate = sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(types.StrictFlags.
WithTruncateAsWarning(!sqlMode.HasStrictMode()).
WithClipNegativeToZero(true),
WithTruncateAsWarning(!sqlMode.HasStrictMode()),
)

// Prevent initializing the mock context in the workers concurrently.
Expand Down
4 changes: 2 additions & 2 deletions pkg/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2163,10 +2163,10 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()).
// WithClipNegativeToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
// WithAllowNegativeToUnsigned with false value indicates values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
WithClipNegativeToZero(sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt),
WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt),
)

vars.PlanCacheParams.Reset()
Expand Down
4 changes: 1 addition & 3 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,7 @@ var fakeSctx = newFakeSctx()

func newFakeSctx() *stmtctx.StatementContext {
sc := stmtctx.NewStmtCtx()
sc.SetTypeFlags(types.StrictFlags.
WithClipNegativeToZero(true),
)
sc.SetTypeFlags(types.StrictFlags)
return sc
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3028,7 +3028,7 @@ func handleAnalyzeOptionsV2(opts []ast.AnalyzeOpt) (map[ast.AnalyzeOptionType]ui
optMap[opt.Type] = v
case ast.AnalyzeOptSampleRate:
// Only Int/Float/decimal is accepted, so pass nil here is safe.
fVal, err := datumValue.ToFloat64(types.DefaultNoWarningContext)
fVal, err := datumValue.ToFloat64(types.DefaultStmtNoWarningContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -3091,7 +3091,7 @@ func handleAnalyzeOptions(opts []ast.AnalyzeOpt, statsVer int) (map[ast.AnalyzeO
optMap[opt.Type] = v
case ast.AnalyzeOptSampleRate:
// Only Int/Float/decimal is accepted, so pass nil here is safe.
fVal, err := datumValue.ToFloat64(types.DefaultNoWarningContext)
fVal, err := datumValue.ToFloat64(types.DefaultStmtNoWarningContext)
if err != nil {
return nil, err
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,22 +428,22 @@ type StatementContext struct {
// NewStmtCtx creates a new statement context
func NewStmtCtx() *StatementContext {
sc := &StatementContext{}
sc.typeCtx = typectx.NewContext(typectx.StrictFlags, time.UTC, sc.AppendWarning)
sc.typeCtx = typectx.NewContext(typectx.DefaultStmtFlags, time.UTC, sc.AppendWarning)
return sc
}

// NewStmtCtxWithTimeZone creates a new StatementContext with the given timezone
func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext {
intest.Assert(tz)
sc := &StatementContext{}
sc.typeCtx = typectx.NewContext(typectx.StrictFlags, tz, sc.AppendWarning)
sc.typeCtx = typectx.NewContext(typectx.DefaultStmtFlags, tz, sc.AppendWarning)
return sc
}

// Reset resets a statement context
func (sc *StatementContext) Reset() {
*sc = StatementContext{
typeCtx: typectx.NewContext(typectx.StrictFlags, time.UTC, sc.AppendWarning),
typeCtx: typectx.NewContext(typectx.DefaultStmtFlags, time.UTC, sc.AppendWarning),
}
}

Expand Down Expand Up @@ -1213,10 +1213,10 @@ func (sc *StatementContext) InitFromPBFlagAndTz(flags uint64, tz *time.Location)
sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0
sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0
sc.SetTimeZone(tz)
sc.SetTypeFlags(typectx.StrictFlags.
sc.SetTypeFlags(typectx.DefaultStmtFlags.
WithIgnoreTruncateErr((flags & model.FlagIgnoreTruncate) > 0).
WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0).
WithClipNegativeToZero(sc.InInsertStmt),
WithAllowNegativeToUnsigned(!sc.InInsertStmt),
)
}

Expand Down Expand Up @@ -1365,7 +1365,7 @@ func (sc *StatementContext) TypeCtxOrDefault() typectx.Context {
return sc.typeCtx
}

return typectx.DefaultNoWarningContext
return typectx.DefaultStmtNoWarningContext
}

// UsedStatsInfoForTable records stats that are used during query and their information.
Expand Down
20 changes: 10 additions & 10 deletions pkg/sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func TestStmtHintsClone(t *testing.T) {

func TestNewStmtCtx(t *testing.T) {
sc := stmtctx.NewStmtCtx()
require.Equal(t, types.StrictFlags, sc.TypeFlags())
require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags())
require.Same(t, time.UTC, sc.TimeZone())
require.Same(t, time.UTC, sc.TimeZone())
sc.AppendWarning(errors.New("err1"))
Expand All @@ -327,7 +327,7 @@ func TestNewStmtCtx(t *testing.T) {

tz := time.FixedZone("UTC+1", 2*60*60)
sc = stmtctx.NewStmtCtxWithTimeZone(tz)
require.Equal(t, types.StrictFlags, sc.TypeFlags())
require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags())
require.Same(t, tz, sc.TimeZone())
require.Same(t, tz, sc.TimeZone())
sc.AppendWarning(errors.New("err2"))
Expand All @@ -347,10 +347,10 @@ func TestSetStmtCtxTimeZone(t *testing.T) {

func TestSetStmtCtxTypeFlags(t *testing.T) {
sc := stmtctx.NewStmtCtx()
require.Equal(t, types.StrictFlags, sc.TypeFlags())
require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags())

sc.SetTypeFlags(typectx.FlagClipNegativeToZero | typectx.FlagSkipASCIICheck)
require.Equal(t, typectx.FlagClipNegativeToZero|typectx.FlagSkipASCIICheck, sc.TypeFlags())
sc.SetTypeFlags(typectx.FlagAllowNegativeToUnsigned | typectx.FlagSkipASCIICheck)
require.Equal(t, typectx.FlagAllowNegativeToUnsigned|typectx.FlagSkipASCIICheck, sc.TypeFlags())
require.Equal(t, sc.TypeFlags(), sc.TypeFlags())

sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagInvalidDateAsWarning)
Expand All @@ -360,24 +360,24 @@ func TestSetStmtCtxTypeFlags(t *testing.T) {

func TestResetStmtCtx(t *testing.T) {
sc := stmtctx.NewStmtCtx()
require.Equal(t, types.StrictFlags, sc.TypeFlags())
require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags())

tz := time.FixedZone("UTC+1", 2*60*60)
sc.SetTimeZone(tz)
sc.SetTypeFlags(typectx.FlagClipNegativeToZero | typectx.FlagSkipASCIICheck)
sc.SetTypeFlags(typectx.FlagAllowNegativeToUnsigned | typectx.FlagSkipASCIICheck)
sc.AppendWarning(errors.New("err1"))
sc.InRestrictedSQL = true
sc.StmtType = "Insert"

require.Same(t, tz, sc.TimeZone())
require.Equal(t, typectx.FlagClipNegativeToZero|typectx.FlagSkipASCIICheck, sc.TypeFlags())
require.Equal(t, typectx.FlagAllowNegativeToUnsigned|typectx.FlagSkipASCIICheck, sc.TypeFlags())
require.Equal(t, 1, len(sc.GetWarnings()))

sc.Reset()
require.Same(t, time.UTC, sc.TimeZone())
require.Same(t, time.UTC, sc.TimeZone())
require.Equal(t, types.StrictFlags, sc.TypeFlags())
require.Equal(t, types.StrictFlags, sc.TypeFlags())
require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags())
require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags())
require.False(t, sc.InRestrictedSQL)
require.Empty(t, sc.StmtType)
require.Equal(t, 0, len(sc.GetWarnings()))
Expand Down
2 changes: 1 addition & 1 deletion pkg/table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd
}

// Clip to zero if get negative value after cast to unsigned.
if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.TypeFlags().ClipNegativeToZero() {
if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && sctx.GetSessionVars().StmtCtx.TypeFlags().AllowNegativeToUnsigned() {
switch datum.Kind() {
case types.KindInt64:
if datum.GetInt64() < 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/binary_literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func TestBinaryLiteral(t *testing.T) {
{"0x1010ffff8080ff12", 0x1010ffff8080ff12, false},
{"0x1010ffff8080ff12ff", 0xffffffffffffffff, true},
}
ctx := DefaultNoWarningContext
ctx := DefaultStmtNoWarningContext
for _, item := range tbl {
hex, err := ParseHexStr(item.Input)
require.NoError(t, err)
Expand Down
9 changes: 7 additions & 2 deletions pkg/types/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,10 @@ const StrictFlags = context.StrictFlags
// NewContext creates a new `Context`
var NewContext = context.NewContext

// DefaultNoWarningContext is an alias of `DefaultNoWarningContext`
var DefaultNoWarningContext = context.DefaultNoWarningContext
// DefaultStmtFlags is the default flags for statement context with the flag `FlagAllowNegativeToUnsigned` set.
// TODO: make DefaultStmtFlags to be equal with StrictFlags, and setting flag `FlagAllowNegativeToUnsigned`
// is only for make the code to be equivalent with the old implement during refactoring.
const DefaultStmtFlags = context.DefaultStmtFlags

// DefaultStmtNoWarningContext is an alias of `DefaultStmtNoWarningContext`
var DefaultStmtNoWarningContext = context.DefaultStmtNoWarningContext
36 changes: 22 additions & 14 deletions pkg/types/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
// StrictFlags is a flags with a fields unset and has the most strict behavior.
const StrictFlags Flags = 0

// Flags indicates how to handle the conversion of a value.
// Flags indicate how to handle the conversion of a value.
type Flags uint16

const (
Expand All @@ -32,10 +32,13 @@ const (
FlagIgnoreTruncateErr Flags = 1 << iota
// FlagTruncateAsWarning indicates to append the truncate error to warnings instead of returning it to user.
FlagTruncateAsWarning
// FlagClipNegativeToZero indicates to clip the value to zero when casting a negative value to an unsigned integer.
// When this flag is set and the clip happens, an overflow error occurs and how to handle it will be determined by flags
// `FlagIgnoreOverflowError` and `FlagOverflowAsWarning`.
FlagClipNegativeToZero
// FlagAllowNegativeToUnsigned indicates to allow the casting from negative to unsigned int.
// When this flag is not set by default, casting a negative value to unsigned results an overflow error.
// The overflow will also be controlled by `FlagIgnoreOverflowError` and `FlagOverflowAsWarning`. When any of them is set,
// a zero value is returned instead.
// Whe this flag is set, casting a negative value to unsigned will be allowed. And the negative value will be cast to
// a positive value by adding the max value of the unsigned type.
FlagAllowNegativeToUnsigned
// FlagIgnoreOverflowError indicates to ignore the overflow error.
// If this flag is set, `FlagOverflowAsWarning` will be ignored.
FlagIgnoreOverflowError
Expand Down Expand Up @@ -65,17 +68,17 @@ const (
FlagSkipUTF8MB4Check
)

// ClipNegativeToZero indicates whether the flag `FlagClipNegativeToZero` is set
func (f Flags) ClipNegativeToZero() bool {
return f&FlagClipNegativeToZero != 0
// AllowNegativeToUnsigned indicates whether the flag `FlagAllowNegativeToUnsigned` is set
func (f Flags) AllowNegativeToUnsigned() bool {
return f&FlagAllowNegativeToUnsigned != 0
}

// WithClipNegativeToZero returns a new flags with `FlagClipNegativeToZero` set/unset according to the clip parameter
func (f Flags) WithClipNegativeToZero(clip bool) Flags {
// WithAllowNegativeToUnsigned returns a new flags with `FlagAllowNegativeToUnsigned` set/unset according to the clip parameter
func (f Flags) WithAllowNegativeToUnsigned(clip bool) Flags {
if clip {
return f | FlagClipNegativeToZero
return f | FlagAllowNegativeToUnsigned
}
return f &^ FlagClipNegativeToZero
return f &^ FlagAllowNegativeToUnsigned
}

// SkipASCIICheck indicates whether the flag `FlagSkipASCIICheck` is set
Expand Down Expand Up @@ -204,7 +207,12 @@ func (c *Context) AppendWarningFunc() func(err error) {
return c.appendWarningFn
}

// DefaultNoWarningContext is the context without any special configuration
var DefaultNoWarningContext = NewContext(StrictFlags, time.UTC, func(_ error) {
// DefaultStmtFlags is the default flags for statement context with the flag `FlagAllowNegativeToUnsigned` set.
// TODO: make DefaultStmtFlags to be equal with StrictFlags, and setting flag `FlagAllowNegativeToUnsigned`
// is only for make the code to be equivalent with the old implement during refactoring.
const DefaultStmtFlags = StrictFlags | FlagAllowNegativeToUnsigned

// DefaultStmtNoWarningContext is the context with default statement flags without any other special configuration
var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, func(_ error) {
// the error is ignored
})
8 changes: 4 additions & 4 deletions pkg/types/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ func TestSimpleOnOffFlags(t *testing.T) {
writeFn func(Flags, bool) Flags
}{
{
name: "FlagClipNegativeToZero",
flag: FlagClipNegativeToZero,
name: "FlagAllowNegativeToUnsigned",
flag: FlagAllowNegativeToUnsigned,
readFn: func(f Flags) bool {
return f.ClipNegativeToZero()
return f.AllowNegativeToUnsigned()
},
writeFn: func(f Flags, clip bool) Flags {
return f.WithClipNegativeToZero(clip)
return f.WithAllowNegativeToUnsigned(clip)
},
},
{
Expand Down
4 changes: 2 additions & 2 deletions pkg/types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) {

// ConvertIntToUint converts an int value to an uint value.
func ConvertIntToUint(flags Flags, val int64, upperBound uint64, tp byte) (uint64, error) {
if val < 0 && flags.ClipNegativeToZero() {
if val < 0 && !flags.AllowNegativeToUnsigned() {
return 0, overflow(val, tp)
}

Expand All @@ -170,7 +170,7 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {
func ConvertFloatToUint(flags Flags, fval float64, upperBound uint64, tp byte) (uint64, error) {
val := RoundFloat(fval)
if val < 0 {
if flags.ClipNegativeToZero() {
if !flags.AllowNegativeToUnsigned() {
return 0, overflow(val, tp)
}
return uint64(int64(val)), overflow(val, tp)
Expand Down
14 changes: 7 additions & 7 deletions pkg/types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ func TestConvertToBinaryString(t *testing.T) {
}

func testStrToInt(t *testing.T, str string, expect int64, truncateAsErr bool, expectErr error) {
ctx := DefaultNoWarningContext.WithFlags(StrictFlags.WithIgnoreTruncateErr(!truncateAsErr))
ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(!truncateAsErr))
val, err := StrToInt(ctx, str, false)
if expectErr != nil {
require.Truef(t, terror.ErrorEqual(err, expectErr), "err %v", err)
Expand All @@ -483,7 +483,7 @@ func testStrToInt(t *testing.T, str string, expect int64, truncateAsErr bool, ex
}

func testStrToUint(t *testing.T, str string, expect uint64, truncateAsErr bool, expectErr error) {
ctx := DefaultNoWarningContext.WithFlags(StrictFlags.WithIgnoreTruncateErr(!truncateAsErr))
ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(!truncateAsErr))
val, err := StrToUint(ctx, str, false)
if expectErr != nil {
require.Truef(t, terror.ErrorEqual(err, expectErr), "err %v", err)
Expand All @@ -494,7 +494,7 @@ func testStrToUint(t *testing.T, str string, expect uint64, truncateAsErr bool,
}

func testStrToFloat(t *testing.T, str string, expect float64, truncateAsErr bool, expectErr error) {
ctx := DefaultNoWarningContext.WithFlags(StrictFlags.WithIgnoreTruncateErr(!truncateAsErr))
ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(!truncateAsErr))
val, err := StrToFloat(ctx, str, false)
if expectErr != nil {
require.Truef(t, terror.ErrorEqual(err, expectErr), "err %v", err)
Expand Down Expand Up @@ -927,7 +927,7 @@ func TestGetValidInt(t *testing.T) {
{"123e+", "123", true},
{"123de", "123", true},
}
sc.SetTypeFlags(StrictFlags)
sc.SetTypeFlags(DefaultStmtFlags)
sc.InSelectStmt = false
for _, tt := range tests2 {
prefix, err := getValidIntPrefix(sc.TypeCtxOrDefault(), tt.origin, false)
Expand Down Expand Up @@ -963,7 +963,7 @@ func TestGetValidFloat(t *testing.T) {
{"9-3", "9"},
{"1001001\\u0000\\u0000\\u0000", "1001001"},
}
ctx := DefaultNoWarningContext
ctx := DefaultStmtNoWarningContext
for _, tt := range tests {
prefix, _ := getValidFloatPrefix(ctx, tt.origin, false)
require.Equal(t, tt.valid, prefix)
Expand Down Expand Up @@ -1114,7 +1114,7 @@ func TestConvertJSONToFloat(t *testing.T) {
{in: "123.456hello", out: 123.456, ty: JSONTypeCodeString, err: true},
{in: "1234", out: 1234, ty: JSONTypeCodeString},
}
ctx := DefaultNoWarningContext
ctx := DefaultStmtNoWarningContext
for _, tt := range tests {
j := CreateBinaryJSON(tt.in)
require.Equal(t, tt.ty, j.TypeCode)
Expand Down Expand Up @@ -1143,7 +1143,7 @@ func TestConvertJSONToDecimal(t *testing.T) {
{in: `false`, out: NewDecFromStringForTest("0")},
{in: `null`, out: NewDecFromStringForTest("0"), err: true},
}
ctx := DefaultNoWarningContext
ctx := DefaultStmtNoWarningContext
for _, tt := range tests {
j, err := ParseBinaryJSONFromString(tt.in)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ func (d *Datum) SetValue(val interface{}, tp *types.FieldType) {
// Compare compares datum to another datum.
// Notes: don't rely on datum.collation to get the collator, it's tend to buggy.
func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collate.Collator) (int, error) {
typeCtx := DefaultNoWarningContext
typeCtx := DefaultStmtNoWarningContext
if sc != nil {
typeCtx = sc.TypeCtx()
}
Expand Down
Loading

0 comments on commit 92749f7

Please sign in to comment.