@@ -38,37 +38,37 @@ const (
3838
3939// Shard 连接分片,封装单个QUIC连接及其流管理
4040type Shard struct {
41- streams sync.Map // 存储流的映射表
42- idChan chan string // 可用流ID通道
43- first atomic.Bool // 首次标志
44- quicConn atomic.Pointer [quic.Conn ] // QUIC连接
45- quicListener atomic.Pointer [quic.Listener ] // QUIC监听器
46- listenAddr atomic.Pointer [net.Addr ] // 监听器地址
47- index int // 分片索引
48- maxStreams int // 此分片的最大流数
41+ streams sync.Map // 存储流的映射表
42+ idChan chan string // 可用流ID通道
43+ first atomic.Bool // 首次标志
44+ quicConn atomic.Pointer [quic.Conn ] // QUIC连接
45+ listenAddr atomic.Pointer [net.Addr ] // 监听器地址(服务端)
46+ index int // 分片索引
47+ maxStreams int // 此分片的最大流数
4948}
5049
5150// Pool QUIC连接池结构体,用于管理QUIC流
5251type Pool struct {
53- shards []* Shard // 连接分片切片
54- numShards int // 分片数量
55- idChan chan string // 全局可用流ID通道
56- tlsCode string // TLS安全模式代码
57- hostname string // 主机名
58- clientIP string // 客户端IP
59- tlsConfig * tls.Config // TLS配置
60- addrResolver func () (string , error ) // 地址解析器
61- listenAddr string // 监听地址
62- errCount atomic.Int32 // 错误计数
63- capacity atomic.Int32 // 当前容量
64- minCap int // 最小容量
65- maxCap int // 最大容量
66- interval atomic.Int64 // 流创建间隔
67- minIvl time.Duration // 最小间隔
68- maxIvl time.Duration // 最大间隔
69- keepAlive time.Duration // 保活间隔
70- ctx context.Context // 上下文
71- cancel context.CancelFunc // 取消函数
52+ shards []* Shard // 连接分片切片
53+ numShards int // 分片数量
54+ idChan chan string // 全局可用流ID通道
55+ tlsCode string // TLS安全模式代码
56+ hostname string // 主机名
57+ clientIP string // 客户端IP
58+ tlsConfig * tls.Config // TLS配置
59+ addrResolver func () (string , error ) // 地址解析器
60+ listenAddr string // 监听地址
61+ sharedListener atomic.Pointer [quic.Listener ] // 共享监听器(服务端)
62+ errCount atomic.Int32 // 错误计数
63+ capacity atomic.Int32 // 当前容量
64+ minCap int // 最小容量
65+ maxCap int // 最大容量
66+ interval atomic.Int64 // 流创建间隔
67+ minIvl time.Duration // 最小间隔
68+ maxIvl time.Duration // 最大间隔
69+ keepAlive time.Duration // 保活间隔
70+ ctx context.Context // 上下文
71+ cancel context.CancelFunc // 取消函数
7272}
7373
7474// StreamConn 将QUIC流包装为接口
@@ -229,14 +229,11 @@ func (s *Shard) flushShard() {
229229 s .idChan = make (chan string , s .maxStreams )
230230}
231231
232- // closeShard 关闭分片的连接和监听器
232+ // closeShard 关闭分片的连接
233233func (s * Shard ) closeShard () {
234234 if conn := s .quicConn .Swap (nil ); conn != nil {
235235 conn .CloseWithError (0 , "pool closed" )
236236 }
237- if listener := s .quicListener .Swap (nil ); listener != nil {
238- listener .Close ()
239- }
240237}
241238
242239// NewClientPool 创建新的客户端QUIC池
@@ -440,31 +437,89 @@ func (s *Shard) establishConnection(ctx context.Context, addrResolver func() (st
440437 return nil
441438}
442439
443- // startListener 为分片启动QUIC监听器
444- func (s * Shard ) startListener ( listenAddr string , tlsConfig * tls. Config , keepAlive time. Duration ) error {
445- if s . quicListener .Load () != nil {
440+ // startSharedListener 启动共享监听器(服务端)
441+ func (p * Pool ) startSharedListener ( ) error {
442+ if p . sharedListener .Load () != nil {
446443 return nil
447444 }
448- if tlsConfig == nil {
449- return fmt .Errorf ("startListener : server mode requires TLS config" )
445+ if p . tlsConfig == nil {
446+ return fmt .Errorf ("startSharedListener : server mode requires TLS config" )
450447 }
451448
452- clonedTLS := tlsConfig .Clone ()
449+ clonedTLS := p . tlsConfig .Clone ()
453450 clonedTLS .NextProtos = []string {defaultALPN }
454451 clonedTLS .MinVersion = tls .VersionTLS13
455452
456- quicConfig := buildQUICConfig (keepAlive )
457- newListener , err := quic .ListenAddr (listenAddr , clonedTLS , quicConfig )
453+ quicConfig := buildQUICConfig (p . keepAlive )
454+ listener , err := quic .ListenAddr (p . listenAddr , clonedTLS , quicConfig )
458455 if err != nil {
459- return fmt .Errorf ("startListener[shard %d]: %w" , s .index , err )
456+ return fmt .Errorf ("startSharedListener: %w" , err )
457+ }
458+
459+ p .sharedListener .Store (listener )
460+
461+ // 保存监听器地址到所有分片
462+ addr := listener .Addr ()
463+ for _ , shard := range p .shards {
464+ shard .listenAddr .Store (& addr )
460465 }
461466
462- s .quicListener .Store (newListener )
463- addr := newListener .Addr ()
464- s .listenAddr .Store (& addr )
465467 return nil
466468}
467469
470+ // acceptAndDistribute 接受连接并轮询分配到分片
471+ func (p * Pool ) acceptAndDistribute () {
472+ var shardIndex int32 = - 1
473+
474+ for p .ctx .Err () == nil {
475+ listener := p .sharedListener .Load ()
476+ if listener == nil {
477+ return
478+ }
479+
480+ conn , err := listener .Accept (p .ctx )
481+ if err != nil {
482+ if p .ctx .Err () != nil {
483+ return
484+ }
485+ select {
486+ case <- p .ctx .Done ():
487+ return
488+ case <- time .After (acceptRetryInterval ):
489+ }
490+ continue
491+ }
492+
493+ // 验证客户端IP
494+ if p .clientIP != "" {
495+ remoteAddr := conn .RemoteAddr ().(* net.UDPAddr )
496+ if remoteAddr .IP .String () != p .clientIP {
497+ conn .CloseWithError (0 , "unauthorized IP" )
498+ continue
499+ }
500+ }
501+
502+ // 轮询分配到分片
503+ idx := int (atomic .AddInt32 (& shardIndex , 1 )) % p .numShards
504+ shard := p .shards [idx ]
505+ shard .quicConn .Store (conn )
506+
507+ // 为该连接启动流接受协程
508+ go p .acceptStreams (shard , conn )
509+ }
510+ }
511+
512+ // acceptStreams 为单个连接接受流
513+ func (p * Pool ) acceptStreams (shard * Shard , conn * quic.Conn ) {
514+ for p .ctx .Err () == nil {
515+ stream , err := (* conn ).AcceptStream (p .ctx )
516+ if err != nil {
517+ return
518+ }
519+ go shard .handleStream (stream , p .idChan , p .maxCap , p .Active )
520+ }
521+ }
522+
468523// ClientManager 客户端QUIC池管理器
469524func (p * Pool ) ClientManager () {
470525 if p .cancel != nil {
@@ -537,67 +592,16 @@ func (p *Pool) ServerManager() {
537592 }
538593 p .ctx , p .cancel = context .WithCancel (context .Background ())
539594
540- // 为每个分片启动独立的监听器
541- var wg sync.WaitGroup
542- for i := range p .shards {
543- wg .Add (1 )
544- go func (shard * Shard ) {
545- defer wg .Done ()
546- p .manageShardServer (shard )
547- }(p .shards [i ])
548- }
549- wg .Wait ()
550- }
551-
552- // manageShardServer 管理单个分片的服务端监听和流接收
553- func (p * Pool ) manageShardServer (shard * Shard ) {
554- // 启动分片的QUIC监听器
555- if err := shard .startListener (p .listenAddr , p .tlsConfig , p .keepAlive ); err != nil {
595+ // 启动共享监听器
596+ if err := p .startSharedListener (); err != nil {
556597 return
557598 }
558599
559- // 接受QUIC连接
560- for p .ctx .Err () == nil {
561- listener := shard .quicListener .Load ()
562- if listener == nil {
563- return
564- }
565-
566- conn , err := listener .Accept (p .ctx )
567- if err != nil {
568- if p .ctx .Err () != nil {
569- return
570- }
571- select {
572- case <- p .ctx .Done ():
573- return
574- case <- time .After (acceptRetryInterval ):
575- }
576- continue
577- }
578-
579- // 验证客户端IP
580- if p .clientIP != "" {
581- remoteAddr := conn .RemoteAddr ().(* net.UDPAddr )
582- if remoteAddr .IP .String () != p .clientIP {
583- conn .CloseWithError (0 , "unauthorized IP" )
584- continue
585- }
586- }
587-
588- // 存储连接并接受流
589- shard .quicConn .Store (conn )
600+ // 启动连接分配器
601+ go p .acceptAndDistribute ()
590602
591- go func (c * quic.Conn ) {
592- for p .ctx .Err () == nil {
593- stream , err := (* c ).AcceptStream (p .ctx )
594- if err != nil {
595- return
596- }
597- go shard .handleStream (stream , p .idChan , p .maxCap , p .Active )
598- }
599- }(conn )
600- }
603+ // 等待上下文取消
604+ <- p .ctx .Done ()
601605}
602606
603607// OutgoingGet 根据ID获取可用流
@@ -660,7 +664,12 @@ func (p *Pool) Close() {
660664 }
661665 p .Flush ()
662666
663- // 并行关闭所有分片的连接和监听器
667+ // 关闭共享监听器
668+ if listener := p .sharedListener .Swap (nil ); listener != nil {
669+ listener .Close ()
670+ }
671+
672+ // 并行关闭所有分片的连接
664673 var wg sync.WaitGroup
665674 for _ , shard := range p .shards {
666675 wg .Add (1 )
@@ -670,7 +679,9 @@ func (p *Pool) Close() {
670679 }(shard )
671680 }
672681 wg .Wait ()
673- } // Ready 检查连接池是否已初始化
682+ }
683+
684+ // Ready 检查连接池是否已初始化
674685func (p * Pool ) Ready () bool {
675686 return p .ctx != nil
676687}
0 commit comments