Skip to content

Commit 5bcbb29

Browse files
authored
rpc: add PeerInfo (#24255)
This replaces the sketchy and undocumented string context keys for HTTP requests with a defined interface. Using string keys with context is discouraged because they may clash with keys created by other packages. We added these keys to make connection metadata available in the signer, so this change also updates signer/core to use the new PeerInfo API.
1 parent 514ae7c commit 5bcbb29

11 files changed

+183
-64
lines changed

rpc/client.go

+16-36
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,6 @@ const (
5858
maxClientSubscriptionBuffer = 20000
5959
)
6060

61-
const (
62-
httpScheme = "http"
63-
wsScheme = "ws"
64-
ipcScheme = "ipc"
65-
)
66-
6761
// BatchElem is an element in a batch request.
6862
type BatchElem struct {
6963
Method string
@@ -80,7 +74,7 @@ type BatchElem struct {
8074
// Client represents a connection to an RPC server.
8175
type Client struct {
8276
idgen func() ID // for subscriptions
83-
scheme string // connection type: http, ws or ipc
77+
isHTTP bool // connection type: http, ws or ipc
8478
services *serviceRegistry
8579

8680
idCounter uint32
@@ -115,11 +109,9 @@ type clientConn struct {
115109
}
116110

117111
func (c *Client) newClientConn(conn ServerCodec) *clientConn {
118-
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
119-
// Http connections have already set the scheme
120-
if !c.isHTTP() && c.scheme != "" {
121-
ctx = context.WithValue(ctx, "scheme", c.scheme)
122-
}
112+
ctx := context.Background()
113+
ctx = context.WithValue(ctx, clientContextKey{}, c)
114+
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
123115
handler := newHandler(ctx, conn, c.idgen, c.services)
124116
return &clientConn{conn, handler}
125117
}
@@ -145,7 +137,7 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro
145137
select {
146138
case <-ctx.Done():
147139
// Send the timeout to dispatch so it can remove the request IDs.
148-
if !c.isHTTP() {
140+
if !c.isHTTP {
149141
select {
150142
case c.reqTimeout <- op:
151143
case <-c.closing:
@@ -212,18 +204,10 @@ func newClient(initctx context.Context, connect reconnectFunc) (*Client, error)
212204
}
213205

214206
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
215-
scheme := ""
216-
switch conn.(type) {
217-
case *httpConn:
218-
scheme = httpScheme
219-
case *websocketCodec:
220-
scheme = wsScheme
221-
case *jsonCodec:
222-
scheme = ipcScheme
223-
}
207+
_, isHTTP := conn.(*httpConn)
224208
c := &Client{
209+
isHTTP: isHTTP,
225210
idgen: idgen,
226-
scheme: scheme,
227211
services: services,
228212
writeConn: conn,
229213
close: make(chan struct{}),
@@ -236,7 +220,7 @@ func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *C
236220
reqSent: make(chan error, 1),
237221
reqTimeout: make(chan *requestOp),
238222
}
239-
if !c.isHTTP() {
223+
if !isHTTP {
240224
go c.dispatch(conn)
241225
}
242226
return c
@@ -267,7 +251,7 @@ func (c *Client) SupportedModules() (map[string]string, error) {
267251

268252
// Close closes the client, aborting any in-flight requests.
269253
func (c *Client) Close() {
270-
if c.isHTTP() {
254+
if c.isHTTP {
271255
return
272256
}
273257
select {
@@ -281,7 +265,7 @@ func (c *Client) Close() {
281265
// This method only works for clients using HTTP, it doesn't have
282266
// any effect for clients using another transport.
283267
func (c *Client) SetHeader(key, value string) {
284-
if !c.isHTTP() {
268+
if !c.isHTTP {
285269
return
286270
}
287271
conn := c.writeConn.(*httpConn)
@@ -315,7 +299,7 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
315299
}
316300
op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)}
317301

318-
if c.isHTTP() {
302+
if c.isHTTP {
319303
err = c.sendHTTP(ctx, op, msg)
320304
} else {
321305
err = c.send(ctx, op, msg)
@@ -378,7 +362,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
378362
}
379363

380364
var err error
381-
if c.isHTTP() {
365+
if c.isHTTP {
382366
err = c.sendBatchHTTP(ctx, op, msgs)
383367
} else {
384368
err = c.send(ctx, op, msgs)
@@ -417,7 +401,7 @@ func (c *Client) Notify(ctx context.Context, method string, args ...interface{})
417401
}
418402
msg.ID = nil
419403

420-
if c.isHTTP() {
404+
if c.isHTTP {
421405
return c.sendHTTP(ctx, op, msg)
422406
}
423407
return c.send(ctx, op, msg)
@@ -450,12 +434,12 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
450434
// Check type of channel first.
451435
chanVal := reflect.ValueOf(channel)
452436
if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 {
453-
panic("first argument to Subscribe must be a writable channel")
437+
panic(fmt.Sprintf("channel argument of Subscribe has type %T, need writable channel", channel))
454438
}
455439
if chanVal.IsNil() {
456440
panic("channel given to Subscribe must not be nil")
457441
}
458-
if c.isHTTP() {
442+
if c.isHTTP {
459443
return nil, ErrNotificationsUnsupported
460444
}
461445

@@ -509,8 +493,8 @@ func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error
509493
}
510494

511495
func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
512-
// The previous write failed. Try to establish a new connection.
513496
if c.writeConn == nil {
497+
// The previous write failed. Try to establish a new connection.
514498
if err := c.reconnect(ctx); err != nil {
515499
return err
516500
}
@@ -657,7 +641,3 @@ func (c *Client) read(codec ServerCodec) {
657641
c.readOp <- readOp{msgs, batch}
658642
}
659643
}
660-
661-
func (c *Client) isHTTP() bool {
662-
return c.scheme == httpScheme
663-
}

rpc/http.go

+18-12
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,18 @@ type httpConn struct {
4848
headers http.Header
4949
}
5050

51-
// httpConn is treated specially by Client.
51+
// httpConn implements ServerCodec, but it is treated specially by Client
52+
// and some methods don't work. The panic() stubs here exist to ensure
53+
// this special treatment is correct.
54+
5255
func (hc *httpConn) writeJSON(context.Context, interface{}) error {
5356
panic("writeJSON called on httpConn")
5457
}
5558

59+
func (hc *httpConn) peerInfo() PeerInfo {
60+
panic("peerInfo called on httpConn")
61+
}
62+
5663
func (hc *httpConn) remoteAddr() string {
5764
return hc.url
5865
}
@@ -236,20 +243,19 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
236243
http.Error(w, err.Error(), code)
237244
return
238245
}
246+
247+
// Create request-scoped context.
248+
connInfo := PeerInfo{Transport: "http", RemoteAddr: r.RemoteAddr}
249+
connInfo.HTTP.Version = r.Proto
250+
connInfo.HTTP.Host = r.Host
251+
connInfo.HTTP.Origin = r.Header.Get("Origin")
252+
connInfo.HTTP.UserAgent = r.Header.Get("User-Agent")
253+
ctx := r.Context()
254+
ctx = context.WithValue(ctx, peerInfoContextKey{}, connInfo)
255+
239256
// All checks passed, create a codec that reads directly from the request body
240257
// until EOF, writes the response to w, and orders the server to process a
241258
// single request.
242-
ctx := r.Context()
243-
ctx = context.WithValue(ctx, "remote", r.RemoteAddr)
244-
ctx = context.WithValue(ctx, "scheme", r.Proto)
245-
ctx = context.WithValue(ctx, "local", r.Host)
246-
if ua := r.Header.Get("User-Agent"); ua != "" {
247-
ctx = context.WithValue(ctx, "User-Agent", ua)
248-
}
249-
if origin := r.Header.Get("Origin"); origin != "" {
250-
ctx = context.WithValue(ctx, "Origin", origin)
251-
}
252-
253259
w.Header().Set("content-type", contentType)
254260
codec := newHTTPServerConn(r, w)
255261
defer codec.close()

rpc/http_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,39 @@ func TestHTTPErrorResponse(t *testing.T) {
162162
t.Error("unexpected error message", errMsg)
163163
}
164164
}
165+
166+
func TestHTTPPeerInfo(t *testing.T) {
167+
s := newTestServer()
168+
defer s.Stop()
169+
ts := httptest.NewServer(s)
170+
defer ts.Close()
171+
172+
c, err := Dial(ts.URL)
173+
if err != nil {
174+
t.Fatal(err)
175+
}
176+
c.SetHeader("user-agent", "ua-testing")
177+
c.SetHeader("origin", "origin.example.com")
178+
179+
// Request peer information.
180+
var info PeerInfo
181+
if err := c.Call(&info, "test_peerInfo"); err != nil {
182+
t.Fatal(err)
183+
}
184+
185+
if info.RemoteAddr == "" {
186+
t.Error("RemoteAddr not set")
187+
}
188+
if info.Transport != "http" {
189+
t.Errorf("wrong Transport %q", info.Transport)
190+
}
191+
if info.HTTP.Version != "HTTP/1.1" {
192+
t.Errorf("wrong HTTP.Version %q", info.HTTP.Version)
193+
}
194+
if info.HTTP.UserAgent != "ua-testing" {
195+
t.Errorf("wrong HTTP.UserAgent %q", info.HTTP.UserAgent)
196+
}
197+
if info.HTTP.Origin != "origin.example.com" {
198+
t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent)
199+
}
200+
}

rpc/json.go

+5
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ func NewCodec(conn Conn) ServerCodec {
198198
return NewFuncCodec(conn, enc.Encode, dec.Decode)
199199
}
200200

201+
func (c *jsonCodec) peerInfo() PeerInfo {
202+
// This returns "ipc" because all other built-in transports have a separate codec type.
203+
return PeerInfo{Transport: "ipc", RemoteAddr: c.remote}
204+
}
205+
201206
func (c *jsonCodec) remoteAddr() string {
202207
return c.remote
203208
}

rpc/server.go

+35
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,38 @@ func (s *RPCService) Modules() map[string]string {
145145
}
146146
return modules
147147
}
148+
149+
// PeerInfo contains information about the remote end of the network connection.
150+
//
151+
// This is available within RPC method handlers through the context. Call
152+
// PeerInfoFromContext to get information about the client connection related to
153+
// the current method call.
154+
type PeerInfo struct {
155+
// Transport is name of the protocol used by the client.
156+
// This can be "http", "ws" or "ipc".
157+
Transport string
158+
159+
// Address of client. This will usually contain the IP address and port.
160+
RemoteAddr string
161+
162+
// Addditional information for HTTP and WebSocket connections.
163+
HTTP struct {
164+
// Protocol version, i.e. "HTTP/1.1". This is not set for WebSocket.
165+
Version string
166+
// Header values sent by the client.
167+
UserAgent string
168+
Origin string
169+
Host string
170+
}
171+
}
172+
173+
type peerInfoContextKey struct{}
174+
175+
// PeerInfoFromContext returns information about the client's network connection.
176+
// Use this with the context passed to RPC method handler functions.
177+
//
178+
// The zero value is returned if no connection info is present in ctx.
179+
func PeerInfoFromContext(ctx context.Context) PeerInfo {
180+
info, _ := ctx.Value(peerInfoContextKey{}).(PeerInfo)
181+
return info
182+
}

rpc/server_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) {
4545
t.Fatalf("Expected service calc to be registered")
4646
}
4747

48-
wantCallbacks := 9
48+
wantCallbacks := 10
4949
if len(svc.callbacks) != wantCallbacks {
5050
t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
5151
}

rpc/testservice_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *
8080
return echoResult{str, i, args}
8181
}
8282

83+
func (s *testService) PeerInfo(ctx context.Context) PeerInfo {
84+
return PeerInfoFromContext(ctx)
85+
}
86+
8387
func (s *testService) Sleep(ctx context.Context, duration time.Duration) {
8488
time.Sleep(duration)
8589
}

rpc/types.go

+2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ type API struct {
4040
// a RPC session. Implementations must be go-routine safe since the codec can be called in
4141
// multiple go-routines concurrently.
4242
type ServerCodec interface {
43+
peerInfo() PeerInfo
4344
readBatch() (msgs []*jsonrpcMessage, isBatch bool, err error)
4445
close()
46+
4547
jsonWriter
4648
}
4749

rpc/websocket.go

+17-3
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
6060
log.Debug("WebSocket upgrade failed", "err", err)
6161
return
6262
}
63-
codec := newWebsocketCodec(conn)
63+
codec := newWebsocketCodec(conn, r.Host, r.Header)
6464
s.ServeCodec(codec, 0)
6565
})
6666
}
@@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
197197
}
198198
return nil, hErr
199199
}
200-
return newWebsocketCodec(conn), nil
200+
return newWebsocketCodec(conn, endpoint, header), nil
201201
})
202202
}
203203

@@ -235,12 +235,13 @@ func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
235235
type websocketCodec struct {
236236
*jsonCodec
237237
conn *websocket.Conn
238+
info PeerInfo
238239

239240
wg sync.WaitGroup
240241
pingReset chan struct{}
241242
}
242243

243-
func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
244+
func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec {
244245
conn.SetReadLimit(wsMessageSizeLimit)
245246
conn.SetPongHandler(func(appData string) error {
246247
conn.SetReadDeadline(time.Time{})
@@ -250,7 +251,16 @@ func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
250251
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
251252
conn: conn,
252253
pingReset: make(chan struct{}, 1),
254+
info: PeerInfo{
255+
Transport: "ws",
256+
RemoteAddr: conn.RemoteAddr().String(),
257+
},
253258
}
259+
// Fill in connection details.
260+
wc.info.HTTP.Host = host
261+
wc.info.HTTP.Origin = req.Get("Origin")
262+
wc.info.HTTP.UserAgent = req.Get("User-Agent")
263+
// Start pinger.
254264
wc.wg.Add(1)
255265
go wc.pingLoop()
256266
return wc
@@ -261,6 +271,10 @@ func (wc *websocketCodec) close() {
261271
wc.wg.Wait()
262272
}
263273

274+
func (wc *websocketCodec) peerInfo() PeerInfo {
275+
return wc.info
276+
}
277+
264278
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
265279
err := wc.jsonCodec.writeJSON(ctx, v)
266280
if err == nil {

0 commit comments

Comments
 (0)