Skip to content

Commit 7d3a501

Browse files
authored
feat: implement shared listener for QUIC connection pooling and streamline shard management
1 parent 6d4c991 commit 7d3a501

File tree

1 file changed

+113
-102
lines changed

1 file changed

+113
-102
lines changed

quic.go

Lines changed: 113 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -38,37 +38,37 @@ const (
3838

3939
// Shard 连接分片,封装单个QUIC连接及其流管理
4040
type Shard struct {
41-
streams sync.Map // 存储流的映射表
42-
idChan chan string // 可用流ID通道
43-
first atomic.Bool // 首次标志
44-
quicConn atomic.Pointer[quic.Conn] // QUIC连接
45-
quicListener atomic.Pointer[quic.Listener] // QUIC监听器
46-
listenAddr atomic.Pointer[net.Addr] // 监听器地址
47-
index int // 分片索引
48-
maxStreams int // 此分片的最大流数
41+
streams sync.Map // 存储流的映射表
42+
idChan chan string // 可用流ID通道
43+
first atomic.Bool // 首次标志
44+
quicConn atomic.Pointer[quic.Conn] // QUIC连接
45+
listenAddr atomic.Pointer[net.Addr] // 监听器地址(服务端)
46+
index int // 分片索引
47+
maxStreams int // 此分片的最大流数
4948
}
5049

5150
// Pool QUIC连接池结构体,用于管理QUIC流
5251
type Pool struct {
53-
shards []*Shard // 连接分片切片
54-
numShards int // 分片数量
55-
idChan chan string // 全局可用流ID通道
56-
tlsCode string // TLS安全模式代码
57-
hostname string // 主机名
58-
clientIP string // 客户端IP
59-
tlsConfig *tls.Config // TLS配置
60-
addrResolver func() (string, error) // 地址解析器
61-
listenAddr string // 监听地址
62-
errCount atomic.Int32 // 错误计数
63-
capacity atomic.Int32 // 当前容量
64-
minCap int // 最小容量
65-
maxCap int // 最大容量
66-
interval atomic.Int64 // 流创建间隔
67-
minIvl time.Duration // 最小间隔
68-
maxIvl time.Duration // 最大间隔
69-
keepAlive time.Duration // 保活间隔
70-
ctx context.Context // 上下文
71-
cancel context.CancelFunc // 取消函数
52+
shards []*Shard // 连接分片切片
53+
numShards int // 分片数量
54+
idChan chan string // 全局可用流ID通道
55+
tlsCode string // TLS安全模式代码
56+
hostname string // 主机名
57+
clientIP string // 客户端IP
58+
tlsConfig *tls.Config // TLS配置
59+
addrResolver func() (string, error) // 地址解析器
60+
listenAddr string // 监听地址
61+
sharedListener atomic.Pointer[quic.Listener] // 共享监听器(服务端)
62+
errCount atomic.Int32 // 错误计数
63+
capacity atomic.Int32 // 当前容量
64+
minCap int // 最小容量
65+
maxCap int // 最大容量
66+
interval atomic.Int64 // 流创建间隔
67+
minIvl time.Duration // 最小间隔
68+
maxIvl time.Duration // 最大间隔
69+
keepAlive time.Duration // 保活间隔
70+
ctx context.Context // 上下文
71+
cancel context.CancelFunc // 取消函数
7272
}
7373

7474
// StreamConn 将QUIC流包装为接口
@@ -229,14 +229,11 @@ func (s *Shard) flushShard() {
229229
s.idChan = make(chan string, s.maxStreams)
230230
}
231231

232-
// closeShard 关闭分片的连接和监听器
232+
// closeShard 关闭分片的连接
233233
func (s *Shard) closeShard() {
234234
if conn := s.quicConn.Swap(nil); conn != nil {
235235
conn.CloseWithError(0, "pool closed")
236236
}
237-
if listener := s.quicListener.Swap(nil); listener != nil {
238-
listener.Close()
239-
}
240237
}
241238

242239
// NewClientPool 创建新的客户端QUIC池
@@ -440,31 +437,89 @@ func (s *Shard) establishConnection(ctx context.Context, addrResolver func() (st
440437
return nil
441438
}
442439

443-
// startListener 为分片启动QUIC监听器
444-
func (s *Shard) startListener(listenAddr string, tlsConfig *tls.Config, keepAlive time.Duration) error {
445-
if s.quicListener.Load() != nil {
440+
// startSharedListener 启动共享监听器(服务端)
441+
func (p *Pool) startSharedListener() error {
442+
if p.sharedListener.Load() != nil {
446443
return nil
447444
}
448-
if tlsConfig == nil {
449-
return fmt.Errorf("startListener: server mode requires TLS config")
445+
if p.tlsConfig == nil {
446+
return fmt.Errorf("startSharedListener: server mode requires TLS config")
450447
}
451448

452-
clonedTLS := tlsConfig.Clone()
449+
clonedTLS := p.tlsConfig.Clone()
453450
clonedTLS.NextProtos = []string{defaultALPN}
454451
clonedTLS.MinVersion = tls.VersionTLS13
455452

456-
quicConfig := buildQUICConfig(keepAlive)
457-
newListener, err := quic.ListenAddr(listenAddr, clonedTLS, quicConfig)
453+
quicConfig := buildQUICConfig(p.keepAlive)
454+
listener, err := quic.ListenAddr(p.listenAddr, clonedTLS, quicConfig)
458455
if err != nil {
459-
return fmt.Errorf("startListener[shard %d]: %w", s.index, err)
456+
return fmt.Errorf("startSharedListener: %w", err)
457+
}
458+
459+
p.sharedListener.Store(listener)
460+
461+
// 保存监听器地址到所有分片
462+
addr := listener.Addr()
463+
for _, shard := range p.shards {
464+
shard.listenAddr.Store(&addr)
460465
}
461466

462-
s.quicListener.Store(newListener)
463-
addr := newListener.Addr()
464-
s.listenAddr.Store(&addr)
465467
return nil
466468
}
467469

470+
// acceptAndDistribute 接受连接并轮询分配到分片
471+
func (p *Pool) acceptAndDistribute() {
472+
var shardIndex int32 = -1
473+
474+
for p.ctx.Err() == nil {
475+
listener := p.sharedListener.Load()
476+
if listener == nil {
477+
return
478+
}
479+
480+
conn, err := listener.Accept(p.ctx)
481+
if err != nil {
482+
if p.ctx.Err() != nil {
483+
return
484+
}
485+
select {
486+
case <-p.ctx.Done():
487+
return
488+
case <-time.After(acceptRetryInterval):
489+
}
490+
continue
491+
}
492+
493+
// 验证客户端IP
494+
if p.clientIP != "" {
495+
remoteAddr := conn.RemoteAddr().(*net.UDPAddr)
496+
if remoteAddr.IP.String() != p.clientIP {
497+
conn.CloseWithError(0, "unauthorized IP")
498+
continue
499+
}
500+
}
501+
502+
// 轮询分配到分片
503+
idx := int(atomic.AddInt32(&shardIndex, 1)) % p.numShards
504+
shard := p.shards[idx]
505+
shard.quicConn.Store(conn)
506+
507+
// 为该连接启动流接受协程
508+
go p.acceptStreams(shard, conn)
509+
}
510+
}
511+
512+
// acceptStreams 为单个连接接受流
513+
func (p *Pool) acceptStreams(shard *Shard, conn *quic.Conn) {
514+
for p.ctx.Err() == nil {
515+
stream, err := (*conn).AcceptStream(p.ctx)
516+
if err != nil {
517+
return
518+
}
519+
go shard.handleStream(stream, p.idChan, p.maxCap, p.Active)
520+
}
521+
}
522+
468523
// ClientManager 客户端QUIC池管理器
469524
func (p *Pool) ClientManager() {
470525
if p.cancel != nil {
@@ -537,67 +592,16 @@ func (p *Pool) ServerManager() {
537592
}
538593
p.ctx, p.cancel = context.WithCancel(context.Background())
539594

540-
// 为每个分片启动独立的监听器
541-
var wg sync.WaitGroup
542-
for i := range p.shards {
543-
wg.Add(1)
544-
go func(shard *Shard) {
545-
defer wg.Done()
546-
p.manageShardServer(shard)
547-
}(p.shards[i])
548-
}
549-
wg.Wait()
550-
}
551-
552-
// manageShardServer 管理单个分片的服务端监听和流接收
553-
func (p *Pool) manageShardServer(shard *Shard) {
554-
// 启动分片的QUIC监听器
555-
if err := shard.startListener(p.listenAddr, p.tlsConfig, p.keepAlive); err != nil {
595+
// 启动共享监听器
596+
if err := p.startSharedListener(); err != nil {
556597
return
557598
}
558599

559-
// 接受QUIC连接
560-
for p.ctx.Err() == nil {
561-
listener := shard.quicListener.Load()
562-
if listener == nil {
563-
return
564-
}
565-
566-
conn, err := listener.Accept(p.ctx)
567-
if err != nil {
568-
if p.ctx.Err() != nil {
569-
return
570-
}
571-
select {
572-
case <-p.ctx.Done():
573-
return
574-
case <-time.After(acceptRetryInterval):
575-
}
576-
continue
577-
}
578-
579-
// 验证客户端IP
580-
if p.clientIP != "" {
581-
remoteAddr := conn.RemoteAddr().(*net.UDPAddr)
582-
if remoteAddr.IP.String() != p.clientIP {
583-
conn.CloseWithError(0, "unauthorized IP")
584-
continue
585-
}
586-
}
587-
588-
// 存储连接并接受流
589-
shard.quicConn.Store(conn)
600+
// 启动连接分配器
601+
go p.acceptAndDistribute()
590602

591-
go func(c *quic.Conn) {
592-
for p.ctx.Err() == nil {
593-
stream, err := (*c).AcceptStream(p.ctx)
594-
if err != nil {
595-
return
596-
}
597-
go shard.handleStream(stream, p.idChan, p.maxCap, p.Active)
598-
}
599-
}(conn)
600-
}
603+
// 等待上下文取消
604+
<-p.ctx.Done()
601605
}
602606

603607
// OutgoingGet 根据ID获取可用流
@@ -660,7 +664,12 @@ func (p *Pool) Close() {
660664
}
661665
p.Flush()
662666

663-
// 并行关闭所有分片的连接和监听器
667+
// 关闭共享监听器
668+
if listener := p.sharedListener.Swap(nil); listener != nil {
669+
listener.Close()
670+
}
671+
672+
// 并行关闭所有分片的连接
664673
var wg sync.WaitGroup
665674
for _, shard := range p.shards {
666675
wg.Add(1)
@@ -670,7 +679,9 @@ func (p *Pool) Close() {
670679
}(shard)
671680
}
672681
wg.Wait()
673-
} // Ready 检查连接池是否已初始化
682+
}
683+
684+
// Ready 检查连接池是否已初始化
674685
func (p *Pool) Ready() bool {
675686
return p.ctx != nil
676687
}

0 commit comments

Comments
 (0)