Skip to content

Commit f896daa

Browse files
committed
migrate to reliable.Group; add cancel stream on context cancel
1 parent a3b13d6 commit f896daa

File tree

13 files changed

+95
-103
lines changed

13 files changed

+95
-103
lines changed

cmd/connet/client.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ import (
1616
"github.com/connet-dev/connet/model"
1717
"github.com/connet-dev/connet/nat"
1818
"github.com/connet-dev/connet/netc"
19+
"github.com/connet-dev/connet/reliable"
1920
"github.com/connet-dev/connet/statusc"
2021
"github.com/quic-go/quic-go"
2122
"github.com/spf13/cobra"
22-
"golang.org/x/sync/errgroup"
2323
)
2424

2525
type ClientConfig struct {
@@ -240,10 +240,10 @@ func clientRun(ctx context.Context, cfg ClientConfig, logger *slog.Logger) error
240240
return fmt.Errorf("create client: %w", err)
241241
}
242242

243-
g, ctx := errgroup.WithContext(ctx)
243+
g := reliable.NewGroup(ctx)
244244

245245
if statusAddr != nil {
246-
g.Go(func() error {
246+
g.Go(func(ctx context.Context) error {
247247
logger.Debug("running status server", "addr", statusAddr)
248248
return statusc.Run(ctx, statusAddr, cl.Status)
249249
})
@@ -254,12 +254,12 @@ func clientRun(ctx context.Context, cfg ClientConfig, logger *slog.Logger) error
254254
if err != nil {
255255
return err
256256
}
257-
g.Go(func() error {
257+
g.Go(func(ctx context.Context) error {
258258
<-dst.Context().Done()
259259
return fmt.Errorf("[destination %s] unexpected error: %w", name, context.Cause(dst.Context()))
260260
})
261261
if dstrun := destinationHandlers[name]; dstrun != nil {
262-
g.Go(func() error { return dstrun(dst).Run(ctx) })
262+
g.Go(dstrun(dst).Run)
263263
}
264264
}
265265

@@ -268,12 +268,12 @@ func clientRun(ctx context.Context, cfg ClientConfig, logger *slog.Logger) error
268268
if err != nil {
269269
return err
270270
}
271-
g.Go(func() error {
271+
g.Go(func(ctx context.Context) error {
272272
<-src.Context().Done()
273273
return fmt.Errorf("[source %s] unexpected error: %w", name, context.Cause(src.Context()))
274274
})
275275
if srcrun := sourceHandlers[name]; srcrun != nil {
276-
g.Go(func() error { return srcrun(src).Run(ctx) })
276+
g.Go(srcrun(src).Run)
277277
}
278278
}
279279

cmd/connet/main.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ import (
1717

1818
"github.com/connet-dev/connet/model"
1919
"github.com/connet-dev/connet/netc"
20+
"github.com/connet-dev/connet/reliable"
2021
"github.com/connet-dev/connet/slogc"
2122
"github.com/connet-dev/connet/statusc"
2223
"github.com/pelletier/go-toml/v2"
2324
"github.com/spf13/cobra"
24-
"golang.org/x/sync/errgroup"
2525
)
2626

2727
type Config struct {
@@ -269,9 +269,9 @@ func runWithStatus[T any](ctx context.Context, srv withStatus[T], statusAddr *ne
269269
return srv.Run(ctx)
270270
}
271271

272-
g, ctx := errgroup.WithContext(ctx)
273-
g.Go(func() error { return srv.Run(ctx) })
274-
g.Go(func() error {
272+
g := reliable.NewGroup(ctx)
273+
g.Go(srv.Run)
274+
g.Go(func(ctx context.Context) error {
275275
logger.Debug("running status server", "addr", statusAddr)
276276
return statusc.Run(ctx, statusAddr, srv.Status)
277277
})

control/clients.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"github.com/connet-dev/connet/reliable"
2525
"github.com/connet-dev/connet/slogc"
2626
"github.com/quic-go/quic-go"
27-
"golang.org/x/sync/errgroup"
2827
)
2928

3029
type ClientAuthenticateRequest struct {
@@ -623,9 +622,9 @@ func (s *clientStream) announce(ctx context.Context, req *pbclient.Request_Annou
623622
}
624623
}()
625624

626-
g, ctx := errgroup.WithContext(ctx)
625+
g := reliable.NewGroup(ctx)
627626

628-
g.Go(func() error {
627+
g.Go(func(ctx context.Context) error {
629628
for {
630629
req, err := pbclient.ReadRequest(s.stream)
631630
if err != nil {
@@ -652,7 +651,7 @@ func (s *clientStream) announce(ctx context.Context, req *pbclient.Request_Annou
652651
}
653652
})
654653

655-
g.Go(func() error {
654+
g.Go(func(ctx context.Context) error {
656655
defer s.conn.logger.Debug("completed sources notify")
657656
return s.conn.server.listen(ctx, endpoint, role.Invert(), func(peers []*pbclient.RemotePeer) error {
658657
s.conn.logger.Debug("updated sources list", "peers", len(peers))
@@ -697,15 +696,15 @@ func (s *clientStream) relay(ctx context.Context, req *pbclient.Request_Relay) e
697696
return err
698697
}
699698

700-
g, ctx := errgroup.WithContext(ctx)
699+
g := reliable.NewGroup(ctx)
701700

702-
g.Go(func() error {
701+
g.Go(func(ctx context.Context) error {
703702
connCtx := s.conn.conn.Context()
704703
<-connCtx.Done()
705704
return context.Cause(connCtx)
706705
})
707706

708-
g.Go(func() error {
707+
g.Go(func(ctx context.Context) error {
709708
defer s.conn.logger.Debug("completed relay notify")
710709
return s.conn.server.relays.Client(ctx, endpoint, role, clientCert, s.conn.auth, func(relays map[RelayID]relayCacheValue) error {
711710
s.conn.logger.Debug("updated relay list", "relays", len(relays))

e2e_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ import (
2525
"github.com/connet-dev/connet/model"
2626
"github.com/connet-dev/connet/netc"
2727
relaysrv "github.com/connet-dev/connet/relay"
28+
"github.com/connet-dev/connet/reliable"
2829
"github.com/connet-dev/connet/restr"
2930
"github.com/connet-dev/connet/selfhosted"
3031
"github.com/connet-dev/connet/server"
3132
"github.com/connet-dev/connet/statusc"
3233
"github.com/gorilla/websocket"
3334
"github.com/pires/go-proxyproto"
3435
"github.com/stretchr/testify/require"
35-
"golang.org/x/sync/errgroup"
3636
)
3737

3838
type connectedTestCase struct {
@@ -212,10 +212,10 @@ func TestE2E(t *testing.T) {
212212
)
213213
require.NoError(t, err)
214214

215-
g, ctx := errgroup.WithContext(ctx)
216-
g.Go(func() error { return proxyProtoServer(ctx, ppListen) })
217-
g.Go(func() error { return echoServer(ctx, echoListen) })
218-
g.Go(func() error { return srv.Run(ctx) })
215+
g := reliable.NewGroup(ctx)
216+
g.Go(reliable.Bind(ppListen, proxyProtoServer))
217+
g.Go(reliable.Bind(echoListen, echoServer))
218+
g.Go(srv.Run)
219219

220220
time.Sleep(time.Millisecond) // time for server to come online
221221

@@ -427,17 +427,17 @@ func TestE2E(t *testing.T) {
427427
switch {
428428
case tc.isSuccessProxyProto():
429429
dstSrv := NewTCPDestination(dst, ppAddr, 0, logger)
430-
g.Go(func() error { return dstSrv.Run(ctx) })
430+
g.Go(dstSrv.Run)
431431
case tc.isSuccessTLS():
432432
clientTransport := htsServer.Client().Transport.(*http.Transport)
433433
dstSrv := NewTLSDestination(dst, htsAddr, clientTransport.TLSClientConfig, 0, logger)
434-
g.Go(func() error { return dstSrv.Run(ctx) })
434+
g.Go(dstSrv.Run)
435435
case tc.isSuccessWS():
436436
dstSrv := NewTCPDestination(dst, echoAddr, 0, logger)
437-
g.Go(func() error { return dstSrv.Run(ctx) })
437+
g.Go(dstSrv.Run)
438438
default:
439439
dstSrv := NewTCPDestination(dst, htAddr, 0, logger)
440-
g.Go(func() error { return dstSrv.Run(ctx) })
440+
g.Go(dstSrv.Run)
441441
}
442442

443443
src, err := clSrc.Source(ctx, tc.s)
@@ -446,20 +446,20 @@ func TestE2E(t *testing.T) {
446446
switch {
447447
case tc.isSuccessTLS():
448448
srcSrv := NewTLSSource(src, fmt.Sprintf(":%d", tc.sport), htsServer.TLS, logger)
449-
g.Go(func() error { return srcSrv.Run(ctx) })
449+
g.Go(srcSrv.Run)
450450
case tc.isSuccessWS():
451451
srcURL, err := url.Parse(fmt.Sprintf("ws://:%d", tc.sport))
452452
require.NoError(t, err)
453453
srcSrv := NewWSSource(src, srcURL, nil, logger)
454-
g.Go(func() error { return srcSrv.Run(ctx) })
454+
g.Go(srcSrv.Run)
455455
case tc.isFail():
456456
srcURL, err := url.Parse(fmt.Sprintf("http://:%d", tc.sport))
457457
require.NoError(t, err)
458458
srcSrv := NewHTTPSource(src, srcURL, nil)
459-
g.Go(func() error { return srcSrv.Run(ctx) })
459+
g.Go(srcSrv.Run)
460460
default:
461461
srcSrv := NewTCPSource(src, fmt.Sprintf(":%d", tc.sport), logger)
462-
g.Go(func() error { return srcSrv.Run(ctx) })
462+
g.Go(srcSrv.Run)
463463
}
464464
}
465465

endpoint.go

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ import (
1212
"github.com/connet-dev/connet/model"
1313
"github.com/connet-dev/connet/proto"
1414
"github.com/connet-dev/connet/proto/pbclient"
15+
"github.com/connet-dev/connet/quicc"
1516
"github.com/connet-dev/connet/reliable"
1617
"github.com/connet-dev/connet/slogc"
1718
"github.com/connet-dev/connet/statusc"
1819
"github.com/quic-go/quic-go"
19-
"golang.org/x/sync/errgroup"
2020
)
2121

2222
type endpointStatus struct {
@@ -177,15 +177,10 @@ func (ep *endpoint) runAnnounce(ctx context.Context, conn *quic.Conn) error {
177177
}
178178
}()
179179

180-
g, ctx := errgroup.WithContext(ctx)
180+
g := reliable.NewGroup(ctx)
181+
g.Go(quicc.CancelStream(stream))
181182

182-
g.Go(func() error {
183-
<-ctx.Done()
184-
stream.CancelRead(0)
185-
return nil
186-
})
187-
188-
g.Go(func() error {
183+
g.Go(func(ctx context.Context) error {
189184
defer ep.logger.Debug("completed announce notify")
190185
return ep.peer.selfListen(ctx, func(peer *pbclient.Peer) error {
191186
ep.logger.Debug("updated announce", "direct", len(peer.Directs), "relays", len(peer.RelayIds))
@@ -199,7 +194,7 @@ func (ep *endpoint) runAnnounce(ctx context.Context, conn *quic.Conn) error {
199194
})
200195
})
201196

202-
g.Go(func() error {
197+
g.Go(func(ctx context.Context) error {
203198
for {
204199
resp, err := pbclient.ReadResponse(stream)
205200
ep.onlineReport(err)
@@ -240,15 +235,10 @@ func (ep *endpoint) runRelay(ctx context.Context, conn *quic.Conn) error {
240235
return err
241236
}
242237

243-
g, ctx := errgroup.WithContext(ctx)
244-
245-
g.Go(func() error {
246-
<-ctx.Done()
247-
stream.CancelRead(0)
248-
return nil
249-
})
238+
g := reliable.NewGroup(ctx)
239+
g.Go(quicc.CancelStream(stream))
250240

251-
g.Go(func() error {
241+
g.Go(func(ctx context.Context) error {
252242
for {
253243
resp, err := pbclient.ReadResponse(stream)
254244
if err != nil {

netc/join.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
func Join(l io.ReadWriteCloser, r io.ReadWriteCloser) error {
1313
var g errgroup.Group
14+
1415
g.Go(func() error {
1516
defer func() {
1617
if err := l.Close(); err != nil {
@@ -20,6 +21,7 @@ func Join(l io.ReadWriteCloser, r io.ReadWriteCloser) error {
2021
_, err := io.Copy(l, r)
2122
return err
2223
})
24+
2325
g.Go(func() error {
2426
defer func() {
2527
if err := r.Close(); err != nil {
@@ -29,6 +31,7 @@ func Join(l io.ReadWriteCloser, r io.ReadWriteCloser) error {
2931
_, err := io.Copy(r, l)
3032
return err
3133
})
34+
3235
return g.Wait()
3336
}
3437

peer_remote.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"github.com/connet-dev/connet/reliable"
2222
"github.com/connet-dev/connet/slogc"
2323
"github.com/quic-go/quic-go"
24-
"golang.org/x/sync/errgroup"
2524
)
2625

2726
type remotePeer struct {
@@ -58,10 +57,10 @@ func (p *remotePeer) run(ctx context.Context) {
5857
p.local.removeActiveConns(p.remoteID)
5958
}()
6059

61-
g, ctx := errgroup.WithContext(ctx)
60+
g := reliable.NewGroup(ctx)
6261

63-
g.Go(func() error { return p.runRemote(ctx) })
64-
g.Go(func() error {
62+
g.Go(p.runRemote)
63+
g.Go(func(ctx context.Context) error {
6564
<-p.closer
6665
return errPeeringStop
6766
})

quicc/stream.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package quicc
2+
3+
import (
4+
"context"
5+
6+
"github.com/quic-go/quic-go"
7+
)
8+
9+
func CancelStream(stream *quic.Stream) func(ctx context.Context) error {
10+
return func(ctx context.Context) error {
11+
<-ctx.Done()
12+
stream.CancelRead(0)
13+
return nil
14+
}
15+
}

relay.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"github.com/connet-dev/connet/reliable"
1919
"github.com/connet-dev/connet/slogc"
2020
"github.com/quic-go/quic-go"
21-
"golang.org/x/sync/errgroup"
2221
)
2322

2423
type relayID string
@@ -48,10 +47,10 @@ func newRelay(local *peer, id relayID, hps []model.HostPort, serverConf *serverT
4847
}
4948

5049
func (r *relay) run(ctx context.Context) {
51-
g, ctx := errgroup.WithContext(ctx)
50+
g := reliable.NewGroup(ctx)
5251

53-
g.Go(func() error { return r.runConn(ctx) })
54-
g.Go(func() error {
52+
g.Go(r.runConn)
53+
g.Go(func(ctx context.Context) error {
5554
<-r.closer
5655
return errPeeringStop
5756
})

relay/clients.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ import (
2020
"github.com/connet-dev/connet/proto/pbconnect"
2121
"github.com/connet-dev/connet/proto/pberror"
2222
"github.com/connet-dev/connet/quicc"
23+
"github.com/connet-dev/connet/reliable"
2324
"github.com/connet-dev/connet/slogc"
2425
"github.com/quic-go/quic-go"
25-
"golang.org/x/sync/errgroup"
2626
)
2727

2828
type clientAuth struct {
@@ -331,9 +331,9 @@ func (c *clientConn) runSource(ctx context.Context) error {
331331
fcs := c.server.addSource(c)
332332
defer c.server.removeSource(fcs, c)
333333

334-
g, ctx := errgroup.WithContext(ctx)
334+
g := reliable.NewGroup(ctx)
335335

336-
g.Go(func() error {
336+
g.Go(func(ctx context.Context) error {
337337
quicc.LogRTTStats(c.conn, c.logger)
338338
for {
339339
select {
@@ -347,7 +347,7 @@ func (c *clientConn) runSource(ctx context.Context) error {
347347
}
348348
})
349349

350-
g.Go(func() error {
350+
g.Go(func(ctx context.Context) error {
351351
for {
352352
stream, err := c.conn.AcceptStream(ctx)
353353
if err != nil {

0 commit comments

Comments
 (0)