Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support reload tls used by mysql protocol in place (#14749) #15080

Merged
merged 4 commits into from
Mar 5, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
cherry pick #14749 to release-3.0
Signed-off-by: sre-bot <sre-bot@pingcap.com>
  • Loading branch information
lysu authored and sre-bot committed Mar 3, 2020
commit 393c274e14d1e92b48367d90349e5f648f411131
173 changes: 173 additions & 0 deletions domain/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,179 @@ func sysMockFactory(dom *Domain) (pools.Resource, error) {
return nil, nil
}

<<<<<<< HEAD
=======
type mockEtcdBackend struct {
kv.Storage
pdAddrs []string
}

func (mebd *mockEtcdBackend) EtcdAddrs() []string {
return mebd.pdAddrs
}
func (mebd *mockEtcdBackend) TLSConfig() *tls.Config { return nil }
func (mebd *mockEtcdBackend) StartGCWorker() error {
panic("not implemented")
}

// ETCD use ip:port as unix socket address, however this address is invalid on windows.
// We have to skip some of the test in such case.
// https://github.com/etcd-io/etcd/blob/f0faa5501d936cd8c9f561bb9d1baca70eb67ab1/pkg/types/urls.go#L42
func unixSocketAvailable() bool {
c, err := net.Listen("unix", "127.0.0.1:0")
if err == nil {
c.Close()
return true
}
return false
}

func TestInfo(t *testing.T) {
if !unixSocketAvailable() {
return
}
defer testleak.AfterTestT(t)()
ddlLease := 80 * time.Millisecond
s, err := mockstore.NewMockTikvStore()
if err != nil {
t.Fatal(err)
}
clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
defer clus.Terminate(t)
mockStore := &mockEtcdBackend{
Storage: s,
pdAddrs: []string{clus.Members[0].GRPCAddr()}}
dom := NewDomain(mockStore, ddlLease, 0, mockFactory)
defer func() {
dom.Close()
s.Close()
}()

cli := clus.RandClient()
dom.etcdClient = cli
// Mock new DDL and init the schema syncer with etcd client.
goCtx := context.Background()
dom.ddl = ddl.NewDDL(
goCtx,
ddl.WithEtcdClient(dom.GetEtcdClient()),
ddl.WithStore(s),
ddl.WithInfoHandle(dom.infoHandle),
ddl.WithLease(ddlLease),
)
err = failpoint.Enable("github.com/pingcap/tidb/domain/MockReplaceDDL", `return(true)`)
if err != nil {
t.Fatal(err)
}
err = dom.Init(ddlLease, sysMockFactory)
if err != nil {
t.Fatal(err)
}
err = failpoint.Disable("github.com/pingcap/tidb/domain/MockReplaceDDL")
if err != nil {
t.Fatal(err)
}

// Test for GetServerInfo and GetServerInfoByID.
ddlID := dom.ddl.GetID()
serverInfo, err := infosync.GetServerInfo()
if err != nil {
t.Fatal(err)
}
info, err := infosync.GetServerInfoByID(goCtx, ddlID)
if err != nil {
t.Fatal(err)
}
if serverInfo.ID != info.ID {
t.Fatalf("server self info %v, info %v", serverInfo, info)
}
_, err = infosync.GetServerInfoByID(goCtx, "not_exist_id")
if err == nil || (err != nil && err.Error() != "[info-syncer] get /tidb/server/info/not_exist_id failed") {
t.Fatal(err)
}

// Test for GetAllServerInfo.
infos, err := infosync.GetAllServerInfo(goCtx)
if err != nil {
t.Fatal(err)
}
if len(infos) != 1 || infos[ddlID].ID != info.ID {
t.Fatalf("server one info %v, info %v", infos[ddlID], info)
}

// Test the scene where syncer.Done() gets the information.
err = failpoint.Enable("github.com/pingcap/tidb/ddl/util/ErrorMockSessionDone", `return(true)`)
if err != nil {
t.Fatal(err)
}
<-dom.ddl.SchemaSyncer().Done()
err = failpoint.Disable("github.com/pingcap/tidb/ddl/util/ErrorMockSessionDone")
if err != nil {
t.Fatal(err)
}
time.Sleep(15 * time.Millisecond)
syncerStarted := false
for i := 0; i < 200; i++ {
if dom.SchemaValidator.IsStarted() {
syncerStarted = true
break
}
time.Sleep(5 * time.Millisecond)
}
if !syncerStarted {
t.Fatal("start syncer failed")
}
// Make sure loading schema is normal.
cs := &ast.CharsetOpt{
Chs: "utf8",
Col: "utf8_bin",
}
ctx := mock.NewContext()
err = dom.ddl.CreateSchema(ctx, model.NewCIStr("aaa"), cs)
if err != nil {
t.Fatal(err)
}
err = dom.Reload()
if err != nil {
t.Fatal(err)
}
if dom.InfoSchema().SchemaMetaVersion() != 1 {
t.Fatalf("update schema version failed, ver %d", dom.InfoSchema().SchemaMetaVersion())
}

// Test for RemoveServerInfo.
dom.info.RemoveServerInfo()
infos, err = infosync.GetAllServerInfo(goCtx)
if err != nil || len(infos) != 0 {
t.Fatalf("err %v, infos %v", err, infos)
}
}

type mockSessionManager struct {
PS []*util.ProcessInfo
}

func (msm *mockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo {
ret := make(map[uint64]*util.ProcessInfo)
for _, item := range msm.PS {
ret[item.ID] = item
}
return ret
}

func (msm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) {
for _, item := range msm.PS {
if item.ID == id {
return item, true
}
}
return &util.ProcessInfo{}, false
}

func (msm *mockSessionManager) Kill(cid uint64, query bool) {}

func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {}

>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749)
func (*testSuite) TestT(c *C) {
defer testleak.AfterTest(c)()
store, err := mockstore.NewMockTikvStore()
Expand Down
4 changes: 4 additions & 0 deletions executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package executor

import (
"context"
"crypto/tls"

. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
Expand Down Expand Up @@ -62,6 +63,9 @@ func (msm *mockSessionManager) Kill(cid uint64, query bool) {

}

func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {
}

func (s *testExecSuite) TestShowProcessList(c *C) {
// Compose schema.
names := []string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"}
Expand Down
4 changes: 4 additions & 0 deletions executor/explainfor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package executor_test

import (
"crypto/tls"
"fmt"

. "github.com/pingcap/check"
Expand Down Expand Up @@ -51,6 +52,9 @@ func (msm *mockSessionManager1) Kill(cid uint64, query bool) {

}

func (msm *mockSessionManager1) UpdateTLSConfig(cfg *tls.Config) {
}

func (s *testSuite) TestExplainFor(c *C) {
tkRoot := testkit.NewTestKitWithInit(c, s.store)
tkUser := testkit.NewTestKitWithInit(c, s.store)
Expand Down
23 changes: 23 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
Expand Down Expand Up @@ -108,6 +109,8 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
err = e.executeUse(x)
case *ast.FlushStmt:
err = e.executeFlush(x)
case *ast.AlterInstanceStmt:
err = e.executeAlterInstance(x)
case *ast.BeginStmt:
err = e.executeBegin(ctx, x)
case *ast.CommitStmt:
Expand Down Expand Up @@ -1093,6 +1096,26 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error {
return nil
}

func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
if s.ReloadTLS {
logutil.BgLogger().Info("execute reload tls", zap.Bool("NoRollbackOnError", s.NoRollbackOnError))
sm := e.ctx.GetSessionManager()
tlsCfg, err := util.LoadTLSCertificates(
variable.SysVars["ssl_ca"].Value,
variable.SysVars["ssl_key"].Value,
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
return err
}
logutil.BgLogger().Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
}
sm.UpdateTLSConfig(tlsCfg)
}
return nil
}

func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error {
h := domain.GetDomain(e.ctx).StatsHandle()
err := h.DeleteTableStatsFromKV(s.Table.TableInfo.ID)
Expand Down
2 changes: 2 additions & 0 deletions infoschema/tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ func (sm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool

func (sm *mockSessionManager) Kill(connectionID uint64, query bool) {}

func (sm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {}

func (s *testTableSuite) TestSomeTables(c *C) {
tk := testkit.NewTestKit(c, s.store)

Expand Down
11 changes: 10 additions & 1 deletion planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (Plan, error) {
case *ast.AnalyzeTableStmt:
return b.buildAnalyze(x)
case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.AlterInstanceStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt,
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt, *ast.ShutdownStmt:
return b.buildSimple(node.(ast.StmtNode))
Expand Down Expand Up @@ -1385,6 +1385,15 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) {
p := &Simple{Statement: node}

switch raw := node.(type) {
<<<<<<< HEAD
=======
case *ast.FlushStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("RELOAD")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.ReloadPriv, "", "", "", err)
case *ast.AlterInstanceStmt:
err := ErrSpecificAccessDenied.GenWithStack("ALTER INSTANCE")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", err)
>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749)
case *ast.AlterUserStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
Expand Down
37 changes: 20 additions & 17 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,23 +495,26 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}

if (resp.Capability&mysql.ClientSSL > 0) && cc.server.tlsConfig != nil {
// The packet is a SSLRequest, let's switch to TLS.
if err = cc.upgradeToTLS(cc.server.tlsConfig); err != nil {
return err
}
// Read the following HandshakeResponse packet.
data, err = cc.readPacket()
if err != nil {
return err
}
if isOldVersion {
pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data)
} else {
pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
}
if err != nil {
return err
if resp.Capability&mysql.ClientSSL > 0 {
tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig))
if tlsConfig != nil {
// The packet is a SSLRequest, let's switch to TLS.
if err = cc.upgradeToTLS(tlsConfig); err != nil {
return err
}
// Read the following HandshakeResponse packet.
data, err = cc.readPacket()
if err != nil {
return err
}
if isOldVersion {
pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data)
} else {
pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
}
if err != nil {
return err
}
}
}

Expand Down
Loading