diff --git a/ra_sender.go b/ra_sender.go index 9709223..07f7b06 100644 --- a/ra_sender.go +++ b/ra_sender.go @@ -7,7 +7,6 @@ import ( "log/slog" "net/netip" "reflect" - "sync" "time" "github.com/mdlayher/ndp" @@ -20,14 +19,9 @@ type raSender struct { initialConfig *InterfaceConfig reloadCh chan *InterfaceConfig - stopCh any + stopCh chan any sock rAdvSocket - - childWg *sync.WaitGroup - childReloadCh []chan *InterfaceConfig - childStopCh []chan any - - socketCtor rAdvSocketCtor + socketCtor rAdvSocketCtor } func newRASender(initialConfig *InterfaceConfig, ctor rAdvSocketCtor, logger *slog.Logger) *raSender { @@ -36,19 +30,19 @@ func newRASender(initialConfig *InterfaceConfig, ctor rAdvSocketCtor, logger *sl initialConfig: initialConfig, reloadCh: make(chan *InterfaceConfig), stopCh: make(chan any), - childWg: &sync.WaitGroup{}, - childReloadCh: []chan *InterfaceConfig{}, - childStopCh: []chan any{}, socketCtor: ctor, } } func (s *raSender) run(ctx context.Context) { + // The current desired configuration + config := s.initialConfig + // Create the socket err := retry.Constant(ctx, time.Second, func(ctx context.Context) error { var err error - s.sock, err = s.socketCtor(s.initialConfig.Name) + s.sock, err = s.socketCtor(config.Name) if err != nil { // These are the unrecoverable errors we're aware of now. if errors.Is(err, unix.EPERM) || errors.Is(err, unix.EINVAL) { @@ -64,43 +58,6 @@ func (s *raSender) run(ctx context.Context) { return } - s.spawnChild(ctx, s.runUnsolicitedRASender) - s.childWg.Wait() - s.sock.close() -} - -func (s *raSender) reload(ctx context.Context, newConfig *InterfaceConfig) error { - for _, ch := range s.childReloadCh { - select { - case ch <- newConfig: - case <-ctx.Done(): - return ctx.Err() - } - } - return nil -} - -func (s *raSender) stop() { - for _, ch := range s.childStopCh { - close(ch) - } -} - -func (s *raSender) spawnChild(ctx context.Context, f func(context.Context, chan *InterfaceConfig, chan any)) { - s.childWg.Add(1) - reloadCh := make(chan *InterfaceConfig) - stopCh := make(chan any) - s.childReloadCh = append(s.childReloadCh, reloadCh) - s.childStopCh = append(s.childStopCh, stopCh) - go f(ctx, reloadCh, stopCh) -} - -func (s *raSender) runUnsolicitedRASender(ctx context.Context, reloadCh chan *InterfaceConfig, stopCh chan any) { - defer s.childWg.Done() - - // The current desired configuration - config := s.initialConfig - reload: for { msg := &ndp.RouterAdvertisement{ @@ -108,7 +65,9 @@ reload: RouterLifetime: 1800 * time.Second, } + // For unsolicited RA ticker := time.NewTicker(time.Duration(config.RAIntervalMilliseconds) * time.Millisecond) + for { select { case <-ticker.C: @@ -116,7 +75,7 @@ reload: if err != nil { continue } - case newConfig := <-reloadCh: + case newConfig := <-s.reloadCh: if reflect.DeepEqual(config, newConfig) { s.logger.Info("No configuration change. Skip reloading.") continue @@ -126,11 +85,26 @@ reload: continue reload case <-ctx.Done(): s.logger.Info("Context is done. Stopping.") - return - case <-stopCh: + break reload + case <-s.stopCh: s.logger.Info("Stop event received. Stopping.") - return + break reload } } } + + s.sock.close() +} + +func (s *raSender) reload(ctx context.Context, newConfig *InterfaceConfig) error { + select { + case s.reloadCh <- newConfig: + case <-ctx.Done(): + return ctx.Err() + } + return nil +} + +func (s *raSender) stop() { + close(s.stopCh) }