diff --git a/pkg/apiserver/profiling/fetcher.go b/pkg/apiserver/profiling/fetcher.go index c532a21e2a..da54516ae9 100644 --- a/pkg/apiserver/profiling/fetcher.go +++ b/pkg/apiserver/profiling/fetcher.go @@ -15,7 +15,6 @@ package profiling import ( "fmt" - "net/http" "time" "go.uber.org/fx" @@ -103,8 +102,9 @@ type pdFetcher struct { func (f *pdFetcher) fetch(op *fetchOptions) ([]byte, error) { baseURL := fmt.Sprintf("%s://%s:%d", f.statusAPIHTTPScheme, op.ip, op.port) - f.client.WithBeforeRequest(func(req *http.Request) { - req.Header.Add("PD-Allow-follower-handle", "true") - }) - return f.client.WithTimeout(maxProfilingTimeout).WithBaseURL(baseURL).SendGetRequest(op.path) + return f.client. + WithTimeout(maxProfilingTimeout). + WithBaseURL(baseURL). + AddRequestHeader("PD-Allow-follower-handle", "true"). + SendGetRequest(op.path) } diff --git a/pkg/httpc/client.go b/pkg/httpc/client.go index 7424f08e51..2e2f9b597c 100644 --- a/pkg/httpc/client.go +++ b/pkg/httpc/client.go @@ -36,7 +36,8 @@ const ( type Client struct { http.Client - BeforeRequest func(req *http.Request) + + header http.Header } func NewHTTPClient(lc fx.Lifecycle, config *config.Config) *Client { @@ -63,14 +64,27 @@ func NewHTTPClient(lc fx.Lifecycle, config *config.Config) *Client { } } +// Clone is a temporary solution to the unexpected shared pointer field and race problem +// TODO: use latest `/util/client` for better api experience. +func (c *Client) Clone() *Client { + return &Client{ + Client: c.Client, + header: c.header.Clone(), + } +} + func (c Client) WithTimeout(timeout time.Duration) *Client { c.Timeout = timeout return &c } -func (c Client) WithBeforeRequest(callback func(req *http.Request)) *Client { - c.BeforeRequest = callback - return &c +func (c *Client) CloneAndAddRequestHeader(key, value string) *Client { + cc := c.Clone() + if cc.header == nil { + cc.header = http.Header{} + } + cc.header.Add(key, value) + return cc } // TODO: Replace using go-resty @@ -101,10 +115,7 @@ func (c *Client) Send( log.Warn("SendRequest failed", zap.String("uri", uri), zap.Error(err)) return nil, e } - - if c.BeforeRequest != nil { - c.BeforeRequest(req) - } + req.Header = c.header resp, err := c.Do(req) if err != nil { diff --git a/pkg/httpc/client_test.go b/pkg/httpc/client_test.go new file mode 100644 index 0000000000..204d7e9bd8 --- /dev/null +++ b/pkg/httpc/client_test.go @@ -0,0 +1,67 @@ +// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. + +package httpc + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/fx/fxtest" + + "github.com/pingcap/tidb-dashboard/pkg/config" +) + +func newTestClient(t *testing.T) *Client { + lc := fxtest.NewLifecycle(t) + config := &config.Config{} + return NewHTTPClient(lc, config) +} + +func Test_Clone(t *testing.T) { + c := newTestClient(t) + cc := c.Clone() + + require.NotSame(t, c, cc) + + require.Nil(t, c.header) + require.Nil(t, cc.header) + require.NotSame(t, c.header, cc.header) +} + +func Test_CloneAndAddRequestHeader(t *testing.T) { + c := newTestClient(t) + cc := c.CloneAndAddRequestHeader("1", "11") + + require.Nil(t, c.header) + require.Equal(t, "11", cc.header.Get("1")) + + cc2 := cc.CloneAndAddRequestHeader("2", "22") + require.Equal(t, "11", cc.header.Get("1")) + require.Equal(t, "", cc.header.Get("2")) + require.Equal(t, "11", cc2.header.Get("1")) + require.Equal(t, "22", cc2.header.Get("2")) +} + +func Test_Send_withHeader(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(r.Header.Get("1"))) + })) + defer ts.Close() + + c := newTestClient(t) + resp1, _ := c.Send(context.Background(), ts.URL, http.MethodGet, nil, nil, "") + d1, _ := resp1.Body() + require.Equal(t, "", string(d1)) + + cc := c.CloneAndAddRequestHeader("1", "11") + resp2, _ := cc.Send(context.Background(), ts.URL, http.MethodGet, nil, nil, "") + d2, _ := resp2.Body() + require.Equal(t, "11", string(d2)) + + resp3, _ := c.Send(context.Background(), ts.URL, http.MethodGet, nil, nil, "") + d3, _ := resp3.Body() + require.Equal(t, "", string(d3)) +} diff --git a/pkg/keyvisual/storage/model.go b/pkg/keyvisual/storage/model.go index 9111a58748..1d0723cbdd 100644 --- a/pkg/keyvisual/storage/model.go +++ b/pkg/keyvisual/storage/model.go @@ -18,6 +18,8 @@ import ( "encoding/gob" "time" + "gorm.io/gorm" + "github.com/pingcap/tidb-dashboard/pkg/dbstore" "github.com/pingcap/tidb-dashboard/pkg/keyvisual/matrix" ) @@ -77,7 +79,9 @@ func CreateTableAxisModelIfNotExists(db *dbstore.DB) (bool, error) { } func ClearTableAxisModel(db *dbstore.DB) error { - return db.Delete(&AxisModel{}).Error + return db.Session(&gorm.Session{AllowGlobalUpdate: true}). + Delete(&AxisModel{}). + Error } func FindAxisModelsOrderByTime(db *dbstore.DB, layerNum uint8) ([]*AxisModel, error) { diff --git a/pkg/keyvisual/storage/model_test.go b/pkg/keyvisual/storage/model_test.go index 14d3a1a3a2..fa15ac7030 100644 --- a/pkg/keyvisual/storage/model_test.go +++ b/pkg/keyvisual/storage/model_test.go @@ -74,7 +74,7 @@ func (t *testDbstoreSuite) TestClearTableAxisModel(c *C) { if err != nil { c.Fatalf("Count table AxisModel error: %v", err) } - c.Assert(count, Equals, 1) + c.Assert(count, Equals, int64(1)) err = ClearTableAxisModel(t.db) c.Assert(err, IsNil) @@ -83,7 +83,7 @@ func (t *testDbstoreSuite) TestClearTableAxisModel(c *C) { if err != nil { c.Fatalf("Count table AxisModel error: %v", err) } - c.Assert(count, Equals, 0) + c.Assert(count, Equals, int64(0)) } func (t *testDbstoreSuite) TestAxisModelFunc(c *C) { @@ -123,7 +123,7 @@ func (t *testDbstoreSuite) TestAxisModelFunc(c *C) { if err != nil { c.Fatalf("Count table AxisModel error: %v", err) } - c.Assert(count, Equals, 0) + c.Assert(count, Equals, int64(0)) err = axisModel.Delete(t.db) c.Assert(err, IsNil) @@ -157,7 +157,7 @@ func (t *testDbstoreSuite) TestAxisModelsFindAndDelete(c *C) { if err != nil { c.Fatalf("Count table AxisModel error: %v", err) } - c.Assert(count, Equals, int(maxLayerNum)*axisModelNumEachLayer) + c.Assert(count, Equals, int64(int(maxLayerNum)*axisModelNumEachLayer)) findLayerNum := maxLayerNum - 1 axisModels, err := FindAxisModelsOrderByTime(t.db, findLayerNum) @@ -175,7 +175,7 @@ func (t *testDbstoreSuite) TestAxisModelsFindAndDelete(c *C) { if err != nil { c.Fatalf("Count table AxisModel error: %v", err) } - c.Assert(count, Equals, int(maxLayerNum-1)*axisModelNumEachLayer) + c.Assert(count, Equals, int64(int(maxLayerNum-1)*axisModelNumEachLayer)) } func axisModelsDeepEqual(obtainedAxisModels []*AxisModel, expectedAxisModels []*AxisModel, c *C) { diff --git a/pkg/pd/client.go b/pkg/pd/client.go index be724c20f9..c505caa616 100644 --- a/pkg/pd/client.go +++ b/pkg/pd/client.go @@ -77,8 +77,8 @@ func (c Client) WithTimeout(timeout time.Duration) *Client { return &c } -func (c Client) WithBeforeRequest(callback func(req *http.Request)) *Client { - c.httpClient.BeforeRequest = callback +func (c Client) AddRequestHeader(key, value string) *Client { + c.httpClient = c.httpClient.CloneAndAddRequestHeader(key, value) return &c } diff --git a/pkg/pd/client_test.go b/pkg/pd/client_test.go new file mode 100644 index 0000000000..6df946a7c4 --- /dev/null +++ b/pkg/pd/client_test.go @@ -0,0 +1,52 @@ +// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. + +package pd + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/fx/fxtest" + + "github.com/pingcap/tidb-dashboard/pkg/config" + "github.com/pingcap/tidb-dashboard/pkg/httpc" +) + +func newTestClient(t *testing.T) *Client { + lc := fxtest.NewLifecycle(t) + config := &config.Config{} + c := NewPDClient(lc, httpc.NewHTTPClient(lc, config), config) + c.lifecycleCtx = context.Background() + return c +} + +func Test_AddRequestHeader_returnDifferentHTTPClient(t *testing.T) { + c := newTestClient(t) + cc := c.AddRequestHeader("1", "11") + + require.NotSame(t, c.httpClient, cc.httpClient) +} + +func Test_Get_withHeader(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(r.Header.Get("1"))) + })) + defer ts.Close() + + c := newTestClient(t).WithBaseURL(ts.URL) + resp1, _ := c.Get("") + d1, _ := resp1.Body() + require.Equal(t, "", string(d1)) + + cc := c.AddRequestHeader("1", "11") + resp2, _ := cc.Get("") + d2, _ := resp2.Body() + require.Equal(t, "11", string(d2)) + + resp3, _ := c.Get("") + d3, _ := resp3.Body() + require.Equal(t, "", string(d3)) +} diff --git a/release-version b/release-version index 55e3ea6c17..86127f31a6 100644 --- a/release-version +++ b/release-version @@ -1,3 +1,3 @@ # This file specifies the TiDB Dashboard internal version, which will be printed in `--version` # and UI. In release branch, changing this file will result in publishing a new version and tag. -2021.11.08.1 +2021.12.06.1