diff --git a/cmd/query/app/mocks/Watcher.go b/cmd/query/app/mocks/Watcher.go deleted file mode 100644 index 50a3e361005..00000000000 --- a/cmd/query/app/mocks/Watcher.go +++ /dev/null @@ -1,87 +0,0 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. - -// Copyright (c) 2021 The Jaeger Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mocks - -import ( - fsnotify "github.com/fsnotify/fsnotify" - mock "github.com/stretchr/testify/mock" -) - -// Watcher is an autogenerated mock type for the Watcher type -type Watcher struct { - mock.Mock -} - -// Add provides a mock function with given fields: name -func (_m *Watcher) Add(name string) error { - ret := _m.Called(name) - - var r0 error - if rf, ok := ret.Get(0).(func(string) error); ok { - r0 = rf(name) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Close provides a mock function with given fields: name -func (_m *Watcher) Close() error { - ret := _m.Called() - - var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Errors provides a mock function with given fields: -func (_m *Watcher) Errors() chan error { - ret := _m.Called() - - var r0 chan error - if rf, ok := ret.Get(0).(func() chan error); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(chan error) - } - } - - return r0 -} - -// Events provides a mock function with given fields: -func (_m *Watcher) Events() chan fsnotify.Event { - ret := _m.Called() - - var r0 chan fsnotify.Event - if rf, ok := ret.Get(0).(func() chan fsnotify.Event); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(chan fsnotify.Event) - } - } - - return r0 -} diff --git a/cmd/query/app/static_handler.go b/cmd/query/app/static_handler.go index f53a28a3524..9e01c8ff3f4 100644 --- a/cmd/query/app/static_handler.go +++ b/cmd/query/app/static_handler.go @@ -27,7 +27,6 @@ import ( "strings" "sync/atomic" - "github.com/fsnotify/fsnotify" "github.com/gorilla/mux" "go.uber.org/zap" @@ -61,10 +60,10 @@ func RegisterStaticHandler(r *mux.Router, logger *zap.Logger, qOpts *QueryOption // StaticAssetsHandler handles static assets type StaticAssetsHandler struct { - options StaticAssetsHandlerOptions - indexHTML atomic.Value // stores []byte - assetsFS http.FileSystem - newWatcher func() (fswatcher.Watcher, error) + options StaticAssetsHandlerOptions + indexHTML atomic.Value // stores []byte + assetsFS http.FileSystem + watcher *fswatcher.FSWatcher } // StaticAssetsHandlerOptions defines options for NewStaticAssetsHandler @@ -73,7 +72,6 @@ type StaticAssetsHandlerOptions struct { UIConfigPath string LogAccess bool Logger *zap.Logger - NewWatcher func() (fswatcher.Watcher, error) } type loadedConfig struct { @@ -92,23 +90,23 @@ func NewStaticAssetsHandler(staticAssetsRoot string, options StaticAssetsHandler options.Logger = zap.NewNop() } - if options.NewWatcher == nil { - options.NewWatcher = fswatcher.NewWatcher - } - indexHTML, err := loadAndEnrichIndexHTML(assetsFS.Open, options) if err != nil { return nil, err } h := &StaticAssetsHandler{ - options: options, - assetsFS: assetsFS, - newWatcher: options.NewWatcher, + options: options, + assetsFS: assetsFS, } + watcher, err := fswatcher.New([]string{options.UIConfigPath}, h.reloadUIConfig, h.options.Logger) + if err != nil { + return nil, err + } + h.watcher = watcher + h.indexHTML.Store(indexHTML) - h.watch() return h, nil } @@ -142,67 +140,14 @@ func loadAndEnrichIndexHTML(open func(string) (http.File, error), options Static return indexBytes, nil } -func (sH *StaticAssetsHandler) configListener(watcher fswatcher.Watcher) { - for { - select { - case event := <-watcher.Events(): - // ignore if the event filename is not the UI configuration - if filepath.Base(event.Name) != filepath.Base(sH.options.UIConfigPath) { - continue - } - // ignore if the event is a chmod event (permission or owner changes) - if event.Op&fsnotify.Chmod == fsnotify.Chmod { - continue - } - if event.Op&fsnotify.Remove == fsnotify.Remove { - sH.options.Logger.Warn("the UI config file has been removed, using the last known version") - continue - } - // this will catch events for all files inside the same directory, which is OK if we don't have many changes - sH.options.Logger.Info("reloading UI config", zap.String("filename", sH.options.UIConfigPath)) - content, err := loadAndEnrichIndexHTML(sH.assetsFS.Open, sH.options) - if err != nil { - sH.options.Logger.Error("error while reloading the UI config", zap.Error(err)) - } - sH.indexHTML.Store(content) - sH.options.Logger.Info("reloaded UI config", zap.String("filename", sH.options.UIConfigPath)) - case err, ok := <-watcher.Errors(): - if !ok { - return - } - sH.options.Logger.Error("event", zap.Error(err)) - } - } -} - -func (sH *StaticAssetsHandler) watch() { - if sH.options.UIConfigPath == "" { - sH.options.Logger.Info("UI config path not provided, config file will not be watched") - return - } - - watcher, err := sH.newWatcher() +func (sH *StaticAssetsHandler) reloadUIConfig() { + sH.options.Logger.Info("reloading UI config", zap.String("filename", sH.options.UIConfigPath)) + content, err := loadAndEnrichIndexHTML(sH.assetsFS.Open, sH.options) if err != nil { - sH.options.Logger.Error("failed to create a new watcher for the UI config", zap.Error(err)) - return - } - - go func() { - sH.configListener(watcher) - }() - - if err := watcher.Add(sH.options.UIConfigPath); err != nil { - sH.options.Logger.Error("error adding watcher to file", zap.String("file", sH.options.UIConfigPath), zap.Error(err)) - } else { - sH.options.Logger.Info("watching", zap.String("file", sH.options.UIConfigPath)) - } - - dir := filepath.Dir(sH.options.UIConfigPath) - if err := watcher.Add(dir); err != nil { - sH.options.Logger.Error("error adding watcher to dir", zap.String("dir", dir), zap.Error(err)) - } else { - sH.options.Logger.Info("watching", zap.String("dir", dir)) + sH.options.Logger.Error("error while reloading the UI config", zap.Error(err)) } + sH.indexHTML.Store(content) + sH.options.Logger.Info("reloaded UI config", zap.String("filename", sH.options.UIConfigPath)) } func loadIndexHTML(open func(string) (http.File, error)) ([]byte, error) { diff --git a/cmd/query/app/static_handler_test.go b/cmd/query/app/static_handler_test.go index c8d342e5bb8..b9e0a121b0b 100644 --- a/cmd/query/app/static_handler_test.go +++ b/cmd/query/app/static_handler_test.go @@ -26,17 +26,13 @@ import ( "testing" "time" - "github.com/fsnotify/fsnotify" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" - "github.com/jaegertracing/jaeger/cmd/query/app/mocks" - "github.com/jaegertracing/jaeger/pkg/fswatcher" "github.com/jaegertracing/jaeger/pkg/testutils" ) @@ -151,100 +147,6 @@ func TestNewStaticAssetsHandlerErrors(t *testing.T) { } } -func TestWatcherError(t *testing.T) { - const totalWatcherAddCalls = 2 - - for _, tc := range []struct { - name string - errorOnNthAdd int - newWatcherErr error - watcherAddErr error - wantWatcherAddCalls int - }{ - { - name: "NewWatcher error", - newWatcherErr: fmt.Errorf("new watcher error"), - }, - { - name: "Watcher.Add first call error", - errorOnNthAdd: 0, - watcherAddErr: fmt.Errorf("add first error"), - wantWatcherAddCalls: 2, - }, - { - name: "Watcher.Add second call error", - errorOnNthAdd: 1, - watcherAddErr: fmt.Errorf("add second error"), - wantWatcherAddCalls: 2, - }, - } { - t.Run(tc.name, func(t *testing.T) { - // Prepare - zcore, logObserver := observer.New(zapcore.InfoLevel) - logger := zap.New(zcore) - defer func() { - if r := recover(); r != nil { - // Select loop exits without logging error, only containing previous error log. - assert.Equal(t, logObserver.FilterMessage("event").Len(), 1) - assert.Equal(t, "send on closed channel", fmt.Sprint(r)) - } - }() - - watcher := &mocks.Watcher{} - for i := 0; i < totalWatcherAddCalls; i++ { - var err error - if i == tc.errorOnNthAdd { - err = tc.watcherAddErr - } - watcher.On("Add", mock.Anything).Return(err).Once() - } - watcher.On("Events").Return(make(chan fsnotify.Event)) - errChan := make(chan error) - watcher.On("Errors").Return(errChan) - - // Test - _, err := NewStaticAssetsHandler("fixture", StaticAssetsHandlerOptions{ - UIConfigPath: "fixture/ui-config-hotreload.json", - NewWatcher: func() (fswatcher.Watcher, error) { - return watcher, tc.newWatcherErr - }, - Logger: logger, - }) - - // Validate - - // Error logged but not returned - assert.NoError(t, err) - if tc.newWatcherErr != nil { - assert.Equal(t, logObserver.FilterField(zap.Error(tc.newWatcherErr)).Len(), 1) - } else { - assert.Zero(t, logObserver.FilterField(zap.Error(tc.newWatcherErr)).Len()) - } - - if tc.watcherAddErr != nil { - assert.Equal(t, logObserver.FilterField(zap.Error(tc.watcherAddErr)).Len(), 1) - } else { - assert.Zero(t, logObserver.FilterField(zap.Error(tc.watcherAddErr)).Len()) - } - - watcher.AssertNumberOfCalls(t, "Add", tc.wantWatcherAddCalls) - - // Validate Events and Errors channels - if tc.newWatcherErr == nil { - errChan <- fmt.Errorf("first error") - - waitUntil(t, func() bool { - return logObserver.FilterMessage("event").Len() > 0 - }, 100, 10*time.Millisecond, "timed out waiting for error") - assert.Equal(t, logObserver.FilterMessage("event").Len(), 1) - - close(errChan) - errChan <- fmt.Errorf("second error on closed chan") - } - }) - } -} - func TestHotReloadUIConfig(t *testing.T) { dir, err := os.MkdirTemp("", "ui-config-hotreload-*") require.NoError(t, err) diff --git a/pkg/config/tlscfg/cert_watcher.go b/pkg/config/tlscfg/cert_watcher.go index c03d87acb37..106249c3647 100644 --- a/pkg/config/tlscfg/cert_watcher.go +++ b/pkg/config/tlscfg/cert_watcher.go @@ -15,41 +15,41 @@ package tlscfg import ( - "crypto/sha256" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" - "os" - "path" "path/filepath" "sync" - "github.com/fsnotify/fsnotify" "go.uber.org/zap" "github.com/jaegertracing/jaeger/pkg/fswatcher" ) +const ( + logMsgPairReloaded = "Reloaded modified key pair" + logMsgCertReloaded = "Reloaded modified certificate" + logMsgPairNotReloaded = "Failed to reload key pair, using previous versions" + logMsgCertNotReloaded = "Failed to reload certificate, using previous version" +) + // certWatcher watches filesystem changes on certificates supplied via Options // The changed RootCAs and ClientCAs certificates are added to x509.CertPool without invalidating the previously used certificate. // The certificate and key can be obtained via certWatcher.certificate. // The consumers of this API should use GetCertificate or GetClientCertificate from tls.Config to supply the certificate to the config. type certWatcher struct { - mu sync.RWMutex - opts Options - logger *zap.Logger - watcher fswatcher.Watcher - cert *tls.Certificate - caHash string - clientCAHash string - certHash string - keyHash string + mu sync.RWMutex + opts Options + logger *zap.Logger + watchers []*fswatcher.FSWatcher + cert *tls.Certificate } var _ io.Closer = (*certWatcher)(nil) -func newCertWatcher(opts Options, logger *zap.Logger) (*certWatcher, error) { +func newCertWatcher(opts Options, logger *zap.Logger, rootCAs, clientCAs *x509.CertPool) (*certWatcher, error) { var cert *tls.Certificate if opts.CertPath != "" && opts.KeyPath != "" { // load certs at startup to catch missing certs error early @@ -60,27 +60,31 @@ func newCertWatcher(opts Options, logger *zap.Logger) (*certWatcher, error) { cert = &c } - watcher, err := fswatcher.NewWatcher() - if err != nil { - return nil, err - } - w := &certWatcher{ - opts: opts, - logger: logger, - cert: cert, - watcher: watcher, + opts: opts, + logger: logger, + cert: cert, } - if err := w.setupWatchedPaths(); err != nil { - watcher.Close() + if err := w.watchCertPair(); err != nil { return nil, err } + if err := w.watchCert(w.opts.CAPath, rootCAs); err != nil { + return nil, err + } + if err := w.watchCert(w.opts.ClientCAPath, clientCAs); err != nil { + return nil, err + } + return w, nil } func (w *certWatcher) Close() error { - return w.watcher.Close() + var errs []error + for _, w := range w.watchers { + errs = append(errs, w.Close()) + } + return errors.Join(errs...) } func (w *certWatcher) certificate() *tls.Certificate { @@ -89,152 +93,59 @@ func (w *certWatcher) certificate() *tls.Certificate { return w.cert } -// setupWatchedPaths retrieves hashes of all configured certificates -// and adds their parent directories to the watcher. -func (w *certWatcher) setupWatchedPaths() error { - uniqueDirs := make(map[string]bool) - addPath := func(certPath string, hashPtr *string) error { - if certPath == "" { - return nil - } - if h, err := hashFile(certPath); err == nil { - *hashPtr = h - } else { - return err - } - dir := path.Dir(certPath) - if _, ok := uniqueDirs[dir]; !ok { - w.watcher.Add(dir) - uniqueDirs[dir] = true - } +func (w *certWatcher) watchCertPair() error { + watcher, err := fswatcher.New( + []string{w.opts.CertPath, w.opts.KeyPath}, + w.onCertPairChange, + w.logger, + ) + if err == nil { + w.watchers = append(w.watchers, watcher) return nil } - - if err := addPath(w.opts.CAPath, &w.caHash); err != nil { - return err - } - if err := addPath(w.opts.ClientCAPath, &w.clientCAHash); err != nil { - return err - } - if err := addPath(w.opts.CertPath, &w.certHash); err != nil { - return err - } - if err := addPath(w.opts.KeyPath, &w.keyHash); err != nil { - return err - } - return nil + w.Close() + return fmt.Errorf("failed to watch key pair %s and %s: %w", w.opts.KeyPath, w.opts.CertPath, err) } -// watchChangesLoop waits for notifications of changes in the watched directories -// and attempts to reload all certificates that changed. -// -// Write and Rename events indicate that some files might have changed and reload might be necessary. -// Remove event indicates that the file was deleted and we should write an error to log. -// -// Reasoning: -// -// Write event is sent if the file content is rewritten. -// -// Usually files are not rewritten, but they are updated by swapping them with new -// ones by calling Rename. That avoids files being read while they are not yet -// completely written but it also means that inotify on file level will not work: -// watch is invalidated when the old file is deleted. -// -// If reading from Kubernetes Secret volumes the target files are symbolic links -// to files in a different directory. That directory is swapped with a new one, -// while the symbolic links remain the same. This guarantees atomic swap for all -// files at once, but it also means any Rename event in the directory might -// indicate that the files were replaced, even if event.Name is not any of the -// files we are monitoring. We check the hashes of the files to detect if they -// were really changed. -func (w *certWatcher) watchChangesLoop(rootCAs, clientCAs *x509.CertPool) { - for { - select { - case event, ok := <-w.watcher.Events(): - if !ok { - return // channel closed means the watcher is closed - } - w.logger.Debug("Received event", zap.String("event", event.String())) - if event.Op&fsnotify.Write == fsnotify.Write || - event.Op&fsnotify.Rename == fsnotify.Rename || - event.Op&fsnotify.Remove == fsnotify.Remove { - w.attemptReload(rootCAs, clientCAs) - } - case err, ok := <-w.watcher.Errors(): - if !ok { - return // channel closed means the watcher is closed - } - w.logger.Error("Watcher got error", zap.Error(err)) - } - } -} - -// attemptReload checks if the watched files have been modified and reloads them if necessary. -func (w *certWatcher) attemptReload(rootCAs, clientCAs *x509.CertPool) { - w.reloadIfModified(w.opts.CAPath, &w.caHash, rootCAs) - w.reloadIfModified(w.opts.ClientCAPath, &w.clientCAHash, clientCAs) - - isCertModified, newCertHash := w.isModified(w.opts.CertPath, w.certHash) - isKeyModified, newKeyHash := w.isModified(w.opts.KeyPath, w.keyHash) - if isCertModified || isKeyModified { - c, err := tls.LoadX509KeyPair(filepath.Clean(w.opts.CertPath), filepath.Clean(w.opts.KeyPath)) - if err == nil { - w.mu.Lock() - w.cert = &c - w.certHash = newCertHash - w.keyHash = newKeyHash - w.mu.Unlock() - w.logger.Info("Loaded modified certificate", zap.String("certificate", w.opts.CertPath)) - w.logger.Info("Loaded modified certificate", zap.String("certificate", w.opts.KeyPath)) - } else { - w.logger.Error( - "Failed to load certificate pair", - zap.String("certificate", w.opts.CertPath), - zap.String("key", w.opts.KeyPath), - zap.Error(err), - ) - } - } -} +func (w *certWatcher) watchCert(certPath string, certPool *x509.CertPool) error { + onCertChange := func() { w.onCertChange(certPath, certPool) } -func (w *certWatcher) reloadIfModified(certPath string, certHash *string, certPool *x509.CertPool) { - if mod, newHash := w.isModified(certPath, *certHash); mod { - if err := addCertToPool(certPath, certPool); err == nil { - w.mu.Lock() - *certHash = newHash - w.mu.Unlock() - w.logger.Info("Loaded modified certificate", zap.String("certificate", certPath)) - } else { - w.logger.Error("Failed to load certificate", zap.String("certificate", certPath), zap.Error(err)) - } + watcher, err := fswatcher.New([]string{certPath}, onCertChange, w.logger) + if err == nil { + w.watchers = append(w.watchers, watcher) + return nil } + w.Close() + return fmt.Errorf("failed to watch cert %s: %w", certPath, err) } -// isModified returns true if the file has been modified since the last check. -func (w *certWatcher) isModified(file string, previousHash string) (bool, string) { - if file == "" { - return false, "" +func (w *certWatcher) onCertPairChange() { + cert, err := tls.LoadX509KeyPair(filepath.Clean(w.opts.CertPath), filepath.Clean(w.opts.KeyPath)) + if err == nil { + w.mu.Lock() + w.cert = &cert + w.mu.Unlock() + w.logger.Info( + logMsgPairReloaded, + zap.String("key", w.opts.KeyPath), + zap.String("cert", w.opts.CertPath), + ) + } else { + w.logger.Error( + logMsgPairNotReloaded, + zap.String("key", w.opts.KeyPath), + zap.String("cert", w.opts.CertPath), + zap.Error(err), + ) } - hash, err := hashFile(file) - if err != nil { - w.logger.Warn("Certificate has been removed, using the last known version", zap.String("certificate", file)) - return false, "" - } - return previousHash != hash, hash } -// hashFile returns the SHA256 hash of the file. -func hashFile(file string) (string, error) { - f, err := os.Open(filepath.Clean(file)) - if err != nil { - return "", err - } - defer f.Close() - - h := sha256.New() - if _, err := io.Copy(h, f); err != nil { - return "", err +func (w *certWatcher) onCertChange(certPath string, certPool *x509.CertPool) { + w.mu.Lock() // prevent concurrent updates to the same certPool + if err := addCertToPool(certPath, certPool); err == nil { + w.logger.Info(logMsgCertReloaded, zap.String("cert", certPath)) + } else { + w.logger.Error(logMsgCertNotReloaded, zap.String("cert", certPath), zap.Error(err)) } - - return fmt.Sprintf("%x", h.Sum(nil)), nil + w.mu.Unlock() } diff --git a/pkg/config/tlscfg/cert_watcher_test.go b/pkg/config/tlscfg/cert_watcher_test.go index cd36b4d46bd..9c69136066e 100644 --- a/pkg/config/tlscfg/cert_watcher_test.go +++ b/pkg/config/tlscfg/cert_watcher_test.go @@ -17,7 +17,6 @@ package tlscfg import ( "crypto/tls" "crypto/x509" - "fmt" "os" "path/filepath" "testing" @@ -28,8 +27,6 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" - - "github.com/jaegertracing/jaeger/pkg/fswatcher" ) const ( @@ -44,7 +41,7 @@ const ( ) func copyToTempFile(t *testing.T, pattern string, filename string) (file *os.File, closeFn func()) { - tempFile, err := os.CreateTemp("", pattern) + tempFile, err := os.CreateTemp("", pattern+"_") require.NoError(t, err) data, err := os.ReadFile(filename) @@ -67,7 +64,7 @@ func copyFile(t *testing.T, dest string, src string) { require.NoError(t, err) } -func TestReload(t *testing.T) { +func TestReloadKeyPair(t *testing.T) { // copy certs to temp so we can modify them certFile, certFileCloseFn := copyToTempFile(t, "cert.crt", serverCert) defer certFileCloseFn() @@ -83,40 +80,29 @@ func TestReload(t *testing.T) { CertPath: certFile.Name(), KeyPath: keyFile.Name(), } - watcher, err := newCertWatcher(opts, logger) + certPool := x509.NewCertPool() + watcher, err := newCertWatcher(opts, logger, certPool, certPool) require.NoError(t, err) assert.NotNil(t, watcher.certificate()) defer watcher.Close() - certPool := x509.NewCertPool() require.NoError(t, err) - go watcher.watchChangesLoop(certPool, certPool) cert, err := tls.LoadX509KeyPair(serverCert, serverKey) require.NoError(t, err) assert.Equal(t, &cert, watcher.certificate()) - // Write the client's public key. + // Replace certificate part of the pair with client's cert, which should fail to load. copyFile(t, certFile.Name(), clientCert) - assertLogs(t, - func() bool { - // Logged when the cert is reloaded with mismatching client public key and existing server private key. - return logObserver.FilterMessage("Failed to load certificate pair"). - FilterField(zap.String("certificate", certFile.Name())).Len() > 0 - }, - "Unable to locate 'Failed to load certificate pair' in log. All logs: %v", logObserver) + assertLogs(t, logObserver, logMsgPairNotReloaded, [][2]string{{"key", keyFile.Name()}, {"cert", certFile.Name()}}) + assert.Equal(t, &cert, watcher.certificate(), "key pair unchanged") + logObserver.TakeAll() // clean up logs - // Write the client's private key. + // Replace key part with client's private key. Valid pair, should reload. copyFile(t, keyFile.Name(), clientKey) - assertLogs(t, - func() bool { - // Logged when the client private key is modified in the cert which enables successful reloading of - // the cert as both private and public keys now match. - return logObserver.FilterMessage("Loaded modified certificate"). - FilterField(zap.String("certificate", keyFile.Name())).Len() > 0 - }, - "Unable to locate 'Loaded modified certificate' in log. All logs: %v", logObserver) + assertLogs(t, logObserver, logMsgPairReloaded, [][2]string{{"key", keyFile.Name()}, {"cert", certFile.Name()}}) + logObserver.TakeAll() // clean up logs cert, err = tls.LoadX509KeyPair(filepath.Clean(clientCert), clientKey) require.NoError(t, err) @@ -125,9 +111,9 @@ func TestReload(t *testing.T) { func TestReload_ca_certs(t *testing.T) { // copy certs to temp so we can modify them - caFile, caFileCloseFn := copyToTempFile(t, "cert.crt", caCert) + caFile, caFileCloseFn := copyToTempFile(t, "ca.crt", caCert) defer caFileCloseFn() - clientCaFile, clientCaFileClostFn := copyToTempFile(t, "key.crt", caCert) + clientCaFile, clientCaFileClostFn := copyToTempFile(t, "client-ca.crt", caCert) defer clientCaFileClostFn() zcore, logObserver := observer.New(zapcore.InfoLevel) @@ -136,36 +122,34 @@ func TestReload_ca_certs(t *testing.T) { CAPath: caFile.Name(), ClientCAPath: clientCaFile.Name(), } - watcher, err := newCertWatcher(opts, logger) + certPool := x509.NewCertPool() + watcher, err := newCertWatcher(opts, logger, certPool, certPool) require.NoError(t, err) defer watcher.Close() - certPool := x509.NewCertPool() require.NoError(t, err) - go watcher.watchChangesLoop(certPool, certPool) // update the content with different certs to trigger reload. copyFile(t, caFile.Name(), wrongCaCert) copyFile(t, clientCaFile.Name(), wrongCaCert) - assertLogs(t, - func() bool { - return logObserver.FilterField(zap.String("certificate", caFile.Name())).Len() > 0 - }, - "Unable to locate 'certificate' in log. All logs: %v", logObserver) + assertLogs(t, logObserver, logMsgCertReloaded, [][2]string{{"cert", caFile.Name()}}) + assertLogs(t, logObserver, logMsgCertReloaded, [][2]string{{"cert", clientCaFile.Name()}}) + logObserver.TakeAll() // clean up logs - assertLogs(t, - func() bool { - return logObserver.FilterField(zap.String("certificate", clientCaFile.Name())).Len() > 0 - }, - "Unable to locate 'certificate' in log. All logs: %v", logObserver) + // update the content with invalid certs to trigger failed reload. + copyFile(t, caFile.Name(), badCaCert) + copyFile(t, clientCaFile.Name(), badCaCert) + + assertLogs(t, logObserver, logMsgCertNotReloaded, [][2]string{{"cert", caFile.Name()}}) + assertLogs(t, logObserver, logMsgCertNotReloaded, [][2]string{{"cert", clientCaFile.Name()}}) } func TestReload_err_cert_update(t *testing.T) { // copy certs to temp so we can modify them certFile, certFileCloseFn := copyToTempFile(t, "cert.crt", serverCert) defer certFileCloseFn() - keyFile, keyFileCloseFn := copyToTempFile(t, "cert.crt", serverKey) + keyFile, keyFileCloseFn := copyToTempFile(t, "key.crt", serverKey) defer keyFileCloseFn() zcore, logObserver := observer.New(zapcore.InfoLevel) @@ -176,14 +160,13 @@ func TestReload_err_cert_update(t *testing.T) { CertPath: certFile.Name(), KeyPath: keyFile.Name(), } - watcher, err := newCertWatcher(opts, logger) + certPool := x509.NewCertPool() + watcher, err := newCertWatcher(opts, logger, certPool, certPool) require.NoError(t, err) assert.NotNil(t, watcher.certificate()) defer watcher.Close() - certPool := x509.NewCertPool() require.NoError(t, err) - go watcher.watchChangesLoop(certPool, certPool) serverCert, err := tls.LoadX509KeyPair(filepath.Clean(serverCert), filepath.Clean(serverKey)) require.NoError(t, err) assert.Equal(t, &serverCert, watcher.certificate()) @@ -192,26 +175,12 @@ func TestReload_err_cert_update(t *testing.T) { copyFile(t, certFile.Name(), badCaCert) copyFile(t, keyFile.Name(), clientKey) - assertLogs(t, - func() bool { - return logObserver.FilterMessage("Failed to load certificate pair"). - FilterField(zap.String("certificate", certFile.Name())).Len() > 0 - }, "Unable to locate 'Failed to load certificate pair' in log. All logs: %v", logObserver) - assert.Equal(t, &serverCert, watcher.certificate()) -} - -func TestReload_err_watch(t *testing.T) { - opts := Options{ - CAPath: "doesnotexists", - } - watcher, err := newCertWatcher(opts, zap.NewNop()) - require.Error(t, err) - assert.Contains(t, err.Error(), "no such file or directory") - assert.Nil(t, watcher) + assertLogs(t, logObserver, logMsgPairNotReloaded, [][2]string{{"key", opts.KeyPath}, {"cert", opts.CertPath}}) + assert.Equal(t, &serverCert, watcher.certificate(), "values unchanged") } func TestReload_kubernetes_secret_update(t *testing.T) { - mountDir, err := os.MkdirTemp("", "secret-mountpoint") + mountDir, err := os.MkdirTemp("", "secret-mountpoint_") require.NoError(t, err) defer os.RemoveAll(mountDir) @@ -247,20 +216,27 @@ func TestReload_kubernetes_secret_update(t *testing.T) { zcore, logObserver := observer.New(zapcore.InfoLevel) logger := zap.New(zcore) - watcher, err := newCertWatcher(opts, logger) + + certPool := x509.NewCertPool() + + watcher, err := newCertWatcher(opts, logger, certPool, certPool) require.NoError(t, err) defer watcher.Close() - certPool := x509.NewCertPool() require.NoError(t, err) - go watcher.watchChangesLoop(certPool, certPool) expectedCert, err := tls.LoadX509KeyPair(serverCert, serverKey) require.NoError(t, err) - assert.Equal(t, expectedCert.Certificate, watcher.certificate().Certificate, - "certificate should be updated: %v", logObserver.All()) + assert.Equal(t, + expectedCert.Certificate, + watcher.certificate().Certificate, + "certificate should be updated: %v", logObserver.All(), + ) + + logObserver.TakeAll() // clean logs + // Create second dir with different key pair and CA. // After the update, the directory looks like following: // // /secret-mountpoint/ca.crt # symbolic link to ..data/ca.crt @@ -271,10 +247,9 @@ func TestReload_kubernetes_secret_update(t *testing.T) { // /secret-mountpoint/..timestamp-2/ca.crt # new version of ca.crt // /secret-mountpoint/..timestamp-2/tls.crt # new version of tls.crt // /secret-mountpoint/..timestamp-2/tls.key # new version of tls.key - logObserver.TakeAll() timestamp2Dir := filepath.Join(mountDir, "..timestamp-2") - createTimestampDir(t, timestamp2Dir, caCert, clientCert, clientKey) + createTimestampDir(t, timestamp2Dir, serverCert, clientCert, clientKey) err = os.Symlink("..timestamp-2", filepath.Join(mountDir, "..data_tmp")) require.NoError(t, err) @@ -284,21 +259,17 @@ func TestReload_kubernetes_secret_update(t *testing.T) { err = os.RemoveAll(timestamp1Dir) require.NoError(t, err) - assertLogs(t, - func() bool { - return logObserver.FilterMessage("Loaded modified certificate"). - FilterField(zap.String("certificate", opts.CertPath)).Len() > 0 - }, - "Unable to locate 'Loaded modified certificate' in log. All logs: %v", logObserver) + assertLogs(t, logObserver, logMsgPairReloaded, [][2]string{{"key", opts.KeyPath}, {"cert", opts.CertPath}}) + assertLogs(t, logObserver, logMsgCertReloaded, [][2]string{{"cert", opts.CAPath}}) expectedCert, err = tls.LoadX509KeyPair(clientCert, clientKey) require.NoError(t, err) assert.Equal(t, expectedCert.Certificate, watcher.certificate().Certificate, "certificate should be updated: %v", logObserver.All()) - // Make third update to make sure that the watcher is still working. - logObserver.TakeAll() + logObserver.TakeAll() // clean logs + // Make third update to make sure that the watcher is still working. timestamp3Dir := filepath.Join(mountDir, "..timestamp-3") createTimestampDir(t, timestamp3Dir, caCert, serverCert, serverKey) err = os.Symlink("..timestamp-3", filepath.Join(mountDir, "..data_tmp")) @@ -308,17 +279,16 @@ func TestReload_kubernetes_secret_update(t *testing.T) { err = os.RemoveAll(timestamp2Dir) require.NoError(t, err) - assertLogs(t, - func() bool { - return logObserver.FilterMessage("Loaded modified certificate"). - FilterField(zap.String("certificate", opts.CertPath)).Len() > 0 - }, - "Unable to locate 'Loaded modified certificate' in log. All logs: %v", logObserver) + assertLogs(t, logObserver, logMsgPairReloaded, [][2]string{{"key", opts.KeyPath}, {"cert", opts.CertPath}}) + assertLogs(t, logObserver, logMsgCertReloaded, [][2]string{{"cert", opts.CAPath}}) expectedCert, err = tls.LoadX509KeyPair(serverCert, serverKey) require.NoError(t, err) - assert.Equal(t, expectedCert.Certificate, watcher.certificate().Certificate, - "certificate should be updated: %v", logObserver.All()) + assert.Equal(t, + expectedCert.Certificate, + watcher.certificate().Certificate, + "certificate should be updated", + ) } func createTimestampDir(t *testing.T, dir string, ca, cert, key string) { @@ -341,13 +311,6 @@ func createTimestampDir(t *testing.T, dir string, ca, cert, key string) { } func TestAddCertsToWatch_err(t *testing.T) { - watcher, err := fswatcher.NewWatcher() - require.NoError(t, err) - defer watcher.Close() - w := &certWatcher{ - watcher: watcher, - } - tests := []struct { opts Options }{ @@ -379,60 +342,38 @@ func TestAddCertsToWatch_err(t *testing.T) { }, } for _, test := range tests { - w.opts = test.opts - err := w.setupWatchedPaths() + watcher, err := newCertWatcher(test.opts, nil, nil, nil) require.Error(t, err) assert.Contains(t, err.Error(), "no such file or directory") + assert.Nil(t, watcher) } } -func TestAddCertsToWatch_remove_ca(t *testing.T) { - caFile, caFileCloseFn := copyToTempFile(t, "cert.crt", caCert) - defer caFileCloseFn() - clientCaFile, clientCaFileClostFn := copyToTempFile(t, "key.crt", caCert) - defer clientCaFileClostFn() - - zcore, logObserver := observer.New(zapcore.InfoLevel) - logger := zap.New(zcore) - opts := Options{ - CAPath: caFile.Name(), - ClientCAPath: clientCaFile.Name(), +func assertLogs(t *testing.T, + logs *observer.ObservedLogs, + logMsg string, + fields [][2]string, +) { + errMsg := "Expecting log '" + logMsg + "'" + for _, field := range fields { + errMsg = errMsg + " " + field[0] + "=" + field[1] + } + fn := func() bool { + l := logs + if logMsg != "" { + l = l.FilterMessageSnippet(logMsg) + } + for _, field := range fields { + l = l.FilterField(zap.String(field[0], field[1])) + } + return l.Len() > 0 + } + ok := assert.Eventuallyf(t, fn, 5*time.Second, 10*time.Millisecond, errMsg) + if !ok { + for _, log := range logs.All() { + t.Log(log) + } } - watcher, err := newCertWatcher(opts, logger) - require.NoError(t, err) - defer watcher.Close() - - certPool := x509.NewCertPool() - require.NoError(t, err) - go watcher.watchChangesLoop(certPool, certPool) - - require.NoError(t, os.Remove(caFile.Name())) - require.NoError(t, os.Remove(clientCaFile.Name())) - assertLogs(t, - func() bool { - return logObserver.FilterMessage("Certificate has been removed, using the last known version").Len() >= 2 - }, - "Unable to locate 'Certificate has been removed' in log. All logs: %v", logObserver) - assert.True(t, logObserver.FilterMessage("Certificate has been removed, using the last known version").FilterField(zap.String("certificate", caFile.Name())).Len() > 0) - assert.True(t, logObserver.FilterMessage("Certificate has been removed, using the last known version").FilterField(zap.String("certificate", clientCaFile.Name())).Len() > 0) -} - -type delayedFormat struct { - fn func() interface{} -} - -func (df delayedFormat) String() string { - return fmt.Sprintf("%v", df.fn()) -} - -func assertLogs(t *testing.T, f func() bool, errorMsg string, logObserver *observer.ObservedLogs) { - assert.Eventuallyf(t, f, - 10*time.Second, 10*time.Millisecond, - errorMsg, - delayedFormat{ - fn: func() interface{} { return logObserver.All() }, - }, - ) } // syncWrite ensures data is written to the given filename and flushed to disk. @@ -448,42 +389,3 @@ func syncWrite(filename string, data []byte, perm os.FileMode) error { } return f.Sync() } - -func TestReload_err_ca_cert_update(t *testing.T) { - // copy certs to temp so we can modify them - caFile, caFileCloseFn := copyToTempFile(t, "cert.crt", caCert) - defer caFileCloseFn() - clientCaFile, clientCaFileClostFn := copyToTempFile(t, "key.crt", caCert) - defer clientCaFileClostFn() - - zcore, logObserver := observer.New(zapcore.InfoLevel) - logger := zap.New(zcore) - opts := Options{ - CAPath: caFile.Name(), - ClientCAPath: clientCaFile.Name(), - } - watcher, err := newCertWatcher(opts, logger) - require.NoError(t, err) - defer watcher.Close() - - certPool := x509.NewCertPool() - require.NoError(t, err) - go watcher.watchChangesLoop(certPool, certPool) - - // update the content with bad certs. - copyFile(t, caFile.Name(), badCaCert) - assertLogs(t, - func() bool { - return logObserver.FilterMessage("Failed to load certificate"). - FilterField(zap.String("certificate", caFile.Name())).Len() > 0 - }, - "Unable to locate 'certificate' in log. All logs: %v", logObserver) - - copyFile(t, clientCaFile.Name(), badCaCert) - assertLogs(t, - func() bool { - return logObserver.FilterMessage("Failed to load certificate"). - FilterField(zap.String("certificate", clientCaFile.Name())).Len() > 0 - }, - "Unable to locate 'Failed to load certificate' in log. All logs: %v", logObserver) -} diff --git a/pkg/config/tlscfg/options.go b/pkg/config/tlscfg/options.go index d70964c834f..fc0422be942 100644 --- a/pkg/config/tlscfg/options.go +++ b/pkg/config/tlscfg/options.go @@ -87,6 +87,7 @@ func (p *Options) Config(logger *zap.Logger) (*tls.Config, error) { } if p.ClientCAPath != "" { + // TODO this should be moved to certWatcher, since it already loads key pair certPool := x509.NewCertPool() if err := addCertToPool(p.ClientCAPath, certPool); err != nil { return nil, err @@ -95,11 +96,11 @@ func (p *Options) Config(logger *zap.Logger) (*tls.Config, error) { tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert } - w, err := newCertWatcher(*p, logger) + certWatcher, err := newCertWatcher(*p, logger, tlsCfg.RootCAs, tlsCfg.ClientCAs) if err != nil { return nil, err } - p.certWatcher = w + p.certWatcher = certWatcher if (p.CertPath == "" && p.KeyPath != "") || (p.CertPath != "" && p.KeyPath == "") { return nil, fmt.Errorf("for client auth via TLS, either both client certificate and key must be supplied, or neither") @@ -114,7 +115,6 @@ func (p *Options) Config(logger *zap.Logger) (*tls.Config, error) { } } - go p.certWatcher.watchChangesLoop(tlsCfg.RootCAs, tlsCfg.ClientCAs) return tlsCfg, nil } @@ -148,7 +148,7 @@ func addCertToPool(caPath string, certPool *x509.CertPool) error { var _ io.Closer = (*Options)(nil) -// Close closes Options. +// Close shuts down the embedded certificate watcher. func (p *Options) Close() error { if p.certWatcher != nil { return p.certWatcher.Close() diff --git a/pkg/fswatcher/fs_watcher.go b/pkg/fswatcher/fs_watcher.go deleted file mode 100644 index ea61b967920..00000000000 --- a/pkg/fswatcher/fs_watcher.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2021 The Jaeger Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fswatcher - -import "github.com/fsnotify/fsnotify" - -// Watcher watches for Events and Errors once a resource is added to the watch list. -// Primarily used for mocking the fsnotify lib. -type Watcher interface { - Add(name string) error - Close() error - Events() chan fsnotify.Event - Errors() chan error -} - -// fsnotifyWatcherWrapper wraps the fsnotify.Watcher and implements Watcher. -type fsnotifyWatcherWrapper struct { - fsnotifyWatcher *fsnotify.Watcher -} - -// Add adds the filename to watch. -func (f *fsnotifyWatcherWrapper) Add(name string) error { - return f.fsnotifyWatcher.Add(name) -} - -// Close closes the watcher. -func (f *fsnotifyWatcherWrapper) Close() error { - return f.fsnotifyWatcher.Close() -} - -// Events returns the fsnotify.Watcher's Events chan. -func (f *fsnotifyWatcherWrapper) Events() chan fsnotify.Event { - return f.fsnotifyWatcher.Events -} - -// Errors returns the fsnotify.Watcher's Errors chan. -func (f *fsnotifyWatcherWrapper) Errors() chan error { - return f.fsnotifyWatcher.Errors -} - -// NewWatcher creates a new fsnotifyWatcherWrapper, wrapping the fsnotify.Watcher. -func NewWatcher() (Watcher, error) { - w, err := fsnotify.NewWatcher() - return &fsnotifyWatcherWrapper{fsnotifyWatcher: w}, err -} diff --git a/pkg/fswatcher/fs_watcher_test.go b/pkg/fswatcher/fs_watcher_test.go deleted file mode 100644 index b2ce42bc004..00000000000 --- a/pkg/fswatcher/fs_watcher_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2021 The Jaeger Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fswatcher - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestFsWatcher(t *testing.T) { - w, err := NewWatcher() - require.NoError(t, err) - assert.IsType(t, &fsnotifyWatcherWrapper{}, w) - - err = w.Add("foo") - assert.Error(t, err) - - err = w.Add("../../cmd/query/app/fixture/ui-config.json") - assert.NoError(t, err) - - err = w.Close() - assert.NoError(t, err) - - events := w.Events() - assert.NotZero(t, events) - - errs := w.Errors() - assert.NotZero(t, errs) -} diff --git a/pkg/fswatcher/fswatcher.go b/pkg/fswatcher/fswatcher.go new file mode 100644 index 00000000000..d3c929d1d9b --- /dev/null +++ b/pkg/fswatcher/fswatcher.go @@ -0,0 +1,169 @@ +// Copyright (c) 2021 The Jaeger Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fswatcher + +import ( + "crypto/sha256" + "fmt" + "io" + "os" + "path" + "path/filepath" + "sync" + + "github.com/fsnotify/fsnotify" + "go.uber.org/zap" +) + +type FSWatcher struct { + watcher *fsnotify.Watcher + logger *zap.Logger + fileHashContentMap map[string]string + onChange func() + mu sync.RWMutex +} + +// FSWatcher waits for notifications of changes in the watched directories +// and attempts to reload all files that changed. +// +// Write and Rename events indicate that some files might have changed and reload might be necessary. +// Remove event indicates that the file was deleted and we should write a warn to log. +// +// Reasoning: +// +// Write event is sent if the file content is rewritten. +// +// Usually files are not rewritten, but they are updated by swapping them with new +// ones by calling Rename. That avoids files being read while they are not yet +// completely written but it also means that inotify on file level will not work: +// watch is invalidated when the old file is deleted. +// +// If reading from Kubernetes Secret volumes the target files are symbolic links +// to files in a different directory. That directory is swapped with a new one, +// while the symbolic links remain the same. This guarantees atomic swap for all +// files at once, but it also means any Rename event in the directory might +// indicate that the files were replaced, even if event.Name is not any of the +// files we are monitoring. We check the hashes of the files to detect if they +// were really changed. +func New(filepaths []string, onChange func(), logger *zap.Logger) (*FSWatcher, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + w := &FSWatcher{ + watcher: watcher, + logger: logger, + fileHashContentMap: make(map[string]string), + onChange: onChange, + } + + if err = w.setupWatchedPaths(filepaths); err != nil { + w.Close() + return nil, err + } + + go w.watch() + + return w, nil +} + +func (w *FSWatcher) setupWatchedPaths(filepaths []string) error { + uniqueDirs := make(map[string]bool) + + for _, p := range filepaths { + if p == "" { + continue + } + if h, err := hashFile(p); err == nil { + w.fileHashContentMap[p] = h + } else { + return err + } + dir := path.Dir(p) + if _, ok := uniqueDirs[dir]; !ok { + if err := w.watcher.Add(dir); err != nil { + return err + } + uniqueDirs[dir] = true + } + } + + return nil +} + +// Watch watches for Events and Errors of files. +// Each time an Event happen, all the files are checked for content change. +// If a file's content changes, its hashed content is updated and +// onChange is invoked after all file checks. +func (w *FSWatcher) watch() { + for { + select { + case event, ok := <-w.watcher.Events: + if !ok { + return + } + w.logger.Info("Received event", zap.String("event", event.String())) + var changed bool + w.mu.Lock() + for file, hash := range w.fileHashContentMap { + fileChanged, newHash := w.isModified(file, hash) + if fileChanged { + changed = fileChanged + w.fileHashContentMap[file] = newHash + } + } + w.mu.Unlock() + if changed { + w.onChange() + } + case err, ok := <-w.watcher.Errors: + if !ok { + return + } + w.logger.Error("fsnotifier reported an error", zap.Error(err)) + } + } +} + +// Close closes the watcher. +func (w *FSWatcher) Close() error { + return w.watcher.Close() +} + +// isModified returns true if the file has been modified since the last check. +func (w *FSWatcher) isModified(filepath string, previousHash string) (bool, string) { + hash, err := hashFile(filepath) + if err != nil { + w.logger.Warn("Unable to read the file", zap.String("file", filepath), zap.Error(err)) + return true, "" + } + return previousHash != hash, hash +} + +// hashFile returns the SHA256 hash of the file. +func hashFile(file string) (string, error) { + f, err := os.Open(filepath.Clean(file)) + if err != nil { + return "", err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + + return fmt.Sprintf("%x", h.Sum(nil)), nil +} diff --git a/pkg/fswatcher/fswatcher_test.go b/pkg/fswatcher/fswatcher_test.go new file mode 100644 index 00000000000..cc2488bfcf6 --- /dev/null +++ b/pkg/fswatcher/fswatcher_test.go @@ -0,0 +1,250 @@ +// Copyright (c) 2021 The Jaeger Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fswatcher + +import ( + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +func createTestFiles(t *testing.T) (file1 string, file2 string, file3 string, close func()) { + testDir1, err := os.MkdirTemp("", "test1_") + require.NoError(t, err) + + file1 = filepath.Join(testDir1, "test1.doc") + err = os.WriteFile(file1, []byte("test data"), 0o600) + require.NoError(t, err) + + file2 = filepath.Join(testDir1, "test2.doc") + err = os.WriteFile(file2, []byte("test data"), 0o600) + require.NoError(t, err) + + testDir2, err := os.MkdirTemp("", "test2_") + require.NoError(t, err) + + file3 = filepath.Join(testDir2, "test3.doc") + err = os.WriteFile(file3, []byte("test data"), 0o600) + require.NoError(t, err) + + close = func() { + assert.NoError(t, os.RemoveAll(testDir1)) + assert.NoError(t, os.RemoveAll(testDir2)) + } + return +} + +func TestFSWatcherAddFiles(t *testing.T) { + file1, file2, file3, close := createTestFiles(t) + defer close() + + // Add one unreadable file + _, err := New([]string{"invalid-file-name"}, nil, nil) + require.Error(t, err) + + // Add one readable file + w, err := New([]string{file1}, nil, nil) + require.NoError(t, err) + assert.IsType(t, &FSWatcher{}, w) + assert.NoError(t, w.Close()) + + // Add one empty-name file and one readable file + w, err = New([]string{"", file1}, nil, nil) + require.NoError(t, err) + assert.IsType(t, &FSWatcher{}, w) + assert.NoError(t, w.Close()) + + // Add one readable file and one unreadable file + _, err = New([]string{file1, "invalid-file-name"}, nil, nil) + require.Error(t, err) + + // Add two readable files from one dir + w, err = New([]string{file1, file2}, nil, nil) + require.NoError(t, err) + assert.IsType(t, &FSWatcher{}, w) + assert.NoError(t, w.Close()) + + // Add two readable files from two different repos + w, err = New([]string{file1, file3}, nil, nil) + require.NoError(t, err) + assert.IsType(t, &FSWatcher{}, w) + assert.NoError(t, w.Close()) +} + +func TestFSWatcherWithMultipleFiles(t *testing.T) { + testFile1, err := os.CreateTemp("", "") + require.NoError(t, err) + defer testFile1.Close() + + testFile2, err := os.CreateTemp("", "") + require.NoError(t, err) + defer testFile2.Close() + + _, err = testFile1.WriteString("test content 1") + require.NoError(t, err) + + _, err = testFile2.WriteString("test content 2") + require.NoError(t, err) + + zcore, logObserver := observer.New(zapcore.InfoLevel) + logger := zap.New(zcore) + + onChange := func() { + logger.Info("Change happens") + } + + w, err := New([]string{testFile1.Name(), testFile2.Name()}, onChange, logger) + require.NoError(t, err) + require.IsType(t, &FSWatcher{}, w) + defer w.Close() + + // Test Write event + testFile1.WriteString(" changed") + testFile2.WriteString(" changed") + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Received event").Len() > 0 + }, + "Unable to locate 'Received event' in log. All logs: %v", logObserver) + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Change happens").Len() > 0 + }, + "Unable to locate 'Change happens' in log. All logs: %v", logObserver) + newHash1, err := hashFile(testFile1.Name()) + require.NoError(t, err) + newHash2, err := hashFile(testFile2.Name()) + require.NoError(t, err) + w.mu.RLock() + assert.Equal(t, newHash1, w.fileHashContentMap[testFile1.Name()]) + assert.Equal(t, newHash2, w.fileHashContentMap[testFile2.Name()]) + w.mu.RUnlock() + + // Test Remove event + os.Remove(testFile1.Name()) + os.Remove(testFile2.Name()) + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Received event").Len() > 0 + }, + "Unable to locate 'Received event' in log. All logs: %v", logObserver) + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Unable to read the file").FilterField(zap.String("file", testFile1.Name())).Len() > 0 + }, + "Unable to locate 'Unable to read the file' in log. All logs: %v", logObserver) + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Unable to read the file").FilterField(zap.String("file", testFile2.Name())).Len() > 0 + }, + "Unable to locate 'Unable to read the file' in log. All logs: %v", logObserver) +} + +func TestFSWatcherWithSymlinkAndRepoChanges(t *testing.T) { + testDir, err := os.MkdirTemp("", "test") + require.NoError(t, err) + defer os.RemoveAll(testDir) + + err = os.Symlink("..timestamp-1", filepath.Join(testDir, "..data")) + require.NoError(t, err) + err = os.Symlink(filepath.Join("..data", "test.doc"), filepath.Join(testDir, "test.doc")) + require.NoError(t, err) + + timestamp1Dir := filepath.Join(testDir, "..timestamp-1") + createTimestampDir(t, timestamp1Dir) + + zcore, logObserver := observer.New(zapcore.InfoLevel) + logger := zap.New(zcore) + + onChange := func() {} + + w, err := New([]string{filepath.Join(testDir, "test.doc")}, onChange, logger) + require.NoError(t, err) + require.IsType(t, &FSWatcher{}, w) + defer w.Close() + + timestamp2Dir := filepath.Join(testDir, "..timestamp-2") + createTimestampDir(t, timestamp2Dir) + + err = os.Symlink("..timestamp-2", filepath.Join(testDir, "..data_tmp")) + require.NoError(t, err) + + os.Rename(filepath.Join(testDir, "..data_tmp"), filepath.Join(testDir, "..data")) + require.NoError(t, err) + err = os.RemoveAll(timestamp1Dir) + require.NoError(t, err) + + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Received event").Len() > 0 + }, + "Unable to locate 'Received event' in log. All logs: %v", logObserver) + byteData, err := os.ReadFile(filepath.Join(testDir, "test.doc")) + require.NoError(t, err) + assert.Equal(t, byteData, []byte("test data")) + + timestamp3Dir := filepath.Join(testDir, "..timestamp-3") + createTimestampDir(t, timestamp3Dir) + err = os.Symlink("..timestamp-3", filepath.Join(testDir, "..data_tmp")) + require.NoError(t, err) + os.Rename(filepath.Join(testDir, "..data_tmp"), filepath.Join(testDir, "..data")) + require.NoError(t, err) + err = os.RemoveAll(timestamp2Dir) + require.NoError(t, err) + + assertLogs(t, + func() bool { + return logObserver.FilterMessage("Received event").Len() > 0 + }, + "Unable to locate 'Received event' in log. All logs: %v", logObserver) + byteData, err = os.ReadFile(filepath.Join(testDir, "test.doc")) + require.NoError(t, err) + assert.Equal(t, byteData, []byte("test data")) +} + +func createTimestampDir(t *testing.T, dir string) { + t.Helper() + err := os.MkdirAll(dir, 0o700) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(dir, "test.doc"), []byte("test data"), 0o600) + require.NoError(t, err) +} + +type delayedFormat struct { + fn func() interface{} +} + +func (df delayedFormat) String() string { + return fmt.Sprintf("%v", df.fn()) +} + +func assertLogs(t *testing.T, f func() bool, errorMsg string, logObserver *observer.ObservedLogs) { + assert.Eventuallyf(t, f, + 10*time.Second, 10*time.Millisecond, + errorMsg, + delayedFormat{ + fn: func() interface{} { return logObserver.All() }, + }, + ) +}