From e7efd4fc425273ec781bb5e3b21f6ce21bdb7893 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 8 Mar 2024 13:47:06 +0000 Subject: [PATCH 01/15] build(deps): bump golang.org/x/net from 0.21.0 to 0.22.0 (#51609) --- DEPS.bzl | 36 ++++++++++++++++++------------------ go.mod | 6 +++--- go.sum | 12 ++++++------ 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/DEPS.bzl b/DEPS.bzl index 65f46ebd28914..c205fe157b3a4 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -10152,13 +10152,13 @@ def go_deps(): name = "org_golang_x_crypto", build_file_proto_mode = "disable_global", importpath = "golang.org/x/crypto", - sha256 = "760c835d533e083f3455b6e95d490cf1aba53da2d6cfadb145f02db6c353418d", - strip_prefix = "golang.org/x/crypto@v0.20.0", + sha256 = "689d6b9313d406e061863b9b84eb43b02b7fbe081a49bb25097bfb192f1b90e0", + strip_prefix = "golang.org/x/crypto@v0.21.0", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.20.0.zip", - "http://ats.apps.svc/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.20.0.zip", - "https://cache.hawkingrei.com/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.20.0.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.20.0.zip", + "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.21.0.zip", + "http://ats.apps.svc/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.21.0.zip", + "https://cache.hawkingrei.com/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.21.0.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.21.0.zip", ], ) go_repository( @@ -10243,13 +10243,13 @@ def go_deps(): name = "org_golang_x_net", build_file_proto_mode = "disable_global", importpath = "golang.org/x/net", - sha256 = "4e9cb4bded1957e73fe709741c29879eab05047617c9b14b7237314ff9024913", - strip_prefix = "golang.org/x/net@v0.21.0", + sha256 = "2f624e504f4cd569e907a9449d349f1c4e3652623fb9e352e81d2155ecc2c133", + strip_prefix = "golang.org/x/net@v0.22.0", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/net/org_golang_x_net-v0.21.0.zip", - "http://ats.apps.svc/gomod/golang.org/x/net/org_golang_x_net-v0.21.0.zip", - "https://cache.hawkingrei.com/gomod/golang.org/x/net/org_golang_x_net-v0.21.0.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/net/org_golang_x_net-v0.21.0.zip", + "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/net/org_golang_x_net-v0.22.0.zip", + "http://ats.apps.svc/gomod/golang.org/x/net/org_golang_x_net-v0.22.0.zip", + "https://cache.hawkingrei.com/gomod/golang.org/x/net/org_golang_x_net-v0.22.0.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/net/org_golang_x_net-v0.22.0.zip", ], ) go_repository( @@ -10308,13 +10308,13 @@ def go_deps(): name = "org_golang_x_term", build_file_proto_mode = "disable_global", importpath = "golang.org/x/term", - sha256 = "a38f40301a9ca1154edc70dcbfc6dd2a2ce55abbd49dad8031fb15c1a5e62459", - strip_prefix = "golang.org/x/term@v0.17.0", + sha256 = "60652f7dd2fa4185c62867bcaa3fa56e59b07f5b71083d8f72ab882d251355a6", + strip_prefix = "golang.org/x/term@v0.18.0", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/term/org_golang_x_term-v0.17.0.zip", - "http://ats.apps.svc/gomod/golang.org/x/term/org_golang_x_term-v0.17.0.zip", - "https://cache.hawkingrei.com/gomod/golang.org/x/term/org_golang_x_term-v0.17.0.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/term/org_golang_x_term-v0.17.0.zip", + "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/term/org_golang_x_term-v0.18.0.zip", + "http://ats.apps.svc/gomod/golang.org/x/term/org_golang_x_term-v0.18.0.zip", + "https://cache.hawkingrei.com/gomod/golang.org/x/term/org_golang_x_term-v0.18.0.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/term/org_golang_x_term-v0.18.0.zip", ], ) go_repository( diff --git a/go.mod b/go.mod index 46f7c4378be56..f5e349ac4315a 100644 --- a/go.mod +++ b/go.mod @@ -129,11 +129,11 @@ require ( go.uber.org/multierr v1.11.0 go.uber.org/zap v1.26.0 golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 - golang.org/x/net v0.21.0 + golang.org/x/net v0.22.0 golang.org/x/oauth2 v0.17.0 golang.org/x/sync v0.6.0 golang.org/x/sys v0.18.0 - golang.org/x/term v0.17.0 + golang.org/x/term v0.18.0 golang.org/x/text v0.14.0 golang.org/x/time v0.5.0 golang.org/x/tools v0.18.0 @@ -294,7 +294,7 @@ require ( go.opentelemetry.io/otel/sdk v1.22.0 // indirect go.opentelemetry.io/otel/trace v1.22.0 // indirect go.opentelemetry.io/proto/otlp v1.1.0 // indirect - golang.org/x/crypto v0.20.0 // indirect + golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp/typeparams v0.0.0-20231219180239-dc181d75b848 // indirect golang.org/x/mod v0.15.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect diff --git a/go.sum b/go.sum index 1760519f56cf8..0216469470779 100644 --- a/go.sum +++ b/go.sum @@ -1019,8 +1019,8 @@ golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.20.0 h1:jmAMJJZXr5KiCw05dfYK9QnqaqKLYXijU23lsEdcQqg= -golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1123,8 +1123,8 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1233,8 +1233,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= From 7e51c7732c569e187ba6d0e4c3bd1c46320026fc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:26:07 +0000 Subject: [PATCH 02/15] build(deps): bump golang.org/x/crypto from 0.20.0 to 0.21.0 (#51610) From f2ae6987cafdc46c45ea8f1a03b4e4618b5943ef Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 10 Mar 2024 10:33:38 +0000 Subject: [PATCH 03/15] build(deps): bump go.uber.org/zap from 1.26.0 to 1.27.0 (#51608) --- DEPS.bzl | 12 ++++++------ go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/DEPS.bzl b/DEPS.bzl index c205fe157b3a4..3bd29d0970c3b 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -10581,13 +10581,13 @@ def go_deps(): name = "org_uber_go_zap", build_file_proto_mode = "disable_global", importpath = "go.uber.org/zap", - sha256 = "70582d5e7a6da19b70bb9f42fdb1fff86215002796e06d833d525b908a346b42", - strip_prefix = "go.uber.org/zap@v1.26.0", + sha256 = "b994b96ff0bb504a3d58288ab88b9f3c6604689ea1afb69d25b509769705a6c2", + strip_prefix = "go.uber.org/zap@v1.27.0", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/go.uber.org/zap/org_uber_go_zap-v1.26.0.zip", - "http://ats.apps.svc/gomod/go.uber.org/zap/org_uber_go_zap-v1.26.0.zip", - "https://cache.hawkingrei.com/gomod/go.uber.org/zap/org_uber_go_zap-v1.26.0.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/go.uber.org/zap/org_uber_go_zap-v1.26.0.zip", + "http://bazel-cache.pingcap.net:8080/gomod/go.uber.org/zap/org_uber_go_zap-v1.27.0.zip", + "http://ats.apps.svc/gomod/go.uber.org/zap/org_uber_go_zap-v1.27.0.zip", + "https://cache.hawkingrei.com/gomod/go.uber.org/zap/org_uber_go_zap-v1.27.0.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/go.uber.org/zap/org_uber_go_zap-v1.27.0.zip", ], ) go_repository( diff --git a/go.mod b/go.mod index f5e349ac4315a..5a3d97b31d0a3 100644 --- a/go.mod +++ b/go.mod @@ -127,7 +127,7 @@ require ( go.uber.org/goleak v1.3.0 go.uber.org/mock v0.4.0 go.uber.org/multierr v1.11.0 - go.uber.org/zap v1.26.0 + go.uber.org/zap v1.27.0 golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 golang.org/x/net v0.22.0 golang.org/x/oauth2 v0.17.0 diff --git a/go.sum b/go.sum index 0216469470779..8baa969875812 100644 --- a/go.sum +++ b/go.sum @@ -1001,8 +1001,8 @@ go.uber.org/zap v1.12.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= -go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= -go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= From 5f3fc33bf831d896a05118f70b2104aebd06f13f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Mon, 11 Mar 2024 10:50:08 +0800 Subject: [PATCH 04/15] expression: introduce optional properties for `EvalContext` (#51487) close pingcap/tidb#51477 --- pkg/expression/BUILD.bazel | 1 + pkg/expression/builtin.go | 6 + pkg/expression/builtin_info.go | 60 +++++--- pkg/expression/builtin_info_vec.go | 30 ++-- pkg/expression/context.go | 12 ++ pkg/expression/context/BUILD.bazel | 17 ++- pkg/expression/context/context.go | 2 + pkg/expression/context/optional.go | 134 ++++++++++++++++++ pkg/expression/context/optional_test.go | 76 ++++++++++ pkg/expression/contextimpl/BUILD.bazel | 6 + pkg/expression/contextimpl/sessionctx.go | 32 ++++- pkg/expression/contextimpl/sessionctx_test.go | 38 +++++ pkg/expression/contextopt/BUILD.bazel | 30 ++++ pkg/expression/contextopt/current_user.go | 62 ++++++++ pkg/expression/contextopt/optional.go | 87 ++++++++++++ pkg/expression/contextopt/optional_test.go | 46 ++++++ pkg/expression/distsql_builtin.go | 4 +- 17 files changed, 611 insertions(+), 32 deletions(-) create mode 100644 pkg/expression/context/optional.go create mode 100644 pkg/expression/context/optional_test.go create mode 100644 pkg/expression/contextopt/BUILD.bazel create mode 100644 pkg/expression/contextopt/current_user.go create mode 100644 pkg/expression/contextopt/optional.go create mode 100644 pkg/expression/contextopt/optional_test.go diff --git a/pkg/expression/BUILD.bazel b/pkg/expression/BUILD.bazel index 0b124f71dbb4d..bf5c92a855792 100644 --- a/pkg/expression/BUILD.bazel +++ b/pkg/expression/BUILD.bazel @@ -75,6 +75,7 @@ go_library( "//pkg/errctx", "//pkg/errno", "//pkg/expression/context", + "//pkg/expression/contextopt", "//pkg/extension", "//pkg/kv", "//pkg/parser", diff --git a/pkg/expression/builtin.go b/pkg/expression/builtin.go index dea63bcf96132..d4f1e2c71f169 100644 --- a/pkg/expression/builtin.go +++ b/pkg/expression/builtin.go @@ -33,6 +33,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/expression/contextopt" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" @@ -63,6 +64,10 @@ func (b *baseBuiltinFunc) PbCode() tipb.ScalarFuncSig { return b.pbCode } +func (*baseBuiltinFunc) RequiredOptionalEvalProps() (set OptionalEvalPropKeySet) { + return +} + // metadata returns the metadata of a function. // metadata means some functions contain extra inner fields which will not // contain in `tipb.Expr.children` but must be pushed down to coprocessor @@ -475,6 +480,7 @@ type vecBuiltinFunc interface { // builtinFunc stands for a particular function signature. type builtinFunc interface { + contextopt.RequireOptionalEvalProps vecBuiltinFunc // evalInt evaluates int result of builtinFunc by given row. diff --git a/pkg/expression/builtin_info.go b/pkg/expression/builtin_info.go index 7b27ab5083dc4..a52c249d47ba1 100644 --- a/pkg/expression/builtin_info.go +++ b/pkg/expression/builtin_info.go @@ -26,6 +26,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/expression/contextopt" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" @@ -173,12 +174,13 @@ func (c *currentUserFunctionClass) getFunction(ctx BuildContext, args []Expressi return nil, err } bf.tp.SetFlen(64) - sig := &builtinCurrentUserSig{bf} + sig := &builtinCurrentUserSig{baseBuiltinFunc: bf} return sig, nil } type builtinCurrentUserSig struct { baseBuiltinFunc + contextopt.CurrentUserPropReader } func (b *builtinCurrentUserSig) Clone() builtinFunc { @@ -187,14 +189,22 @@ func (b *builtinCurrentUserSig) Clone() builtinFunc { return newSig } +// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. +func (b *builtinCurrentUserSig) RequiredOptionalEvalProps() (set OptionalEvalPropKeySet) { + return b.CurrentUserPropReader.RequiredOptionalEvalProps() +} + // evalString evals a builtinCurrentUserSig. // See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user -func (b *builtinCurrentUserSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - data := ctx.GetSessionVars() - if data == nil || data.User == nil { +func (b *builtinCurrentUserSig) evalString(ctx EvalContext, _ chunk.Row) (string, bool, error) { + user, err := b.CurrentUser(ctx) + if err != nil { + return "", true, err + } + if user == nil { return "", true, errors.Errorf("Missing session variable when eval builtin") } - return data.User.String(), false, nil + return user.String(), false, nil } type currentRoleFunctionClass struct { @@ -210,12 +220,13 @@ func (c *currentRoleFunctionClass) getFunction(ctx BuildContext, args []Expressi return nil, err } bf.tp.SetFlen(64) - sig := &builtinCurrentRoleSig{bf} + sig := &builtinCurrentRoleSig{baseBuiltinFunc: bf} return sig, nil } type builtinCurrentRoleSig struct { baseBuiltinFunc + contextopt.CurrentUserPropReader } func (b *builtinCurrentRoleSig) Clone() builtinFunc { @@ -224,24 +235,32 @@ func (b *builtinCurrentRoleSig) Clone() builtinFunc { return newSig } +// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. +func (b *builtinCurrentRoleSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { + return b.CurrentUserPropReader.RequiredOptionalEvalProps() +} + // evalString evals a builtinCurrentUserSig. // See https://dev.mysql.com/doc/refman/8.0/en/information-functions.html#function_current-role func (b *builtinCurrentRoleSig) evalString(ctx EvalContext, row chunk.Row) (res string, isNull bool, err error) { - data := ctx.GetSessionVars() - if data == nil || data.ActiveRoles == nil { + roles, err := b.ActiveRoles(ctx) + if err != nil { + return "", true, err + } + if roles == nil { return "", true, errors.Errorf("Missing session variable when eval builtin") } - if len(data.ActiveRoles) == 0 { + if len(roles) == 0 { return "NONE", false, nil } sortedRes := make([]string, 0, 10) - for _, r := range data.ActiveRoles { + for _, r := range roles { sortedRes = append(sortedRes, r.String()) } slices.Sort(sortedRes) for i, r := range sortedRes { res += r - if i != len(data.ActiveRoles)-1 { + if i != len(roles)-1 { res += "," } } @@ -309,12 +328,13 @@ func (c *userFunctionClass) getFunction(ctx BuildContext, args []Expression) (bu return nil, err } bf.tp.SetFlen(64) - sig := &builtinUserSig{bf} + sig := &builtinUserSig{baseBuiltinFunc: bf} return sig, nil } type builtinUserSig struct { baseBuiltinFunc + contextopt.CurrentUserPropReader } func (b *builtinUserSig) Clone() builtinFunc { @@ -323,14 +343,22 @@ func (b *builtinUserSig) Clone() builtinFunc { return newSig } +// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. +func (b *builtinUserSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { + return b.CurrentUserPropReader.RequiredOptionalEvalProps() +} + // evalString evals a builtinUserSig. // See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_user -func (b *builtinUserSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - data := ctx.GetSessionVars() - if data == nil || data.User == nil { +func (b *builtinUserSig) evalString(ctx EvalContext, _ chunk.Row) (string, bool, error) { + user, err := b.CurrentUser(ctx) + if err != nil { + return "", true, err + } + if user == nil { return "", true, errors.Errorf("Missing session variable when eval builtin") } - return data.User.LoginString(), false, nil + return user.LoginString(), false, nil } type connectionIDFunctionClass struct { diff --git a/pkg/expression/builtin_info_vec.go b/pkg/expression/builtin_info_vec.go index 86b4a866ddfc1..9c91b2c1048bc 100644 --- a/pkg/expression/builtin_info_vec.go +++ b/pkg/expression/builtin_info_vec.go @@ -106,14 +106,17 @@ func (b *builtinCurrentUserSig) vectorized() bool { // See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user func (b *builtinCurrentUserSig) vecEvalString(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + user, err := b.CurrentUser(ctx) + if err != nil { + return err + } - data := ctx.GetSessionVars() result.ReserveString(n) - if data == nil || data.User == nil { + if user == nil { return errors.Errorf("Missing session variable when eval builtin") } for i := 0; i < n; i++ { - result.AppendString(data.User.String()) + result.AppendString(user.String()) } return nil } @@ -145,13 +148,17 @@ func (b *builtinCurrentRoleSig) vectorized() bool { func (b *builtinCurrentRoleSig) vecEvalString(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() - data := ctx.GetSessionVars() - if data == nil || data.ActiveRoles == nil { + roles, err := b.ActiveRoles(ctx) + if err != nil { + return err + } + + if roles == nil { return errors.Errorf("Missing session variable when eval builtin") } result.ReserveString(n) - if len(data.ActiveRoles) == 0 { + if len(roles) == 0 { for i := 0; i < n; i++ { result.AppendString("NONE") } @@ -159,7 +166,7 @@ func (b *builtinCurrentRoleSig) vecEvalString(ctx EvalContext, input *chunk.Chun } sortedRes := make([]string, 0, 10) - for _, r := range data.ActiveRoles { + for _, r := range roles { sortedRes = append(sortedRes, r.String()) } slices.Sort(sortedRes) @@ -178,14 +185,17 @@ func (b *builtinUserSig) vectorized() bool { // See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_user func (b *builtinUserSig) vecEvalString(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() - data := ctx.GetSessionVars() - if data == nil || data.User == nil { + user, err := b.CurrentUser(ctx) + if err != nil { + return err + } + if user == nil { return errors.Errorf("Missing session variable when eval builtin") } result.ReserveString(n) for i := 0; i < n; i++ { - result.AppendString(data.User.LoginString()) + result.AppendString(user.LoginString()) } return nil } diff --git a/pkg/expression/context.go b/pkg/expression/context.go index 03dd6ec60a0b0..b843e9a2511bf 100644 --- a/pkg/expression/context.go +++ b/pkg/expression/context.go @@ -31,6 +31,18 @@ type EvalContext = context.EvalContext // BuildContext is used to build an expression type BuildContext = context.BuildContext +// OptionalEvalPropKey is an alias of context.OptionalEvalPropKey +type OptionalEvalPropKey = context.OptionalEvalPropKey + +// OptionalEvalPropProvider is an alias of context.OptionalEvalPropProvider +type OptionalEvalPropProvider = context.OptionalEvalPropProvider + +// OptionalEvalPropKeySet is an alias of context.OptionalEvalPropKeySet +type OptionalEvalPropKeySet = context.OptionalEvalPropKeySet + +// OptionalEvalPropDesc is an alias of context.OptionalEvalPropDesc +type OptionalEvalPropDesc = context.OptionalEvalPropDesc + func sqlMode(ctx EvalContext) mysql.SQLMode { return ctx.SQLMode() } diff --git a/pkg/expression/context/BUILD.bazel b/pkg/expression/context/BUILD.bazel index 71c5b75c2e1bf..4b0f7f27716ff 100644 --- a/pkg/expression/context/BUILD.bazel +++ b/pkg/expression/context/BUILD.bazel @@ -1,8 +1,11 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "context", - srcs = ["context.go"], + srcs = [ + "context.go", + "optional.go", + ], importpath = "github.com/pingcap/tidb/pkg/expression/context", visibility = ["//visibility:public"], deps = [ @@ -15,3 +18,13 @@ go_library( "//pkg/types", ], ) + +go_test( + name = "context_test", + timeout = "short", + srcs = ["optional_test.go"], + embed = [":context"], + flaky = True, + shard_count = 4, + deps = ["@com_github_stretchr_testify//require"], +) diff --git a/pkg/expression/context/context.go b/pkg/expression/context/context.go index 546cf7fffcd42..103feb048912e 100644 --- a/pkg/expression/context/context.go +++ b/pkg/expression/context/context.go @@ -71,6 +71,8 @@ type EvalContext interface { GetInfoSchema() infoschema.InfoSchemaMetaVersion // GetDomainInfoSchema returns the latest information schema in domain GetDomainInfoSchema() infoschema.InfoSchemaMetaVersion + // GetOptionalPropProvider gets the optional property provider by key + GetOptionalPropProvider(OptionalEvalPropKey) (OptionalEvalPropProvider, bool) } // BuildContext is used to build an expression diff --git a/pkg/expression/context/optional.go b/pkg/expression/context/optional.go new file mode 100644 index 0000000000000..eb5a11f2cdb27 --- /dev/null +++ b/pkg/expression/context/optional.go @@ -0,0 +1,134 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "fmt" + "log" + "unsafe" +) + +func init() { + // Ensure the count of optional properties is less than the bits of OptionalEvalProps. + if maxCnt := int64(unsafe.Sizeof(*new(OptionalEvalPropKeySet))) * 8; int64(OptPropsCnt) > maxCnt { + log.Fatalf( + "The count optional properties should less than the bits of OptionalEvalProps, but %d > %d", + OptPropsCnt, maxCnt, + ) + } + + // check optionalPropertyDescList are valid + if len(optionalPropertyDescList) != OptPropsCnt { + log.Fatalf( + "The lenghth of OptionalPropertieDescs should be %d, but got %d", + OptPropsCnt, len(optionalPropertyDescList), + ) + } + + for i := 0; i < OptPropsCnt; i++ { + if key := optionalPropertyDescList[i].Key(); key != OptionalEvalPropKey(i) { + log.Fatalf( + "Invalid optionalPropertyDescList[%d].Key, unexpected index: %d", + i, key, + ) + } + } +} + +// OptionalEvalPropKey is the key for optional evaluation properties in EvalContext. +type OptionalEvalPropKey int + +// AsPropKeySet returns the set only contains the property key. +func (k OptionalEvalPropKey) AsPropKeySet() OptionalEvalPropKeySet { + return 1 << k +} + +// Desc returns the description for the property key. +func (k OptionalEvalPropKey) Desc() *OptionalEvalPropDesc { + return &optionalPropertyDescList[k] +} + +// String implements fmt.Stringer interface. +func (k OptionalEvalPropKey) String() string { + if k < optPropsCnt { + return k.Desc().str + } + return fmt.Sprintf("UnknownOptionalEvalPropKey(%d)", k) +} + +const ( + // OptPropCurrentUser indicates the current user property + OptPropCurrentUser OptionalEvalPropKey = iota + optPropsCnt +) + +const allOptPropsMask = (1 << optPropsCnt) - 1 + +// OptPropsCnt is the count of optional properties. +const OptPropsCnt = int(optPropsCnt) + +// OptionalEvalPropDesc is the description for optional evaluation properties in EvalContext. +type OptionalEvalPropDesc struct { + key OptionalEvalPropKey + str string + // TODO: add more fields if needed +} + +// Key returns the property key. +func (desc *OptionalEvalPropDesc) Key() OptionalEvalPropKey { + return desc.key +} + +// OptionalEvalPropProvider is the interface to provide optional properties in EvalContext. +type OptionalEvalPropProvider interface { + Desc() *OptionalEvalPropDesc +} + +// optionalPropertyDescList contains all optional property descriptions in EvalContext. +var optionalPropertyDescList = []OptionalEvalPropDesc{ + { + key: OptPropCurrentUser, + str: "OptPropCurrentUser", + }, +} + +// OptionalEvalPropKeySet is a bit map for optional evaluation properties in EvalContext +// to indicate whether some properties are set. +type OptionalEvalPropKeySet uint64 + +// Add adds the property key to the set +func (b OptionalEvalPropKeySet) Add(key OptionalEvalPropKey) OptionalEvalPropKeySet { + return b | key.AsPropKeySet() +} + +// Remove removes the property key from the set +func (b OptionalEvalPropKeySet) Remove(key OptionalEvalPropKey) OptionalEvalPropKeySet { + return b &^ key.AsPropKeySet() +} + +// Contains checks whether the set contains the property +func (b OptionalEvalPropKeySet) Contains(key OptionalEvalPropKey) bool { + return b&key.AsPropKeySet() != 0 +} + +// IsEmpty checks whether the bit map is empty. +func (b OptionalEvalPropKeySet) IsEmpty() bool { + return b&allOptPropsMask == 0 +} + +// IsFull checks whether all optional properties are contained in the bit map. +func (b OptionalEvalPropKeySet) IsFull() bool { + return b&allOptPropsMask == allOptPropsMask +} diff --git a/pkg/expression/context/optional_test.go b/pkg/expression/context/optional_test.go new file mode 100644 index 0000000000000..a7bb089ec097b --- /dev/null +++ b/pkg/expression/context/optional_test.go @@ -0,0 +1,76 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOptionalPropKeySet(t *testing.T) { + var keySet OptionalEvalPropKeySet + require.True(t, keySet.IsEmpty()) + require.False(t, keySet.IsFull()) + require.False(t, keySet.Contains(OptPropCurrentUser)) + + // Add one key + keySet2 := keySet.Add(OptPropCurrentUser) + require.True(t, keySet2.Contains(OptPropCurrentUser)) + require.False(t, keySet2.IsEmpty()) + require.True(t, keySet2.IsFull()) + + // old key is not affected + require.True(t, keySet.IsEmpty()) + + // remove one key + keySet3 := keySet2.Remove(OptPropCurrentUser) + require.True(t, keySet3.IsEmpty()) + require.False(t, keySet2.IsEmpty()) +} + +func TestOptionalPropKeySetWithUnusedBits(t *testing.T) { + require.Less(t, OptPropsCnt, 64) + full := OptionalEvalPropKeySet(math.MaxUint64) + + bits := full << OptionalEvalPropKeySet(OptPropsCnt) + require.True(t, bits.IsEmpty()) + require.False(t, bits.Contains(OptPropCurrentUser)) + bits = bits.Add(OptPropCurrentUser) + require.True(t, bits.Contains(OptPropCurrentUser)) + + bits = full >> (64 - OptPropsCnt) + require.True(t, bits.IsFull()) + require.True(t, bits.Contains(OptPropCurrentUser)) + bits = bits.Remove(OptPropCurrentUser) + require.False(t, bits.Contains(OptPropCurrentUser)) +} + +func TestOptionalPropKey(t *testing.T) { + keySet := OptPropCurrentUser.AsPropKeySet() + require.True(t, keySet.Contains(OptPropCurrentUser)) + keySet = keySet.Remove(OptPropCurrentUser) + require.True(t, keySet.IsEmpty()) +} + +func TestOptionalPropDescList(t *testing.T) { + require.Equal(t, OptPropsCnt, len(optionalPropertyDescList)) + for i := 0; i < OptPropsCnt; i++ { + key := OptionalEvalPropKey(i) + require.Equal(t, key, optionalPropertyDescList[i].Key()) + require.Same(t, &optionalPropertyDescList[i], key.Desc()) + } +} diff --git a/pkg/expression/contextimpl/BUILD.bazel b/pkg/expression/contextimpl/BUILD.bazel index 58d04d53109d8..784436b4e6735 100644 --- a/pkg/expression/contextimpl/BUILD.bazel +++ b/pkg/expression/contextimpl/BUILD.bazel @@ -8,11 +8,14 @@ go_library( deps = [ "//pkg/errctx", "//pkg/expression/context", + "//pkg/expression/contextopt", + "//pkg/parser/auth", "//pkg/parser/mysql", "//pkg/sessionctx", "//pkg/sessionctx/stmtctx", "//pkg/sessionctx/variable", "//pkg/types", + "//pkg/util/intest", ], ) @@ -24,6 +27,9 @@ go_test( deps = [ ":contextimpl", "//pkg/errctx", + "//pkg/expression/context", + "//pkg/expression/contextopt", + "//pkg/parser/auth", "//pkg/parser/mysql", "//pkg/sessionctx/stmtctx", "//pkg/types", diff --git a/pkg/expression/contextimpl/sessionctx.go b/pkg/expression/contextimpl/sessionctx.go index f5111346c2e13..c23529c9fff9a 100644 --- a/pkg/expression/contextimpl/sessionctx.go +++ b/pkg/expression/contextimpl/sessionctx.go @@ -19,11 +19,14 @@ import ( "github.com/pingcap/tidb/pkg/errctx" exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/expression/contextopt" + "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/intest" ) // sessionctx.Context + *ExprCtxExtendedImpl should implement `expression.BuildContext` @@ -35,12 +38,25 @@ var _ exprctx.BuildContext = struct { // ExprCtxExtendedImpl extends the sessionctx.Context to implement `expression.BuildContext` type ExprCtxExtendedImpl struct { - sctx sessionctx.Context + sctx sessionctx.Context + props contextopt.OptionalEvalPropProviders } // NewExprExtendedImpl creates a new ExprCtxExtendedImpl. func NewExprExtendedImpl(sctx sessionctx.Context) *ExprCtxExtendedImpl { - return &ExprCtxExtendedImpl{sctx: sctx} + impl := &ExprCtxExtendedImpl{sctx: sctx} + // set all optional properties + impl.setOptionalProp(currentUserProp(sctx)) + // When EvalContext is created from a session, it should contain all the optional properties. + intest.Assert(impl.props.PropKeySet().IsFull()) + return impl +} + +func (ctx *ExprCtxExtendedImpl) setOptionalProp(prop exprctx.OptionalEvalPropProvider) { + intest.AssertFunc(func() bool { + return !ctx.props.Contains(prop.Desc().Key()) + }) + ctx.props.Add(prop) } // CtxID returns the context id. @@ -102,3 +118,15 @@ func (ctx *ExprCtxExtendedImpl) GetDefaultWeekFormatMode() string { } return mode } + +// GetOptionalPropProvider gets the optional property provider by key +func (ctx *ExprCtxExtendedImpl) GetOptionalPropProvider(key exprctx.OptionalEvalPropKey) (exprctx.OptionalEvalPropProvider, bool) { + return ctx.props.Get(key) +} + +func currentUserProp(sctx sessionctx.Context) exprctx.OptionalEvalPropProvider { + return contextopt.CurrentUserPropProvider(func() (*auth.UserIdentity, []*auth.RoleIdentity) { + vars := sctx.GetSessionVars() + return vars.User, vars.ActiveRoles + }) +} diff --git a/pkg/expression/contextimpl/sessionctx_test.go b/pkg/expression/contextimpl/sessionctx_test.go index 3da4d18d6db28..9db97da45f412 100644 --- a/pkg/expression/contextimpl/sessionctx_test.go +++ b/pkg/expression/contextimpl/sessionctx_test.go @@ -20,7 +20,10 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/expression/contextimpl" + "github.com/pingcap/tidb/pkg/expression/contextopt" + "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/types" @@ -34,6 +37,14 @@ func TestEvalContextImplWithSessionCtx(t *testing.T) { sc := vars.StmtCtx impl := contextimpl.NewExprExtendedImpl(ctx) + // should contain all the optional properties + for i := 0; i < context.OptPropsCnt; i++ { + provider, ok := impl.GetOptionalPropProvider(context.OptionalEvalPropKey(i)) + require.True(t, ok) + require.NotNil(t, provider) + require.Same(t, context.OptionalEvalPropKey(i).Desc(), provider.Desc()) + } + ctx.ResetSessionAndStmtTimeZone(time.FixedZone("UTC+11", 11*3600)) vars.SQLMode = mysql.ModeStrictTransTables | mysql.ModeNoZeroDate sc.SetTypeFlags(types.FlagIgnoreInvalidDateErr | types.FlagSkipUTF8Check) @@ -74,3 +85,30 @@ func TestEvalContextImplWithSessionCtx(t *testing.T) { require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level) require.Equal(t, "err1", warnings[0].Err.Error()) } + +func getProvider[T context.OptionalEvalPropProvider]( + t *testing.T, + impl *contextimpl.ExprCtxExtendedImpl, + key context.OptionalEvalPropKey, +) T { + val, ok := impl.GetOptionalPropProvider(key) + require.True(t, ok) + p, ok := val.(T) + require.True(t, ok) + return p +} + +func TestEvalContextImplWithSessionCtxForOptProps(t *testing.T) { + ctx := mock.NewContext() + impl := contextimpl.NewExprExtendedImpl(ctx) + + // test for OptPropCurrentUser + ctx.GetSessionVars().User = &auth.UserIdentity{Username: "user1", Hostname: "host1"} + ctx.GetSessionVars().ActiveRoles = []*auth.RoleIdentity{ + {Username: "role1", Hostname: "host1"}, + {Username: "role2", Hostname: "host2"}, + } + user, roles := getProvider[contextopt.CurrentUserPropProvider](t, impl, context.OptPropCurrentUser)() + require.Equal(t, ctx.GetSessionVars().User, user) + require.Equal(t, ctx.GetSessionVars().ActiveRoles, roles) +} diff --git a/pkg/expression/contextopt/BUILD.bazel b/pkg/expression/contextopt/BUILD.bazel new file mode 100644 index 0000000000000..a1ebf6c462b6e --- /dev/null +++ b/pkg/expression/contextopt/BUILD.bazel @@ -0,0 +1,30 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "contextopt", + srcs = [ + "current_user.go", + "optional.go", + ], + importpath = "github.com/pingcap/tidb/pkg/expression/contextopt", + visibility = ["//visibility:public"], + deps = [ + "//pkg/expression/context", + "//pkg/parser/auth", + "//pkg/util/intest", + "@com_github_pingcap_errors//:errors", + ], +) + +go_test( + name = "contextopt_test", + timeout = "short", + srcs = ["optional_test.go"], + embed = [":contextopt"], + flaky = True, + deps = [ + "//pkg/expression/context", + "//pkg/parser/auth", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/expression/contextopt/current_user.go b/pkg/expression/contextopt/current_user.go new file mode 100644 index 0000000000000..997c9c7c03469 --- /dev/null +++ b/pkg/expression/contextopt/current_user.go @@ -0,0 +1,62 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package contextopt + +import ( + "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/parser/auth" +) + +var _ RequireOptionalEvalProps = CurrentUserPropReader{} + +// CurrentUserPropProvider is a provider to get the current user +type CurrentUserPropProvider func() (*auth.UserIdentity, []*auth.RoleIdentity) + +// Desc returns the description for the property key. +func (p CurrentUserPropProvider) Desc() *context.OptionalEvalPropDesc { + return context.OptPropCurrentUser.Desc() +} + +// CurrentUserPropReader is used by expression to read property context.OptPropCurrentUser +type CurrentUserPropReader struct{} + +// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. +func (r CurrentUserPropReader) RequiredOptionalEvalProps() context.OptionalEvalPropKeySet { + return context.OptPropCurrentUser.AsPropKeySet() +} + +// CurrentUser returns the current user +func (r CurrentUserPropReader) CurrentUser(ctx context.EvalContext) (*auth.UserIdentity, error) { + p, err := r.getProvider(ctx) + if err != nil { + return nil, err + } + user, _ := p() + return user, nil +} + +// ActiveRoles returns the active roles +func (r CurrentUserPropReader) ActiveRoles(ctx context.EvalContext) ([]*auth.RoleIdentity, error) { + p, err := r.getProvider(ctx) + if err != nil { + return nil, err + } + _, roles := p() + return roles, nil +} + +func (r CurrentUserPropReader) getProvider(ctx context.EvalContext) (CurrentUserPropProvider, error) { + return getPropProvider[CurrentUserPropProvider](ctx, context.OptPropCurrentUser) +} diff --git a/pkg/expression/contextopt/optional.go b/pkg/expression/contextopt/optional.go new file mode 100644 index 0000000000000..6d75fbaf23cf2 --- /dev/null +++ b/pkg/expression/contextopt/optional.go @@ -0,0 +1,87 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package contextopt + +import ( + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/util/intest" +) + +// RequireOptionalEvalProps is the interface for the function that requires optional evaluation properties or not. +type RequireOptionalEvalProps interface { + // RequiredOptionalEvalProps returns the optional properties that this function requires. + // If the returned `OptionalEvalPropKeySet` is empty, + // it means this function does not require any optional properties. + RequiredOptionalEvalProps() context.OptionalEvalPropKeySet +} + +// OptionalEvalPropProviders contains some evaluation property providers in EvalContext. +type OptionalEvalPropProviders [context.OptPropsCnt]context.OptionalEvalPropProvider + +// Contains checks whether the provider by key exists. +func (o *OptionalEvalPropProviders) Contains(key context.OptionalEvalPropKey) bool { + return o[key] != nil +} + +// Get gets the provider by key. +func (o *OptionalEvalPropProviders) Get(key context.OptionalEvalPropKey) (context.OptionalEvalPropProvider, bool) { + if val := o[key]; val != nil { + intest.Assert(key == val.Desc().Key()) + return val, true + } + return nil, false +} + +// Add adds an optional property +func (o *OptionalEvalPropProviders) Add(val context.OptionalEvalPropProvider) { + intest.AssertFunc(func() bool { + intest.AssertNotNil(val) + switch val.Desc().Key() { + case context.OptPropCurrentUser: + _, ok := val.(CurrentUserPropProvider) + intest.Assert(ok) + default: + intest.Assert(false) + } + return true + }) + o[val.Desc().Key()] = val +} + +// PropKeySet returns the set for optional evaluation properties in EvalContext. +func (o *OptionalEvalPropProviders) PropKeySet() (set context.OptionalEvalPropKeySet) { + for _, p := range o { + if p != nil { + set = set.Add(p.Desc().Key()) + } + } + return +} + +func getPropProvider[T context.OptionalEvalPropProvider](ctx context.EvalContext, key context.OptionalEvalPropKey) (p T, _ error) { + val, ok := ctx.GetOptionalPropProvider(key) + if !ok { + return p, errors.Errorf("optional property: '%s' not exists in EvalContext", key) + } + + p, ok = val.(T) + if !ok { + intest.Assert(false) + return p, errors.Errorf("cannot cast OptionalEvalPropProvider to %T for key '%s'", p, key) + } + + return p, nil +} diff --git a/pkg/expression/contextopt/optional_test.go b/pkg/expression/contextopt/optional_test.go new file mode 100644 index 0000000000000..0e91cc8fe7ff3 --- /dev/null +++ b/pkg/expression/contextopt/optional_test.go @@ -0,0 +1,46 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package contextopt + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/stretchr/testify/require" +) + +func TestOptionalEvalPropProviders(t *testing.T) { + var providers OptionalEvalPropProviders + require.True(t, providers.PropKeySet().IsEmpty()) + require.False(t, providers.Contains(context.OptPropCurrentUser)) + val, ok := providers.Get(context.OptPropCurrentUser) + require.False(t, ok) + require.Nil(t, val) + + var p context.OptionalEvalPropProvider + + user := &auth.UserIdentity{Username: "u1", Hostname: "h1"} + roles := []*auth.RoleIdentity{{Username: "u2", Hostname: "h2"}, {Username: "u3", Hostname: "h3"}} + p = CurrentUserPropProvider(func() (*auth.UserIdentity, []*auth.RoleIdentity) { return user, roles }) + providers.Add(p) + require.True(t, providers.PropKeySet().Contains(context.OptPropCurrentUser)) + require.True(t, providers.Contains(context.OptPropCurrentUser)) + val, ok = providers.Get(context.OptPropCurrentUser) + require.True(t, ok) + user2, roles2 := val.(CurrentUserPropProvider)() + require.Equal(t, user, user2) + require.Equal(t, roles, roles2) +} diff --git a/pkg/expression/distsql_builtin.go b/pkg/expression/distsql_builtin.go index 6969834cb8a9d..381a9037c6dd1 100644 --- a/pkg/expression/distsql_builtin.go +++ b/pkg/expression/distsql_builtin.go @@ -597,9 +597,9 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie case tipb.ScalarFuncSig_FoundRows: f = &builtinFoundRowsSig{base} case tipb.ScalarFuncSig_CurrentUser: - f = &builtinCurrentUserSig{base} + f = &builtinCurrentUserSig{baseBuiltinFunc: base} case tipb.ScalarFuncSig_User: - f = &builtinUserSig{base} + f = &builtinUserSig{baseBuiltinFunc: base} case tipb.ScalarFuncSig_ConnectionID: f = &builtinConnectionIDSig{base} case tipb.ScalarFuncSig_LastInsertID: From 1e5c179bef054b9bdc2a145f602d75864a9ccecc Mon Sep 17 00:00:00 2001 From: Lynn Date: Mon, 11 Mar 2024 11:24:38 +0800 Subject: [PATCH 05/15] ddl, tests: add expression default values feature relevant tests for some DDLs and fix a related bug (#51571) close pingcap/tidb#51554, close pingcap/tidb#51570 --- pkg/ddl/db_integration_test.go | 17 +- pkg/ddl/ddl_api.go | 2 +- .../r/ddl/default_as_expression.result | 308 ++++++++++++++++-- .../t/ddl/default_as_expression.test | 124 ++++++- 4 files changed, 412 insertions(+), 39 deletions(-) diff --git a/pkg/ddl/db_integration_test.go b/pkg/ddl/db_integration_test.go index c4b6b25c20259..f41de8b89b207 100644 --- a/pkg/ddl/db_integration_test.go +++ b/pkg/ddl/db_integration_test.go @@ -1614,7 +1614,7 @@ func TestDefaultValueAsExpressions(t *testing.T) { store := testkit.CreateMockStoreWithSchemaLease(t, testLease) tk := testkit.NewTestKit(t, store) tk.MustExec("use test") - tk.MustExec("drop table if exists t, t1") + tk.MustExec("drop table if exists t, t1, t2") // date_format tk.MustExec("create table t6 (c int(10), c1 int default (date_format(now(),'%Y-%m-%d %H:%i:%s')))") @@ -1630,6 +1630,16 @@ func TestDefaultValueAsExpressions(t *testing.T) { tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "xyz", Hostname: "localhost"} tk.MustExec("insert into t(c) values (4),(5),(6)") tk.MustExec("insert into t values (7, default)") + rows := tk.MustQuery("SELECT c1 from t order by c").Rows() + for i, row := range rows { + d, ok := row[0].(string) + require.True(t, ok) + if i < 3 { + require.Equal(t, "ROOT", d) + } else { + require.Equal(t, "XYZ", d) + } + } // replace tk.MustExec("create table t1 (c int(10), c1 int default (REPLACE(UPPER(UUID()), '-', '')))") @@ -1642,6 +1652,11 @@ func TestDefaultValueAsExpressions(t *testing.T) { if int(sqlErr.Code) != errno.ErrTruncatedWrongValue { require.Equal(t, errno.ErrDataOutOfRange, int(sqlErr.Code)) } + // test modify column + // The error message has UUID, so put this test here. + tk.MustExec("create table t2(c int(10), c1 varchar(256) default (REPLACE(UPPER(UUID()), '-', '')), index idx(c1));") + tk.MustExec("insert into t2(c) values (1),(2),(3);") + tk.MustGetErrCode("alter table t2 modify column c1 varchar(30) default 'xx';", errno.WarnDataTruncated) } func TestChangingDBCharset(t *testing.T) { diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 2d01c4670b1ce..118dcd9170796 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -6280,8 +6280,8 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt // Clean the NoDefaultValueFlag value. col.DelFlag(mysql.NoDefaultValueFlag) + col.DefaultIsExpr = false if len(specNewColumn.Options) == 0 { - col.DefaultIsExpr = false err = col.SetDefaultValue(nil) if err != nil { return errors.Trace(err) diff --git a/tests/integrationtest/r/ddl/default_as_expression.result b/tests/integrationtest/r/ddl/default_as_expression.result index 4ae2932c99154..e5779aa89a3f6 100644 --- a/tests/integrationtest/r/ddl/default_as_expression.result +++ b/tests/integrationtest/r/ddl/default_as_expression.result @@ -19,7 +19,7 @@ SELECT * FROM t0 WHERE c = date_format(@x,'%Y-%m') OR c = date_format(DATE_ADD(@ c c1 insert into t1(c) values (1); insert into t1 values (2, default); -SELECT * FROM t1 WHERE c = date_format(@x,'%Y-%m-%d') OR c = date_format(DATE_ADD(@x, INTERVAL 1 SECOND), '%Y-%m-%d'); +SELECT * FROM t1 WHERE c = date_format(@x,'%Y-%m-%d'); c c1 insert into t2(c) values (1); insert into t2 values (2, default); @@ -56,29 +56,71 @@ t2 CREATE TABLE `t2` ( `c` int(10) DEFAULT NULL, `c1` varchar(256) DEFAULT date_format(now(), _utf8mb4'%Y-%m-%d %H.%i.%s') ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +alter table t0 add index idx(c1); +alter table t1 add index idx(c1); +insert into t0 values (3, default); +insert into t1 values (3, default); +show create table t0; +Table Create Table +t0 CREATE TABLE `t0` ( + `c` int(10) DEFAULT NULL, + `c1` varchar(256) DEFAULT date_format(now(), _utf8mb4'%Y-%m'), + KEY `idx` (`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `c` int(10) DEFAULT NULL, + `c1` datetime DEFAULT date_format(now(), _utf8mb4'%Y-%m-%d'), + KEY `idx` (`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin alter table t0 modify column c1 varchar(30) default 'xx'; alter table t1 modify column c1 varchar(30) default 'xx'; +insert into t0 values (4, default); +insert into t1 values (4, default); show create table t0; Table Create Table t0 CREATE TABLE `t0` ( `c` int(10) DEFAULT NULL, - `c1` varchar(30) DEFAULT 'xx' + `c1` varchar(30) DEFAULT 'xx', + KEY `idx` (`c1`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin show create table t1; Table Create Table t1 CREATE TABLE `t1` ( `c` int(10) DEFAULT NULL, - `c1` varchar(30) DEFAULT 'xx' + `c1` varchar(30) DEFAULT 'xx', + KEY `idx` (`c1`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin alter table t0 modify column c1 datetime DEFAULT (date_format(now(), '%Y-%m-%d')); Error 1292 (22007): Incorrect datetime value: '2024-03' +alter table t0 alter column c1 SET DEFAULT (date_format(now(), '%Y-%m-%d')); +insert into t0 values (5, default); alter table t1 modify column c1 datetime DEFAULT (date_format(now(), '%Y-%m-%d')); +Error 1292 (22007): Incorrect datetime value: 'xx' +delete from t1 where c = 4; +alter table t1 modify column c1 datetime DEFAULT (date_format(now(), '%Y-%m-%d')); +insert into t1 values (5, default); +alter table t0 drop index idx; +alter table t1 drop index idx; +show create table t0; +Table Create Table +t0 CREATE TABLE `t0` ( + `c` int(10) DEFAULT NULL, + `c1` varchar(30) DEFAULT date_format(now(), _utf8mb4'%Y-%m-%d') +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin show create table t1; Table Create Table t1 CREATE TABLE `t1` ( `c` int(10) DEFAULT NULL, `c1` datetime DEFAULT date_format(now(), _utf8mb4'%Y-%m-%d') ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +SELECT count(1) FROM t0 WHERE c1 = date_format(@x,'%Y-%m') OR c1 = date_format(@x,'%Y-%m-%d') OR c1 = "xx"; +count(1) +5 +SELECT count(1) FROM t1 WHERE c1 = date_format(@x,'%Y-%m-%d'); +count(1) +4 SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; column_default extra date_format(now(), _utf8mb4'%Y-%m-%d') DEFAULT_GENERATED @@ -86,9 +128,9 @@ show columns from test.t1 where field='c1'; Field Type Null Key Default Extra c1 datetime YES date_format(now(), _utf8mb4'%Y-%m-%d') DEFAULT_GENERATED drop table if exists t, t1, t2; -create table t (c int(10), c1 varchar(256) default (REPLACE(UPPER(UUID()), '-', ''))); -create table t1 (c int(10), c1 int default (REPLACE(UPPER(UUID()), '-', ''))); -create table t2 (c int(10), c1 varchar(256) default (REPLACE(CONVERT(UPPER(UUID()) USING UTF8MB4), '-', ''))); +create table t (c int(10), c1 varchar(256) default (REPLACE(UPPER(UUID()), '-', '')), index idx(c1)); +create table t1 (c int(10), c1 int default (REPLACE(UPPER(UUID()), '-', '')), index idx(c1)); +create table t2 (c int(10), c1 varchar(256) default (REPLACE(CONVERT(UPPER(UUID()) USING UTF8MB4), '-', '')), index idx(c1)); create table t1 (c int(10), c1 varchar(256) default (REPLACE('xdfj-jfj', '-', ''))); Error 3770 (HY000): Default value expression of column 'c1' contains a disallowed function: `REPLACE`. create table t1 (c int(10), c1 varchar(256) default (UPPER(UUID()))); @@ -102,7 +144,7 @@ Error 1674 (HY000): Statement is unsafe because it uses a system function that m alter table t add column c4 int default (REPLACE(UPPER('dfdkj-kjkl-d'), '-', '')); Error 1674 (HY000): Statement is unsafe because it uses a system function that may return a different value on the slave insert into t(c) values (1),(2),(3); -insert into t values (4, default) +insert into t values (4, default); SELECT count(1) FROM t WHERE c1 REGEXP '^[A-Z0-9]+$'; count(1) 4 @@ -110,40 +152,65 @@ show create table t; Table Create Table t CREATE TABLE `t` ( `c` int(10) DEFAULT NULL, - `c1` varchar(256) DEFAULT replace(upper(uuid()), _utf8mb4'-', _utf8mb4'') + `c1` varchar(256) DEFAULT replace(upper(uuid()), _utf8mb4'-', _utf8mb4''), + KEY `idx` (`c1`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin show create table t1; Table Create Table t1 CREATE TABLE `t1` ( `c` int(10) DEFAULT NULL, - `c1` int(11) DEFAULT replace(upper(uuid()), _utf8mb4'-', _utf8mb4'') + `c1` int(11) DEFAULT replace(upper(uuid()), _utf8mb4'-', _utf8mb4''), + KEY `idx` (`c1`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin show create table t2; Table Create Table t2 CREATE TABLE `t2` ( `c` int(10) DEFAULT NULL, - `c1` varchar(256) DEFAULT replace(convert(upper(uuid()) using 'utf8mb4'), _utf8mb4'-', _utf8mb4'') + `c1` varchar(256) DEFAULT replace(convert(upper(uuid()) using 'utf8mb4'), _utf8mb4'-', _utf8mb4''), + KEY `idx` (`c1`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin -alter table t1 modify column c1 varchar(30) default 'xx'; -show create table t1; +alter table t alter column c1 set default 'xx'; +alter table t drop index idx; +show create table t; Table Create Table -t1 CREATE TABLE `t1` ( +t CREATE TABLE `t` ( `c` int(10) DEFAULT NULL, - `c1` varchar(30) DEFAULT 'xx' + `c1` varchar(256) DEFAULT 'xx' ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin -alter table t1 modify column c1 varchar(32) default (REPLACE(UPPER(UUID()), '-', '')); -show create table t1; +insert into t values (5, default); +show create table t; Table Create Table -t1 CREATE TABLE `t1` ( +t CREATE TABLE `t` ( `c` int(10) DEFAULT NULL, - `c1` varchar(32) DEFAULT replace(upper(uuid()), _utf8mb4'-', _utf8mb4'') + `c1` varchar(256) DEFAULT 'xx' ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin -SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; +alter table t add unique index idx(c, c1); +alter table t modify column c1 varchar(32) default (REPLACE(UPPER(UUID()), '-', '')); +insert into t values (6, default); +SELECT count(1) FROM t WHERE c1 REGEXP '^[A-Z0-9]+$'; +count(1) +5 +show create table t; +Table Create Table +t CREATE TABLE `t` ( + `c` int(10) DEFAULT NULL, + `c1` varchar(32) DEFAULT replace(upper(uuid()), _utf8mb4'-', _utf8mb4''), + UNIQUE KEY `idx` (`c`,`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t' AND COLUMN_NAME='c1'; column_default extra replace(upper(uuid()), _utf8mb4'-', _utf8mb4'') DEFAULT_GENERATED +alter table t alter column c1 set default null; +insert into t(c) values (7); +alter table t alter column c1 drop default; +insert into t(c) values (8); +Error 1364 (HY000): Field 'c1' doesn't have a default value +SELECT count(1) FROM t WHERE c1 REGEXP '^[A-Z0-9]+$'; +count(1) +5 drop table if exists t0, t1, t2, t3, t4, t5; -create table t0 (c int(10), c1 varchar(32) default (str_to_date('1980-01-01','%Y-%m-%d')), c2 date default (str_to_date('9999-01-01','%Y-%m-%d'))); -create table t1 (c int(10), c1 int default (str_to_date('1980-01-01','%Y-%m-%d')), c2 int default (str_to_date('9999-01-01','%Y-%m-%d'))); +create table t0 (c int(10), c1 varchar(32) default (str_to_date('1980-01-01','%Y-%m-%d')), c2 date default (str_to_date('9999-01-01','%Y-%m-%d')), index idx(c, c1)); +create table t1 (c int(10), c1 int default (str_to_date('1980-01-01','%Y-%m-%d')), c2 int default (str_to_date('9999-01-01','%Y-%m-%d')), unique key idx(c, c1)); create table t3 (c int(10), c1 varchar(32) default (str_to_date('1980-01-01','%m-%d'))); create table t4 (c int(10), c1 varchar(32) default (str_to_date('01-01','%Y-%m-%d'))); set @sqlMode := @@session.sql_mode; @@ -174,11 +241,126 @@ insert into t2 values (4, default, default); set session sql_mode=@sqlMode; insert into t2(c) values (5); Error 1292 (22007): Incorrect datetime value: '0000-00-00 00:00:00' +select * from t0; +c c1 c2 +1 1980-01-01 9999-01-01 +2 1980-01-01 9999-01-01 +3 1980-01-01 9999-01-01 +4 1980-01-01 9999-01-01 +select * from t1; +c c1 c2 +1 19800101 99990101 +2 19800101 99990101 +3 19800101 99990101 +4 19800101 99990101 +select * from t2; +c c1 c2 +1 1980-01-01 NULL +2 1980-01-01 NULL +3 1980-01-01 NULL +4 1980-01-01 NULL +show create table t0; +Table Create Table +t0 CREATE TABLE `t0` ( + `c` int(10) DEFAULT NULL, + `c1` varchar(32) DEFAULT str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d'), + `c2` date DEFAULT str_to_date(_utf8mb4'9999-01-01', _utf8mb4'%Y-%m-%d'), + KEY `idx` (`c`,`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `c` int(10) DEFAULT NULL, + `c1` int(11) DEFAULT str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d'), + `c2` int(11) DEFAULT str_to_date(_utf8mb4'9999-01-01', _utf8mb4'%Y-%m-%d'), + UNIQUE KEY `idx` (`c`,`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t2; +Table Create Table +t2 CREATE TABLE `t2` ( + `c` int(10) DEFAULT NULL, + `c1` blob DEFAULT str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d'), + `c2` blob DEFAULT str_to_date(_utf8mb4'9999-01-01', _utf8mb4'%m-%d') +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +alter table t0 add index idx1(c1); +alter table t1 add unique index idx1(c, c1); +insert into t0 values (5, default, default); +insert into t1 values (5, default, default); +show create table t0; +Table Create Table +t0 CREATE TABLE `t0` ( + `c` int(10) DEFAULT NULL, + `c1` varchar(32) DEFAULT str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d'), + `c2` date DEFAULT str_to_date(_utf8mb4'9999-01-01', _utf8mb4'%Y-%m-%d'), + KEY `idx` (`c`,`c1`), + KEY `idx1` (`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `c` int(10) DEFAULT NULL, + `c1` int(11) DEFAULT str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d'), + `c2` int(11) DEFAULT str_to_date(_utf8mb4'9999-01-01', _utf8mb4'%Y-%m-%d'), + UNIQUE KEY `idx` (`c`,`c1`), + UNIQUE KEY `idx1` (`c`,`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +alter table t0 alter column c2 set default (current_date()); +alter table t1 modify column c1 varchar(30) default 'xx'; +insert into t0 values (6, default, default); +insert into t1 values (6, default, default); +show create table t0; +Table Create Table +t0 CREATE TABLE `t0` ( + `c` int(10) DEFAULT NULL, + `c1` varchar(32) DEFAULT str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d'), + `c2` date DEFAULT CURRENT_DATE, + KEY `idx` (`c`,`c1`), + KEY `idx1` (`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `c` int(10) DEFAULT NULL, + `c1` varchar(30) DEFAULT 'xx', + `c2` int(11) DEFAULT str_to_date(_utf8mb4'9999-01-01', _utf8mb4'%Y-%m-%d'), + UNIQUE KEY `idx` (`c`,`c1`), + UNIQUE KEY `idx1` (`c`,`c1`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +alter table t0 alter column c1 drop default; +alter table t1 modify column c1 varchar(32) default (str_to_date('1980-01-01','%Y-%m-%d')); +insert into t0 values (7, default, default); +Error 1364 (HY000): Field 'c1' doesn't have a default value +insert into t1 values (7, default, default); +select * from t0 where c < 6; +c c1 c2 +1 1980-01-01 9999-01-01 +2 1980-01-01 9999-01-01 +3 1980-01-01 9999-01-01 +4 1980-01-01 9999-01-01 +5 1980-01-01 9999-01-01 +select c, c1 from t0 where c = 6 and c2 = date_format(now(),'%Y-%m-%d');; +c c1 +6 1980-01-01 +select * from t1; +c c1 c2 +1 19800101 99990101 +2 19800101 99990101 +3 19800101 99990101 +4 19800101 99990101 +5 19800101 99990101 +6 xx 99990101 +7 1980-01-01 99990101 +select * from t2; +c c1 c2 +1 1980-01-01 NULL +2 1980-01-01 NULL +3 1980-01-01 NULL +4 1980-01-01 NULL SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; column_default extra str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d') DEFAULT_GENERATED drop table if exists t, t1, t2; -create table t (c int(10), c1 varchar(256) default (upper(substring_index(user(),'@',1)))); +create table t (c int(10), c1 varchar(256) default (upper(substring_index(user(),'@',1))), unique index idx(c, c1)); create table t1 (c int(10), c1 int default (upper(substring_index(user(),_utf8mb4'@',1)))); create table t2 (c int(10), c1 varchar(256) default (substring_index(user(),'@',1))); Error 3770 (HY000): Default value expression of column 'c1' contains a disallowed function: `substring_index`. @@ -196,7 +378,8 @@ show create table t; Table Create Table t CREATE TABLE `t` ( `c` int(10) DEFAULT NULL, - `c1` varchar(256) DEFAULT upper(substring_index(user(), _utf8mb4'@', 1)) + `c1` varchar(256) DEFAULT upper(substring_index(user(), _utf8mb4'@', 1)), + UNIQUE KEY `idx` (`c`,`c1`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin show create table t1; Table Create Table @@ -212,11 +395,13 @@ t1 CREATE TABLE `t1` ( `c1` varchar(30) DEFAULT 'xx' ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin alter table t1 modify column c1 varchar(32) default (upper(substring_index(user(),'@',1))); +alter table t1 add index idx1(c1); show create table t1; Table Create Table t1 CREATE TABLE `t1` ( `c` int(10) DEFAULT NULL, - `c1` varchar(32) DEFAULT upper(substring_index(user(), _utf8mb4'@', 1)) + `c1` varchar(32) DEFAULT upper(substring_index(user(), _utf8mb4'@', 1)), + KEY `idx1` (`c1`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; column_default extra @@ -285,6 +470,81 @@ date_format(now(), _utf8mb4'%Y-%m-%d') DEFAULT_GENERATED SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t3' AND COLUMN_NAME='c1'; column_default extra date_format(now(), _utf8mb4'%Y-%m-%d') DEFAULT_GENERATED +alter table t0 alter column c1 set default "xx"; +Error 1101 (42000): BLOB/TEXT/JSON column 'c1' can't have a default value +alter table t1 alter column c1 set default "xx"; +Error 1101 (42000): BLOB/TEXT/JSON column 'c1' can't have a default value +alter table t2 alter column c1 set default 'y'; +alter table t3 alter column c1 set default 'n'; +INSERT INTO t0 values (2, DEFAULT); +INSERT INTO t2 values (2, DEFAULT); +INSERT INTO t3 values (2, DEFAULT); +alter table t0 modify column c1 BLOB default (date_format(now(),'%Y-%m-%d')); +alter table t1 modify column c1 JSON default (date_format(now(),'%Y-%m-%d')); +alter table t2 modify column c1 ENUM('y','n') default (date_format(now(),'%Y-%m-%d')); +alter table t3 modify column c1 SET('y','n') default (date_format(now(),'%Y-%m-%d')); +INSERT INTO t0 values (3, DEFAULT); +show create table t0; +Table Create Table +t0 CREATE TABLE `t0` ( + `c` int(10) DEFAULT NULL, + `c1` blob DEFAULT date_format(now(), _utf8mb4'%Y-%m-%d') +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `c` int(10) DEFAULT NULL, + `c1` json DEFAULT date_format(now(), _utf8mb4'%Y-%m-%d') +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t2; +Table Create Table +t2 CREATE TABLE `t2` ( + `c` int(10) DEFAULT NULL, + `c1` enum('y','n') DEFAULT date_format(now(), _utf8mb4'%Y-%m-%d') +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t3; +Table Create Table +t3 CREATE TABLE `t3` ( + `c` int(10) DEFAULT NULL, + `c1` set('y','n') DEFAULT date_format(now(), _utf8mb4'%Y-%m-%d') +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +alter table t0 alter column c1 drop default; +alter table t1 alter column c1 drop default; +alter table t2 alter column c1 drop default; +alter table t3 alter column c1 drop default; +show create table t0; +Table Create Table +t0 CREATE TABLE `t0` ( + `c` int(10) DEFAULT NULL, + `c1` blob +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `c` int(10) DEFAULT NULL, + `c1` json +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t2; +Table Create Table +t2 CREATE TABLE `t2` ( + `c` int(10) DEFAULT NULL, + `c1` enum('y','n') +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +show create table t3; +Table Create Table +t3 CREATE TABLE `t3` ( + `c` int(10) DEFAULT NULL, + `c1` set('y','n') +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +select count(1) from t0 where c1 = date_format(now(), '%Y-%m-%d'); +count(1) +4 +select * from t2; +c c1 +2 y +select * from t3; +c c1 +2 n drop table t0, t1, t2, t3; create table t0 (c int(10), c1 BLOB default (REPLACE(UPPER(UUID()), '-', ''))); create table t1 (c int(10), c1 JSON default (REPLACE(UPPER(UUID()), '-', ''))); diff --git a/tests/integrationtest/t/ddl/default_as_expression.test b/tests/integrationtest/t/ddl/default_as_expression.test index db2a0e3293f33..09d81e1cf0e55 100644 --- a/tests/integrationtest/t/ddl/default_as_expression.test +++ b/tests/integrationtest/t/ddl/default_as_expression.test @@ -21,7 +21,7 @@ insert into t0 values (2, default); SELECT * FROM t0 WHERE c = date_format(@x,'%Y-%m') OR c = date_format(DATE_ADD(@x, INTERVAL 1 SECOND), '%Y-%m'); insert into t1(c) values (1); insert into t1 values (2, default); -SELECT * FROM t1 WHERE c = date_format(@x,'%Y-%m-%d') OR c = date_format(DATE_ADD(@x, INTERVAL 1 SECOND), '%Y-%m-%d'); +SELECT * FROM t1 WHERE c = date_format(@x,'%Y-%m-%d'); insert into t2(c) values (1); insert into t2 values (2, default); SELECT * FROM t2 WHERE c = date_format(@x,'%Y-%m-%d %H.%i.%s') OR c = date_format(DATE_ADD(@x, INTERVAL 1 SECOND), '%Y-%m-%d %H.%i.%s'); @@ -35,26 +35,48 @@ SELECT * FROM t4 WHERE c = date_format(@x,'%Y-%m-%d %H:%i:%s') OR c = date_forma insert into t5(c) values (1); insert into t5 values (2, default); SELECT * FROM t5 WHERE c = date_format(@x,'%Y-%m-%d %H:%i:%s') OR c = date_format(DATE_ADD(@x, INTERVAL 1 SECOND), '%Y-%m-%d %H:%i:%s'); + show create table t0; show create table t1; show create table t2; + +# test modify column, set default value, add index +alter table t0 add index idx(c1); +alter table t1 add index idx(c1); +insert into t0 values (3, default); +insert into t1 values (3, default); +show create table t0; +show create table t1; alter table t0 modify column c1 varchar(30) default 'xx'; alter table t1 modify column c1 varchar(30) default 'xx'; +insert into t0 values (4, default); +insert into t1 values (4, default); show create table t0; show create table t1; -- error 1292 alter table t0 modify column c1 datetime DEFAULT (date_format(now(), '%Y-%m-%d')); +alter table t0 alter column c1 SET DEFAULT (date_format(now(), '%Y-%m-%d')); +insert into t0 values (5, default); +-- error 1292 alter table t1 modify column c1 datetime DEFAULT (date_format(now(), '%Y-%m-%d')); +delete from t1 where c = 4; +alter table t1 modify column c1 datetime DEFAULT (date_format(now(), '%Y-%m-%d')); +insert into t1 values (5, default); +alter table t0 drop index idx; +alter table t1 drop index idx; +show create table t0; show create table t1; +SELECT count(1) FROM t0 WHERE c1 = date_format(@x,'%Y-%m') OR c1 = date_format(@x,'%Y-%m-%d') OR c1 = "xx"; +SELECT count(1) FROM t1 WHERE c1 = date_format(@x,'%Y-%m-%d'); SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; show columns from test.t1 where field='c1'; # TestDefaultColumnWithReplace # replace drop table if exists t, t1, t2; -create table t (c int(10), c1 varchar(256) default (REPLACE(UPPER(UUID()), '-', ''))); -create table t1 (c int(10), c1 int default (REPLACE(UPPER(UUID()), '-', ''))); -create table t2 (c int(10), c1 varchar(256) default (REPLACE(CONVERT(UPPER(UUID()) USING UTF8MB4), '-', ''))); +create table t (c int(10), c1 varchar(256) default (REPLACE(UPPER(UUID()), '-', '')), index idx(c1)); +create table t1 (c int(10), c1 int default (REPLACE(UPPER(UUID()), '-', '')), index idx(c1)); +create table t2 (c int(10), c1 varchar(256) default (REPLACE(CONVERT(UPPER(UUID()) USING UTF8MB4), '-', '')), index idx(c1)); -- error 3770 create table t1 (c int(10), c1 varchar(256) default (REPLACE('xdfj-jfj', '-', ''))); -- error 3770 @@ -73,7 +95,7 @@ alter table t add column c4 int default (REPLACE(UPPER('dfdkj-kjkl-d'), '-', '') # insert records insert into t(c) values (1),(2),(3); -insert into t values (4, default) +insert into t values (4, default); # It consists of uppercase letters or numbers. SELECT count(1) FROM t WHERE c1 REGEXP '^[A-Z0-9]+$'; @@ -82,18 +104,32 @@ SELECT count(1) FROM t WHERE c1 REGEXP '^[A-Z0-9]+$'; show create table t; show create table t1; show create table t2; -alter table t1 modify column c1 varchar(30) default 'xx'; -show create table t1; -alter table t1 modify column c1 varchar(32) default (REPLACE(UPPER(UUID()), '-', '')); -show create table t1; -SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; + +# test modify column, set default value, add index +alter table t alter column c1 set default 'xx'; +alter table t drop index idx; +show create table t; +insert into t values (5, default); +show create table t; +alter table t add unique index idx(c, c1); +alter table t modify column c1 varchar(32) default (REPLACE(UPPER(UUID()), '-', '')); +insert into t values (6, default); +SELECT count(1) FROM t WHERE c1 REGEXP '^[A-Z0-9]+$'; +show create table t; +SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t' AND COLUMN_NAME='c1'; +alter table t alter column c1 set default null; +insert into t(c) values (7); +alter table t alter column c1 drop default; +-- error 1364 +insert into t(c) values (8); +SELECT count(1) FROM t WHERE c1 REGEXP '^[A-Z0-9]+$'; # TestDefaultColumnWithStrToDate # str_to_date drop table if exists t0, t1, t2, t3, t4, t5; # create table -create table t0 (c int(10), c1 varchar(32) default (str_to_date('1980-01-01','%Y-%m-%d')), c2 date default (str_to_date('9999-01-01','%Y-%m-%d'))); -create table t1 (c int(10), c1 int default (str_to_date('1980-01-01','%Y-%m-%d')), c2 int default (str_to_date('9999-01-01','%Y-%m-%d'))); +create table t0 (c int(10), c1 varchar(32) default (str_to_date('1980-01-01','%Y-%m-%d')), c2 date default (str_to_date('9999-01-01','%Y-%m-%d')), index idx(c, c1)); +create table t1 (c int(10), c1 int default (str_to_date('1980-01-01','%Y-%m-%d')), c2 int default (str_to_date('9999-01-01','%Y-%m-%d')), unique key idx(c, c1)); create table t3 (c int(10), c1 varchar(32) default (str_to_date('1980-01-01','%m-%d'))); create table t4 (c int(10), c1 varchar(32) default (str_to_date('01-01','%Y-%m-%d'))); set @sqlMode := @@session.sql_mode; @@ -128,13 +164,43 @@ insert into t2 values (4, default, default); set session sql_mode=@sqlMode; -- error 1292 insert into t2(c) values (5); +select * from t0; +select * from t1; +select * from t2; + +show create table t0; +show create table t1; +show create table t2; + +# test modify column, set default value, add index +alter table t0 add index idx1(c1); +alter table t1 add unique index idx1(c, c1); +insert into t0 values (5, default, default); +insert into t1 values (5, default, default); +show create table t0; +show create table t1; +alter table t0 alter column c2 set default (current_date()); +alter table t1 modify column c1 varchar(30) default 'xx'; +insert into t0 values (6, default, default); +insert into t1 values (6, default, default); +show create table t0; +show create table t1; +alter table t0 alter column c1 drop default; +alter table t1 modify column c1 varchar(32) default (str_to_date('1980-01-01','%Y-%m-%d')); +-- error 1364 +insert into t0 values (7, default, default); +insert into t1 values (7, default, default); +select * from t0 where c < 6; +select c, c1 from t0 where c = 6 and c2 = date_format(now(),'%Y-%m-%d');; +select * from t1; +select * from t2; SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; # TestDefaultColumnWithUpper # upper drop table if exists t, t1, t2; # create table -create table t (c int(10), c1 varchar(256) default (upper(substring_index(user(),'@',1)))); +create table t (c int(10), c1 varchar(256) default (upper(substring_index(user(),'@',1))), unique index idx(c, c1)); create table t1 (c int(10), c1 int default (upper(substring_index(user(),_utf8mb4'@',1)))); -- error 3770 create table t2 (c int(10), c1 varchar(256) default (substring_index(user(),'@',1))); @@ -150,9 +216,12 @@ alter table t add column c3 int default (upper(substring_index('fjks@jkkl','@',1 insert into t1(c) values (1); show create table t; show create table t1; + +# test modify column, set default value, add index alter table t1 modify column c1 varchar(30) default 'xx'; show create table t1; alter table t1 modify column c1 varchar(32) default (upper(substring_index(user(),'@',1))); +alter table t1 add index idx1(c1); show create table t1; SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; @@ -189,6 +258,35 @@ SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema= SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t2' AND COLUMN_NAME='c1'; SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t3' AND COLUMN_NAME='c1'; +-- error 1101 +alter table t0 alter column c1 set default "xx"; +-- error 1101 +alter table t1 alter column c1 set default "xx"; +alter table t2 alter column c1 set default 'y'; +alter table t3 alter column c1 set default 'n'; +INSERT INTO t0 values (2, DEFAULT); +INSERT INTO t2 values (2, DEFAULT); +INSERT INTO t3 values (2, DEFAULT); +alter table t0 modify column c1 BLOB default (date_format(now(),'%Y-%m-%d')); +alter table t1 modify column c1 JSON default (date_format(now(),'%Y-%m-%d')); +alter table t2 modify column c1 ENUM('y','n') default (date_format(now(),'%Y-%m-%d')); +alter table t3 modify column c1 SET('y','n') default (date_format(now(),'%Y-%m-%d')); +INSERT INTO t0 values (3, DEFAULT); +show create table t0; +show create table t1; +show create table t2; +show create table t3; +alter table t0 alter column c1 drop default; +alter table t1 alter column c1 drop default; +alter table t2 alter column c1 drop default; +alter table t3 alter column c1 drop default; +show create table t0; +show create table t1; +show create table t2; +show create table t3; +select count(1) from t0 where c1 = date_format(now(), '%Y-%m-%d'); +select * from t2; +select * from t3; drop table t0, t1, t2, t3; # Different data types for replace. create table t0 (c int(10), c1 BLOB default (REPLACE(UPPER(UUID()), '-', ''))); From 7be9a1e89a47071a3abc87bd786f08fbdcd442f4 Mon Sep 17 00:00:00 2001 From: Lynn Date: Mon, 11 Mar 2024 12:27:38 +0800 Subject: [PATCH 06/15] ddl, test: fix the problem that table creation with auto_increment and default expressions (#51640) close pingcap/tidb#51587 --- pkg/ddl/ddl_api.go | 14 +++++--- .../r/ddl/default_as_expression.result | 35 +++++++++++++++++++ .../t/ddl/default_as_expression.test | 24 +++++++++++++ 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 118dcd9170796..3c12e55dad49a 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -1661,6 +1661,9 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue if c.GetDefaultValue() != nil { if c.DefaultIsExpr { + if mysql.HasAutoIncrementFlag(c.GetFlag()) { + return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) + } return nil } if _, err := table.GetColDefaultValue(ctx.GetExprCtx(), c.ToInfo()); err != nil { @@ -5490,6 +5493,8 @@ func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu if err != nil { return hasDefaultValue, errors.Trace(err) } + } else { + hasDefaultValue = true } err = setDefaultValueWithBinaryPadding(col, value) if err != nil { @@ -5526,8 +5531,8 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col return errors.Trace(err) } -// ProcessColumnOptions process column options. -func ProcessColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error { +// ProcessModifyColumnOptions process column options. +func ProcessModifyColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error { var sb strings.Builder restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutSchemaName @@ -5605,7 +5610,8 @@ func ProcessColumnOptions(ctx sessionctx.Context, col *table.Column, options []* return nil } -func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Column, outPriKeyConstraint *ast.Constraint, hasDefaultValue, setOnUpdateNow, hasNullFlag bool) error { +func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Column, + outPriKeyConstraint *ast.Constraint, hasDefaultValue, setOnUpdateNow, hasNullFlag bool) error { processDefaultValue(col, hasDefaultValue, setOnUpdateNow) processColumnFlags(col) @@ -5774,7 +5780,7 @@ func GetModifiableColumnJob( // TODO: If user explicitly set NULL, we should throw error ErrPrimaryCantHaveNull. } - if err = ProcessColumnOptions(sctx, newCol, specNewColumn.Options); err != nil { + if err = ProcessModifyColumnOptions(sctx, newCol, specNewColumn.Options); err != nil { return nil, errors.Trace(err) } diff --git a/tests/integrationtest/r/ddl/default_as_expression.result b/tests/integrationtest/r/ddl/default_as_expression.result index e5779aa89a3f6..848011b4dfb42 100644 --- a/tests/integrationtest/r/ddl/default_as_expression.result +++ b/tests/integrationtest/r/ddl/default_as_expression.result @@ -736,3 +736,38 @@ SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema= column_default extra SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t3' AND COLUMN_NAME='c1'; column_default extra +create table t0 (c int(10), c1 int auto_increment default (str_to_date('1980-01-01','%Y-%m-%d'))); +Error 1067 (42000): Invalid default value for 'c1' +CREATE TABLE t0 (id int, c int); +insert into t0(id) values (1); +alter table t0 modify column c int auto_increment default (str_to_date('1980-01-01','%Y-%m-%d')); +Error 1067 (42000): Invalid default value for 'c' +ALTER TABLE t0 MODIFY COLUMN c INT PRIMARY KEY DEFAULT(str_to_date('1980-01-01','%Y-%m-%d')); +Error 8200 (HY000): can't change column constraint (PRIMARY KEY) +ALTER TABLE t0 ALTER COLUMN c SET DEFAULT(str_to_date('1980-01-01','%Y-%m-%d')); +insert into t0(id) values (2); +drop table t0; +CREATE TABLE t1 (i INT, b int DEFAULT (str_to_date('1980-01-01','%Y-%m-%d')), c INT GENERATED ALWAYS AS (b+2)); +SHOW COLUMNS FROM t1; +Field Type Null Key Default Extra +i int(11) YES NULL +b int(11) YES str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d') DEFAULT_GENERATED +c int(11) YES NULL VIRTUAL GENERATED +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `i` int(11) DEFAULT NULL, + `b` int(11) DEFAULT str_to_date(_utf8mb4'1980-01-01', _utf8mb4'%Y-%m-%d'), + `c` int(11) GENERATED ALWAYS AS (`b` + 2) VIRTUAL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +INSERT INTO t1(i) VALUES (1); +INSERT INTO t1(i, b) VALUES (2, DEFAULT); +INSERT INTO t1(i, b) VALUES (3, 123); +INSERT INTO t1(i, b) VALUES (NULL, NULL); +SELECT * FROM t1; +i b c +1 19800101 19800103 +2 19800101 19800103 +3 123 125 +NULL NULL NULL +drop table t1; diff --git a/tests/integrationtest/t/ddl/default_as_expression.test b/tests/integrationtest/t/ddl/default_as_expression.test index 09d81e1cf0e55..e6c67dba4cf5a 100644 --- a/tests/integrationtest/t/ddl/default_as_expression.test +++ b/tests/integrationtest/t/ddl/default_as_expression.test @@ -383,3 +383,27 @@ SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema= SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t1' AND COLUMN_NAME='c1'; SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t2' AND COLUMN_NAME='c1'; SELECT column_default, extra FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='test' AND TABLE_NAME='t3' AND COLUMN_NAME='c1'; + +# test auto_increment +-- error 1067 +create table t0 (c int(10), c1 int auto_increment default (str_to_date('1980-01-01','%Y-%m-%d'))); +CREATE TABLE t0 (id int, c int); +insert into t0(id) values (1); +-- error 1067 +alter table t0 modify column c int auto_increment default (str_to_date('1980-01-01','%Y-%m-%d')); +-- error 8200 +ALTER TABLE t0 MODIFY COLUMN c INT PRIMARY KEY DEFAULT(str_to_date('1980-01-01','%Y-%m-%d')); +ALTER TABLE t0 ALTER COLUMN c SET DEFAULT(str_to_date('1980-01-01','%Y-%m-%d')); +insert into t0(id) values (2); +drop table t0; + +# test generated column +CREATE TABLE t1 (i INT, b int DEFAULT (str_to_date('1980-01-01','%Y-%m-%d')), c INT GENERATED ALWAYS AS (b+2)); +SHOW COLUMNS FROM t1; +show create table t1; +INSERT INTO t1(i) VALUES (1); +INSERT INTO t1(i, b) VALUES (2, DEFAULT); +INSERT INTO t1(i, b) VALUES (3, 123); +INSERT INTO t1(i, b) VALUES (NULL, NULL); +SELECT * FROM t1; +drop table t1; From 5cb6c0e2af0e809407902625171d065895aa638b Mon Sep 17 00:00:00 2001 From: lance6716 Date: Mon, 11 Mar 2024 13:47:08 +0800 Subject: [PATCH 07/15] br: move ScatterRegions into split package (#51614) ref pingcap/tidb#51533 --- br/pkg/restore/import_retry_test.go | 115 ++++++++++--------- br/pkg/restore/range_test.go | 11 +- br/pkg/restore/split.go | 105 ++--------------- br/pkg/restore/split/BUILD.bazel | 9 +- br/pkg/restore/split/client.go | 89 ++++++++++++++- br/pkg/restore/split/split_test.go | 119 +++++++++++++++++++- br/pkg/restore/split_test.go | 167 +++++++--------------------- br/pkg/restore/util_test.go | 79 +++++++------ 8 files changed, 357 insertions(+), 337 deletions(-) diff --git a/br/pkg/restore/import_retry_test.go b/br/pkg/restore/import_retry_test.go index 97d1d10aacae0..6dbc05e2d402e 100644 --- a/br/pkg/restore/import_retry_test.go +++ b/br/pkg/restore/import_retry_test.go @@ -1,6 +1,6 @@ // Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. -package restore_test +package restore import ( "context" @@ -18,7 +18,6 @@ import ( "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/metapb" berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/restore" "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/kv" @@ -58,35 +57,35 @@ func TestScanSuccess(t *testing.T) { ctx := context.Background() // make exclusive to inclusive. - ctl := restore.OverRegionsInRange([]byte("aa"), []byte("aay"), cli, &rs) + ctl := OverRegionsInRange([]byte("aa"), []byte("aay"), cli, &rs) collectedRegions := []*split.RegionInfo{} - ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { collectedRegions = append(collectedRegions, r) - return restore.RPCResultOK() + return RPCResultOK() }) assertRegions(t, collectedRegions, "", "aay", "bba") - ctl = restore.OverRegionsInRange([]byte("aaz"), []byte("bb"), cli, &rs) + ctl = OverRegionsInRange([]byte("aaz"), []byte("bb"), cli, &rs) collectedRegions = []*split.RegionInfo{} - ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { collectedRegions = append(collectedRegions, r) - return restore.RPCResultOK() + return RPCResultOK() }) assertRegions(t, collectedRegions, "aay", "bba", "bbh", "cca") - ctl = restore.OverRegionsInRange([]byte("aa"), []byte("cc"), cli, &rs) + ctl = OverRegionsInRange([]byte("aa"), []byte("cc"), cli, &rs) collectedRegions = []*split.RegionInfo{} - ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { collectedRegions = append(collectedRegions, r) - return restore.RPCResultOK() + return RPCResultOK() }) assertRegions(t, collectedRegions, "", "aay", "bba", "bbh", "cca", "") - ctl = restore.OverRegionsInRange([]byte("aa"), []byte(""), cli, &rs) + ctl = OverRegionsInRange([]byte("aa"), []byte(""), cli, &rs) collectedRegions = []*split.RegionInfo{} - ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { collectedRegions = append(collectedRegions, r) - return restore.RPCResultOK() + return RPCResultOK() }) assertRegions(t, collectedRegions, "", "aay", "bba", "bbh", "cca", "") } @@ -95,7 +94,7 @@ func TestNotLeader(t *testing.T) { // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) cli := initTestClient(false) rs := utils.InitialRetryState(1, 0, 0) - ctl := restore.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() notLeader := errorpb.Error{ @@ -109,17 +108,17 @@ func TestNotLeader(t *testing.T) { meetRegions := []*split.RegionInfo{} // record all regions we meet with id == 2. idEqualsTo2Regions := []*split.RegionInfo{} - err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { if r.Region.Id == 2 { idEqualsTo2Regions = append(idEqualsTo2Regions, r) } if r.Region.Id == 2 && (r.Leader == nil || r.Leader.Id != 42) { - return restore.RPCResult{ + return RPCResult{ StoreError: ¬Leader, } } meetRegions = append(meetRegions, r) - return restore.RPCResultOK() + return RPCResultOK() }) require.NoError(t, err) @@ -135,7 +134,7 @@ func TestServerIsBusy(t *testing.T) { // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) cli := initTestClient(false) rs := utils.InitialRetryState(2, 0, 0) - ctl := restore.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() serverIsBusy := errorpb.Error{ @@ -149,16 +148,16 @@ func TestServerIsBusy(t *testing.T) { // record all regions we meet with id == 2. idEqualsTo2Regions := []*split.RegionInfo{} theFirstRun := true - err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { if theFirstRun && r.Region.Id == 2 { idEqualsTo2Regions = append(idEqualsTo2Regions, r) theFirstRun = false - return restore.RPCResult{ + return RPCResult{ StoreError: &serverIsBusy, } } meetRegions = append(meetRegions, r) - return restore.RPCResultOK() + return RPCResultOK() }) require.NoError(t, err) @@ -176,7 +175,7 @@ func TestServerIsBusyWithMemoryIsLimited(t *testing.T) { // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) cli := initTestClient(false) rs := utils.InitialRetryState(2, 0, 0) - ctl := restore.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() serverIsBusy := errorpb.Error{ @@ -190,16 +189,16 @@ func TestServerIsBusyWithMemoryIsLimited(t *testing.T) { // record all regions we meet with id == 2. idEqualsTo2Regions := []*split.RegionInfo{} theFirstRun := true - err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { if theFirstRun && r.Region.Id == 2 { idEqualsTo2Regions = append(idEqualsTo2Regions, r) theFirstRun = false - return restore.RPCResult{ + return RPCResult{ StoreError: &serverIsBusy, } } meetRegions = append(meetRegions, r) - return restore.RPCResultOK() + return RPCResultOK() }) require.NoError(t, err) @@ -228,7 +227,7 @@ func TestEpochNotMatch(t *testing.T) { // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) cli := initTestClient(false) rs := utils.InitialRetryState(2, 0, 0) - ctl := restore.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() printPDRegion("cli", cli.regionsInfo.Regions) @@ -262,18 +261,18 @@ func TestEpochNotMatch(t *testing.T) { firstRunRegions := []*split.RegionInfo{} secondRunRegions := []*split.RegionInfo{} isSecondRun := false - err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { if !isSecondRun && r.Region.Id == left.Region.Id { mergeRegion() isSecondRun = true - return restore.RPCResultFromPBError(epochNotMatch) + return RPCResultFromPBError(epochNotMatch) } if isSecondRun { secondRunRegions = append(secondRunRegions, r) } else { firstRunRegions = append(firstRunRegions, r) } - return restore.RPCResultOK() + return RPCResultOK() }) printRegion("first", firstRunRegions) printRegion("second", secondRunRegions) @@ -287,7 +286,7 @@ func TestRegionSplit(t *testing.T) { // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) cli := initTestClient(false) rs := utils.InitialRetryState(2, 0, 0) - ctl := restore.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() printPDRegion("cli", cli.regionsInfo.Regions) @@ -338,18 +337,18 @@ func TestRegionSplit(t *testing.T) { firstRunRegions := []*split.RegionInfo{} secondRunRegions := []*split.RegionInfo{} isSecondRun := false - err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { if !isSecondRun && r.Region.Id == target.Region.Id { splitRegion() isSecondRun = true - return restore.RPCResultFromPBError(epochNotMatch) + return RPCResultFromPBError(epochNotMatch) } if isSecondRun { secondRunRegions = append(secondRunRegions, r) } else { firstRunRegions = append(firstRunRegions, r) } - return restore.RPCResultOK() + return RPCResultOK() }) printRegion("first", firstRunRegions) printRegion("second", secondRunRegions) @@ -363,7 +362,7 @@ func TestRetryBackoff(t *testing.T) { // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) cli := initTestClient(false) rs := utils.InitialRetryState(2, time.Millisecond, 10*time.Millisecond) - ctl := restore.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() printPDRegion("cli", cli.regionsInfo.Regions) @@ -380,12 +379,12 @@ func TestRetryBackoff(t *testing.T) { }, }} isSecondRun := false - err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { if !isSecondRun && r.Region.Id == left.Region.Id { isSecondRun = true - return restore.RPCResultFromPBError(epochNotLeader) + return RPCResultFromPBError(epochNotLeader) } - return restore.RPCResultOK() + return RPCResultOK() }) printPDRegion("cli", cli.regionsInfo.Regions) require.Equal(t, 1, rs.Attempt()) @@ -395,10 +394,10 @@ func TestRetryBackoff(t *testing.T) { } func TestWrappedError(t *testing.T) { - result := restore.RPCResultFromError(errors.Trace(status.Error(codes.Unavailable, "the server is slacking. ><=·>"))) - require.Equal(t, result.StrategyForRetry(), restore.StrategyFromThisRegion) - result = restore.RPCResultFromError(errors.Trace(status.Error(codes.Unknown, "the server said something hard to understand"))) - require.Equal(t, result.StrategyForRetry(), restore.StrategyGiveUp) + result := RPCResultFromError(errors.Trace(status.Error(codes.Unavailable, "the server is slacking. ><=·>"))) + require.Equal(t, result.StrategyForRetry(), StrategyFromThisRegion) + result = RPCResultFromError(errors.Trace(status.Error(codes.Unknown, "the server said something hard to understand"))) + require.Equal(t, result.StrategyForRetry(), StrategyGiveUp) } func envInt(name string, def int) int { @@ -414,22 +413,22 @@ func TestPaginateScanLeader(t *testing.T) { // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) cli := initTestClient(false) rs := utils.InitialRetryState(2, time.Millisecond, 10*time.Millisecond) - ctl := restore.OverRegionsInRange([]byte("aa"), []byte("aaz"), cli, &rs) + ctl := OverRegionsInRange([]byte("aa"), []byte("aaz"), cli, &rs) ctx := context.Background() cli.InjectErr = true cli.InjectTimes = int32(envInt("PAGINATE_SCAN_LEADER_FAILURE_COUNT", 2)) collectedRegions := []*split.RegionInfo{} - ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) restore.RPCResult { + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { collectedRegions = append(collectedRegions, r) - return restore.RPCResultOK() + return RPCResultOK() }) assertRegions(t, collectedRegions, "", "aay", "bba") } func TestImportKVFiles(t *testing.T) { var ( - importer = restore.FileImporter{} + importer = FileImporter{} ctx = context.Background() shiftStartTS uint64 = 100 startTS uint64 = 200 @@ -438,7 +437,7 @@ func TestImportKVFiles(t *testing.T) { err := importer.ImportKVFiles( ctx, - []*restore.LogDataFileInfo{ + []*LogDataFileInfo{ { DataFileInfo: &backuppb.DataFileInfo{ Path: "log3", @@ -460,7 +459,7 @@ func TestImportKVFiles(t *testing.T) { } func TestFilterFilesByRegion(t *testing.T) { - files := []*restore.LogDataFileInfo{ + files := []*LogDataFileInfo{ { DataFileInfo: &backuppb.DataFileInfo{ Path: "log3", @@ -484,7 +483,7 @@ func TestFilterFilesByRegion(t *testing.T) { testCases := []struct { r split.RegionInfo - subfiles []*restore.LogDataFileInfo + subfiles []*LogDataFileInfo err error }{ { @@ -494,7 +493,7 @@ func TestFilterFilesByRegion(t *testing.T) { EndKey: []byte("1110"), }, }, - subfiles: []*restore.LogDataFileInfo{}, + subfiles: []*LogDataFileInfo{}, err: nil, }, { @@ -504,7 +503,7 @@ func TestFilterFilesByRegion(t *testing.T) { EndKey: []byte("1111"), }, }, - subfiles: []*restore.LogDataFileInfo{ + subfiles: []*LogDataFileInfo{ files[0], }, err: nil, @@ -516,7 +515,7 @@ func TestFilterFilesByRegion(t *testing.T) { EndKey: []byte("2222"), }, }, - subfiles: []*restore.LogDataFileInfo{ + subfiles: []*LogDataFileInfo{ files[0], }, err: nil, @@ -528,7 +527,7 @@ func TestFilterFilesByRegion(t *testing.T) { EndKey: []byte("3332"), }, }, - subfiles: []*restore.LogDataFileInfo{ + subfiles: []*LogDataFileInfo{ files[0], }, err: nil, @@ -540,7 +539,7 @@ func TestFilterFilesByRegion(t *testing.T) { EndKey: []byte("3332"), }, }, - subfiles: []*restore.LogDataFileInfo{}, + subfiles: []*LogDataFileInfo{}, err: nil, }, { @@ -550,7 +549,7 @@ func TestFilterFilesByRegion(t *testing.T) { EndKey: []byte("3333"), }, }, - subfiles: []*restore.LogDataFileInfo{ + subfiles: []*LogDataFileInfo{ files[1], }, err: nil, @@ -562,7 +561,7 @@ func TestFilterFilesByRegion(t *testing.T) { EndKey: []byte("5555"), }, }, - subfiles: []*restore.LogDataFileInfo{ + subfiles: []*LogDataFileInfo{ files[1], }, err: nil, @@ -574,7 +573,7 @@ func TestFilterFilesByRegion(t *testing.T) { EndKey: nil, }, }, - subfiles: []*restore.LogDataFileInfo{ + subfiles: []*LogDataFileInfo{ files[1], }, err: nil, @@ -592,7 +591,7 @@ func TestFilterFilesByRegion(t *testing.T) { } for _, c := range testCases { - subfile, err := restore.FilterFilesByRegion(files, ranges, &c.r) + subfile, err := FilterFilesByRegion(files, ranges, &c.r) require.Equal(t, err, c.err) require.Equal(t, subfile, c.subfiles) } diff --git a/br/pkg/restore/range_test.go b/br/pkg/restore/range_test.go index a03271de3da03..322789ec023c1 100644 --- a/br/pkg/restore/range_test.go +++ b/br/pkg/restore/range_test.go @@ -1,12 +1,11 @@ // Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. -package restore_test +package restore import ( "testing" "github.com/pingcap/kvproto/pkg/import_sstpb" - "github.com/pingcap/tidb/br/pkg/restore" "github.com/pingcap/tidb/br/pkg/rtree" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/stretchr/testify/require" @@ -25,7 +24,7 @@ func TestSortRange(t *testing.T) { {OldKeyPrefix: tablecodec.GenTableRecordPrefix(1), NewKeyPrefix: tablecodec.GenTableRecordPrefix(4)}, {OldKeyPrefix: tablecodec.GenTableRecordPrefix(2), NewKeyPrefix: tablecodec.GenTableRecordPrefix(5)}, } - rewriteRules := &restore.RewriteRules{ + rewriteRules := &RewriteRules{ Data: dataRules, } ranges1 := []rtree.Range{ @@ -34,7 +33,7 @@ func TestSortRange(t *testing.T) { EndKey: append(tablecodec.GenTableRecordPrefix(1), []byte("bbb")...), Files: nil, }, } - rs1, err := restore.SortRanges(ranges1, rewriteRules) + rs1, err := SortRanges(ranges1, rewriteRules) require.NoErrorf(t, err, "sort range1 failed: %v", err) rangeEquals(t, rs1, []rtree.Range{ { @@ -49,13 +48,13 @@ func TestSortRange(t *testing.T) { EndKey: append(tablecodec.GenTableRecordPrefix(2), []byte("bbb")...), Files: nil, }, } - _, err = restore.SortRanges(ranges2, rewriteRules) + _, err = SortRanges(ranges2, rewriteRules) require.Error(t, err) require.Regexp(t, "table id mismatch.*", err.Error()) ranges3 := initRanges() rewriteRules1 := initRewriteRules() - rs3, err := restore.SortRanges(ranges3, rewriteRules1) + rs3, err := SortRanges(ranges3, rewriteRules1) require.NoErrorf(t, err, "sort range1 failed: %v", err) rangeEquals(t, rs3, []rtree.Range{ {StartKey: []byte("bbd"), EndKey: []byte("bbf"), Files: nil}, diff --git a/br/pkg/restore/split.go b/br/pkg/restore/split.go index 0790999e8be24..43052e8ef6588 100644 --- a/br/pkg/restore/split.go +++ b/br/pkg/restore/split.go @@ -6,7 +6,6 @@ import ( "bytes" "context" "sort" - "strconv" "strings" "sync" "time" @@ -24,11 +23,8 @@ import ( "github.com/pingcap/tidb/br/pkg/utils/iter" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/util/codec" - "go.uber.org/multierr" "go.uber.org/zap" "golang.org/x/sync/errgroup" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) type Granularity string @@ -254,7 +250,10 @@ func (rs *RegionSplitter) splitAndScatterRegions( } return nil, errors.Trace(err) } - rs.ScatterRegions(ctx, append(newRegions, regionInfo)) + err2 := rs.client.ScatterRegions(ctx, append(newRegions, regionInfo)) + if err2 != nil { + log.Warn("failed to scatter regions", zap.Error(err2)) + } return newRegions, nil } @@ -274,32 +273,6 @@ func (rs *RegionSplitter) splitRegions( return newRegions, nil } -// scatterRegions scatter the regions. -// for same reason just log and ignore error. -// See the comments of function waitRegionScattered. -func (rs *RegionSplitter) ScatterRegions(ctx context.Context, newRegions []*split.RegionInfo) { - log.Info("start to scatter regions", zap.Int("regions", len(newRegions))) - // the retry is for the temporary network errors during sending request. - err := utils.WithRetry(ctx, func() error { - err := rs.client.ScatterRegions(ctx, newRegions) - if isUnsupportedError(err) { - log.Warn("batch scatter isn't supported, rollback to old method", logutil.ShortError(err)) - rs.ScatterRegionsSequentially( - ctx, newRegions, - // backoff about 6s, or we give up scattering this region. - &split.ExponentialBackoffer{ - Attempts: 7, - BaseBackoff: 100 * time.Millisecond, - }) - return nil - } - return err - }, &split.ExponentialBackoffer{Attempts: 3, BaseBackoff: 500 * time.Millisecond}) - if err != nil { - log.Warn("failed to scatter regions", logutil.ShortError(err)) - } -} - // waitRegionsSplitted check multiple regions have finished the split. func (rs *RegionSplitter) waitRegionsSplitted(ctx context.Context, splitRegions []*split.RegionInfo) { // Wait for a while until the regions successfully split. @@ -350,50 +323,6 @@ func (rs *RegionSplitter) waitRegionsScattered(ctx context.Context, scatterRegio } } -// ScatterRegionsSequentially scatter the region with some backoffer. -// This function is for testing the retry mechanism. -// For a real cluster, directly use ScatterRegions would be fine. -func (rs *RegionSplitter) ScatterRegionsSequentially(ctx context.Context, newRegions []*split.RegionInfo, backoffer utils.Backoffer) { - newRegionSet := make(map[uint64]*split.RegionInfo, len(newRegions)) - for _, newRegion := range newRegions { - newRegionSet[newRegion.Region.Id] = newRegion - } - - if err := utils.WithRetry(ctx, func() error { - log.Info("trying to scatter regions...", zap.Int("remain", len(newRegionSet))) - var errs error - for _, region := range newRegionSet { - err := rs.client.ScatterRegion(ctx, region) - if err == nil { - // it is safe according to the Go language spec. - delete(newRegionSet, region.Region.Id) - } else if !split.PdErrorCanRetry(err) { - log.Warn("scatter meet error cannot be retried, skipping", - logutil.ShortError(err), - logutil.Region(region.Region), - ) - delete(newRegionSet, region.Region.Id) - } - errs = multierr.Append(errs, err) - } - return errs - }, backoffer); err != nil { - log.Warn("Some regions haven't been scattered because errors.", - zap.Int("count", len(newRegionSet)), - // if all region are failed to scatter, the short error might also be verbose... - logutil.ShortError(err), - logutil.AbbreviatedArray("failed-regions", newRegionSet, func(i any) []string { - m := i.(map[uint64]*split.RegionInfo) - result := make([]string, 0, len(m)) - for id := range m { - result = append(result, strconv.Itoa(int(id))) - } - return result - }), - ) - } -} - // hasHealthyRegion is used to check whether region splitted success func (rs *RegionSplitter) hasHealthyRegion(ctx context.Context, regionID uint64) (bool, error) { regionInfo, err := rs.client.GetRegionByID(ctx, regionID) @@ -462,7 +391,10 @@ func (rs *RegionSplitter) WaitForScatterRegionsTimeout(ctx context.Context, regi } if len(reScatterRegions) > 0 { - rs.ScatterRegions(ctx, reScatterRegions) + err2 := rs.client.ScatterRegions(ctx, reScatterRegions) + if err2 != nil { + log.Warn("failed to scatter regions", zap.Error(err2)) + } } if time.Since(startTime) > timeout { @@ -950,27 +882,6 @@ func (splitIter *LogFilesIterWithSplitHelper) TryNext(ctx context.Context) iter. return res } -// isUnsupportedError checks whether we should fallback to ScatterRegion API when meeting the error. -func isUnsupportedError(err error) bool { - s, ok := status.FromError(errors.Cause(err)) - if !ok { - // Not a gRPC error. Something other went wrong. - return false - } - // In two conditions, we fallback to ScatterRegion: - // (1) If the RPC endpoint returns UNIMPLEMENTED. (This is just for making test cases not be so magic.) - // (2) If the Message is "region 0 not found": - // In fact, PD reuses the gRPC endpoint `ScatterRegion` for the batch version of scattering. - // When the request contains the field `regionIDs`, it would use the batch version, - // Otherwise, it uses the old version and scatter the region with `regionID` in the request. - // When facing 4.x, BR(which uses v5.x PD clients and call `ScatterRegions`!) would set `regionIDs` - // which would be ignored by protocol buffers, and leave the `regionID` be zero. - // Then the older version of PD would try to search the region with ID 0. - // (Then it consistently fails, and returns "region 0 not found".) - return s.Code() == codes.Unimplemented || - strings.Contains(s.Message(), "region 0 not found") -} - type splitBackoffer struct { state utils.RetryState } diff --git a/br/pkg/restore/split/BUILD.bazel b/br/pkg/restore/split/BUILD.bazel index 45b2d8907f718..2663c89bcd56c 100644 --- a/br/pkg/restore/split/BUILD.bazel +++ b/br/pkg/restore/split/BUILD.bazel @@ -47,13 +47,18 @@ go_test( "split_test.go", "sum_sorted_test.go", ], + embed = [":split"], flaky = True, - shard_count = 4, + shard_count = 6, deps = [ - ":split", "//br/pkg/errors", "//br/pkg/utils", "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/metapb", + "@com_github_pingcap_kvproto//pkg/pdpb", "@com_github_stretchr_testify//require", + "@com_github_tikv_pd_client//:client", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//status", ], ) diff --git a/br/pkg/restore/split/client.go b/br/pkg/restore/split/client.go index cad86e6602cca..edfd18fd484c4 100644 --- a/br/pkg/restore/split/client.go +++ b/br/pkg/restore/split/client.go @@ -6,6 +6,7 @@ import ( "bytes" "context" "crypto/tls" + "strconv" "strings" "sync" "time" @@ -23,6 +24,7 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/config" "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" pd "github.com/tikv/pd/client" pdhttp "github.com/tikv/pd/client/http" "go.uber.org/multierr" @@ -138,10 +140,27 @@ func (c *pdClient) needScatter(ctx context.Context) bool { return c.needScatterVal } -// ScatterRegions scatters regions in a batch. -func (c *pdClient) ScatterRegions(ctx context.Context, regionInfo []*RegionInfo) error { - c.mu.Lock() - defer c.mu.Unlock() +func (c *pdClient) ScatterRegions(ctx context.Context, newRegions []*RegionInfo) error { + log.Info("scatter regions", zap.Int("regions", len(newRegions))) + // the retry is for the temporary network errors during sending request. + return utils.WithRetry(ctx, func() error { + err := c.scatterRegions(ctx, newRegions) + if isUnsupportedError(err) { + log.Warn("batch scatter isn't supported, rollback to old method", logutil.ShortError(err)) + c.scatterRegionsSequentially( + ctx, newRegions, + // backoff about 6s, or we give up scattering this region. + &ExponentialBackoffer{ + Attempts: 7, + BaseBackoff: 100 * time.Millisecond, + }) + return nil + } + return err + }, &ExponentialBackoffer{Attempts: 3, BaseBackoff: 500 * time.Millisecond}) +} + +func (c *pdClient) scatterRegions(ctx context.Context, regionInfo []*RegionInfo) error { regionsID := make([]uint64, 0, len(regionInfo)) for _, v := range regionInfo { regionsID = append(regionsID, v.Region.Id) @@ -564,6 +583,47 @@ func (c *pdClient) SetStoresLabel( return nil } +func (c *pdClient) scatterRegionsSequentially(ctx context.Context, newRegions []*RegionInfo, backoffer utils.Backoffer) { + newRegionSet := make(map[uint64]*RegionInfo, len(newRegions)) + for _, newRegion := range newRegions { + newRegionSet[newRegion.Region.Id] = newRegion + } + + if err := utils.WithRetry(ctx, func() error { + log.Info("trying to scatter regions...", zap.Int("remain", len(newRegionSet))) + var errs error + for _, region := range newRegionSet { + err := c.ScatterRegion(ctx, region) + if err == nil { + // it is safe according to the Go language spec. + delete(newRegionSet, region.Region.Id) + } else if !PdErrorCanRetry(err) { + log.Warn("scatter meet error cannot be retried, skipping", + logutil.ShortError(err), + logutil.Region(region.Region), + ) + delete(newRegionSet, region.Region.Id) + } + errs = multierr.Append(errs, err) + } + return errs + }, backoffer); err != nil { + log.Warn("Some regions haven't been scattered because errors.", + zap.Int("count", len(newRegionSet)), + // if all region are failed to scatter, the short error might also be verbose... + logutil.ShortError(err), + logutil.AbbreviatedArray("failed-regions", newRegionSet, func(i any) []string { + m := i.(map[uint64]*RegionInfo) + result := make([]string, 0, len(m)) + for id := range m { + result = append(result, strconv.Itoa(int(id))) + } + return result + }), + ) + } +} + func (c *pdClient) IsScatterRegionFinished( ctx context.Context, regionID uint64, @@ -665,3 +725,24 @@ func (b *ExponentialBackoffer) NextBackoff(error) time.Duration { func (b *ExponentialBackoffer) Attempt() int { return b.Attempts } + +// isUnsupportedError checks whether we should fallback to ScatterRegion API when meeting the error. +func isUnsupportedError(err error) bool { + s, ok := status.FromError(errors.Cause(err)) + if !ok { + // Not a gRPC error. Something other went wrong. + return false + } + // In two conditions, we fallback to ScatterRegion: + // (1) If the RPC endpoint returns UNIMPLEMENTED. (This is just for making test cases not be so magic.) + // (2) If the Message is "region 0 not found": + // In fact, PD reuses the gRPC endpoint `ScatterRegion` for the batch version of scattering. + // When the request contains the field `regionIDs`, it would use the batch version, + // Otherwise, it uses the old version and scatter the region with `regionID` in the request. + // When facing 4.x, BR(which uses v5.x PD clients and call `ScatterRegions`!) would set `regionIDs` + // which would be ignored by protocol buffers, and leave the `regionID` be zero. + // Then the older version of PD would try to search the region with ID 0. + // (Then it consistently fails, and returns "region 0 not found".) + return s.Code() == codes.Unimplemented || + strings.Contains(s.Message(), "region 0 not found") +} diff --git a/br/pkg/restore/split/split_test.go b/br/pkg/restore/split/split_test.go index 1b76a9fafc693..060f09b688632 100644 --- a/br/pkg/restore/split/split_test.go +++ b/br/pkg/restore/split/split_test.go @@ -1,20 +1,25 @@ // Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. -package split_test +package split import ( "context" "testing" + "time" "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/br/pkg/utils" "github.com/stretchr/testify/require" + pd "github.com/tikv/pd/client" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestScanRegionBackOfferWithSuccess(t *testing.T) { var counter int - bo := split.NewWaitRegionOnlineBackoffer() + bo := NewWaitRegionOnlineBackoffer() err := utils.WithRetry(context.Background(), func() error { defer func() { @@ -37,7 +42,7 @@ func TestScanRegionBackOfferWithFail(t *testing.T) { }() var counter int - bo := split.NewWaitRegionOnlineBackoffer() + bo := NewWaitRegionOnlineBackoffer() err := utils.WithRetry(context.Background(), func() error { defer func() { @@ -46,7 +51,7 @@ func TestScanRegionBackOfferWithFail(t *testing.T) { return berrors.ErrPDBatchScanRegion }, bo) require.Error(t, err) - require.Equal(t, counter, split.WaitRegionOnlineAttemptTimes) + require.Equal(t, counter, WaitRegionOnlineAttemptTimes) } func TestScanRegionBackOfferWithStopRetry(t *testing.T) { @@ -56,7 +61,7 @@ func TestScanRegionBackOfferWithStopRetry(t *testing.T) { }() var counter int - bo := split.NewWaitRegionOnlineBackoffer() + bo := NewWaitRegionOnlineBackoffer() err := utils.WithRetry(context.Background(), func() error { defer func() { @@ -71,3 +76,105 @@ func TestScanRegionBackOfferWithStopRetry(t *testing.T) { require.Error(t, err) require.Equal(t, counter, 6) } + +type mockScatterFailedPDClient struct { + pd.Client + failed map[uint64]int + failedBefore int +} + +func (c *mockScatterFailedPDClient) ScatterRegion(ctx context.Context, regionID uint64) error { + if c.failed == nil { + c.failed = make(map[uint64]int) + } + c.failed[regionID]++ + if c.failed[regionID] > c.failedBefore { + return nil + } + return status.Errorf(codes.Unknown, "region %d is not fully replicated", regionID) +} + +type recordCntBackoffer struct { + already int +} + +func (b *recordCntBackoffer) NextBackoff(error) time.Duration { + b.already++ + return 0 +} + +func (b *recordCntBackoffer) Attempt() int { + return 100 +} + +func TestScatterSequentiallyRetryCnt(t *testing.T) { + client := pdClient{ + needScatterVal: true, + client: &mockScatterFailedPDClient{failedBefore: 7}, + } + client.needScatterInit.Do(func() {}) + + ctx := context.Background() + regions := []*RegionInfo{ + { + Region: &metapb.Region{ + Id: 1, + }, + }, + { + Region: &metapb.Region{ + Id: 2, + }, + }, + } + backoffer := &recordCntBackoffer{} + client.scatterRegionsSequentially( + ctx, + regions, + backoffer, + ) + require.Equal(t, 7, backoffer.already) +} + +type mockOldPDClient struct { + pd.Client + + scattered map[uint64]struct{} +} + +func (c *mockOldPDClient) ScatterRegion(_ context.Context, regionID uint64) error { + if c.scattered == nil { + c.scattered = make(map[uint64]struct{}) + } + c.scattered[regionID] = struct{}{} + return nil +} + +func (c *mockOldPDClient) ScatterRegions(context.Context, []uint64, ...pd.RegionsOption) (*pdpb.ScatterRegionResponse, error) { + return nil, status.Error(codes.Unimplemented, "Ah, yep") +} + +func TestScatterBackwardCompatibility(t *testing.T) { + client := pdClient{ + needScatterVal: true, + client: &mockOldPDClient{}, + } + client.needScatterInit.Do(func() {}) + + ctx := context.Background() + regions := []*RegionInfo{ + { + Region: &metapb.Region{ + Id: 1, + }, + }, + { + Region: &metapb.Region{ + Id: 2, + }, + }, + } + err := client.ScatterRegions(ctx, regions) + require.NoError(t, err) + require.Equal(t, map[uint64]struct{}{1: {}, 2: {}}, client.client.(*mockOldPDClient).scattered) +} diff --git a/br/pkg/restore/split_test.go b/br/pkg/restore/split_test.go index 15a5f0a12663d..a2a266bf51bdc 100644 --- a/br/pkg/restore/split_test.go +++ b/br/pkg/restore/split_test.go @@ -1,6 +1,6 @@ // Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. -package restore_test +package restore import ( "bytes" @@ -19,10 +19,8 @@ import ( berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/restore" "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/br/pkg/rtree" - "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/br/pkg/utils/iter" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/store/pdtypes" @@ -37,14 +35,13 @@ import ( type TestClient struct { split.SplitClient - mu sync.RWMutex - stores map[uint64]*metapb.Store - regions map[uint64]*split.RegionInfo - regionsInfo *pdtypes.RegionTree // For now it's only used in ScanRegions - nextRegionID uint64 - injectInScatter func(*split.RegionInfo) error - injectInOperator func(uint64) (*pdpb.GetOperatorResponse, error) - supportBatchScatter bool + mu sync.RWMutex + stores map[uint64]*metapb.Store + regions map[uint64]*split.RegionInfo + regionsInfo *pdtypes.RegionTree // For now it's only used in ScanRegions + nextRegionID uint64 + injectInScatter func(*split.RegionInfo) error + injectInOperator func(uint64) (*pdpb.GetOperatorResponse, error) scattered map[uint64]bool InjectErr bool @@ -70,15 +67,8 @@ func NewTestClient( } } -func (c *TestClient) InstallBatchScatterSupport() { - c.supportBatchScatter = true -} - // ScatterRegions scatters regions in a batch. func (c *TestClient) ScatterRegions(ctx context.Context, regionInfo []*split.RegionInfo) error { - if !c.supportBatchScatter { - return status.Error(codes.Unimplemented, "Ah, yep") - } regions := map[uint64]*split.RegionInfo{} for _, region := range regionInfo { regions[region.Region.Id] = region @@ -254,41 +244,13 @@ func (c *TestClient) IsScatterRegionFinished( return split.IsScatterRegionFinished(resp) } -type assertRetryLessThanBackoffer struct { - max int - already int - t *testing.T -} - -func assertRetryLessThan(t *testing.T, times int) utils.Backoffer { - return &assertRetryLessThanBackoffer{ - max: times, - already: 0, - t: t, - } -} - -// NextBackoff returns a duration to wait before retrying again -func (b *assertRetryLessThanBackoffer) NextBackoff(err error) time.Duration { - b.already++ - if b.already >= b.max { - b.t.Logf("retry more than %d time: test failed", b.max) - b.t.FailNow() - } - return 0 -} - -// Attempt returns the remain attempt times -func (b *assertRetryLessThanBackoffer) Attempt() int { - return b.max - b.already -} func TestScanEmptyRegion(t *testing.T) { client := initTestClient(false) ranges := initRanges() // make ranges has only one ranges = ranges[0:1] rewriteRules := initRewriteRules() - regionSplitter := restore.NewRegionSplitter(client) + regionSplitter := NewRegionSplitter(client) ctx := context.Background() err := regionSplitter.ExecuteSplit(ctx, ranges, rewriteRules, 1, false, func(key [][]byte) {}) @@ -296,44 +258,6 @@ func TestScanEmptyRegion(t *testing.T) { require.NoError(t, err) } -func TestScatterFinishInTime(t *testing.T) { - client := initTestClient(false) - ranges := initRanges() - rewriteRules := initRewriteRules() - regionSplitter := restore.NewRegionSplitter(client) - - ctx := context.Background() - err := regionSplitter.ExecuteSplit(ctx, ranges, rewriteRules, 1, false, func(key [][]byte) {}) - require.NoError(t, err) - regions := client.GetAllRegions() - if !validateRegions(regions) { - for _, region := range regions { - t.Logf("region: %v\n", region.Region) - } - t.Log("get wrong result") - t.Fail() - } - - regionInfos := make([]*split.RegionInfo, 0, len(regions)) - for _, info := range regions { - regionInfos = append(regionInfos, info) - } - failed := map[uint64]int{} - client.injectInScatter = func(r *split.RegionInfo) error { - failed[r.Region.Id]++ - if failed[r.Region.Id] > 7 { - return nil - } - return status.Errorf(codes.Unknown, "region %d is not fully replicated", r.Region.Id) - } - - // When using a exponential backoffer, if we try to backoff more than 40 times in 10 regions, - // it would cost time unacceptable. - regionSplitter.ScatterRegionsSequentially(ctx, - regionInfos, - assertRetryLessThan(t, 40)) -} - // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) // range: [aaa, aae), [aae, aaz), [ccd, ccf), [ccf, ccj) // rewrite rules: aa -> xx, cc -> bb @@ -343,17 +267,11 @@ func TestScatterFinishInTime(t *testing.T) { // [bbj, cca), [cca, xxe), [xxe, xxz), [xxz, ) func TestSplitAndScatter(t *testing.T) { t.Run("BatchScatter", func(t *testing.T) { - client := initTestClient(false) - client.InstallBatchScatterSupport() - runTestSplitAndScatterWith(t, client) - }) - t.Run("BackwardCompatibility", func(t *testing.T) { client := initTestClient(false) runTestSplitAndScatterWith(t, client) }) t.Run("WaitScatter", func(t *testing.T) { client := initTestClient(false) - client.InstallBatchScatterSupport() runWaitScatter(t, client) }) } @@ -450,7 +368,7 @@ func runWaitScatter(t *testing.T, client *TestClient) { for _, info := range regionsMap { regions = append(regions, info) } - regionSplitter := restore.NewRegionSplitter(client) + regionSplitter := NewRegionSplitter(client) leftCnt := regionSplitter.WaitForScatterRegionsTimeout(ctx, regions, 2000*time.Second) require.Equal(t, leftCnt, 0) } @@ -458,7 +376,7 @@ func runWaitScatter(t *testing.T, client *TestClient) { func runTestSplitAndScatterWith(t *testing.T, client *TestClient) { ranges := initRanges() rewriteRules := initRewriteRules() - regionSplitter := restore.NewRegionSplitter(client) + regionSplitter := NewRegionSplitter(client) ctx := context.Background() err := regionSplitter.ExecuteSplit(ctx, ranges, rewriteRules, 1, false, func(key [][]byte) {}) @@ -485,7 +403,8 @@ func runTestSplitAndScatterWith(t *testing.T, client *TestClient) { scattered[regionInfo.Region.Id] = true return nil } - regionSplitter.ScatterRegions(ctx, regionInfos) + err = regionSplitter.client.ScatterRegions(ctx, regionInfos) + require.NoError(t, err) for key := range regions { if key == alwaysFailedRegionID { require.Falsef(t, scattered[key], "always failed region %d was scattered successfully", key) @@ -506,7 +425,7 @@ func TestRawSplit(t *testing.T) { client := initTestClient(true) ctx := context.Background() - regionSplitter := restore.NewRegionSplitter(client) + regionSplitter := NewRegionSplitter(client) err := regionSplitter.ExecuteSplit(ctx, ranges, nil, 1, true, func(key [][]byte) {}) require.NoError(t, err) regions := client.GetAllRegions() @@ -579,7 +498,7 @@ func initRanges() []rtree.Range { return ranges[:] } -func initRewriteRules() *restore.RewriteRules { +func initRewriteRules() *RewriteRules { var rules [2]*import_sstpb.RewriteRule rules[0] = &import_sstpb.RewriteRule{ OldKeyPrefix: []byte("aa"), @@ -589,7 +508,7 @@ func initRewriteRules() *restore.RewriteRules { OldKeyPrefix: []byte("cc"), NewKeyPrefix: []byte("bb"), } - return &restore.RewriteRules{ + return &RewriteRules{ Data: rules[:], } } @@ -708,7 +627,7 @@ type fakeRestorer struct { tableIDIsInsequence bool } -func (f *fakeRestorer) SplitRanges(ctx context.Context, ranges []rtree.Range, rewriteRules *restore.RewriteRules, updateCh glue.Progress, isRawKv bool) error { +func (f *fakeRestorer) SplitRanges(ctx context.Context, ranges []rtree.Range, rewriteRules *RewriteRules, updateCh glue.Progress, isRawKv bool) error { f.mu.Lock() defer f.mu.Unlock() @@ -725,7 +644,7 @@ func (f *fakeRestorer) SplitRanges(ctx context.Context, ranges []rtree.Range, re return nil } -func (f *fakeRestorer) RestoreSSTFiles(ctx context.Context, tableIDWithFiles []restore.TableIDWithFiles, rewriteRules *restore.RewriteRules, updateCh glue.Progress) error { +func (f *fakeRestorer) RestoreSSTFiles(ctx context.Context, tableIDWithFiles []TableIDWithFiles, rewriteRules *RewriteRules, updateCh glue.Progress) error { f.mu.Lock() defer f.mu.Unlock() @@ -743,7 +662,7 @@ func (f *fakeRestorer) RestoreSSTFiles(ctx context.Context, tableIDWithFiles []r return err } -func fakeRanges(keys ...string) (r restore.DrainResult) { +func fakeRanges(keys ...string) (r DrainResult) { for i := range keys { if i+1 == len(keys) { return @@ -754,7 +673,7 @@ func fakeRanges(keys ...string) (r restore.DrainResult) { Files: []*backuppb.File{{Name: "fake.sst"}}, }) r.TableEndOffsetInRanges = append(r.TableEndOffsetInRanges, len(r.Ranges)) - r.TablesToSend = append(r.TablesToSend, restore.CreatedTable{ + r.TablesToSend = append(r.TablesToSend, CreatedTable{ Table: &model.TableInfo{ ID: int64(i), }, @@ -769,7 +688,7 @@ type errorInTimeSink struct { t *testing.T } -func (e errorInTimeSink) EmitTables(tables ...restore.CreatedTable) {} +func (e errorInTimeSink) EmitTables(tables ...CreatedTable) {} func (e errorInTimeSink) EmitError(err error) { e.errCh <- err @@ -796,7 +715,7 @@ func assertErrorEmitInTime(ctx context.Context, t *testing.T) errorInTimeSink { } func TestRestoreFailed(t *testing.T) { - ranges := []restore.DrainResult{ + ranges := []DrainResult{ fakeRanges("aax", "abx", "abz"), fakeRanges("abz", "bbz", "bcy"), fakeRanges("bcy", "cad", "xxy"), @@ -804,7 +723,7 @@ func TestRestoreFailed(t *testing.T) { r := &fakeRestorer{ tableIDIsInsequence: true, } - sender, err := restore.NewTiKVSender(context.TODO(), r, nil, 1, string(restore.FineGrained)) + sender, err := NewTiKVSender(context.TODO(), r, nil, 1, string(FineGrained)) require.NoError(t, err) dctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -821,13 +740,13 @@ func TestRestoreFailed(t *testing.T) { } func TestSplitFailed(t *testing.T) { - ranges := []restore.DrainResult{ + ranges := []DrainResult{ fakeRanges("aax", "abx", "abz"), fakeRanges("abz", "bbz", "bcy"), fakeRanges("bcy", "cad", "xxy"), } r := &fakeRestorer{errorInSplit: true, tableIDIsInsequence: true} - sender, err := restore.NewTiKVSender(context.TODO(), r, nil, 1, string(restore.FineGrained)) + sender, err := NewTiKVSender(context.TODO(), r, nil, 1, string(FineGrained)) require.NoError(t, err) dctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -852,7 +771,7 @@ func TestSplitPoint(t *testing.T) { ctx := context.Background() var oldTableID int64 = 50 var tableID int64 = 100 - rewriteRules := &restore.RewriteRules{ + rewriteRules := &RewriteRules{ Data: []*import_sstpb.RewriteRule{ { OldKeyPrefix: tablecodec.EncodeTablePrefix(oldTableID), @@ -875,8 +794,8 @@ func TestSplitPoint(t *testing.T) { client.AppendRegion(keyWithTablePrefix(tableID, "h"), keyWithTablePrefix(tableID, "j")) client.AppendRegion(keyWithTablePrefix(tableID, "j"), keyWithTablePrefix(tableID+1, "a")) - iter := restore.NewSplitHelperIteratorForTest(splitHelper, tableID, rewriteRules) - err := restore.SplitPoint(ctx, iter, client, func(ctx context.Context, rs *restore.RegionSplitter, u uint64, o int64, ri *split.RegionInfo, v []split.Valued) error { + iter := NewSplitHelperIteratorForTest(splitHelper, tableID, rewriteRules) + err := SplitPoint(ctx, iter, client, func(ctx context.Context, rs *RegionSplitter, u uint64, o int64, ri *split.RegionInfo, v []split.Valued) error { require.Equal(t, u, uint64(0)) require.Equal(t, o, int64(0)) require.Equal(t, ri.Region.StartKey, keyWithTablePrefix(tableID, "a")) @@ -902,7 +821,7 @@ func TestSplitPoint2(t *testing.T) { ctx := context.Background() var oldTableID int64 = 50 var tableID int64 = 100 - rewriteRules := &restore.RewriteRules{ + rewriteRules := &RewriteRules{ Data: []*import_sstpb.RewriteRule{ { OldKeyPrefix: tablecodec.EncodeTablePrefix(oldTableID), @@ -933,8 +852,8 @@ func TestSplitPoint2(t *testing.T) { client.AppendRegion(keyWithTablePrefix(tableID, "o"), keyWithTablePrefix(tableID+1, "a")) firstSplit := true - iter := restore.NewSplitHelperIteratorForTest(splitHelper, tableID, rewriteRules) - err := restore.SplitPoint(ctx, iter, client, func(ctx context.Context, rs *restore.RegionSplitter, u uint64, o int64, ri *split.RegionInfo, v []split.Valued) error { + iter := NewSplitHelperIteratorForTest(splitHelper, tableID, rewriteRules) + err := SplitPoint(ctx, iter, client, func(ctx context.Context, rs *RegionSplitter, u uint64, o int64, ri *split.RegionInfo, v []split.Valued) error { if firstSplit { require.Equal(t, u, uint64(0)) require.Equal(t, o, int64(0)) @@ -1007,7 +926,7 @@ func TestGetRewriteTableID(t *testing.T) { var tableID int64 = 76 var oldTableID int64 = 80 { - rewriteRules := &restore.RewriteRules{ + rewriteRules := &RewriteRules{ Data: []*import_sstpb.RewriteRule{ { OldKeyPrefix: tablecodec.EncodeTablePrefix(oldTableID), @@ -1016,12 +935,12 @@ func TestGetRewriteTableID(t *testing.T) { }, } - newTableID := restore.GetRewriteTableID(oldTableID, rewriteRules) + newTableID := GetRewriteTableID(oldTableID, rewriteRules) require.Equal(t, tableID, newTableID) } { - rewriteRules := &restore.RewriteRules{ + rewriteRules := &RewriteRules{ Data: []*import_sstpb.RewriteRule{ { OldKeyPrefix: tablecodec.GenTableRecordPrefix(oldTableID), @@ -1030,7 +949,7 @@ func TestGetRewriteTableID(t *testing.T) { }, } - newTableID := restore.GetRewriteTableID(oldTableID, rewriteRules) + newTableID := GetRewriteTableID(oldTableID, rewriteRules) require.Equal(t, tableID, newTableID) } } @@ -1039,12 +958,12 @@ type mockLogIter struct { next int } -func (m *mockLogIter) TryNext(ctx context.Context) iter.IterResult[*restore.LogDataFileInfo] { +func (m *mockLogIter) TryNext(ctx context.Context) iter.IterResult[*LogDataFileInfo] { if m.next > 10000 { - return iter.Done[*restore.LogDataFileInfo]() + return iter.Done[*LogDataFileInfo]() } m.next += 1 - return iter.Emit(&restore.LogDataFileInfo{ + return iter.Emit(&LogDataFileInfo{ DataFileInfo: &backuppb.DataFileInfo{ StartKey: []byte(fmt.Sprintf("a%d", m.next)), EndKey: []byte("b"), @@ -1056,7 +975,7 @@ func (m *mockLogIter) TryNext(ctx context.Context) iter.IterResult[*restore.LogD func TestLogFilesIterWithSplitHelper(t *testing.T) { var tableID int64 = 76 var oldTableID int64 = 80 - rewriteRules := &restore.RewriteRules{ + rewriteRules := &RewriteRules{ Data: []*import_sstpb.RewriteRule{ { OldKeyPrefix: tablecodec.EncodeTablePrefix(oldTableID), @@ -1064,12 +983,12 @@ func TestLogFilesIterWithSplitHelper(t *testing.T) { }, }, } - rewriteRulesMap := map[int64]*restore.RewriteRules{ + rewriteRulesMap := map[int64]*RewriteRules{ oldTableID: rewriteRules, } mockIter := &mockLogIter{} ctx := context.Background() - logIter := restore.NewLogFilesIterWithSplitHelper(mockIter, rewriteRulesMap, newFakeSplitClient(), 144*1024*1024, 1440000) + logIter := NewLogFilesIterWithSplitHelper(mockIter, rewriteRulesMap, newFakeSplitClient(), 144*1024*1024, 1440000) next := 0 for r := logIter.TryNext(ctx); !r.Finished; r = logIter.TryNext(ctx) { require.NoError(t, r.Err) @@ -1123,7 +1042,7 @@ func TestSplitCheckPartRegionConsistency(t *testing.T) { } func TestGetSplitSortedKeysFromSortedRegions(t *testing.T) { - splitContext := restore.SplitContext{} + splitContext := SplitContext{} sortedKeys := [][]byte{ []byte("b"), []byte("d"), @@ -1154,7 +1073,7 @@ func TestGetSplitSortedKeysFromSortedRegions(t *testing.T) { }, }, } - result := restore.TestGetSplitSortedKeysFromSortedRegionsTest(splitContext, sortedKeys, sortedRegions) + result := TestGetSplitSortedKeysFromSortedRegionsTest(splitContext, sortedKeys, sortedRegions) require.Equal(t, 3, len(result)) require.Equal(t, [][]byte{[]byte("b"), []byte("d")}, result[1]) require.Equal(t, [][]byte{[]byte("g"), []byte("j")}, result[2]) diff --git a/br/pkg/restore/util_test.go b/br/pkg/restore/util_test.go index 2740594a79cf1..e6edc8f81334d 100644 --- a/br/pkg/restore/util_test.go +++ b/br/pkg/restore/util_test.go @@ -1,6 +1,6 @@ // Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. -package restore_test +package restore import ( "context" @@ -14,7 +14,6 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" recover_data "github.com/pingcap/kvproto/pkg/recoverdatapb" berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/restore" "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/pkg/store/pdtypes" "github.com/pingcap/tidb/pkg/tablecodec" @@ -23,19 +22,19 @@ import ( ) func TestParseQuoteName(t *testing.T) { - schema, table := restore.ParseQuoteName("`a`.`b`") + schema, table := ParseQuoteName("`a`.`b`") require.Equal(t, "a", schema) require.Equal(t, "b", table) - schema, table = restore.ParseQuoteName("`a``b`.``````") + schema, table = ParseQuoteName("`a``b`.``````") require.Equal(t, "a`b", schema) require.Equal(t, "``", table) - schema, table = restore.ParseQuoteName("`.`.`.`") + schema, table = ParseQuoteName("`.`.`.`") require.Equal(t, ".", schema) require.Equal(t, ".", table) - schema, table = restore.ParseQuoteName("`.``.`.`.`") + schema, table = ParseQuoteName("`.``.`.`.`") require.Equal(t, ".`.", schema) require.Equal(t, ".", table) } @@ -54,7 +53,7 @@ func TestGetSSTMetaFromFile(t *testing.T) { StartKey: []byte("t2abc"), EndKey: []byte("t3a"), } - sstMeta, err := restore.GetSSTMetaFromFile([]byte{}, file, region, rule, restore.RewriteModeLegacy) + sstMeta, err := GetSSTMetaFromFile([]byte{}, file, region, rule, RewriteModeLegacy) require.Nil(t, err) require.Equal(t, "t2abc", string(sstMeta.GetRange().GetStart())) require.Equal(t, "t2\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", string(sstMeta.GetRange().GetEnd())) @@ -91,14 +90,14 @@ func TestMapTableToFiles(t *testing.T) { }, } - result := restore.MapTableToFiles(append(filesOfTable2, filesOfTable1...)) + result := MapTableToFiles(append(filesOfTable2, filesOfTable1...)) require.Equal(t, filesOfTable1, result[1]) require.Equal(t, filesOfTable2, result[2]) } func TestValidateFileRewriteRule(t *testing.T) { - rules := &restore.RewriteRules{ + rules := &RewriteRules{ Data: []*import_sstpb.RewriteRule{{ OldKeyPrefix: []byte(tablecodec.EncodeTablePrefix(1)), NewKeyPrefix: []byte(tablecodec.EncodeTablePrefix(2)), @@ -106,7 +105,7 @@ func TestValidateFileRewriteRule(t *testing.T) { } // Empty start/end key is not allowed. - err := restore.ValidateFileRewriteRule( + err := ValidateFileRewriteRule( &backuppb.File{ Name: "file_write.sst", StartKey: []byte(""), @@ -118,7 +117,7 @@ func TestValidateFileRewriteRule(t *testing.T) { require.Regexp(t, ".*cannot find rewrite rule.*", err.Error()) // Range is not overlap, no rule found. - err = restore.ValidateFileRewriteRule( + err = ValidateFileRewriteRule( &backuppb.File{ Name: "file_write.sst", StartKey: tablecodec.EncodeTablePrefix(0), @@ -130,7 +129,7 @@ func TestValidateFileRewriteRule(t *testing.T) { require.Regexp(t, ".*cannot find rewrite rule.*", err.Error()) // No rule for end key. - err = restore.ValidateFileRewriteRule( + err = ValidateFileRewriteRule( &backuppb.File{ Name: "file_write.sst", StartKey: tablecodec.EncodeTablePrefix(1), @@ -146,7 +145,7 @@ func TestValidateFileRewriteRule(t *testing.T) { OldKeyPrefix: tablecodec.EncodeTablePrefix(2), NewKeyPrefix: tablecodec.EncodeTablePrefix(3), }) - err = restore.ValidateFileRewriteRule( + err = ValidateFileRewriteRule( &backuppb.File{ Name: "file_write.sst", StartKey: tablecodec.EncodeTablePrefix(1), @@ -162,7 +161,7 @@ func TestValidateFileRewriteRule(t *testing.T) { OldKeyPrefix: tablecodec.EncodeTablePrefix(2), NewKeyPrefix: tablecodec.EncodeTablePrefix(1), }) - err = restore.ValidateFileRewriteRule( + err = ValidateFileRewriteRule( &backuppb.File{ Name: "file_write.sst", StartKey: tablecodec.EncodeTablePrefix(1), @@ -345,7 +344,7 @@ func (c *regionOnlineSlowClient) ScanRegions(ctx context.Context, key, endKey [] } func TestRewriteFileKeys(t *testing.T) { - rewriteRules := restore.RewriteRules{ + rewriteRules := RewriteRules{ Data: []*import_sstpb.RewriteRule{ { NewKeyPrefix: tablecodec.GenTablePrefix(2), @@ -362,7 +361,7 @@ func TestRewriteFileKeys(t *testing.T) { StartKey: tablecodec.GenTableRecordPrefix(1), EndKey: tablecodec.GenTableRecordPrefix(1).PrefixNext(), } - start, end, err := restore.GetRewriteRawKeys(&rawKeyFile, &rewriteRules) + start, end, err := GetRewriteRawKeys(&rawKeyFile, &rewriteRules) require.NoError(t, err) _, end, err = codec.DecodeBytes(end, nil) require.NoError(t, err) @@ -376,7 +375,7 @@ func TestRewriteFileKeys(t *testing.T) { StartKey: codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(1)), EndKey: codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(1).PrefixNext()), } - start, end, err = restore.GetRewriteEncodedKeys(&encodeKeyFile, &rewriteRules) + start, end, err = GetRewriteEncodedKeys(&encodeKeyFile, &rewriteRules) require.NoError(t, err) require.Equal(t, codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(2)), start) require.Equal(t, codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(2).PrefixNext()), end) @@ -388,12 +387,12 @@ func TestRewriteFileKeys(t *testing.T) { EndKey: codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(767).PrefixNext()), } // use raw rewrite should no error but not equal - start, end, err = restore.GetRewriteRawKeys(&encodeKeyFile767, &rewriteRules) + start, end, err = GetRewriteRawKeys(&encodeKeyFile767, &rewriteRules) require.NoError(t, err) require.NotEqual(t, codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(511)), start) require.NotEqual(t, codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(511).PrefixNext()), end) // use encode rewrite should no error and equal - start, end, err = restore.GetRewriteEncodedKeys(&encodeKeyFile767, &rewriteRules) + start, end, err = GetRewriteEncodedKeys(&encodeKeyFile767, &rewriteRules) require.NoError(t, err) require.Equal(t, codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(511)), start) require.Equal(t, codec.EncodeBytes(nil, tablecodec.GenTableRecordPrefix(511).PrefixNext()), end) @@ -410,8 +409,8 @@ func newPeerMeta( commitIndex uint64, version uint64, tombstone bool, -) *restore.RecoverRegion { - return &restore.RecoverRegion{ +) *RecoverRegion { + return &RecoverRegion{ &recover_data.RegionMeta{ RegionId: regionId, PeerId: peerId, @@ -427,12 +426,12 @@ func newPeerMeta( } } -func newRecoverRegionInfo(r *restore.RecoverRegion) *restore.RecoverRegionInfo { - return &restore.RecoverRegionInfo{ +func newRecoverRegionInfo(r *RecoverRegion) *RecoverRegionInfo { + return &RecoverRegionInfo{ RegionVersion: r.Version, RegionId: r.RegionId, - StartKey: restore.PrefixStartKey(r.StartKey), - EndKey: restore.PrefixEndKey(r.EndKey), + StartKey: PrefixStartKey(r.StartKey), + EndKey: PrefixEndKey(r.EndKey), TombStone: r.Tombstone, } } @@ -441,7 +440,7 @@ func TestSortRecoverRegions(t *testing.T) { selectedPeer1 := newPeerMeta(9, 11, 2, []byte("aa"), nil, 2, 0, 0, 0, false) selectedPeer2 := newPeerMeta(19, 22, 3, []byte("bbb"), nil, 2, 1, 0, 1, false) selectedPeer3 := newPeerMeta(29, 30, 1, []byte("c"), nil, 2, 1, 1, 2, false) - regions := map[uint64][]*restore.RecoverRegion{ + regions := map[uint64][]*RecoverRegion{ 9: { // peer 11 should be selected because of log term newPeerMeta(9, 10, 1, []byte("a"), nil, 1, 1, 1, 1, false), @@ -461,8 +460,8 @@ func TestSortRecoverRegions(t *testing.T) { newPeerMeta(29, 32, 3, []byte("ccc"), nil, 2, 1, 0, 0, false), }, } - regionsInfos := restore.SortRecoverRegions(regions) - expectRegionInfos := []*restore.RecoverRegionInfo{ + regionsInfos := SortRecoverRegions(regions) + expectRegionInfos := []*RecoverRegionInfo{ newRecoverRegionInfo(selectedPeer3), newRecoverRegionInfo(selectedPeer2), newRecoverRegionInfo(selectedPeer1), @@ -476,13 +475,13 @@ func TestCheckConsistencyAndValidPeer(t *testing.T) { validPeer2 := newPeerMeta(19, 22, 3, []byte("bb"), []byte("cc"), 2, 1, 0, 1, false) validPeer3 := newPeerMeta(29, 30, 1, []byte("cc"), []byte(""), 2, 1, 1, 2, false) - validRegionInfos := []*restore.RecoverRegionInfo{ + validRegionInfos := []*RecoverRegionInfo{ newRecoverRegionInfo(validPeer1), newRecoverRegionInfo(validPeer2), newRecoverRegionInfo(validPeer3), } - validPeer, err := restore.CheckConsistencyAndValidPeer(validRegionInfos) + validPeer, err := CheckConsistencyAndValidPeer(validRegionInfos) require.NoError(t, err) require.Equal(t, 3, len(validPeer)) var regions = make(map[uint64]struct{}, 3) @@ -497,13 +496,13 @@ func TestCheckConsistencyAndValidPeer(t *testing.T) { invalidPeer2 := newPeerMeta(19, 22, 3, []byte("dd"), []byte("cc"), 2, 1, 0, 1, false) invalidPeer3 := newPeerMeta(29, 30, 1, []byte("cc"), []byte("dd"), 2, 1, 1, 2, false) - invalidRegionInfos := []*restore.RecoverRegionInfo{ + invalidRegionInfos := []*RecoverRegionInfo{ newRecoverRegionInfo(invalidPeer1), newRecoverRegionInfo(invalidPeer2), newRecoverRegionInfo(invalidPeer3), } - _, err = restore.CheckConsistencyAndValidPeer(invalidRegionInfos) + _, err = CheckConsistencyAndValidPeer(invalidRegionInfos) require.Error(t, err) require.Regexp(t, ".*invalid restore range.*", err.Error()) } @@ -514,13 +513,13 @@ func TestLeaderCandidates(t *testing.T) { validPeer2 := newPeerMeta(19, 22, 3, []byte("bb"), []byte("cc"), 2, 1, 0, 1, false) validPeer3 := newPeerMeta(29, 30, 1, []byte("cc"), []byte(""), 2, 1, 0, 2, false) - peers := []*restore.RecoverRegion{ + peers := []*RecoverRegion{ validPeer1, validPeer2, validPeer3, } - candidates, err := restore.LeaderCandidates(peers) + candidates, err := LeaderCandidates(peers) require.NoError(t, err) require.Equal(t, 3, len(candidates)) } @@ -530,30 +529,30 @@ func TestSelectRegionLeader(t *testing.T) { validPeer2 := newPeerMeta(19, 22, 3, []byte("bb"), []byte("cc"), 2, 1, 0, 1, false) validPeer3 := newPeerMeta(29, 30, 1, []byte("cc"), []byte(""), 2, 1, 0, 2, false) - peers := []*restore.RecoverRegion{ + peers := []*RecoverRegion{ validPeer1, validPeer2, validPeer3, } // init store banlance score all is 0 storeBalanceScore := make(map[uint64]int, len(peers)) - leader := restore.SelectRegionLeader(storeBalanceScore, peers) + leader := SelectRegionLeader(storeBalanceScore, peers) require.Equal(t, validPeer1, leader) // change store banlance store storeBalanceScore[2] = 3 storeBalanceScore[3] = 2 storeBalanceScore[1] = 1 - leader = restore.SelectRegionLeader(storeBalanceScore, peers) + leader = SelectRegionLeader(storeBalanceScore, peers) require.Equal(t, validPeer3, leader) // one peer - peer := []*restore.RecoverRegion{ + peer := []*RecoverRegion{ validPeer3, } // init store banlance score all is 0 storeScore := make(map[uint64]int, len(peer)) - leader = restore.SelectRegionLeader(storeScore, peer) + leader = SelectRegionLeader(storeScore, peer) require.Equal(t, validPeer3, leader) } @@ -567,7 +566,7 @@ func TestLogFilesSkipMap(t *testing.T) { ) for ratio < 1 { - skipmap := restore.NewLogFilesSkipMap() + skipmap := NewLogFilesSkipMap() nativemap := make(map[string]map[int]map[int]struct{}) count := 0 for i := 0; i < int(ratio*float64(metaNum*groupNum*fileNum)); i++ { From c6a4aec01fad746d380d9b0aa73319328b58fa00 Mon Sep 17 00:00:00 2001 From: Daemon Date: Mon, 11 Mar 2024 14:18:37 +0800 Subject: [PATCH 08/15] util: use atomic.Pointer to replace with atomic.Value (#51633) ref pingcap/tidb#44736 --- br/pkg/restore/import.go | 4 ++-- pkg/config/config.go | 4 ++-- pkg/domain/infosync/info.go | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/br/pkg/restore/import.go b/br/pkg/restore/import.go index c1b46b353134e..73d60dd3e25e7 100644 --- a/br/pkg/restore/import.go +++ b/br/pkg/restore/import.go @@ -1151,7 +1151,7 @@ func (importer *FileImporter) downloadRawKVSSTV2( } log.Debug("download SST", logutil.SSTMeta(sstMeta), logutil.Region(regionInfo.Region)) - var atomicResp atomic.Value + var atomicResp atomic.Pointer[import_sstpb.DownloadResponse] eg, ectx := errgroup.WithContext(ctx) for _, p := range regionInfo.Region.GetPeers() { peer := p @@ -1176,7 +1176,7 @@ func (importer *FileImporter) downloadRawKVSSTV2( return nil, err } - downloadResp := atomicResp.Load().(*import_sstpb.DownloadResponse) + downloadResp := atomicResp.Load() sstMeta.Range.Start = downloadResp.Range.GetStart() sstMeta.Range.End = downloadResp.Range.GetEnd() sstMeta.ApiVersion = apiVersion diff --git a/pkg/config/config.go b/pkg/config/config.go index d0bdc189cedca..716204baf623d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1083,7 +1083,7 @@ var defaultConf = Config{ } var ( - globalConf atomic.Value + globalConf atomic.Pointer[Config] ) // NewConfig creates a new config instance with default value. @@ -1096,7 +1096,7 @@ func NewConfig() *Config { // It should store configuration from command line and configuration file. // Other parts of the system can read the global configuration use this function. func GetGlobalConfig() *Config { - return globalConf.Load().(*Config) + return globalConf.Load() } // StoreGlobalConfig stores a new config to the globalConf. It mostly uses in the test to avoid some data races. diff --git a/pkg/domain/infosync/info.go b/pkg/domain/infosync/info.go index b20a0a18f1262..51dfbcac6e16b 100644 --- a/pkg/domain/infosync/info.go +++ b/pkg/domain/infosync/info.go @@ -177,15 +177,15 @@ type ServerVersionInfo struct { // globalInfoSyncer stores the global infoSyncer. // Use a global variable for simply the code, use the domain.infoSyncer will have circle import problem in some pkg. -// Use atomic.Value to avoid data race in the test. -var globalInfoSyncer atomic.Value +// Use atomic.Pointer to avoid data race in the test. +var globalInfoSyncer atomic.Pointer[InfoSyncer] func getGlobalInfoSyncer() (*InfoSyncer, error) { v := globalInfoSyncer.Load() if v == nil { return nil, errors.New("infoSyncer is not initialized") } - return v.(*InfoSyncer), nil + return v, nil } func setGlobalInfoSyncer(is *InfoSyncer) { From f495cc509e687b42287441f20d4724d4718d4193 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Mon, 11 Mar 2024 14:18:44 +0800 Subject: [PATCH 09/15] *: fix security alert about JWX (#51648) close pingcap/tidb#51647 --- DEPS.bzl | 48 ++++++++++++++++++++++++------------------------ go.mod | 8 ++++---- go.sum | 14 ++++++++------ 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/DEPS.bzl b/DEPS.bzl index 3bd29d0970c3b..90b2bebe419d3 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -4740,13 +4740,13 @@ def go_deps(): name = "com_github_lestrrat_go_httprc", build_file_proto_mode = "disable_global", importpath = "github.com/lestrrat-go/httprc", - sha256 = "fd0658206207ff68f0561d9a681a99bee765d9cc453665d202a01ce860c72a90", - strip_prefix = "github.com/lestrrat-go/httprc@v1.0.4", + sha256 = "b5ec122596da8970869d3b41a1bc901a440c66a906bbd2fcbe2a19e8728787d7", + strip_prefix = "github.com/lestrrat-go/httprc@v1.0.5", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/github.com/lestrrat-go/httprc/com_github_lestrrat_go_httprc-v1.0.4.zip", - "http://ats.apps.svc/gomod/github.com/lestrrat-go/httprc/com_github_lestrrat_go_httprc-v1.0.4.zip", - "https://cache.hawkingrei.com/gomod/github.com/lestrrat-go/httprc/com_github_lestrrat_go_httprc-v1.0.4.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/github.com/lestrrat-go/httprc/com_github_lestrrat_go_httprc-v1.0.4.zip", + "http://bazel-cache.pingcap.net:8080/gomod/github.com/lestrrat-go/httprc/com_github_lestrrat_go_httprc-v1.0.5.zip", + "http://ats.apps.svc/gomod/github.com/lestrrat-go/httprc/com_github_lestrrat_go_httprc-v1.0.5.zip", + "https://cache.hawkingrei.com/gomod/github.com/lestrrat-go/httprc/com_github_lestrrat_go_httprc-v1.0.5.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/lestrrat-go/httprc/com_github_lestrrat_go_httprc-v1.0.5.zip", ], ) go_repository( @@ -4766,13 +4766,13 @@ def go_deps(): name = "com_github_lestrrat_go_jwx_v2", build_file_proto_mode = "disable_global", importpath = "github.com/lestrrat-go/jwx/v2", - sha256 = "28e43e9f0b531d806db5c31f47076375a6f0c87207f1a8a69cc6bde242c83c65", - strip_prefix = "github.com/lestrrat-go/jwx/v2@v2.0.19", + sha256 = "f49d9cb1482cbd4ed113d8fa1c3f197df5ba498dd461641123cff0337e030af2", + strip_prefix = "github.com/lestrrat-go/jwx/v2@v2.0.21", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/github.com/lestrrat-go/jwx/v2/com_github_lestrrat_go_jwx_v2-v2.0.19.zip", - "http://ats.apps.svc/gomod/github.com/lestrrat-go/jwx/v2/com_github_lestrrat_go_jwx_v2-v2.0.19.zip", - "https://cache.hawkingrei.com/gomod/github.com/lestrrat-go/jwx/v2/com_github_lestrrat_go_jwx_v2-v2.0.19.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/github.com/lestrrat-go/jwx/v2/com_github_lestrrat_go_jwx_v2-v2.0.19.zip", + "http://bazel-cache.pingcap.net:8080/gomod/github.com/lestrrat-go/jwx/v2/com_github_lestrrat_go_jwx_v2-v2.0.21.zip", + "http://ats.apps.svc/gomod/github.com/lestrrat-go/jwx/v2/com_github_lestrrat_go_jwx_v2-v2.0.21.zip", + "https://cache.hawkingrei.com/gomod/github.com/lestrrat-go/jwx/v2/com_github_lestrrat_go_jwx_v2-v2.0.21.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/lestrrat-go/jwx/v2/com_github_lestrrat_go_jwx_v2-v2.0.21.zip", ], ) go_repository( @@ -6911,26 +6911,26 @@ def go_deps(): name = "com_github_stretchr_objx", build_file_proto_mode = "disable_global", importpath = "github.com/stretchr/objx", - sha256 = "1a00b3bb5ad41cb72634ace06b7eb7df857404d77a7cab4e401a7c729561fe4c", - strip_prefix = "github.com/stretchr/objx@v0.5.0", + sha256 = "3c22c1d1c4c4024eb16a12f0187775640bf35d51b0a06649febc7797119451c0", + strip_prefix = "github.com/stretchr/objx@v0.5.2", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/github.com/stretchr/objx/com_github_stretchr_objx-v0.5.0.zip", - "http://ats.apps.svc/gomod/github.com/stretchr/objx/com_github_stretchr_objx-v0.5.0.zip", - "https://cache.hawkingrei.com/gomod/github.com/stretchr/objx/com_github_stretchr_objx-v0.5.0.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/github.com/stretchr/objx/com_github_stretchr_objx-v0.5.0.zip", + "http://bazel-cache.pingcap.net:8080/gomod/github.com/stretchr/objx/com_github_stretchr_objx-v0.5.2.zip", + "http://ats.apps.svc/gomod/github.com/stretchr/objx/com_github_stretchr_objx-v0.5.2.zip", + "https://cache.hawkingrei.com/gomod/github.com/stretchr/objx/com_github_stretchr_objx-v0.5.2.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/stretchr/objx/com_github_stretchr_objx-v0.5.2.zip", ], ) go_repository( name = "com_github_stretchr_testify", build_file_proto_mode = "disable_global", importpath = "github.com/stretchr/testify", - sha256 = "e206daaede0bd03de060bdfbeb984ac2c49b83058753fffc93fe0c220ea87532", - strip_prefix = "github.com/stretchr/testify@v1.8.4", + sha256 = "ee5d4f73cb689b1b5432c6908a189f9fbdb172507c49c32dbdf79b239ea9b8e0", + strip_prefix = "github.com/stretchr/testify@v1.9.0", urls = [ - "http://bazel-cache.pingcap.net:8080/gomod/github.com/stretchr/testify/com_github_stretchr_testify-v1.8.4.zip", - "http://ats.apps.svc/gomod/github.com/stretchr/testify/com_github_stretchr_testify-v1.8.4.zip", - "https://cache.hawkingrei.com/gomod/github.com/stretchr/testify/com_github_stretchr_testify-v1.8.4.zip", - "https://storage.googleapis.com/pingcapmirror/gomod/github.com/stretchr/testify/com_github_stretchr_testify-v1.8.4.zip", + "http://bazel-cache.pingcap.net:8080/gomod/github.com/stretchr/testify/com_github_stretchr_testify-v1.9.0.zip", + "http://ats.apps.svc/gomod/github.com/stretchr/testify/com_github_stretchr_testify-v1.9.0.zip", + "https://cache.hawkingrei.com/gomod/github.com/stretchr/testify/com_github_stretchr_testify-v1.9.0.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/stretchr/testify/com_github_stretchr_testify-v1.9.0.zip", ], ) go_repository( diff --git a/go.mod b/go.mod index 5a3d97b31d0a3..fbe4c5c6d983d 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,7 @@ require ( github.com/klauspost/compress v1.17.4 github.com/ks3sdklib/aws-sdk-go v1.2.9 github.com/kyoh86/exportloopref v0.1.11 - github.com/lestrrat-go/jwx/v2 v2.0.19 + github.com/lestrrat-go/jwx/v2 v2.0.21 github.com/mgechev/revive v1.3.7 github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7 github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef @@ -103,7 +103,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/spkg/bom v1.0.0 github.com/stathat/consistent v1.0.0 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 github.com/tdakkota/asciicheck v0.2.0 github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 github.com/tidwall/btree v1.7.0 @@ -245,7 +245,7 @@ require ( github.com/kylelemons/godebug v1.1.0 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect - github.com/lestrrat-go/httprc v1.0.4 // indirect + github.com/lestrrat-go/httprc v1.0.5 // indirect github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/lufia/plan9stats v0.0.0-20230326075908-cb1d2100619a // indirect @@ -277,7 +277,7 @@ require ( github.com/shurcooL/vfsgen v0.0.0-20181202132449-6a9ea43bcacd // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/cast v1.5.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/go.sum b/go.sum index 8baa969875812..9fc9ab33941d8 100644 --- a/go.sum +++ b/go.sum @@ -597,12 +597,12 @@ github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= -github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= -github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= +github.com/lestrrat-go/httprc v1.0.5 h1:bsTfiH8xaKOJPrg1R+E3iE/AWZr/x0Phj9PBTG/OLUk= +github.com/lestrrat-go/httprc v1.0.5/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= -github.com/lestrrat-go/jwx/v2 v2.0.19 h1:ekv1qEZE6BVct89QA+pRF6+4pCpfVrOnEJnTnT4RXoY= -github.com/lestrrat-go/jwx/v2 v2.0.19/go.mod h1:l3im3coce1lL2cDeAjqmaR+Awx+X8Ih+2k8BuHNJ4CU= +github.com/lestrrat-go/jwx/v2 v2.0.21 h1:jAPKupy4uHgrHFEdjVjNkUgoBKtVDgrQPB/h55FHrR0= +github.com/lestrrat-go/jwx/v2 v2.0.21/go.mod h1:09mLW8zto6bWL9GbwnqAli+ArLf+5M33QLQPDggkUWM= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= @@ -840,8 +840,9 @@ github.com/stathat/consistent v1.0.0/go.mod h1:uajTPbgSygZBJ+V+0mY7meZ8i0XAcZs7A github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -852,8 +853,9 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tdakkota/asciicheck v0.2.0 h1:o8jvnUANo0qXtnslk2d3nMKTFNlOnJjRrNcj0j9qkHM= github.com/tdakkota/asciicheck v0.2.0/go.mod h1:Qb7Y9EgjCLJGup51gDHFzbI08/gbGhL/UVhYIPWG2rg= github.com/tenntenn/modver v1.0.1 h1:2klLppGhDgzJrScMpkj9Ujy3rXPUspSjAcev9tSEBgA= From 8b02143e2b8ce4dd07d7d43d69c3784d65c2f198 Mon Sep 17 00:00:00 2001 From: EasonBall <592838129@qq.com> Date: Mon, 11 Mar 2024 15:29:08 +0800 Subject: [PATCH 10/15] disttask: skip scheduler take slots for some states (#51022) ref pingcap/tidb#49008 --- pkg/disttask/framework/scheduler/BUILD.bazel | 6 +- pkg/disttask/framework/scheduler/interface.go | 9 +- pkg/disttask/framework/scheduler/scheduler.go | 29 +++- .../framework/scheduler/scheduler_manager.go | 129 +++++++++----- .../scheduler/scheduler_manager_nokit_test.go | 164 ++++++++++++++++++ .../scheduler/scheduler_nokit_test.go | 163 +++++++++-------- .../framework/scheduler/scheduler_test.go | 60 ++----- pkg/disttask/framework/scheduler/testutil.go | 49 ++++++ pkg/disttask/framework/storage/table_test.go | 2 +- pkg/disttask/framework/storage/task_table.go | 6 +- pkg/disttask/framework/testutil/context.go | 3 +- .../framework/testutil/scheduler_util.go | 6 +- pkg/metrics/disttask.go | 4 +- .../addindextest1/disttask_test.go | 2 +- 14 files changed, 440 insertions(+), 192 deletions(-) create mode 100644 pkg/disttask/framework/scheduler/scheduler_manager_nokit_test.go create mode 100644 pkg/disttask/framework/scheduler/testutil.go diff --git a/pkg/disttask/framework/scheduler/BUILD.bazel b/pkg/disttask/framework/scheduler/BUILD.bazel index bda7c103c930b..83a5167ae896b 100644 --- a/pkg/disttask/framework/scheduler/BUILD.bazel +++ b/pkg/disttask/framework/scheduler/BUILD.bazel @@ -11,6 +11,7 @@ go_library( "scheduler_manager.go", "slots.go", "state_transform.go", + "testutil.go", ], importpath = "github.com/pingcap/tidb/pkg/disttask/framework/scheduler", visibility = ["//visibility:public"], @@ -18,6 +19,7 @@ go_library( "//br/pkg/lightning/log", "//pkg/disttask/framework/handle", "//pkg/disttask/framework/proto", + "//pkg/disttask/framework/scheduler/mock", "//pkg/disttask/framework/storage", "//pkg/domain/infosync", "//pkg/kv", @@ -34,6 +36,7 @@ go_library( "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_log//:log", "@com_github_prometheus_client_golang//prometheus", + "@org_uber_go_mock//gomock", "@org_uber_go_zap//:zap", ], ) @@ -45,6 +48,7 @@ go_test( "balancer_test.go", "main_test.go", "nodes_test.go", + "scheduler_manager_nokit_test.go", "scheduler_manager_test.go", "scheduler_nokit_test.go", "scheduler_test.go", @@ -53,7 +57,7 @@ go_test( embed = [":scheduler"], flaky = True, race = "off", - shard_count = 31, + shard_count = 33, deps = [ "//pkg/config", "//pkg/disttask/framework/mock", diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index 883b8db0fc961..4a61273eb2c97 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -136,10 +136,11 @@ type Extension interface { // Param is used to pass parameters when creating scheduler. type Param struct { - taskMgr TaskManager - nodeMgr *NodeManager - slotMgr *SlotManager - serverID string + taskMgr TaskManager + nodeMgr *NodeManager + slotMgr *SlotManager + serverID string + allocatedSlots bool } // schedulerFactoryFn is used to create a scheduler. diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index baf18f7ba9886..96e1ddcb6eaaf 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -95,7 +95,7 @@ var MockOwnerChange func() // NewBaseScheduler creates a new BaseScheduler. func NewBaseScheduler(ctx context.Context, task *proto.Task, param Param) *BaseScheduler { - logger := log.L().With(zap.Int64("task-id", task.ID), zap.Stringer("task-type", task.Type)) + logger := log.L().With(zap.Int64("task-id", task.ID), zap.Stringer("task-type", task.Type), zap.Bool("allocated-slots", param.allocatedSlots)) if intest.InTest { logger = logger.With(zap.String("server-id", param.serverID)) } @@ -179,6 +179,10 @@ func (s *BaseScheduler) scheduleTask() { continue } task := *s.GetTask() + // TODO: refine failpoints below. + failpoint.Inject("exitScheduler", func() { + failpoint.Return() + }) failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) { if val.(bool) && task.State == proto.TaskStateRunning { err := s.taskMgr.CancelTask(s.ctx, task.ID) @@ -222,12 +226,35 @@ func (s *BaseScheduler) scheduleTask() { return } case proto.TaskStateResuming: + // Case with 2 nodes. + // Here is the timeline + // 1. task in pausing state. + // 2. node1 and node2 start schedulers with task in pausing state without allocatedSlots. + // 3. node1's scheduler transfer the node from pausing to paused state. + // 4. resume the task. + // 5. node2 scheduler call refreshTask and get task with resuming state. + if !s.allocatedSlots { + s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State)) + return + } err = s.onResuming() case proto.TaskStateReverting: err = s.onReverting() case proto.TaskStatePending: err = s.onPending() case proto.TaskStateRunning: + // Case with 2 nodes. + // Here is the timeline + // 1. task in pausing state. + // 2. node1 and node2 start schedulers with task in pausing state without allocatedSlots. + // 3. node1's scheduler transfer the node from pausing to paused state. + // 4. resume the task. + // 5. node1 start another scheduler and transfer the node from resuming to running state. + // 6. node2 scheduler call refreshTask and get task with running state. + if !s.allocatedSlots { + s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State)) + return + } err = s.onRunning() case proto.TaskStateSucceed, proto.TaskStateReverted, proto.TaskStateFailed: s.onFinished() diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go b/pkg/disttask/framework/scheduler/scheduler_manager.go index 8b0346d66de95..b556cf64356d1 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go @@ -216,52 +216,80 @@ func (sm *Manager) scheduleTaskLoop() { continue } - tasks, err := sm.taskMgr.GetTopUnfinishedTasks(sm.ctx) + schedulableTasks, err := sm.getSchedulableTasks() if err != nil { - sm.logger.Warn("get unfinished tasks failed", zap.Error(err)) continue } - schedulableTasks := make([]*proto.TaskBase, 0, len(tasks)) - for _, task := range tasks { - if sm.hasScheduler(task.ID) { - continue - } - // we check it before start scheduler, so no need to check it again. - // see startScheduler. - // this should not happen normally, unless user modify system table - // directly. - if getSchedulerFactory(task.Type) == nil { - sm.logger.Warn("unknown task type", zap.Int64("task-id", task.ID), - zap.Stringer("task-type", task.Type)) - sm.failTask(task.ID, task.State, errors.New("unknown task type")) - continue - } - schedulableTasks = append(schedulableTasks, task) - } - if len(schedulableTasks) == 0 { + err = sm.startSchedulers(schedulableTasks) + if err != nil { continue } + } +} + +func (sm *Manager) getSchedulableTasks() ([]*proto.TaskBase, error) { + tasks, err := sm.taskMgr.GetTopUnfinishedTasks(sm.ctx) + if err != nil { + sm.logger.Warn("get unfinished tasks failed", zap.Error(err)) + return nil, err + } - if err = sm.slotMgr.update(sm.ctx, sm.nodeMgr, sm.taskMgr); err != nil { - sm.logger.Warn("update used slot failed", zap.Error(err)) + schedulableTasks := make([]*proto.TaskBase, 0, len(tasks)) + for _, task := range tasks { + if sm.hasScheduler(task.ID) { continue } - for _, task := range schedulableTasks { - taskCnt = sm.getSchedulerCount() - if taskCnt >= proto.MaxConcurrentTask { - break - } - reservedExecID, ok := sm.slotMgr.canReserve(task) + // we check it before start scheduler, so no need to check it again. + // see startScheduler. + // this should not happen normally, unless user modify system table + // directly. + if getSchedulerFactory(task.Type) == nil { + sm.logger.Warn("unknown task type", zap.Int64("task-id", task.ID), + zap.Stringer("task-type", task.Type)) + sm.failTask(task.ID, task.State, errors.New("unknown task type")) + continue + } + schedulableTasks = append(schedulableTasks, task) + } + return schedulableTasks, nil +} + +func (sm *Manager) startSchedulers(schedulableTasks []*proto.TaskBase) error { + if len(schedulableTasks) == 0 { + return nil + } + if err := sm.slotMgr.update(sm.ctx, sm.nodeMgr, sm.taskMgr); err != nil { + sm.logger.Warn("update used slot failed", zap.Error(err)) + return err + } + for _, task := range schedulableTasks { + taskCnt := sm.getSchedulerCount() + if taskCnt >= proto.MaxConcurrentTask { + break + } + var reservedExecID string + allocateSlots := true + var ok bool + switch task.State { + case proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateResuming: + reservedExecID, ok = sm.slotMgr.canReserve(task) if !ok { // task of lower rank might be able to be scheduled. continue } - metrics.DistTaskGauge.WithLabelValues(task.Type.String(), metrics.SchedulingStatus).Inc() - metrics.UpdateMetricsForDispatchTask(task.ID, task.Type) - sm.startScheduler(task, reservedExecID) + // reverting/cancelling/pausing + default: + allocateSlots = false + sm.logger.Info("start scheduler without allocating slots", + zap.Int64("task-id", task.ID), zap.Stringer("state", task.State)) } + + metrics.DistTaskGauge.WithLabelValues(task.Type.String(), metrics.SchedulingStatus).Inc() + metrics.UpdateMetricsForScheduleTask(task.ID, task.Type) + sm.startScheduler(task, allocateSlots, reservedExecID) } + return nil } func (sm *Manager) failTask(id int64, currState proto.TaskState, err error) { @@ -300,7 +328,7 @@ func (sm *Manager) gcSubtaskHistoryTableLoop() { } } -func (sm *Manager) startScheduler(basicTask *proto.TaskBase, reservedExecID string) { +func (sm *Manager) startScheduler(basicTask *proto.TaskBase, allocateSlots bool, reservedExecID string) { task, err := sm.taskMgr.GetTaskByID(sm.ctx, basicTask.ID) if err != nil { sm.logger.Error("get task failed", zap.Int64("task-id", basicTask.ID), zap.Error(err)) @@ -309,10 +337,11 @@ func (sm *Manager) startScheduler(basicTask *proto.TaskBase, reservedExecID stri schedulerFactory := getSchedulerFactory(task.Type) scheduler := schedulerFactory(sm.ctx, task, Param{ - taskMgr: sm.taskMgr, - nodeMgr: sm.nodeMgr, - slotMgr: sm.slotMgr, - serverID: sm.serverID, + taskMgr: sm.taskMgr, + nodeMgr: sm.nodeMgr, + slotMgr: sm.slotMgr, + serverID: sm.serverID, + allocatedSlots: allocateSlots, }) if err = scheduler.Init(); err != nil { sm.logger.Error("init scheduler failed", zap.Error(err)) @@ -320,13 +349,17 @@ func (sm *Manager) startScheduler(basicTask *proto.TaskBase, reservedExecID stri return } sm.addScheduler(task.ID, scheduler) - sm.slotMgr.reserve(basicTask, reservedExecID) + if allocateSlots { + sm.slotMgr.reserve(basicTask, reservedExecID) + } sm.logger.Info("task scheduler started", zap.Int64("task-id", task.ID)) sm.schedulerWG.RunWithLog(func() { defer func() { scheduler.Close() sm.delScheduler(task.ID) - sm.slotMgr.unReserve(basicTask, reservedExecID) + if allocateSlots { + sm.slotMgr.unReserve(basicTask, reservedExecID) + } handle.NotifyTaskChange() sm.logger.Info("task scheduler exist", zap.Int64("task-id", task.ID)) }() @@ -416,16 +449,6 @@ func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error { return sm.taskMgr.TransferTasks2History(sm.ctx, cleanedTasks) } -// MockScheduler mock one scheduler for one task, only used for tests. -func (sm *Manager) MockScheduler(task *proto.Task) *BaseScheduler { - return NewBaseScheduler(sm.ctx, task, Param{ - taskMgr: sm.taskMgr, - nodeMgr: sm.nodeMgr, - slotMgr: sm.slotMgr, - serverID: sm.serverID, - }) -} - func (sm *Manager) collectLoop() { sm.logger.Info("collect loop start") ticker := time.NewTicker(defaultCollectMetricsInterval) @@ -450,3 +473,13 @@ func (sm *Manager) collect() { subtaskCollector.subtaskInfo.Store(&subtasks) } + +// MockScheduler mock one scheduler for one task, only used for tests. +func (sm *Manager) MockScheduler(task *proto.Task) *BaseScheduler { + return NewBaseScheduler(sm.ctx, task, Param{ + taskMgr: sm.taskMgr, + nodeMgr: sm.nodeMgr, + slotMgr: sm.slotMgr, + serverID: sm.serverID, + }) +} diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_nokit_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_nokit_test.go new file mode 100644 index 0000000000000..9380d0bd6e927 --- /dev/null +++ b/pkg/disttask/framework/scheduler/scheduler_manager_nokit_test.go @@ -0,0 +1,164 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "context" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/disttask/framework/mock" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestManagerSchedulersOrdered(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mgr := NewManager(context.Background(), nil, "1") + for i := 1; i <= 5; i++ { + task := &proto.Task{TaskBase: proto.TaskBase{ + ID: int64(i * 10), + }} + mockScheduler := mock.NewMockScheduler(ctrl) + mockScheduler.EXPECT().GetTask().Return(task).AnyTimes() + mgr.addScheduler(task.ID, mockScheduler) + } + ordered := func(schedulers []Scheduler) bool { + for i := 1; i < len(schedulers); i++ { + if schedulers[i-1].GetTask().CompareTask(schedulers[i].GetTask()) >= 0 { + return false + } + } + return true + } + require.Len(t, mgr.getSchedulers(), 5) + require.True(t, ordered(mgr.getSchedulers())) + + task35 := &proto.Task{TaskBase: proto.TaskBase{ + ID: int64(35), + }} + mockScheduler35 := mock.NewMockScheduler(ctrl) + mockScheduler35.EXPECT().GetTask().Return(task35).AnyTimes() + + mgr.delScheduler(30) + require.False(t, mgr.hasScheduler(30)) + mgr.addScheduler(task35.ID, mockScheduler35) + require.True(t, mgr.hasScheduler(35)) + require.Len(t, mgr.getSchedulers(), 5) + require.True(t, ordered(mgr.getSchedulers())) +} + +func TestSchedulerCleanupTask(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) + }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + taskMgr := mock.NewMockTaskManager(ctrl) + ctx := context.Background() + mgr := NewManager(ctx, taskMgr, "1") + + // normal + tasks := []*proto.Task{ + {TaskBase: proto.TaskBase{ID: 1}}, + } + taskMgr.EXPECT().GetTasksInStates( + mgr.ctx, + proto.TaskStateFailed, + proto.TaskStateReverted, + proto.TaskStateSucceed).Return(tasks, nil) + + taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(nil) + mgr.doCleanupTask() + require.True(t, ctrl.Satisfied()) + + // fail in transfer + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/WaitCleanUpFinished", "1*return()")) + mockErr := errors.New("transfer err") + taskMgr.EXPECT().GetTasksInStates( + mgr.ctx, + proto.TaskStateFailed, + proto.TaskStateReverted, + proto.TaskStateSucceed).Return(tasks, nil) + taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(mockErr) + mgr.doCleanupTask() + require.True(t, ctrl.Satisfied()) + + taskMgr.EXPECT().GetTasksInStates( + mgr.ctx, + proto.TaskStateFailed, + proto.TaskStateReverted, + proto.TaskStateSucceed).Return(tasks, nil) + taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(nil) + mgr.doCleanupTask() + require.True(t, ctrl.Satisfied()) + + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/WaitCleanUpFinished")) +} + +func TestManagerSchedulerNotAllocateSlots(t *testing.T) { + // the tests make sure allocatedSlots correct. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/exitScheduler", "return()")) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + taskMgr := mock.NewMockTaskManager(ctrl) + mgr := NewManager(context.Background(), taskMgr, "1") + RegisterSchedulerFactory(proto.TaskTypeExample, + func(ctx context.Context, task *proto.Task, param Param) Scheduler { + mockScheduler := NewBaseScheduler(ctx, task, param) + mockScheduler.Extension = GetTestSchedulerExt(ctrl) + return mockScheduler + }) + taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(nil, nil).AnyTimes() + tasks := []*proto.TaskBase{ + { + ID: int64(1), + Concurrency: 1, + Type: proto.TaskTypeExample, + State: proto.TaskStateCancelling, + }, + { + ID: int64(2), + Concurrency: 1, + Type: proto.TaskTypeExample, + State: proto.TaskStateReverting, + }, + { + ID: int64(3), + Concurrency: 1, + Type: proto.TaskTypeExample, + State: proto.TaskStatePausing, + }, + } + for i := 1; i <= 3; i++ { + taskMgr.EXPECT().GetTaskByID(gomock.Any(), int64(i)).Return(&proto.Task{TaskBase: *tasks[i-1]}, nil) + taskMgr.EXPECT().GetTaskBaseByID(gomock.Any(), int64(i)).Return(tasks[i-1], nil) + } + + require.NoError(t, mgr.startSchedulers(tasks)) + schs := mgr.getSchedulers() + require.Equal(t, 3, len(schs)) + for _, sch := range schs { + require.Equal(t, false, sch.(*BaseScheduler).allocatedSlots) + <-mgr.finishCh + } + mgr.schedulerWG.Wait() + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/exitScheduler")) +} diff --git a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go index ba805148a6920..aa0a24d0825ff 100644 --- a/pkg/disttask/framework/scheduler/scheduler_nokit_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_nokit_test.go @@ -21,7 +21,6 @@ import ( "time" "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/config" "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/proto" @@ -33,14 +32,24 @@ import ( "go.uber.org/mock/gomock" ) -func TestDispatcherOnNextStage(t *testing.T) { +func createScheduler(task *proto.Task, allocatedSlots bool, taskMgr TaskManager, ctrl *gomock.Controller) *BaseScheduler { + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "scheduler") + nodeMgr := NewNodeManager() + sch := NewBaseScheduler(ctx, task, Param{ + taskMgr: taskMgr, + nodeMgr: nodeMgr, + slotMgr: newSlotManager(), + allocatedSlots: allocatedSlots, + }) + return sch +} + +func TestSchedulerOnNextStage(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskMgr := mock.NewMockTaskManager(ctrl) schExt := schmock.NewMockExtension(ctrl) - - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "dispatcher") task := proto.Task{ TaskBase: proto.TaskBase{ ID: 1, @@ -49,12 +58,7 @@ func TestDispatcherOnNextStage(t *testing.T) { }, } cloneTask := task - nodeMgr := NewNodeManager() - sch := NewBaseScheduler(ctx, &cloneTask, Param{ - taskMgr: taskMgr, - nodeMgr: nodeMgr, - slotMgr: newSlotManager(), - }) + sch := createScheduler(&cloneTask, true, taskMgr, ctrl) sch.Extension = schExt // test next step is done @@ -140,43 +144,6 @@ func TestDispatcherOnNextStage(t *testing.T) { require.True(t, ctrl.Satisfied()) } -func TestManagerSchedulersOrdered(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mgr := NewManager(context.Background(), nil, "1") - for i := 1; i <= 5; i++ { - task := &proto.Task{TaskBase: proto.TaskBase{ - ID: int64(i * 10), - }} - mockScheduler := mock.NewMockScheduler(ctrl) - mockScheduler.EXPECT().GetTask().Return(task).AnyTimes() - mgr.addScheduler(task.ID, mockScheduler) - } - ordered := func(schedulers []Scheduler) bool { - for i := 1; i < len(schedulers); i++ { - if schedulers[i-1].GetTask().CompareTask(schedulers[i].GetTask()) >= 0 { - return false - } - } - return true - } - require.Len(t, mgr.getSchedulers(), 5) - require.True(t, ordered(mgr.getSchedulers())) - - task35 := &proto.Task{TaskBase: proto.TaskBase{ - ID: int64(35), - }} - mockScheduler35 := mock.NewMockScheduler(ctrl) - mockScheduler35.EXPECT().GetTask().Return(task35).AnyTimes() - - mgr.delScheduler(30) - require.False(t, mgr.hasScheduler(30)) - mgr.addScheduler(task35.ID, mockScheduler35) - require.True(t, mgr.hasScheduler(35)) - require.Len(t, mgr.getSchedulers(), 5) - require.True(t, ordered(mgr.getSchedulers())) -} - func TestGetEligibleNodes(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -219,53 +186,81 @@ func TestSchedulerIsStepSucceed(t *testing.T) { } } -func TestSchedulerCleanupTask(t *testing.T) { - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) - }() +func TestSchedulerNotAllocateSlots(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskMgr := mock.NewMockTaskManager(ctrl) - ctx := context.Background() - mgr := NewManager(ctx, taskMgr, "1") - // normal - tasks := []*proto.Task{ - {TaskBase: proto.TaskBase{ID: 1}}, + // scheduler not allocated slots, task from paused to resuming. Should exit the scheduler. + task := proto.Task{ + TaskBase: proto.TaskBase{ + ID: int64(1), + Concurrency: 1, + Type: proto.TaskTypeExample, + State: proto.TaskStatePaused, + }, } - taskMgr.EXPECT().GetTasksInStates( - mgr.ctx, - proto.TaskStateFailed, - proto.TaskStateReverted, - proto.TaskStateSucceed).Return(tasks, nil) - - taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(nil) - mgr.doCleanupTask() + cloneTask := task + sch := createScheduler(&cloneTask, false, taskMgr, ctrl) + taskMgr.EXPECT().GetTaskBaseByID(gomock.Any(), cloneTask.ID).DoAndReturn(func(_ context.Context, _ int64) (*proto.TaskBase, error) { + cloneTask.State = proto.TaskStateResuming + return &cloneTask.TaskBase, nil + }) + sch.scheduleTask() require.True(t, ctrl.Satisfied()) - // fail in transfer - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/WaitCleanUpFinished", "1*return()")) - mockErr := errors.New("transfer err") - taskMgr.EXPECT().GetTasksInStates( - mgr.ctx, - proto.TaskStateFailed, - proto.TaskStateReverted, - proto.TaskStateSucceed).Return(tasks, nil) - taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(mockErr) - mgr.doCleanupTask() + // scheduler not allocated slots, task from paused to running. Should exit the scheduler. + task.State = proto.TaskStatePaused + cloneTask = task + sch = createScheduler(&cloneTask, false, taskMgr, ctrl) + taskMgr.EXPECT().GetTaskBaseByID(gomock.Any(), cloneTask.ID).DoAndReturn(func(_ context.Context, _ int64) (*proto.TaskBase, error) { + cloneTask.State = proto.TaskStateRunning + return &cloneTask.TaskBase, nil + }) + sch.scheduleTask() require.True(t, ctrl.Satisfied()) - taskMgr.EXPECT().GetTasksInStates( - mgr.ctx, - proto.TaskStateFailed, - proto.TaskStateReverted, - proto.TaskStateSucceed).Return(tasks, nil) - taskMgr.EXPECT().TransferTasks2History(mgr.ctx, tasks).Return(nil) - mgr.doCleanupTask() + // scheduler not allocated slots, but won't exit the scheduler. + task.State = proto.TaskStateReverting + cloneTask = task + + sch = createScheduler(&cloneTask, false, taskMgr, ctrl) + schExt := schmock.NewMockExtension(ctrl) + sch.Extension = schExt + schExt.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + taskMgr.EXPECT().GetTaskBaseByID(gomock.Any(), cloneTask.ID).DoAndReturn(func(_ context.Context, _ int64) (*proto.TaskBase, error) { + return &cloneTask.TaskBase, nil + }) + + taskMgr.EXPECT().GetSubtaskCntGroupByStates(gomock.Any(), cloneTask.ID, cloneTask.Step).Return(map[proto.SubtaskState]int64{ + proto.SubtaskStatePending: 0, + proto.SubtaskStateRunning: 0}, nil) + taskMgr.EXPECT().RevertedTask(gomock.Any(), cloneTask.ID).Return(nil) + taskMgr.EXPECT().GetTaskBaseByID(gomock.Any(), cloneTask.ID).DoAndReturn(func(_ context.Context, _ int64) (*proto.TaskBase, error) { + cloneTask.State = proto.TaskStateReverted + return &cloneTask.TaskBase, nil + }) + sch.scheduleTask() require.True(t, ctrl.Satisfied()) - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/WaitCleanUpFinished")) + task.State = proto.TaskStatePausing + cloneTask = task + sch = createScheduler(&cloneTask, false, taskMgr, ctrl) + schExt = schmock.NewMockExtension(ctrl) + sch.Extension = schExt + taskMgr.EXPECT().GetTaskBaseByID(gomock.Any(), cloneTask.ID).DoAndReturn(func(_ context.Context, _ int64) (*proto.TaskBase, error) { + return &cloneTask.TaskBase, nil + }) + taskMgr.EXPECT().GetSubtaskCntGroupByStates(gomock.Any(), cloneTask.ID, cloneTask.Step).Return(map[proto.SubtaskState]int64{ + proto.SubtaskStatePending: 0, + proto.SubtaskStateRunning: 0}, nil) + taskMgr.EXPECT().GetTaskBaseByID(gomock.Any(), cloneTask.ID).DoAndReturn(func(_ context.Context, _ int64) (*proto.TaskBase, error) { + cloneTask.State = proto.TaskStatePaused + return &cloneTask.TaskBase, nil + }) + taskMgr.EXPECT().PausedTask(gomock.Any(), cloneTask.ID).Return(nil) + sch.scheduleTask() + require.True(t, ctrl.Satisfied()) } func TestSchedulerRefreshTask(t *testing.T) { diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index 3a202663de17d..a7804d66060f5 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -27,7 +27,7 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - mockDispatch "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mock" + mockscheduler "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mock" "github.com/pingcap/tidb/pkg/disttask/framework/storage" "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/domain/infosync" @@ -46,32 +46,8 @@ const ( subtaskCnt = 3 ) -func getTestSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { - mockScheduler := mockDispatch.NewMockExtension(ctrl) - mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() - mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ *proto.Task) ([]string, error) { - return nil, nil - }, - ).AnyTimes() - mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() - mockScheduler.EXPECT().GetNextStep(gomock.Any()).DoAndReturn( - func(task *proto.TaskBase) proto.Step { - return proto.StepDone - }, - ).AnyTimes() - mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ storage.TaskHandle, _ *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { - return nil, nil - }, - ).AnyTimes() - - mockScheduler.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - return mockScheduler -} - func getNumberExampleSchedulerExt(ctrl *gomock.Controller) scheduler.Extension { - mockScheduler := mockDispatch.NewMockExtension(ctrl) + mockScheduler := mockscheduler.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, _ *proto.Task) ([]string, error) { @@ -118,7 +94,7 @@ func MockSchedulerManager(t *testing.T, ctrl *gomock.Controller, pool *pools.Res sch := scheduler.NewManager(util.WithInternalSourceType(ctx, "scheduler"), mgr, "host:port") scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { - mockScheduler := sch.MockScheduler(task) + mockScheduler := scheduler.NewBaseScheduler(ctx, task, param) mockScheduler.Extension = ext return mockScheduler }) @@ -142,11 +118,11 @@ func TestTaskFailInManager(t *testing.T) { ctx := context.Background() ctx = util.WithInternalSourceType(ctx, "handle_test") - mockScheduler := mock.NewMockScheduler(ctrl) - mockScheduler.EXPECT().Init().Return(errors.New("mock scheduler init error")) - schManager, mgr := MockSchedulerManager(t, ctrl, pool, getTestSchedulerExt(ctrl), nil) + schManager, mgr := MockSchedulerManager(t, ctrl, pool, scheduler.GetTestSchedulerExt(ctrl), nil) scheduler.RegisterSchedulerFactory(proto.TaskTypeExample, func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { + mockScheduler := mock.NewMockScheduler(ctrl) + mockScheduler.EXPECT().Init().Return(errors.New("mock scheduler init error")) return mockScheduler }) schManager.Start() @@ -173,9 +149,9 @@ func TestTaskFailInManager(t *testing.T) { }, time.Second*10, time.Millisecond*300) } -func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, isPauseAndResume bool) { +func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, isPauseAndResume bool) { testkit.EnableFailPoint(t, "github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)") - // test DispatchTaskLoop + // test scheduleTaskLoop // test parallelism control var originalConcurrency int if taskCnt == 1 { @@ -351,43 +327,43 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, } func TestSimple(t *testing.T) { - checkDispatch(t, 1, true, false, false, false) + checkSchedule(t, 1, true, false, false, false) } func TestSimpleErrStage(t *testing.T) { - checkDispatch(t, 1, false, false, false, false) + checkSchedule(t, 1, false, false, false, false) } func TestSimpleCancel(t *testing.T) { - checkDispatch(t, 1, false, true, false, false) + checkSchedule(t, 1, false, true, false, false) } func TestSimpleSubtaskCancel(t *testing.T) { - checkDispatch(t, 1, false, false, true, false) + checkSchedule(t, 1, false, false, true, false) } func TestParallel(t *testing.T) { - checkDispatch(t, 3, true, false, false, false) + checkSchedule(t, 3, true, false, false, false) } func TestParallelErrStage(t *testing.T) { - checkDispatch(t, 3, false, false, false, false) + checkSchedule(t, 3, false, false, false, false) } func TestParallelCancel(t *testing.T) { - checkDispatch(t, 3, false, true, false, false) + checkSchedule(t, 3, false, true, false, false) } func TestParallelSubtaskCancel(t *testing.T) { - checkDispatch(t, 3, false, false, true, false) + checkSchedule(t, 3, false, false, true, false) } func TestPause(t *testing.T) { - checkDispatch(t, 1, false, false, false, true) + checkSchedule(t, 1, false, false, false, true) } func TestParallelPause(t *testing.T) { - checkDispatch(t, 3, false, false, false, true) + checkSchedule(t, 3, false, false, false, true) } func TestVerifyTaskStateTransform(t *testing.T) { diff --git a/pkg/disttask/framework/scheduler/testutil.go b/pkg/disttask/framework/scheduler/testutil.go new file mode 100644 index 0000000000000..6303dde35fcdb --- /dev/null +++ b/pkg/disttask/framework/scheduler/testutil.go @@ -0,0 +1,49 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "context" + + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + mockScheduler "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mock" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "go.uber.org/mock/gomock" +) + +// GetTestSchedulerExt return scheduler.Extension for testing. +func GetTestSchedulerExt(ctrl *gomock.Controller) Extension { + mockScheduler := mockScheduler.NewMockExtension(ctrl) + mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() + mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil + }, + ).AnyTimes() + mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(true).AnyTimes() + mockScheduler.EXPECT().GetNextStep(gomock.Any()).DoAndReturn( + func(_ *proto.Task) proto.Step { + return proto.StepDone + }, + ).AnyTimes() + mockScheduler.EXPECT().OnNextSubtasksBatch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, _ storage.TaskHandle, _ *proto.Task, _ []string, _ proto.Step) (metas [][]byte, err error) { + return nil, nil + }, + ).AnyTimes() + + mockScheduler.EXPECT().OnDone(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + return mockScheduler +} diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 01d927cc27d36..237bb64c14a06 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -299,7 +299,7 @@ func TestSwitchTaskStepInBatch(t *testing.T) { require.NoError(t, err) checkAfterSwitchStep(t, startTime, task1, subtasks1, proto.StepOne) - // mock another dispatcher inserted some subtasks + // mock another scheduler inserted some subtasks testkit.EnableFailPoint(t, "github.com/pingcap/tidb/pkg/disttask/framework/storage/waitBeforeInsertSubtasks", `1*return(true)`) task2, subtasks2 := prepare("key2") go func() { diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index a6bc40738e5ac..7b858ea660c53 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -581,7 +581,7 @@ func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*p return err } -// SwitchTaskStep implements the dispatcher.TaskManager interface. +// SwitchTaskStep implements the scheduler.TaskManager interface. func (mgr *TaskManager) SwitchTaskStep( ctx context.Context, task *proto.Task, @@ -662,7 +662,7 @@ func (*TaskManager) insertSubtasks(ctx context.Context, se sessionctx.Context, s return err } -// SwitchTaskStepInBatch implements the dispatcher.TaskManager interface. +// SwitchTaskStepInBatch implements the scheduler.TaskManager interface. func (mgr *TaskManager) SwitchTaskStepInBatch( ctx context.Context, task *proto.Task, @@ -671,7 +671,7 @@ func (mgr *TaskManager) SwitchTaskStepInBatch( subtasks []*proto.Subtask, ) error { return mgr.WithNewSession(func(se sessionctx.Context) error { - // some subtasks may be inserted by other dispatchers, we can skip them. + // some subtasks may be inserted by other schedulers, we can skip them. rs, err := sqlexec.ExecSQL(ctx, se, ` select count(1) from mysql.tidb_background_subtask where task_key = %? and step = %?`, task.ID, nextStep) diff --git a/pkg/disttask/framework/testutil/context.go b/pkg/disttask/framework/testutil/context.go index 6194f31b8945c..469015d45672a 100644 --- a/pkg/disttask/framework/testutil/context.go +++ b/pkg/disttask/framework/testutil/context.go @@ -387,8 +387,7 @@ type TestContext struct { func InitTestContext(t *testing.T, nodeNum int) (context.Context, *gomock.Controller, *TestContext, *testkit.DistExecutionContext) { ctrl := gomock.NewController(t) defer ctrl.Finish() - ctx := context.Background() - ctx = util.WithInternalSourceType(ctx, "dispatcher") + ctx := util.WithInternalSourceType(context.Background(), "scheduler") testkit.EnableFailPoint(t, "github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)") executionContext := testkit.NewDistExecutionContext(t, nodeNum) diff --git a/pkg/disttask/framework/testutil/scheduler_util.go b/pkg/disttask/framework/testutil/scheduler_util.go index 4316b1ea9fd38..d9983ac44e15e 100644 --- a/pkg/disttask/framework/testutil/scheduler_util.go +++ b/pkg/disttask/framework/testutil/scheduler_util.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - mockDispatch "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mock" + mockScheduler "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/mock" "github.com/pingcap/tidb/pkg/disttask/framework/storage" "go.uber.org/mock/gomock" ) @@ -73,7 +73,7 @@ func GetMockSchedulerExt(ctrl *gomock.Controller, schedulerInfo SchedulerInfo) s } stepTransition[currStep] = proto.StepDone - mockScheduler := mockDispatch.NewMockExtension(ctrl) + mockScheduler := mockScheduler.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockScheduler.EXPECT().IsRetryableErr(gomock.Any()).Return(schedulerInfo.AllErrorRetryable).AnyTimes() @@ -138,7 +138,7 @@ func GetStepTwoPlanNotRetryableErrSchedulerExt(ctrl *gomock.Controller) schedule // GetPlanErrSchedulerExt returns mock scheduler.Extension which will generate error when planning. func GetPlanErrSchedulerExt(ctrl *gomock.Controller, testContext *TestContext) scheduler.Extension { - mockScheduler := mockDispatch.NewMockExtension(ctrl) + mockScheduler := mockScheduler.NewMockExtension(ctrl) mockScheduler.EXPECT().OnTick(gomock.Any(), gomock.Any()).Return().AnyTimes() mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).DoAndReturn( func(_ context.Context, _ *proto.Task) ([]string, error) { diff --git a/pkg/metrics/disttask.go b/pkg/metrics/disttask.go index c82f070b0eb28..f5642d2087bee 100644 --- a/pkg/metrics/disttask.go +++ b/pkg/metrics/disttask.go @@ -74,8 +74,8 @@ func UpdateMetricsForAddTask(task *proto.TaskBase) { DistTaskStarttimeGauge.WithLabelValues(task.Type.String(), WaitingStatus, fmt.Sprint(task.ID)).Set(float64(time.Now().UnixMicro())) } -// UpdateMetricsForDispatchTask update metrics when a task is added -func UpdateMetricsForDispatchTask(id int64, taskType proto.TaskType) { +// UpdateMetricsForScheduleTask update metrics when a task is added +func UpdateMetricsForScheduleTask(id int64, taskType proto.TaskType) { DistTaskGauge.WithLabelValues(taskType.String(), WaitingStatus).Dec() DistTaskStarttimeGauge.DeleteLabelValues(taskType.String(), WaitingStatus, fmt.Sprint(id)) DistTaskStarttimeGauge.WithLabelValues(taskType.String(), SchedulingStatus, fmt.Sprint(id)).SetToCurrentTime() diff --git a/tests/realtikvtest/addindextest1/disttask_test.go b/tests/realtikvtest/addindextest1/disttask_test.go index e9c7582c0b667..c2ecf39a1a6f6 100644 --- a/tests/realtikvtest/addindextest1/disttask_test.go +++ b/tests/realtikvtest/addindextest1/disttask_test.go @@ -78,7 +78,7 @@ func TestAddIndexDistBasic(t *testing.T) { tk.MustExec("admin check index t idx;") taskMgr, err := storage.GetTaskManager() require.NoError(t, err) - ctx := util.WithInternalSourceType(context.Background(), "dispatcher") + ctx := util.WithInternalSourceType(context.Background(), "scheduler") task, err := taskMgr.GetTaskByIDWithHistory(ctx, storage.TestLastTaskID.Load()) require.NoError(t, err) require.Equal(t, 1, task.Concurrency) From caf83ed93010bd5bb9c99e29d586ce5a81a478c0 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Mon, 11 Mar 2024 16:06:38 +0800 Subject: [PATCH 11/15] statistics: add `last_analyze_version` for `mysql.stats_meta` (#51489) ref pingcap/tidb#49594 --- pkg/executor/show_stats.go | 29 +++++++--- pkg/executor/show_stats_test.go | 3 ++ pkg/executor/test/analyzetest/analyze_test.go | 10 ++-- pkg/planner/core/planbuilder.go | 4 +- .../handler/optimizor/plan_replayer_test.go | 2 +- .../optimizor/statistics_handler_test.go | 2 +- pkg/session/bootstraptest/main_test.go | 2 + pkg/session/main_test.go | 2 + pkg/statistics/BUILD.bazel | 2 +- .../handle/autoanalyze/autoanalyze.go | 2 +- .../handle/autoanalyze/autoanalyze_test.go | 18 +++---- .../handle/autoanalyze/exec/exec.go | 15 ------ .../handle/autoanalyze/refresher/refresher.go | 19 ++----- .../autoanalyze/refresher/refresher_test.go | 15 ++++-- pkg/statistics/handle/bootstrap.go | 10 ++++ pkg/statistics/handle/storage/read.go | 18 +++++-- pkg/statistics/integration_test.go | 54 +++++++++++++++++++ pkg/statistics/table.go | 28 ++++++---- 18 files changed, 159 insertions(+), 76 deletions(-) diff --git a/pkg/executor/show_stats.go b/pkg/executor/show_stats.go index 88aba55e33816..2da8ff961b612 100644 --- a/pkg/executor/show_stats.go +++ b/pkg/executor/show_stats.go @@ -141,14 +141,27 @@ func (e *ShowExec) appendTableForStatsMeta(dbName, tblName, partitionName string if statsTbl.Pseudo { return } - e.appendRow([]any{ - dbName, - tblName, - partitionName, - e.versionToTime(statsTbl.Version), - statsTbl.ModifyCount, - statsTbl.RealtimeCount, - }) + if !statsTbl.IsAnalyzed() { + e.appendRow([]any{ + dbName, + tblName, + partitionName, + e.versionToTime(statsTbl.Version), + statsTbl.ModifyCount, + statsTbl.RealtimeCount, + nil, + }) + } else { + e.appendRow([]any{ + dbName, + tblName, + partitionName, + e.versionToTime(statsTbl.Version), + statsTbl.ModifyCount, + statsTbl.RealtimeCount, + e.versionToTime(statsTbl.LastAnalyzeVersion), + }) + } } func (e *ShowExec) appendTableForStatsLocked(dbName, tblName, partitionName string) { diff --git a/pkg/executor/show_stats_test.go b/pkg/executor/show_stats_test.go index 539a7f94862b8..725d500cc2d9b 100644 --- a/pkg/executor/show_stats_test.go +++ b/pkg/executor/show_stats_test.go @@ -36,9 +36,12 @@ func TestShowStatsMeta(t *testing.T) { tk.MustExec("create table t1 (a int, b int)") tk.MustExec("analyze table t, t1") result := tk.MustQuery("show stats_meta") + result = result.Sort() require.Len(t, result.Rows(), 2) require.Equal(t, "t", result.Rows()[0][1]) require.Equal(t, "t1", result.Rows()[1][1]) + require.NotEqual(t, "", result.Rows()[0][6]) + require.NotEqual(t, "", result.Rows()[1][6]) result = tk.MustQuery("show stats_meta where table_name = 't'") require.Len(t, result.Rows(), 1) require.Equal(t, "t", result.Rows()[0][1]) diff --git a/pkg/executor/test/analyzetest/analyze_test.go b/pkg/executor/test/analyzetest/analyze_test.go index 1947235540333..d57eff36086fd 100644 --- a/pkg/executor/test/analyzetest/analyze_test.go +++ b/pkg/executor/test/analyzetest/analyze_test.go @@ -1394,9 +1394,9 @@ func TestAnalyzeColumnsWithDynamicPartitionTable(t *testing.T) { rows = tk.MustQuery("show stats_meta where db_name = 'test' and table_name = 't'").Sort().Rows() require.Equal(t, 3, len(rows)) - require.Equal(t, []any{"test", "t", "global", "0", "20"}, append(rows[0][:3], rows[0][4:]...)) - require.Equal(t, []any{"test", "t", "p0", "0", "9"}, append(rows[1][:3], rows[1][4:]...)) - require.Equal(t, []any{"test", "t", "p1", "0", "11"}, append(rows[2][:3], rows[2][4:]...)) + require.Equal(t, []any{"test", "t", "global", "0", "20"}, append(rows[0][:3], rows[0][4:6]...)) + require.Equal(t, []any{"test", "t", "p0", "0", "9"}, append(rows[1][:3], rows[1][4:6]...)) + require.Equal(t, []any{"test", "t", "p1", "0", "11"}, append(rows[2][:3], rows[2][4:6]...)) tk.MustQuery("show stats_topn where db_name = 'test' and table_name = 't' and is_index = 0").Sort().Check( // db, tbl, part, col, is_idx, value, count @@ -1516,8 +1516,8 @@ func TestAnalyzeColumnsWithStaticPartitionTable(t *testing.T) { rows = tk.MustQuery("show stats_meta where db_name = 'test' and table_name = 't'").Sort().Rows() require.Equal(t, 2, len(rows)) - require.Equal(t, []any{"test", "t", "p0", "0", "9"}, append(rows[0][:3], rows[0][4:]...)) - require.Equal(t, []any{"test", "t", "p1", "0", "11"}, append(rows[1][:3], rows[1][4:]...)) + require.Equal(t, []any{"test", "t", "p0", "0", "9"}, append(rows[0][:3], rows[0][4:6]...)) + require.Equal(t, []any{"test", "t", "p1", "0", "11"}, append(rows[1][:3], rows[1][4:6]...)) tk.MustQuery("show stats_topn where db_name = 'test' and table_name = 't' and is_index = 0").Sort().Check( // db, tbl, part, col, is_idx, value, count diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index a9ba17583f614..99f718f7f700a 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -5215,8 +5215,8 @@ func buildShowSchema(s *ast.ShowStmt, isView bool, isSequence bool) (schema *exp names = []string{"NodeID", "Address", "State", "Max_Commit_Ts", "Update_Time"} ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeVarchar} case ast.ShowStatsMeta: - names = []string{"Db_name", "Table_name", "Partition_name", "Update_time", "Modify_count", "Row_count"} - ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeDatetime, mysql.TypeLonglong, mysql.TypeLonglong} + names = []string{"Db_name", "Table_name", "Partition_name", "Update_time", "Modify_count", "Row_count", "Last_analyze_time"} + ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeDatetime, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeDatetime} case ast.ShowStatsExtended: names = []string{"Db_name", "Table_name", "Stats_name", "Column_names", "Stats_type", "Stats_val", "Last_update_version"} ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong} diff --git a/pkg/server/handler/optimizor/plan_replayer_test.go b/pkg/server/handler/optimizor/plan_replayer_test.go index f1afd703bfb16..444a695e335e0 100644 --- a/pkg/server/handler/optimizor/plan_replayer_test.go +++ b/pkg/server/handler/optimizor/plan_replayer_test.go @@ -190,7 +190,7 @@ func TestDumpPlanReplayerAPI(t *testing.T) { var dbName, tableName string var modifyCount, count int64 var other any - err = rows.Scan(&dbName, &tableName, &other, &other, &modifyCount, &count) + err = rows.Scan(&dbName, &tableName, &other, &other, &modifyCount, &count, &other) require.NoError(t, err) require.Equal(t, "planReplayer", dbName) require.Equal(t, "t", tableName) diff --git a/pkg/server/handler/optimizor/statistics_handler_test.go b/pkg/server/handler/optimizor/statistics_handler_test.go index bf977a75fd2f8..fbdee303f72fc 100644 --- a/pkg/server/handler/optimizor/statistics_handler_test.go +++ b/pkg/server/handler/optimizor/statistics_handler_test.go @@ -280,7 +280,7 @@ func checkData(t *testing.T, path string, client *testserverclient.TestServerCli var dbName, tableName string var modifyCount, count int64 var other any - err = rows.Scan(&dbName, &tableName, &other, &other, &modifyCount, &count) + err = rows.Scan(&dbName, &tableName, &other, &other, &modifyCount, &count, &other) require.NoError(t, err) require.Equal(t, "tidb", dbName) require.Equal(t, "test", tableName) diff --git a/pkg/session/bootstraptest/main_test.go b/pkg/session/bootstraptest/main_test.go index 4752c373a6e23..dd86e800eb799 100644 --- a/pkg/session/bootstraptest/main_test.go +++ b/pkg/session/bootstraptest/main_test.go @@ -54,6 +54,8 @@ func TestMain(m *testing.M) { goleak.IgnoreTopFunction("net/http.(*persistConn).writeLoop"), goleak.IgnoreTopFunction("github.com/pingcap/tidb/pkg/ttl/ttlworker.(*ttlDeleteWorker).loop"), goleak.IgnoreTopFunction("github.com/pingcap/tidb/pkg/ttl/ttlworker.(*ttlScanWorker).loop"), + goleak.IgnoreTopFunction("github.com/dgraph-io/ristretto.(*defaultPolicy).processItems"), + goleak.IgnoreTopFunction("github.com/dgraph-io/ristretto.(*Cache).processItems"), } callback := func(i int) int { // wait for MVCCLevelDB to close, MVCCLevelDB will be closed in one second diff --git a/pkg/session/main_test.go b/pkg/session/main_test.go index fb40c953cf77a..a2f62bcf5e336 100644 --- a/pkg/session/main_test.go +++ b/pkg/session/main_test.go @@ -57,6 +57,8 @@ func TestMain(m *testing.M) { goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*http2Client).keepalive"), goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"), goleak.IgnoreTopFunction("net/http.(*persistConn).writeLoop"), + goleak.IgnoreTopFunction("github.com/dgraph-io/ristretto.(*defaultPolicy).processItems"), + goleak.IgnoreTopFunction("github.com/dgraph-io/ristretto.(*Cache).processItems"), goleak.IgnoreTopFunction("github.com/tikv/client-go/v2/txnkv/transaction.keepAlive"), } callback := func(i int) int { diff --git a/pkg/statistics/BUILD.bazel b/pkg/statistics/BUILD.bazel index 8371bf0948a77..e6b3b730c3a66 100644 --- a/pkg/statistics/BUILD.bazel +++ b/pkg/statistics/BUILD.bazel @@ -79,7 +79,7 @@ go_test( data = glob(["testdata/**"]), embed = [":statistics"], flaky = True, - shard_count = 35, + shard_count = 36, deps = [ "//pkg/config", "//pkg/parser/ast", diff --git a/pkg/statistics/handle/autoanalyze/autoanalyze.go b/pkg/statistics/handle/autoanalyze/autoanalyze.go index eb50ac8000538..6a526dc5e2923 100644 --- a/pkg/statistics/handle/autoanalyze/autoanalyze.go +++ b/pkg/statistics/handle/autoanalyze/autoanalyze.go @@ -516,7 +516,7 @@ func tryAutoAnalyzeTable( // // Exposed for test. func NeedAnalyzeTable(tbl *statistics.Table, autoAnalyzeRatio float64) (bool, string) { - analyzed := exec.TableAnalyzed(tbl) + analyzed := tbl.IsAnalyzed() if !analyzed { return true, "table unanalyzed" } diff --git a/pkg/statistics/handle/autoanalyze/autoanalyze_test.go b/pkg/statistics/handle/autoanalyze/autoanalyze_test.go index bdc003c3fb3d8..3d95d0638cfff 100644 --- a/pkg/statistics/handle/autoanalyze/autoanalyze_test.go +++ b/pkg/statistics/handle/autoanalyze/autoanalyze_test.go @@ -207,12 +207,12 @@ func TestTableAnalyzed(t *testing.T) { require.NoError(t, h.Update(is)) statsTbl := h.GetTableStats(tableInfo) - require.False(t, exec.TableAnalyzed(statsTbl)) + require.False(t, statsTbl.LastAnalyzeVersion > 0) testKit.MustExec("analyze table t") require.NoError(t, h.Update(is)) statsTbl = h.GetTableStats(tableInfo) - require.True(t, exec.TableAnalyzed(statsTbl)) + require.True(t, statsTbl.LastAnalyzeVersion > 0) h.Clear() oriLease := h.Lease() @@ -223,7 +223,7 @@ func TestTableAnalyzed(t *testing.T) { }() require.NoError(t, h.Update(is)) statsTbl = h.GetTableStats(tableInfo) - require.True(t, exec.TableAnalyzed(statsTbl)) + require.True(t, statsTbl.LastAnalyzeVersion > 0) } func TestNeedAnalyzeTable(t *testing.T) { @@ -251,42 +251,42 @@ func TestNeedAnalyzeTable(t *testing.T) { }, // table was already analyzed but auto analyze is disabled { - tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}}, + tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}, LastAnalyzeVersion: 1}, ratio: 0, result: false, reason: "", }, // table was already analyzed but modify count is small { - tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 0, RealtimeCount: 1}}, + tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 0, RealtimeCount: 1}, LastAnalyzeVersion: 1}, ratio: 0.3, result: false, reason: "", }, // table was already analyzed { - tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}}, + tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}, LastAnalyzeVersion: 1}, ratio: 0.3, result: true, reason: "too many modifications", }, // table was already analyzed { - tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}}, + tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}, LastAnalyzeVersion: 1}, ratio: 0.3, result: true, reason: "too many modifications", }, // table was already analyzed { - tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}}, + tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}, LastAnalyzeVersion: 1}, ratio: 0.3, result: true, reason: "too many modifications", }, // table was already analyzed { - tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}}, + tbl: &statistics.Table{HistColl: statistics.HistColl{Columns: columns, ModifyCount: 1, RealtimeCount: 1}, LastAnalyzeVersion: 1}, ratio: 0.3, result: true, reason: "too many modifications", diff --git a/pkg/statistics/handle/autoanalyze/exec/exec.go b/pkg/statistics/handle/autoanalyze/exec/exec.go index 36547ae3d4700..4548a2b4eee16 100644 --- a/pkg/statistics/handle/autoanalyze/exec/exec.go +++ b/pkg/statistics/handle/autoanalyze/exec/exec.go @@ -94,21 +94,6 @@ func execAnalyzeStmt( return statsutil.ExecWithOpts(sctx, optFuncs, sql, params...) } -// TableAnalyzed checks if any column or index of the table has been analyzed. -func TableAnalyzed(tbl *statistics.Table) bool { - for _, col := range tbl.Columns { - if col.IsAnalyzed() { - return true - } - } - for _, idx := range tbl.Indices { - if idx.IsAnalyzed() { - return true - } - } - return false -} - // GetAutoAnalyzeParameters gets the auto analyze parameters from mysql.global_variables. func GetAutoAnalyzeParameters(sctx sessionctx.Context) map[string]string { sql := "select variable_name, variable_value from mysql.global_variables where variable_name in (%?, %?, %?)" diff --git a/pkg/statistics/handle/autoanalyze/refresher/refresher.go b/pkg/statistics/handle/autoanalyze/refresher/refresher.go index 48a378df4a939..c2b0c920705c3 100644 --- a/pkg/statistics/handle/autoanalyze/refresher/refresher.go +++ b/pkg/statistics/handle/autoanalyze/refresher/refresher.go @@ -370,7 +370,7 @@ func CalculateChangePercentage( tblStats *statistics.Table, autoAnalyzeRatio float64, ) float64 { - if !exec.TableAnalyzed(tblStats) { + if !tblStats.IsAnalyzed() { return unanalyzedTableDefaultChangePercentage } @@ -424,23 +424,12 @@ func findLastAnalyzeTime( tblStats *statistics.Table, currentTs uint64, ) time.Time { - maxVersion := uint64(0) - for _, idx := range tblStats.Indices { - if idx.IsAnalyzed() { - maxVersion = max(maxVersion, idx.LastUpdateVersion) - } - } - for _, col := range tblStats.Columns { - if col.IsAnalyzed() { - maxVersion = max(maxVersion, col.LastUpdateVersion) - } - } // Table is not analyzed, compose a fake version. - if maxVersion == 0 { + if !tblStats.IsAnalyzed() { phy := oracle.GetTimeFromTS(currentTs) return phy.Add(unanalyzedTableDefaultLastUpdateDuration) } - return oracle.GetTimeFromTS(maxVersion) + return oracle.GetTimeFromTS(tblStats.LastAnalyzeVersion) } // CheckIndexesNeedAnalyze checks if the indexes of the table need to be analyzed. @@ -450,7 +439,7 @@ func CheckIndexesNeedAnalyze( ) []string { // If table is not analyzed, we need to analyze whole table. // So we don't need to check indexes. - if !exec.TableAnalyzed(tblStats) { + if !tblStats.IsAnalyzed() { return nil } diff --git a/pkg/statistics/handle/autoanalyze/refresher/refresher_test.go b/pkg/statistics/handle/autoanalyze/refresher/refresher_test.go index 4235a076eebc6..225850a351f90 100644 --- a/pkg/statistics/handle/autoanalyze/refresher/refresher_test.go +++ b/pkg/statistics/handle/autoanalyze/refresher/refresher_test.go @@ -410,6 +410,7 @@ func TestCalculateChangePercentage(t *testing.T) { Indices: analyzedIndices, ModifyCount: (exec.AutoAnalyzeMinCnt + 1) * 2, }, + LastAnalyzeVersion: 1, }, autoAnalyzeRatio: 0.5, want: 2, @@ -440,6 +441,7 @@ func TestGetTableLastAnalyzeDuration(t *testing.T) { }, }, }, + LastAnalyzeVersion: lastUpdateTs, } // 2024-01-01 10:00:00 currentTime := time.Date(2024, 1, 1, 10, 0, 0, 0, time.UTC) @@ -505,6 +507,7 @@ func TestCheckIndexesNeedAnalyze(t *testing.T) { }, }, }, + LastAnalyzeVersion: 1, }, want: []string{"index1"}, }, @@ -636,7 +639,8 @@ func TestCalculateIndicatorsForPartitions(t *testing.T) { }, }, }, - Version: currentTs, + Version: currentTs, + LastAnalyzeVersion: lastUpdateTs, }, { ID: 2, @@ -661,7 +665,8 @@ func TestCalculateIndicatorsForPartitions(t *testing.T) { }, }, }, - Version: currentTs, + Version: currentTs, + LastAnalyzeVersion: lastUpdateTs, }, }, defs: []model.PartitionDefinition{ @@ -724,7 +729,8 @@ func TestCalculateIndicatorsForPartitions(t *testing.T) { }, }, }, - Version: currentTs, + Version: currentTs, + LastAnalyzeVersion: lastUpdateTs, }, { ID: 2, @@ -749,7 +755,8 @@ func TestCalculateIndicatorsForPartitions(t *testing.T) { }, }, }, - Version: currentTs, + Version: currentTs, + LastAnalyzeVersion: lastUpdateTs, }, }, defs: []model.PartitionDefinition{ diff --git a/pkg/statistics/handle/bootstrap.go b/pkg/statistics/handle/bootstrap.go index 0286fae7a3001..31304c4a867f6 100644 --- a/pkg/statistics/handle/bootstrap.go +++ b/pkg/statistics/handle/bootstrap.go @@ -132,6 +132,8 @@ func (h *Handle) initStatsHistograms4ChunkLite(is infoschema.InfoSchema, cache s lastAnalyzePos.Copy(&index.LastAnalyzePos) if index.IsAnalyzed() { index.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + // The LastAnalyzeVersion is added by ALTER table so its value might be 0. + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, version) } table.Indices[hist.ID] = index } else { @@ -158,6 +160,8 @@ func (h *Handle) initStatsHistograms4ChunkLite(is infoschema.InfoSchema, cache s lastAnalyzePos.Copy(&col.LastAnalyzePos) if col.StatsAvailable() { col.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + // The LastAnalyzeVersion is added by ALTER table so its value might be 0. + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, version) } table.Columns[hist.ID] = col } @@ -208,6 +212,8 @@ func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, cache stats } if statsVer != statistics.Version0 { index.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() + // The LastAnalyzeVersion is added by ALTER table so its value might be 0. + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, version) } lastAnalyzePos.Copy(&index.LastAnalyzePos) table.Indices[hist.ID] = index @@ -234,6 +240,10 @@ func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, cache stats } lastAnalyzePos.Copy(&col.LastAnalyzePos) table.Columns[hist.ID] = col + if statsVer != statistics.Version0 { + // The LastAnalyzeVersion is added by ALTER table so its value might be 0. + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, version) + } } cache.Put(tblID, table) // put this table again since it is updated } diff --git a/pkg/statistics/handle/storage/read.go b/pkg/statistics/handle/storage/read.go index b20b81d23756d..5e384202b79f6 100644 --- a/pkg/statistics/handle/storage/read.go +++ b/pkg/statistics/handle/storage/read.go @@ -303,6 +303,9 @@ func indexStatsFromStorage(sctx sessionctx.Context, row chunk.Row, table *statis if tracker != nil { tracker.Consume(idx.MemoryUsage().TotalMemoryUsage()) } + if idx.IsAnalyzed() { + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, idx.LastUpdateVersion) + } table.Indices[histID] = idx } else { logutil.BgLogger().Debug("we cannot find index id in table info. It may be deleted.", zap.Int64("indexID", histID), zap.String("table", tableInfo.Name.O)) @@ -414,6 +417,9 @@ func columnStatsFromStorage(sctx sessionctx.Context, row chunk.Row, table *stati if tracker != nil { tracker.Consume(col.MemoryUsage().TotalMemoryUsage()) } + if col.IsAnalyzed() { + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, col.LastUpdateVersion) + } table.Columns[col.ID] = col } else { // If we didn't find a Column or Index in tableInfo, we won't load the histogram for it. @@ -562,9 +568,6 @@ func loadNeededColumnHistograms(sctx sessionctx.Context, statsCache statstypes.S IsHandle: c.IsHandle, StatsVer: statsVer, } - if colHist.StatsAvailable() { - colHist.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() - } // Reload the latest stats cache, otherwise the `updateStatsCache` may fail with high probability, because functions // like `GetPartitionStats` called in `fmSketchFromStorage` would have modified the stats cache already. tbl, ok = statsCache.Get(col.TableID) @@ -572,8 +575,12 @@ func loadNeededColumnHistograms(sctx sessionctx.Context, statsCache statstypes.S return nil } tbl = tbl.Copy() - if statsVer != statistics.Version0 { - tbl.StatsVer = int(statsVer) + if colHist.StatsAvailable() { + colHist.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() + tbl.LastAnalyzeVersion = max(tbl.LastAnalyzeVersion, colHist.LastUpdateVersion) + if statsVer != statistics.Version0 { + tbl.StatsVer = int(statsVer) + } } tbl.Columns[c.ID] = colHist statsCache.UpdateStatsCache([]*statistics.Table{tbl}, nil) @@ -629,6 +636,7 @@ func loadNeededIndexHistograms(sctx sessionctx.Context, statsCache statstypes.St tbl.StatsVer = int(idxHist.StatsVer) } tbl.Indices[idx.ID] = idxHist + tbl.LastAnalyzeVersion = max(tbl.LastAnalyzeVersion, idxHist.LastUpdateVersion) statsCache.UpdateStatsCache([]*statistics.Table{tbl}, nil) statistics.HistogramNeededItems.Delete(idx) return nil diff --git a/pkg/statistics/integration_test.go b/pkg/statistics/integration_test.go index 700fefa2e24ad..28a910073640f 100644 --- a/pkg/statistics/integration_test.go +++ b/pkg/statistics/integration_test.go @@ -481,3 +481,57 @@ func TestIssue44369(t *testing.T) { tk.MustExec("alter table t rename column b to bb;") tk.MustExec("select * from t where a = 10 and bb > 20;") } + +func TestTableLastAnalyzeVersion(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + h := dom.StatsHandle() + tk := testkit.NewTestKit(t, store) + + // Only create table should not set the last_analyze_version + tk.MustExec("use test") + tk.MustExec("create table t(a int);") + require.NoError(t, h.HandleDDLEvent(<-h.DDLEventCh())) + is := dom.InfoSchema() + require.NoError(t, h.Update(is)) + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + require.NoError(t, err) + statsTbl, found := h.Get(tbl.Meta().ID) + require.True(t, found) + require.Equal(t, uint64(0), statsTbl.LastAnalyzeVersion) + + // Only alter table should not set the last_analyze_version + tk.MustExec("alter table t add column b int default 0") + is = dom.InfoSchema() + tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + require.NoError(t, err) + require.NoError(t, h.HandleDDLEvent(<-h.DDLEventCh())) + require.NoError(t, h.Update(is)) + statsTbl, found = h.Get(tbl.Meta().ID) + require.True(t, found) + require.Equal(t, uint64(0), statsTbl.LastAnalyzeVersion) + tk.MustExec("alter table t add index idx(a)") + is = dom.InfoSchema() + tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + // We don't handle the ADD INDEX event in the HandleDDLEvent. + require.Equal(t, 0, len(h.DDLEventCh())) + require.NoError(t, err) + require.NoError(t, h.Update(is)) + statsTbl, found = h.Get(tbl.Meta().ID) + require.True(t, found) + require.Equal(t, uint64(0), statsTbl.LastAnalyzeVersion) + + // INSERT and updating the modify_count should not set the last_analyze_version + tk.MustExec("insert into t values(1, 1)") + require.NoError(t, h.DumpStatsDeltaToKV(true)) + require.NoError(t, h.Update(is)) + statsTbl, found = h.Get(tbl.Meta().ID) + require.True(t, found) + require.Equal(t, uint64(0), statsTbl.LastAnalyzeVersion) + + // After analyze, last_analyze_version is set. + tk.MustExec("analyze table t") + require.NoError(t, h.Update(is)) + statsTbl, found = h.Get(tbl.Meta().ID) + require.True(t, found) + require.NotEqual(t, uint64(0), statsTbl.LastAnalyzeVersion) +} diff --git a/pkg/statistics/table.go b/pkg/statistics/table.go index 130d7e5ee9e7c..844729826e67d 100644 --- a/pkg/statistics/table.go +++ b/pkg/statistics/table.go @@ -62,6 +62,8 @@ type Table struct { Name string HistColl Version uint64 + // It's the timestamp of the last analyze time. + LastAnalyzeVersion uint64 // TblInfoUpdateTS is the UpdateTS of the TableInfo used when filling this struct. // It is the schema version of the corresponding table. It is used to skip redundant // loading of stats, i.e, if the cached stats is already update-to-date with mysql.stats_xxx tables, @@ -298,10 +300,11 @@ func (t *Table) Copy() *Table { newHistColl.Indices[id] = idx.Copy() } nt := &Table{ - HistColl: newHistColl, - Version: t.Version, - Name: t.Name, - TblInfoUpdateTS: t.TblInfoUpdateTS, + HistColl: newHistColl, + Version: t.Version, + Name: t.Name, + TblInfoUpdateTS: t.TblInfoUpdateTS, + LastAnalyzeVersion: t.LastAnalyzeVersion, } if t.ExtendedStats != nil { newExtStatsColl := &ExtendedStatsColl{ @@ -331,11 +334,12 @@ func (t *Table) ShallowCopy() *Table { StatsVer: t.StatsVer, } nt := &Table{ - HistColl: newHistColl, - Version: t.Version, - Name: t.Name, - TblInfoUpdateTS: t.TblInfoUpdateTS, - ExtendedStats: t.ExtendedStats, + HistColl: newHistColl, + Version: t.Version, + Name: t.Name, + TblInfoUpdateTS: t.TblInfoUpdateTS, + ExtendedStats: t.ExtendedStats, + LastAnalyzeVersion: t.LastAnalyzeVersion, } return nt } @@ -412,6 +416,12 @@ func (t *Table) GetStatsInfo(id int64, isIndex bool, needCopy bool) (*Histogram, return nil, nil, nil, nil, false } +// IsAnalyzed checks whether the table is analyzed or not by checking its last analyze's timestamp value. +// A valid timestamp must be greater than 0. +func (t *Table) IsAnalyzed() bool { + return t.LastAnalyzeVersion > 0 +} + // GetAnalyzeRowCount tries to get the row count of a column or an index if possible. // This method is useful because this row count doesn't consider the modify count. func (coll *HistColl) GetAnalyzeRowCount() float64 { From 2b6318411e829a89394fe3a790bb2e794936b72e Mon Sep 17 00:00:00 2001 From: crazycs Date: Mon, 11 Mar 2024 17:39:38 +0800 Subject: [PATCH 12/15] executor: fix issue of some insert execution stats was omitted (#51630) close pingcap/tidb#51629 --- pkg/executor/insert.go | 13 ++++--- pkg/executor/insert_common.go | 11 +++--- .../pipelineddmltest/pipelineddml_test.go | 35 +++++++++++++++++++ 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/pkg/executor/insert.go b/pkg/executor/insert.go index 99bf773f2a63d..9b24962c8c1b4 100644 --- a/pkg/executor/insert.go +++ b/pkg/executor/insert.go @@ -73,6 +73,12 @@ func (e *InsertExec) exec(ctx context.Context, rows [][]types.Datum) error { return err } setOptionForTopSQL(sessVars.StmtCtx, txn) + if e.collectRuntimeStatsEnabled() { + if snapshot := txn.GetSnapshot(); snapshot != nil { + snapshot.SetOption(kv.CollectRuntimeStats, e.stats.SnapshotRuntimeStats) + defer snapshot.SetOption(kv.CollectRuntimeStats, nil) + } + } sessVars.StmtCtx.AddRecordRows(uint64(len(rows))) // If you use the IGNORE keyword, duplicate-key error that occurs while executing the INSERT statement are ignored. // For example, without IGNORE, a row that duplicates an existing UNIQUE index or PRIMARY KEY value in @@ -92,7 +98,6 @@ func (e *InsertExec) exec(ctx context.Context, rows [][]types.Datum) error { return err } } else { - e.collectRuntimeStatsEnabled() start := time.Now() for i, row := range rows { var err error @@ -221,12 +226,6 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D return err } - if e.collectRuntimeStatsEnabled() { - if snapshot := txn.GetSnapshot(); snapshot != nil { - snapshot.SetOption(kv.CollectRuntimeStats, e.stats.SnapshotRuntimeStats) - defer snapshot.SetOption(kv.CollectRuntimeStats, nil) - } - } prefetchStart := time.Now() // Use BatchGet to fill cache. // It's an optimization and could be removed without affecting correctness. diff --git a/pkg/executor/insert_common.go b/pkg/executor/insert_common.go index 77d25e7a26d76..e6b478f374f5e 100644 --- a/pkg/executor/insert_common.go +++ b/pkg/executor/insert_common.go @@ -1192,12 +1192,6 @@ func (e *InsertValues) batchCheckAndInsert( return err } setOptionForTopSQL(e.Ctx().GetSessionVars().StmtCtx, txn) - if e.collectRuntimeStatsEnabled() { - if snapshot := txn.GetSnapshot(); snapshot != nil { - snapshot.SetOption(kv.CollectRuntimeStats, e.stats.SnapshotRuntimeStats) - defer snapshot.SetOption(kv.CollectRuntimeStats, nil) - } - } sc := e.Ctx().GetSessionVars().StmtCtx for _, fkc := range e.fkChecks { err = fkc.checkRows(ctx, sc, txn, toBeCheckedRows) @@ -1493,6 +1487,11 @@ func (e *InsertRuntimeStat) String() string { buf.WriteString("}") } else { fmt.Fprintf(buf, "insert:%v", execdetails.FormatDuration(e.CheckInsertTime)) + if e.SnapshotRuntimeStats != nil { + if rpc := e.SnapshotRuntimeStats.String(); len(rpc) > 0 { + fmt.Fprintf(buf, ", rpc:{%s}", rpc) + } + } } return buf.String() } diff --git a/tests/realtikvtest/pipelineddmltest/pipelineddml_test.go b/tests/realtikvtest/pipelineddmltest/pipelineddml_test.go index 1e2ef0ebd0e21..00241ce8dcc61 100644 --- a/tests/realtikvtest/pipelineddmltest/pipelineddml_test.go +++ b/tests/realtikvtest/pipelineddmltest/pipelineddml_test.go @@ -250,6 +250,41 @@ func TestPipelinedDMLInsertOnDuplicateKeyUpdate(t *testing.T) { compareTables(t, tk, "t", "_t") } +func TestPipelinedDMLInsertRPC(t *testing.T) { + store := realtikvtest.CreateMockStoreAndSetup(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (a int, b int, unique index idx(b))") + res := tk.MustQuery("explain analyze insert ignore into t1 values (1,1), (2,2), (3,3), (4,4), (5,5)") + explain := getExplainResult(res) + require.Regexp(t, "Insert.* check_insert: {total_time: .* rpc:{BatchGet:{num_rpc:1, total_time:.*}}}.*", explain) + // Test with bulk dml. + tk.MustExec("set session tidb_dml_type = bulk") + // Test normal insert. + tk.MustExec("truncate table t1") + res = tk.MustQuery("explain analyze insert into t1 values (1,1), (2,2), (3,3), (4,4), (5,5)") + explain = getExplainResult(res) + // TODO: try to optimize the rpc count, when use bulk dml, insert will send many BufferBatchGet rpc. + require.Regexp(t, "Insert.* insert:.*, rpc:{BufferBatchGet:{num_rpc:10, total_time:.*}}.*", explain) + // Test insert ignore. + tk.MustExec("truncate table t1") + res = tk.MustQuery("explain analyze insert ignore into t1 values (1,1), (2,2), (3,3), (4,4), (5,5)") + explain = getExplainResult(res) + // TODO: try to optimize the rpc count, when use bulk dml, insert ignore will send 5 BufferBatchGet and 1 BatchGet rpc. + // but without bulk dml, it will only use 1 BatchGet rpcs. + require.Regexp(t, "Insert.* check_insert: {total_time: .* rpc:{.*BufferBatchGet:{num_rpc:5, total_time:.*}}}.*", explain) + require.Regexp(t, "Insert.* check_insert: {total_time: .* rpc:{.*BatchGet:{num_rpc:1, total_time:.*}}}.*", explain) +} + +func getExplainResult(res *testkit.Result) string { + resBuff := bytes.NewBufferString("") + for _, row := range res.Rows() { + _, _ = fmt.Fprintf(resBuff, "%s\t", row) + } + return resBuff.String() +} + func TestPipelinedDMLInsertOnDuplicateKeyUpdateInTxn(t *testing.T) { require.Nil(t, failpoint.Enable("tikvclient/pipelinedMemDBMinFlushKeys", `return(10)`)) require.Nil(t, failpoint.Enable("tikvclient/pipelinedMemDBMinFlushSize", `return(100)`)) From a632277c57301bf89328bbc7cf9d051274714b4b Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 11 Mar 2024 18:38:09 +0800 Subject: [PATCH 13/15] planner: remove unused binding metrics (#51665) ref pingcap/tidb#51347 --- pkg/bindinfo/binding.go | 40 ----------------------------- pkg/bindinfo/global_handle.go | 2 -- pkg/bindinfo/global_handle_test.go | 4 --- pkg/bindinfo/session_handle.go | 7 +---- pkg/bindinfo/session_handle_test.go | 11 -------- pkg/metrics/bindinfo.go | 18 ------------- pkg/metrics/metrics.go | 2 -- 7 files changed, 1 insertion(+), 83 deletions(-) diff --git a/pkg/bindinfo/binding.go b/pkg/bindinfo/binding.go index 5dbbf7b3eaf40..447e80dd8f778 100644 --- a/pkg/bindinfo/binding.go +++ b/pkg/bindinfo/binding.go @@ -18,7 +18,6 @@ import ( "time" "unsafe" - "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/sessionctx" @@ -224,47 +223,8 @@ func (br Bindings) size() float64 { return mem } -var statusIndex = map[string]int{ - Enabled: 0, - deleted: 1, - Invalid: 2, -} - -func bindingMetrics(br Bindings) ([]float64, []int) { - sizes := make([]float64, len(statusIndex)) - count := make([]int, len(statusIndex)) - if br == nil { - return sizes, count - } - commonLength := float64(0) - // We treat it as deleted if there are no bindings. It could only occur in session handles. - if len(br) == 0 { - sizes[statusIndex[deleted]] = commonLength - count[statusIndex[deleted]] = 1 - return sizes, count - } - // Make the common length counted in the first binding. - sizes[statusIndex[br[0].Status]] = commonLength - for _, binding := range br { - sizes[statusIndex[binding.Status]] += binding.size() - count[statusIndex[binding.Status]]++ - } - return sizes, count -} - // size calculates the memory size of a bind info. func (b *Binding) size() float64 { res := len(b.OriginalSQL) + len(b.Db) + len(b.BindSQL) + len(b.Status) + 2*int(unsafe.Sizeof(b.CreateTime)) + len(b.Charset) + len(b.Collation) + len(b.ID) return float64(res) } - -func updateMetrics(scope string, before Bindings, after Bindings, sizeOnly bool) { - beforeSize, beforeCount := bindingMetrics(before) - afterSize, afterCount := bindingMetrics(after) - for status, index := range statusIndex { - metrics.BindMemoryUsage.WithLabelValues(scope, status).Add(afterSize[index] - beforeSize[index]) - if !sizeOnly { - metrics.BindTotalGauge.WithLabelValues(scope, status).Add(float64(afterCount[index] - beforeCount[index])) - } - } -} diff --git a/pkg/bindinfo/global_handle.go b/pkg/bindinfo/global_handle.go index 9166711e6b7ba..9441426af3e1d 100644 --- a/pkg/bindinfo/global_handle.go +++ b/pkg/bindinfo/global_handle.go @@ -23,7 +23,6 @@ import ( "time" "github.com/pingcap/errors" - "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/format" @@ -243,7 +242,6 @@ func (h *globalBindingHandle) LoadFromStorageToCache(fullLoad bool) (err error) } else { newCache.RemoveBinding(sqlDigest) } - updateMetrics(metrics.ScopeGlobal, oldBinding, newCache.GetBinding(sqlDigest), true) } return nil }) diff --git a/pkg/bindinfo/global_handle_test.go b/pkg/bindinfo/global_handle_test.go index a381a72d15ba1..a7fa40acb5285 100644 --- a/pkg/bindinfo/global_handle_test.go +++ b/pkg/bindinfo/global_handle_test.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/pkg/bindinfo" "github.com/pingcap/tidb/pkg/bindinfo/internal" "github.com/pingcap/tidb/pkg/bindinfo/norm" - "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser" sessiontypes "github.com/pingcap/tidb/pkg/session/types" "github.com/pingcap/tidb/pkg/testkit" @@ -425,9 +424,6 @@ func TestGlobalBinding(t *testing.T) { tk.MustExec("create table t1(i int, s varchar(20))") tk.MustExec("create index index_t on t(i,s)") - metrics.BindTotalGauge.Reset() - metrics.BindMemoryUsage.Reset() - _, err := tk.Exec("create global " + testSQL.createSQL) require.NoError(t, err, "err %v", err) diff --git a/pkg/bindinfo/session_handle.go b/pkg/bindinfo/session_handle.go index 5b9f933a628b1..c235b2d13fd63 100644 --- a/pkg/bindinfo/session_handle.go +++ b/pkg/bindinfo/session_handle.go @@ -21,7 +21,6 @@ import ( "time" "github.com/pingcap/errors" - "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" @@ -68,12 +67,10 @@ func NewSessionBindingHandle() SessionBindingHandle { // appendSessionBinding adds the Bindings to the cache, all the stale bindMetas are // removed from the cache after this operation. func (h *sessionBindingHandle) appendSessionBinding(sqlDigest string, meta Bindings) { - oldBindings := h.ch.GetBinding(sqlDigest) err := h.ch.SetBinding(sqlDigest, meta) if err != nil { logutil.BgLogger().Warn("SessionHandle.appendSessionBinding", zap.String("category", "sql-bind"), zap.Error(err)) } - updateMetrics(metrics.ScopeSession, oldBindings, meta, false) } // CreateSessionBinding creates a Bindings to the cache. @@ -146,9 +143,7 @@ func (h *sessionBindingHandle) DecodeSessionStates(_ context.Context, sctx sessi } // Close closes the session handle. -func (h *sessionBindingHandle) Close() { - updateMetrics(metrics.ScopeSession, h.ch.GetAllBindings(), nil, false) -} +func (*sessionBindingHandle) Close() {} // sessionBindInfoKeyType is a dummy type to avoid naming collision in context. type sessionBindInfoKeyType int diff --git a/pkg/bindinfo/session_handle_test.go b/pkg/bindinfo/session_handle_test.go index 1769164bf9c02..75826afc25a31 100644 --- a/pkg/bindinfo/session_handle_test.go +++ b/pkg/bindinfo/session_handle_test.go @@ -91,9 +91,6 @@ func TestSessionBinding(t *testing.T) { tk.MustExec("create table t1(i int, s varchar(20))") tk.MustExec("create index index_t on t(i,s)") - metrics.BindTotalGauge.Reset() - metrics.BindMemoryUsage.Reset() - _, err := tk.Exec("create session " + testSQL.createSQL) require.NoError(t, err, "err %v", err) @@ -102,14 +99,6 @@ func TestSessionBinding(t *testing.T) { require.NoError(t, err) } - pb := &dto.Metric{} - err = metrics.BindTotalGauge.WithLabelValues(metrics.ScopeSession, bindinfo.Enabled).Write(pb) - require.NoError(t, err) - require.Equal(t, float64(1), pb.GetGauge().GetValue()) - err = metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeSession, bindinfo.Enabled).Write(pb) - require.NoError(t, err) - require.Equal(t, testSQL.memoryUsage, pb.GetGauge().GetValue()) - handle := tk.Session().Value(bindinfo.SessionBindInfoKeyType).(bindinfo.SessionBindingHandle) stmt, err := parser.New().ParseOneStmt(testSQL.originSQL, "", "") require.NoError(t, err) diff --git a/pkg/metrics/bindinfo.go b/pkg/metrics/bindinfo.go index 0c579e4857076..71ba09f686209 100644 --- a/pkg/metrics/bindinfo.go +++ b/pkg/metrics/bindinfo.go @@ -19,8 +19,6 @@ import "github.com/prometheus/client_golang/prometheus" // bindinfo metrics. var ( BindUsageCounter *prometheus.CounterVec - BindTotalGauge *prometheus.GaugeVec - BindMemoryUsage *prometheus.GaugeVec ) // InitBindInfoMetrics initializes bindinfo metrics. @@ -32,20 +30,4 @@ func InitBindInfoMetrics() { Name: "bind_usage_counter", Help: "Counter of query using sql bind", }, []string{LabelScope}) - - BindTotalGauge = NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: "tidb", - Subsystem: "bindinfo", - Name: "bind_total_gauge", - Help: "Total number of sql bind", - }, []string{LabelScope, LblType}) - - BindMemoryUsage = NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: "tidb", - Subsystem: "bindinfo", - Name: "bind_memory_usage", - Help: "Memory usage of sql bind", - }, []string{LabelScope, LblType}) } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index f8226be2976fd..3a2e0301ec86f 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -124,8 +124,6 @@ func RegisterMetrics() { prometheus.MustRegister(AutoIDHistogram) prometheus.MustRegister(BatchAddIdxHistogram) prometheus.MustRegister(BindUsageCounter) - prometheus.MustRegister(BindTotalGauge) - prometheus.MustRegister(BindMemoryUsage) prometheus.MustRegister(CampaignOwnerCounter) prometheus.MustRegister(ConnGauge) prometheus.MustRegister(DisconnectionCounter) From 715399f321c011f863d71b49cd1acab7575b4fb5 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Mon, 11 Mar 2024 19:26:11 +0800 Subject: [PATCH 14/15] executor: fix flaky test TestAnalyzeClusteredIndexPrimary (#51670) close pingcap/tidb#51649 --- pkg/executor/test/analyzetest/analyze_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/executor/test/analyzetest/analyze_test.go b/pkg/executor/test/analyzetest/analyze_test.go index d57eff36086fd..613c1bc553b93 100644 --- a/pkg/executor/test/analyzetest/analyze_test.go +++ b/pkg/executor/test/analyzetest/analyze_test.go @@ -598,7 +598,7 @@ func TestAnalyzeClusteredIndexPrimary(t *testing.T) { tk.MustExec("set @@session.tidb_analyze_version = 1") tk.MustExec("analyze table t0 index primary") tk.MustExec("analyze table t1 index primary") - tk.MustQuery("show stats_buckets").Check(testkit.Rows( + tk.MustQuery("show stats_buckets").Sort().Check(testkit.Rows( "test t0 PRIMARY 1 0 1 1 1111 1111 0", "test t1 PRIMARY 1 0 1 1 1111 1111 0")) tk.MustExec("set @@session.tidb_analyze_version = 2") From f2cbe00ed91864a9cf2f7bf6822f0bfb13c6aa94 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 11 Mar 2024 20:24:38 +0800 Subject: [PATCH 15/15] planner: allow the optimizer to cache query plans accessing generated columns by default (#51654) close pingcap/tidb#45798 --- pkg/executor/explainfor_test.go | 3 +-- pkg/planner/core/plan_cache_test.go | 6 ++---- pkg/planner/core/plan_cacheable_checker.go | 6 +++--- tests/integrationtest/r/planner/core/indexmerge_path.result | 4 +--- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/pkg/executor/explainfor_test.go b/pkg/executor/explainfor_test.go index 0464ba960959c..ba6768e09b4cd 100644 --- a/pkg/executor/explainfor_test.go +++ b/pkg/executor/explainfor_test.go @@ -707,8 +707,7 @@ func TestIndexMerge4PlanCache(t *testing.T) { tk.MustExec("prepare stmt from 'SELECT /*+ USE_INDEX_MERGE(t0, i0, PRIMARY)*/ t0.c0 FROM t0 WHERE t0.c1 OR t0.c0;';") tk.MustQuery("execute stmt;").Check(testkit.Rows("1")) tk.MustQuery("execute stmt;").Check(testkit.Rows("1")) - // The plan contains the generated column, so it can not be cached. - tk.MustQuery("select @@last_plan_from_cache;").Check(testkit.Rows("0")) + tk.MustQuery("select @@last_plan_from_cache;").Check(testkit.Rows("1")) tk.MustExec("drop table if exists t1, t2") tk.MustExec("create table t1(id int primary key, a int, b int, c int, d int)") diff --git a/pkg/planner/core/plan_cache_test.go b/pkg/planner/core/plan_cache_test.go index 8292dac135d76..f0d1c5a798c3b 100644 --- a/pkg/planner/core/plan_cache_test.go +++ b/pkg/planner/core/plan_cache_test.go @@ -1005,6 +1005,8 @@ func TestNonPreparedPlanExplainWarning(t *testing.T) { "select distinct a from t1 where a > 1 and b < 2", // distinct "select count(*) from t1 where a > 1 and b < 2 group by a", // group by "select * from t1 order by a", // order by + "select * from t3 where full_name = 'a b'", // generated column + "select * from t3 where a > 1 and full_name = 'a b'", } unsupported := []string{ @@ -1022,8 +1024,6 @@ func TestNonPreparedPlanExplainWarning(t *testing.T) { "select * from t where bt > 0", // bit "select * from t where a > 1 and bt > 0", "select data_type from INFORMATION_SCHEMA.columns where table_name = 'v'", // memTable - "select * from t3 where full_name = 'a b'", // generated column - "select * from t3 where a > 1 and full_name = 'a b'", "select * from v", // view "select * from t where a = null", // null "select * from t where false", // table dual @@ -1044,8 +1044,6 @@ func TestNonPreparedPlanExplainWarning(t *testing.T) { "skip non-prepared plan-cache: query has some filters with JSON, Enum, Set or Bit columns", "skip non-prepared plan-cache: query has some filters with JSON, Enum, Set or Bit columns", "skip non-prepared plan-cache: access tables in system schema", - "skip non-prepared plan-cache: query accesses generated columns is un-cacheable", - "skip non-prepared plan-cache: query accesses generated columns is un-cacheable", "skip non-prepared plan-cache: queries that access views are not supported", "skip non-prepared plan-cache: query has null constants", "skip non-prepared plan-cache: some parameters may be overwritten when constant propagation", diff --git a/pkg/planner/core/plan_cacheable_checker.go b/pkg/planner/core/plan_cacheable_checker.go index c52ee52f2339f..b24f5ae2f8566 100644 --- a/pkg/planner/core/plan_cacheable_checker.go +++ b/pkg/planner/core/plan_cacheable_checker.go @@ -621,11 +621,11 @@ func getMaxParamLimit(sctx PlanContext) int { func enablePlanCacheForGeneratedCols(sctx PlanContext) bool { // disable this by default since it's not well tested. - // TODO: complete its test and enable it by default. + defaultVal := true if sctx == nil || sctx.GetSessionVars() == nil || sctx.GetSessionVars().GetOptimizerFixControlMap() == nil { - return false + return defaultVal } - return fixcontrol.GetBoolWithDefault(sctx.GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix45798, false) + return fixcontrol.GetBoolWithDefault(sctx.GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix45798, defaultVal) } // checkTableCacheable checks whether a query accessing this table is cacheable. diff --git a/tests/integrationtest/r/planner/core/indexmerge_path.result b/tests/integrationtest/r/planner/core/indexmerge_path.result index c2a9931becafb..4fc1903e018e2 100644 --- a/tests/integrationtest/r/planner/core/indexmerge_path.result +++ b/tests/integrationtest/r/planner/core/indexmerge_path.result @@ -470,15 +470,13 @@ Selection 23.98 root json_overlaps(json_extract(planner__core__indexmerge_path. drop table if exists t; create table t(j json, index kj((cast(j as signed array)))); prepare st from 'select /*+ use_index_merge(t, kj) */ * from t where (1 member of (j))'; -Level Code Message -Warning 1105 skip prepared plan-cache: query accesses generated columns is un-cacheable execute st; j execute st; j select @@last_plan_from_cache; @@last_plan_from_cache -0 +1 drop table if exists t; create table t(j json, unique kj((cast(j as signed array)))); explain select j from t where j=1;