Skip to content

Commit fd04bd2

Browse files
authored
Merge pull request #4232 from norio-nomura/refactor-forwarder-on-event
pkg/portfwd: Refactor `Forwarder.OnEvent()`
2 parents e426bf0 + 9046d8d commit fd04bd2

File tree

4 files changed

+38
-51
lines changed

4 files changed

+38
-51
lines changed

pkg/hostagent/hostagent.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,8 @@ func (a *HostAgent) processGuestAgentEvents(ctx context.Context, client *guestag
813813
if useSSHFwd {
814814
a.portForwarder.OnEvent(ctx, ev)
815815
} else {
816-
a.grpcPortForwarder.OnEvent(ctx, client, ev)
816+
dialContext := portfwd.DialContextToGRPCTunnel(client)
817+
a.grpcPortForwarder.OnEvent(ctx, dialContext, ev)
817818
}
818819
}
819820

pkg/portfwd/client.go

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,61 +18,50 @@ import (
1818
guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client"
1919
)
2020

21-
func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.Conn, guestAddr string) {
22-
id := fmt.Sprintf("tcp-%s-%s", conn.LocalAddr().String(), conn.RemoteAddr().String())
21+
func HandleTCPConnection(_ context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), conn net.Conn, guestAddr string) {
22+
proxy := tcpproxy.DialProxy{Addr: guestAddr, DialContext: dialContext}
23+
proxy.HandleConn(conn)
24+
}
2325

24-
stream, err := client.Tunnel(ctx)
26+
func HandleUDPConnection(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), conn net.PacketConn, guestAddr string) {
27+
proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) {
28+
return dialContext(ctx, "udp", guestAddr)
29+
})
2530
if err != nil {
26-
logrus.Errorf("could not open tcp tunnel for id: %s error:%v", id, err)
27-
return
28-
}
29-
30-
// Handshake message to start tunnel
31-
if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "tcp", GuestAddr: guestAddr}); err != nil {
32-
logrus.Errorf("could not start tcp tunnel for id: %s error:%v", id, err)
31+
logrus.WithError(err).Error("error in udp tunnel proxy")
3332
return
3433
}
3534

36-
rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "tcp"}
37-
proxy := tcpproxy.DialProxy{DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
38-
return conn, nil
39-
}}
40-
proxy.HandleConn(rw)
35+
defer func() {
36+
err := proxy.Close()
37+
if err != nil {
38+
logrus.WithError(err).Error("error in closing udp tunnel proxy")
39+
}
40+
}()
41+
proxy.Run()
4142
}
4243

43-
func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.PacketConn, guestAddr string) {
44-
var udpConnectionCounter atomic.Uint32
45-
initialID := fmt.Sprintf("udp-%s", conn.LocalAddr().String())
46-
44+
func DialContextToGRPCTunnel(client *guestagentclient.GuestAgentClient) func(ctx context.Context, network, addr string) (net.Conn, error) {
4745
// gvisor-tap-vsock's UDPProxy demultiplexes client connections internally based on their source address.
4846
// It calls this dialer function only when it receives a datagram from a new, unrecognized client.
4947
// For each new client, we must return a new net.Conn, which in our case is a new gRPC stream.
5048
// The atomic counter ensures that each stream has a unique ID to distinguish them on the server side.
51-
proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) {
52-
id := fmt.Sprintf("%s-%d", initialID, udpConnectionCounter.Add(1))
53-
stream, err := client.Tunnel(ctx)
49+
var connectionCounter atomic.Uint32
50+
return func(_ context.Context, network, addr string) (net.Conn, error) {
51+
// Passed context.Context is used for timeout on initiate connection, not for the lifetime of the connection.
52+
// We use context.Background() here to avoid unexpected cancellation.
53+
stream, err := client.Tunnel(context.Background())
5454
if err != nil {
55-
return nil, fmt.Errorf("could not open udp tunnel for id: %s error:%w", id, err)
55+
return nil, fmt.Errorf("could not open tunnel for addr: %s error:%w", addr, err)
5656
}
5757
// Handshake message to start tunnel
58-
if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "udp", GuestAddr: guestAddr}); err != nil {
59-
return nil, fmt.Errorf("could not start udp tunnel for id: %s error:%w", id, err)
58+
id := fmt.Sprintf("%s-%s-%d", network, addr, connectionCounter.Add(1))
59+
if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: network, GuestAddr: addr}); err != nil {
60+
return nil, fmt.Errorf("could not start tunnel for id: %s addr: %s error:%w", id, addr, err)
6061
}
61-
rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "udp"}
62+
rw := &GrpcClientRW{stream: stream, id: id, addr: addr, protocol: network}
6263
return rw, nil
63-
})
64-
if err != nil {
65-
logrus.Errorf("error in udp tunnel proxy for id: %s error:%v", initialID, err)
66-
return
6764
}
68-
69-
defer func() {
70-
err := proxy.Close()
71-
if err != nil {
72-
logrus.Errorf("error in closing udp tunnel proxy for id: %s error:%v", initialID, err)
73-
}
74-
}()
75-
proxy.Run()
7665
}
7766

