Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/hostagent/hostagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,8 @@ func (a *HostAgent) processGuestAgentEvents(ctx context.Context, client *guestag
if useSSHFwd {
a.portForwarder.OnEvent(ctx, ev)
} else {
a.grpcPortForwarder.OnEvent(ctx, client, ev)
dialContext := portfwd.DialContextToGRPCTunnel(client)
a.grpcPortForwarder.OnEvent(ctx, dialContext, ev)
}
}

Expand Down
65 changes: 27 additions & 38 deletions pkg/portfwd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,61 +18,50 @@ import (
guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client"
)

func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.Conn, guestAddr string) {
id := fmt.Sprintf("tcp-%s-%s", conn.LocalAddr().String(), conn.RemoteAddr().String())
func HandleTCPConnection(_ context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), conn net.Conn, guestAddr string) {
proxy := tcpproxy.DialProxy{Addr: guestAddr, DialContext: dialContext}
proxy.HandleConn(conn)
}

stream, err := client.Tunnel(ctx)
func HandleUDPConnection(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), conn net.PacketConn, guestAddr string) {
proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) {
return dialContext(ctx, "udp", guestAddr)
})
if err != nil {
logrus.Errorf("could not open tcp tunnel for id: %s error:%v", id, err)
return
}

// Handshake message to start tunnel
if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "tcp", GuestAddr: guestAddr}); err != nil {
logrus.Errorf("could not start tcp tunnel for id: %s error:%v", id, err)
logrus.WithError(err).Error("error in udp tunnel proxy")
return
}

rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "tcp"}
proxy := tcpproxy.DialProxy{DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return conn, nil
}}
proxy.HandleConn(rw)
defer func() {
err := proxy.Close()
if err != nil {
logrus.WithError(err).Error("error in closing udp tunnel proxy")
}
}()
proxy.Run()
}

func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.PacketConn, guestAddr string) {
var udpConnectionCounter atomic.Uint32
initialID := fmt.Sprintf("udp-%s", conn.LocalAddr().String())

func DialContextToGRPCTunnel(client *guestagentclient.GuestAgentClient) func(ctx context.Context, network, addr string) (net.Conn, error) {
// gvisor-tap-vsock's UDPProxy demultiplexes client connections internally based on their source address.
// It calls this dialer function only when it receives a datagram from a new, unrecognized client.
// For each new client, we must return a new net.Conn, which in our case is a new gRPC stream.
// The atomic counter ensures that each stream has a unique ID to distinguish them on the server side.
proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) {
id := fmt.Sprintf("%s-%d", initialID, udpConnectionCounter.Add(1))
stream, err := client.Tunnel(ctx)
var connectionCounter atomic.Uint32
return func(_ context.Context, network, addr string) (net.Conn, error) {
// Passed context.Context is used for timeout on initiate connection, not for the lifetime of the connection.
// We use context.Background() here to avoid unexpected cancellation.
stream, err := client.Tunnel(context.Background())
if err != nil {
return nil, fmt.Errorf("could not open udp tunnel for id: %s error:%w", id, err)
return nil, fmt.Errorf("could not open tunnel for addr: %s error:%w", addr, err)
}
// Handshake message to start tunnel
if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "udp", GuestAddr: guestAddr}); err != nil {
return nil, fmt.Errorf("could not start udp tunnel for id: %s error:%w", id, err)
id := fmt.Sprintf("%s-%s-%d", network, addr, connectionCounter.Add(1))
if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: network, GuestAddr: addr}); err != nil {
return nil, fmt.Errorf("could not start tunnel for id: %s addr: %s error:%w", id, addr, err)
}
rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "udp"}
rw := &GrpcClientRW{stream: stream, id: id, addr: addr, protocol: network}
return rw, nil
})
if err != nil {
logrus.Errorf("error in udp tunnel proxy for id: %s error:%v", initialID, err)
return
}

defer func() {
err := proxy.Close()
if err != nil {
logrus.Errorf("error in closing udp tunnel proxy for id: %s error:%v", initialID, err)
}
}()
proxy.Run()
}

