Skip to content

Commit

Permalink
🐛 Recreate watcher if the file unlinked and replaced (#2893)
Browse files Browse the repository at this point in the history
* Add certwatcher test for file rename

* Handle fsnotify.Chmod events as Removals
  • Loading branch information
m-messiah authored Aug 3, 2024
1 parent abb2d86 commit 5b943b9
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 8 deletions.
10 changes: 7 additions & 3 deletions pkg/certwatcher/certwatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,14 @@ func (cw *CertWatcher) ReadCertificate() error {

func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
// Only care about events which may modify the contents of the file.
if !(isWrite(event) || isRemove(event) || isCreate(event)) {
if !(isWrite(event) || isRemove(event) || isCreate(event) || isChmod(event)) {
return
}

log.V(1).Info("certificate event", "event", event)

// If the file was removed, re-add the watch.
if isRemove(event) {
// If the file was removed or renamed, re-add the watch to the previous name
if isRemove(event) || isChmod(event) {
if err := cw.watcher.Add(event.Name); err != nil {
log.Error(err, "error re-watching file")
}
Expand All @@ -202,3 +202,7 @@ func isCreate(event fsnotify.Event) bool {
func isRemove(event fsnotify.Event) bool {
return event.Op.Has(fsnotify.Remove)
}

func isChmod(event fsnotify.Event) bool {
return event.Op.Has(fsnotify.Chmod)
}
2 changes: 1 addition & 1 deletion pkg/certwatcher/certwatcher_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ var _ = BeforeSuite(func() {
})

var _ = AfterSuite(func() {
for _, file := range []string{certPath, keyPath} {
for _, file := range []string{certPath, keyPath, certPath + ".new", keyPath + ".new", certPath + ".old", keyPath + ".old"} {
_ = os.Remove(file)
}
})
39 changes: 35 additions & 4 deletions pkg/certwatcher/certwatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,36 @@ var _ = Describe("CertWatcher", func() {
Expect(called.Load()).To(BeNumerically(">=", 1))
})

It("should reload currentCert when changed with rename", func() {
doneCh := startWatcher()
called := atomic.Int64{}
watcher.RegisterCallback(func(crt tls.Certificate) {
called.Add(1)
Expect(crt.Certificate).ToNot(BeEmpty())
})

firstcert, _ := watcher.GetCertificate(nil)

err := writeCerts(certPath+".new", keyPath+".new", "192.168.0.2")
Expect(err).ToNot(HaveOccurred())

Expect(os.Link(certPath, certPath+".old")).To(Succeed())
Expect(os.Rename(certPath+".new", certPath)).To(Succeed())

Expect(os.Link(keyPath, keyPath+".old")).To(Succeed())
Expect(os.Rename(keyPath+".new", keyPath)).To(Succeed())

Eventually(func() bool {
secondcert, _ := watcher.GetCertificate(nil)
first := firstcert.PrivateKey.(*rsa.PrivateKey)
return first.Equal(secondcert.PrivateKey)
}).ShouldNot(BeTrue())

ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
Expect(called.Load()).To(BeNumerically(">=", 1))
})

Context("prometheus metric read_certificate_total", func() {
var readCertificateTotalBefore float64
var readCertificateErrorsBefore float64
Expand Down Expand Up @@ -159,17 +189,18 @@ var _ = Describe("CertWatcher", func() {

Expect(os.Remove(keyPath)).To(Succeed())

// Note, we are checking two errors here, because os.Remove generates two fsnotify events: Chmod + Remove
Eventually(func() error {
readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
if readCertificateTotalAfter != readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
if readCertificateTotalAfter != readCertificateTotalBefore+2.0 {
return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+2.0, readCertificateTotalAfter)
}
return nil
}, "4s").Should(Succeed())
Eventually(func() error {
readCertificateErrorsAfter := testutil.ToFloat64(metrics.ReadCertificateErrors)
if readCertificateErrorsAfter != readCertificateErrorsBefore+1.0 {
return fmt.Errorf("metric read certificate errors expected: %v and got: %v", readCertificateErrorsBefore+1.0, readCertificateErrorsAfter)
if readCertificateErrorsAfter != readCertificateErrorsBefore+2.0 {
return fmt.Errorf("metric read certificate errors expected: %v and got: %v", readCertificateErrorsBefore+2.0, readCertificateErrorsAfter)
}
return nil
}, "4s").Should(Succeed())
Expand Down

0 comments on commit 5b943b9

Please sign in to comment.