From 396efd11a57016450abd9196001f181f77e34e3e Mon Sep 17 00:00:00 2001 From: xhe Date: Tue, 25 Apr 2023 15:40:33 +0800 Subject: [PATCH] backend: add onTraffic callback for serverless (#276) Signed-off-by: xhe --- Makefile | 2 +- pkg/proxy/backend/backend_conn_mgr.go | 1 + pkg/proxy/backend/backend_conn_mgr_test.go | 40 ++++++++++++++++++++++ pkg/proxy/backend/handshake_handler.go | 12 +++++++ 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 19359455..e3440f6e 100644 --- a/Makefile +++ b/Makefile @@ -44,8 +44,8 @@ golangci-lint: GOBIN=$(GOBIN) go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest lint: golangci-lint tidy - $(GOBIN)/golangci-lint run cd lib && $(GOBIN)/golangci-lint run + $(GOBIN)/golangci-lint run gocovmerge: GOBIN=$(GOBIN) go install github.com/wadey/gocovmerge@master diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index dbfaadf2..eb26313a 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -266,6 +266,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (err error) { defer func() { mgr.setQuitSourceByErr(err) + mgr.handshakeHandler.OnTraffic(mgr) }() if len(request) < 1 { err = mysql.ErrMalformPacket diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 38cb5801..04eb209d 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -784,6 +784,46 @@ func TestHandlerReturnError(t *testing.T) { } } +func TestOnTraffic(t *testing.T) { + i := 0 + inbytes, outbytes := []int{ + 0x99, + }, []int{ + 0xce, + } + ts := newBackendMgrTester(t, func(config *testConfig) { + config.proxyConfig.checkBackendInterval = 10 * time.Millisecond + config.proxyConfig.handler.onTraffic = func(cc ConnContext) { + require.Equal(t, uint64(inbytes[i]), cc.ClientInBytes()) + require.Equal(t, uint64(outbytes[i]), cc.ClientOutBytes()) + i++ + } + }) + runners := []runner{ + // 1st handshake + { + client: ts.mc.authenticate, + proxy: ts.firstHandshake4Proxy, + backend: ts.handshake4Backend, + }, + // query + { + client: func(packetIO *pnet.PacketIO) error { + ts.mc.sql = "select 1" + return ts.mc.request(packetIO) + }, + proxy: ts.forwardCmd4Proxy, + backend: func(packetIO *pnet.PacketIO) error { + ts.mb.respondType = responseTypeResultSet + ts.mb.columns = 1 + ts.mb.rows = 1 + return ts.mb.respond(packetIO) + }, + }, + } + ts.runTests(runners) +} + func TestGetBackendIO(t *testing.T) { addrs := make([]string, 0, 3) listeners := make([]net.Listener, 0, cap(addrs)) diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index e567ddca..29a4b6f6 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -65,6 +65,7 @@ func (es ErrorSource) String() string { } var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) +var _ HandshakeHandler = (*CustomHandshakeHandler)(nil) type ConnContext interface { ClientAddr() string @@ -82,6 +83,7 @@ type HandshakeHandler interface { GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) OnHandshake(ctx ConnContext, to string, err error) OnConnClose(ctx ConnContext) error + OnTraffic(ctx ConnContext) GetCapability() pnet.Capability GetServerVersion() string } @@ -117,6 +119,9 @@ func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Ha func (handler *DefaultHandshakeHandler) OnHandshake(ConnContext, string, error) { } +func (handler *DefaultHandshakeHandler) OnTraffic(ConnContext) { +} + func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext) error { return nil } @@ -135,6 +140,7 @@ func (handler *DefaultHandshakeHandler) GetServerVersion() string { type CustomHandshakeHandler struct { getRouter func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) onHandshake func(ConnContext, string, error) + onTraffic func(ConnContext) onConnClose func(ConnContext) error handleHandshakeResp func(ctx ConnContext, resp *pnet.HandshakeResp) error getCapability func() pnet.Capability @@ -154,6 +160,12 @@ func (h *CustomHandshakeHandler) OnHandshake(ctx ConnContext, addr string, err e } } +func (h *CustomHandshakeHandler) OnTraffic(ctx ConnContext) { + if h.onTraffic != nil { + h.onTraffic(ctx) + } +} + func (h *CustomHandshakeHandler) OnConnClose(ctx ConnContext) error { if h.onConnClose != nil { return h.onConnClose(ctx)