Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions proxy/wireguard/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0)
type Server struct {
bindServer *netBindServer

info routingInfo
policyManager policy.Manager
info routingInfo
policyManager policy.Manager
tag string
sniffingRequest session.SniffingRequest
}

type routingInfo struct {
ctx context.Context
dispatcher routing.Dispatcher
inboundTag *session.Inbound
contentTag *session.Content
ctx context.Context
dispatcher routing.Dispatcher
}

func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
Expand All @@ -58,6 +58,14 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
}

// Retrieve tag and sniffing config from context (set by AlwaysOnInboundHandler)
if inbound := session.InboundFromContext(ctx); inbound != nil {
server.tag = inbound.Tag
}
if content := session.ContentFromContext(ctx); content != nil {
server.sniffingRequest = content.SniffingRequest
}

tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
if err != nil {
return nil, err
Expand All @@ -81,8 +89,6 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
s.info = routingInfo{
ctx: ctx,
dispatcher: dispatcher,
inboundTag: session.InboundFromContext(ctx),
contentTag: session.ContentFromContext(ctx),
}

ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
Expand Down Expand Up @@ -129,21 +135,21 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
sid := session.NewID()
ctx = c.ContextWithID(ctx, sid)
inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
if s.info.inboundTag != nil {
inbound = *s.info.inboundTag

inbound := session.Inbound{
Name: "wireguard",
Tag: s.tag,
CanSpliceCopy: 3,
// overwrite the source to use the tun address for each sub context.
// Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
// Currently we have no way to link to the original source address
Source: net.DestinationFromAddr(conn.RemoteAddr()),
}
inbound.Name = "wireguard"
inbound.CanSpliceCopy = 3

// overwrite the source to use the tun address for each sub context.
// Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
// Currently we have no way to link to the original source address
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
ctx = session.ContextWithInbound(ctx, &inbound)
if s.info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag)
}
ctx = session.ContextWithContent(ctx, &session.Content{
SniffingRequest: s.sniffingRequest,
})
ctx = session.SubContextFromMuxInbound(ctx)

plcy := s.policyManager.ForLevel(0)
Expand Down