type GrpcClientRW struct {
Expand Down
5 changes: 2 additions & 3 deletions pkg/portfwd/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/sirupsen/logrus"

"github.com/lima-vm/lima/v2/pkg/guestagent/api"
guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client"
"github.com/lima-vm/lima/v2/pkg/limatype"
"github.com/lima-vm/lima/v2/pkg/limayaml"
)
Expand All @@ -38,7 +37,7 @@ func (fw *Forwarder) Close() error {
return fw.closableListeners.Close()
}

func (fw *Forwarder) OnEvent(ctx context.Context, client *guestagentclient.GuestAgentClient, ev *api.Event) {
func (fw *Forwarder) OnEvent(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), ev *api.Event) {
for _, f := range ev.AddedLocalPorts {
// Before forwarding, check if any static rule matches this port otherwise it will be forwarded twice and cause a port conflict
if fw.isPortStaticallyForwarded(f) {
Expand All @@ -55,7 +54,7 @@ func (fw *Forwarder) OnEvent(ctx context.Context, client *guestagentclient.Guest
continue
}
logrus.Infof("Forwarding %s from %s to %s", strings.ToUpper(f.Protocol), remote, local)
fw.closableListeners.Forward(ctx, client, f.Protocol, local, remote)
fw.closableListeners.Forward(ctx, dialContext, f.Protocol, local, remote)
}
for _, f := range ev.RemovedLocalPorts {
local, remote := fw.forwardingAddresses(f)
Expand Down
16 changes: 7 additions & 9 deletions pkg/portfwd/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"sync"

"github.com/sirupsen/logrus"

guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client"
)

type ClosableListeners struct {
Expand Down Expand Up @@ -59,14 +57,14 @@ func (p *ClosableListeners) Close() error {
return errors.Join(errs...)
}

func (p *ClosableListeners) Forward(ctx context.Context, client *guestagentclient.GuestAgentClient,
func (p *ClosableListeners) Forward(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error),
protocol string, hostAddress string, guestAddress string,
) {
switch protocol {
case "tcp", "tcp6":
go p.forwardTCP(ctx, client, hostAddress, guestAddress)
go p.forwardTCP(ctx, dialContext, hostAddress, guestAddress)
case "udp", "udp6":
go p.forwardUDP(ctx, client, hostAddress, guestAddress)
go p.forwardUDP(ctx, dialContext, hostAddress, guestAddress)
}
}

Expand All @@ -93,7 +91,7 @@ func (p *ClosableListeners) Remove(_ context.Context, protocol, hostAddress, gue
}
}

func (p *ClosableListeners) forwardTCP(ctx context.Context, client *guestagentclient.GuestAgentClient, hostAddress, guestAddress string) {
func (p *ClosableListeners) forwardTCP(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), hostAddress, guestAddress string) {
key := key("tcp", hostAddress, guestAddress)

p.listenersRW.Lock()
Expand Down Expand Up @@ -124,11 +122,11 @@ func (p *ClosableListeners) forwardTCP(ctx context.Context, client *guestagentcl
}
return
}
go HandleTCPConnection(ctx, client, conn, guestAddress)
go HandleTCPConnection(ctx, dialContext, conn, guestAddress)
}
}

func (p *ClosableListeners) forwardUDP(ctx context.Context, client *guestagentclient.GuestAgentClient, hostAddress, guestAddress string) {
func (p *ClosableListeners) forwardUDP(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), hostAddress, guestAddress string) {
key := key("udp", hostAddress, guestAddress)
defer p.Remove(ctx, "udp", hostAddress, guestAddress)

Expand All @@ -148,7 +146,7 @@ func (p *ClosableListeners) forwardUDP(ctx context.Context, client *guestagentcl
p.udpListeners[key] = udpConn
p.udpListenersRW.Unlock()

HandleUDPConnection(ctx, client, udpConn, guestAddress)
HandleUDPConnection(ctx, dialContext, udpConn, guestAddress)
}

func key(protocol, hostAddress, guestAddress string) string {
Expand Down
Loading