7867
type GrpcClientRW struct {

pkg/portfwd/forward.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"github.com/sirupsen/logrus"
1212

1313
"github.com/lima-vm/lima/v2/pkg/guestagent/api"
14-
guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client"
1514
"github.com/lima-vm/lima/v2/pkg/limatype"
1615
"github.com/lima-vm/lima/v2/pkg/limayaml"
1716
)
@@ -38,7 +37,7 @@ func (fw *Forwarder) Close() error {
3837
return fw.closableListeners.Close()
3938
}
4039

41-
func (fw *Forwarder) OnEvent(ctx context.Context, client *guestagentclient.GuestAgentClient, ev *api.Event) {
40+
func (fw *Forwarder) OnEvent(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), ev *api.Event) {
4241
for _, f := range ev.AddedLocalPorts {
4342
// Before forwarding, check if any static rule matches this port otherwise it will be forwarded twice and cause a port conflict
4443
if fw.isPortStaticallyForwarded(f) {
@@ -55,7 +54,7 @@ func (fw *Forwarder) OnEvent(ctx context.Context, client *guestagentclient.Guest
5554
continue
5655
}
5756
logrus.Infof("Forwarding %s from %s to %s", strings.ToUpper(f.Protocol), remote, local)
58-
fw.closableListeners.Forward(ctx, client, f.Protocol, local, remote)
57+
fw.closableListeners.Forward(ctx, dialContext, f.Protocol, local, remote)
5958
}
6059
for _, f := range ev.RemovedLocalPorts {
6160
local, remote := fw.forwardingAddresses(f)

pkg/portfwd/listener.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ import (
1414
"sync"
1515

1616
"github.com/sirupsen/logrus"
17-
18-
guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client"
1917
)
2018

2119
type ClosableListeners struct {
@@ -59,14 +57,14 @@ func (p *ClosableListeners) Close() error {
5957
return errors.Join(errs...)
6058
}
6159

62-
func (p *ClosableListeners) Forward(ctx context.Context, client *guestagentclient.GuestAgentClient,
60+
func (p *ClosableListeners) Forward(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error),
6361
protocol string, hostAddress string, guestAddress string,
6462
) {
6563
switch protocol {
6664
case "tcp", "tcp6":
67-
go p.forwardTCP(ctx, client, hostAddress, guestAddress)
65+
go p.forwardTCP(ctx, dialContext, hostAddress, guestAddress)
6866
case "udp", "udp6":
69-
go p.forwardUDP(ctx, client, hostAddress, guestAddress)
67+
go p.forwardUDP(ctx, dialContext, hostAddress, guestAddress)
7068
}
7169
}
7270

@@ -93,7 +91,7 @@ func (p *ClosableListeners) Remove(_ context.Context, protocol, hostAddress, gue
9391
}
9492
}
9593

96-
func (p *ClosableListeners) forwardTCP(ctx context.Context, client *guestagentclient.GuestAgentClient, hostAddress, guestAddress string) {
94+
func (p *ClosableListeners) forwardTCP(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), hostAddress, guestAddress string) {
9795
key := key("tcp", hostAddress, guestAddress)
9896

9997
p.listenersRW.Lock()
@@ -124,11 +122,11 @@ func (p *ClosableListeners) forwardTCP(ctx context.Context, client *guestagentcl
124122
}
125123
return
126124
}
127-
go HandleTCPConnection(ctx, client, conn, guestAddress)
125+
go HandleTCPConnection(ctx, dialContext, conn, guestAddress)
128126
}
129127
}
130128

131-
func (p *ClosableListeners) forwardUDP(ctx context.Context, client *guestagentclient.GuestAgentClient, hostAddress, guestAddress string) {
129+
func (p *ClosableListeners) forwardUDP(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), hostAddress, guestAddress string) {
132130
key := key("udp", hostAddress, guestAddress)
133131
defer p.Remove(ctx, "udp", hostAddress, guestAddress)
134132

@@ -148,7 +146,7 @@ func (p *ClosableListeners) forwardUDP(ctx context.Context, client *guestagentcl
148146
p.udpListeners[key] = udpConn
149147
p.udpListenersRW.Unlock()
150148

151-
HandleUDPConnection(ctx, client, udpConn, guestAddress)
149+
HandleUDPConnection(ctx, dialContext, udpConn, guestAddress)
152150
}
153151

154152
func key(protocol, hostAddress, guestAddress string) string {

0 commit comments

Comments
 (0)