Skip to content

Commit

Permalink
[tailscale] net, net/http: add enforcement hooks
Browse files Browse the repository at this point in the history
Updates #55
Updates tailscale/corp#8944
Updates tailscale/corp#12702

Signed-off-by: Jenny Zhang <jz@tailscale.com>
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>

(Cherry-picked from 13373ca)
(cherry picked from commit 043e09a)
(cherry picked from commit 8df9488)
  • Loading branch information
phirework authored and bradfitz committed Aug 21, 2024
1 parent 1e42045 commit ab59d02
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 0 deletions.
3 changes: 3 additions & 0 deletions api/go1.99999.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55
pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55
pkg net/http, func SetRoundTripEnforcer(func(*Request) error) #55
48 changes: 48 additions & 0 deletions src/net/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,24 @@ func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet s
return "", 0, UnknownNetworkError(network)
}

// SetResolveEnforcer set a program-global resolver enforcer that can cause resolvers to
// fail based on the context and/or other arguments.
//
// f must be non-nil, it can only be called once, and must not be called
// concurrent with any dial/resolve.
func SetResolveEnforcer(f func(ctx context.Context, op, network, addr string, hint Addr) error) {
if f == nil {
panic("nil func")
}
if resolveEnforcer != nil {
panic("already called")
}
resolveEnforcer = f
}

// resolveEnforcer, if non-nil, is the installed hook from SetResolveEnforcer.
var resolveEnforcer func(ctx context.Context, op, network, addr string, hint Addr) error

// resolveAddrList resolves addr using hint and returns a list of
// addresses. The result contains at least one address when error is
// nil.
Expand All @@ -299,6 +317,13 @@ func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string
}
return addrList{addr}, nil
}

if resolveEnforcer != nil {
if err := resolveEnforcer(ctx, op, network, addr, hint); err != nil {
return nil, err
}
}

addrs, err := r.internetAddrList(ctx, afnet, addr)
if err != nil || op != "dial" || hint == nil {
return addrs, err
Expand Down Expand Up @@ -603,9 +628,32 @@ func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addr
}
}

// SetDialEnforcer set a program-global dial enforcer that can cause dials to
// fail based on the context and/or Addr(s).
//
// f must be non-nil, it can only be called once, and must not be called
// concurrent with any dial.
func SetDialEnforcer(f func(context.Context, []Addr) error) {
if f == nil {
panic("nil func")
}
if dialEnforcer != nil {
panic("already called")
}
dialEnforcer = f
}

// dialEnforce, if non-nil, is any installed hook from SetDialEnforcer.
var dialEnforcer func(context.Context, []Addr) error

// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
if dialEnforcer != nil {
if err := dialEnforcer(ctx, ras); err != nil {
return nil, err
}
}
var firstErr error // The error from the first address is most relevant.

for i, ra := range ras {
Expand Down
21 changes: 21 additions & 0 deletions src/net/http/tailscale.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package http

var roundTripEnforcer func(*Request) error

// SetRoundTripEnforcer set a program-global resolver enforcer that can cause
// RoundTrip calls to fail based on the request and its context.
//
// f must be non-nil.
//
// SetRoundTripEnforcer can only be called once, and must not be called
// concurrent with any RoundTrip call; it's expected to be registered during
// init.
func SetRoundTripEnforcer(f func(*Request) error) {
if f == nil {
panic("nil func")
}
if roundTripEnforcer != nil {
panic("already called")
}
roundTripEnforcer = f
}
6 changes: 6 additions & 0 deletions src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,12 @@ func validateHeaders(hdrs Header) string {

// roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (_ *Response, err error) {
if roundTripEnforcer != nil {
if err := roundTripEnforcer(req); err != nil {
return nil, err
}
}

t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
ctx := req.Context()
trace := httptrace.ContextClientTrace(ctx)
Expand Down

0 comments on commit ab59d02

Please sign in to comment.