Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 223 additions & 4 deletions cmd/litestream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,20 +294,39 @@ func (c *Config) Validate() error {
}

// Validate database configs
for _, db := range c.DBs {
for idx, db := range c.DBs {
// Validate that either path or dir is specified, but not both
if db.Path != "" && db.Dir != "" {
return fmt.Errorf("database config #%d: cannot specify both 'path' and 'dir'", idx+1)
}
if db.Path == "" && db.Dir == "" {
return fmt.Errorf("database config #%d: must specify either 'path' or 'dir'", idx+1)
}

// When using dir, pattern must be specified
if db.Dir != "" && db.Pattern == "" {
return fmt.Errorf("database config #%d: 'pattern' is required when using 'dir'", idx+1)
}

// Use path or dir for identifying the config in error messages
dbIdentifier := db.Path
if dbIdentifier == "" {
dbIdentifier = db.Dir
}

// Validate sync intervals for replicas
if db.Replica != nil && db.Replica.SyncInterval != nil && *db.Replica.SyncInterval <= 0 {
return &ConfigValidationError{
Err: ErrInvalidSyncInterval,
Field: fmt.Sprintf("dbs[%s].replica.sync-interval", db.Path),
Field: fmt.Sprintf("dbs[%s].replica.sync-interval", dbIdentifier),
Value: *db.Replica.SyncInterval,
}
}
for i, replica := range db.Replicas {
if replica.SyncInterval != nil && *replica.SyncInterval <= 0 {
return &ConfigValidationError{
Err: ErrInvalidSyncInterval,
Field: fmt.Sprintf("dbs[%s].replicas[%d].sync-interval", db.Path, i),
Field: fmt.Sprintf("dbs[%s].replicas[%d].sync-interval", dbIdentifier, i),
Value: *replica.SyncInterval,
}
}
Expand Down Expand Up @@ -409,6 +428,9 @@ func ParseConfig(r io.Reader, expandEnv bool) (_ Config, err error) {

// Normalize paths.
for _, dbConfig := range config.DBs {
if dbConfig.Path == "" {
continue
}
if dbConfig.Path, err = expand(dbConfig.Path); err != nil {
return config, err
}
Expand Down Expand Up @@ -440,9 +462,12 @@ type CompactionLevelConfig struct {
Interval time.Duration `yaml:"interval"`
}

// DBConfig represents the configuration for a single database.
// DBConfig represents the configuration for a single database or directory of databases.
type DBConfig struct {
Path string `yaml:"path"`
Dir string `yaml:"dir"` // Directory to scan for databases
Pattern string `yaml:"pattern"` // File pattern to match (e.g., "*.db", "*.sqlite")
Recursive bool `yaml:"recursive"` // Scan subdirectories recursively
MetaPath *string `yaml:"meta-path"`
MonitorInterval *time.Duration `yaml:"monitor-interval"`
CheckpointInterval *time.Duration `yaml:"checkpoint-interval"`
Expand Down Expand Up @@ -512,6 +537,200 @@ func NewDBFromConfig(dbc *DBConfig) (*litestream.DB, error) {
return db, nil
}

// NewDBsFromDirectoryConfig scans a directory and creates DB instances for all SQLite databases found.
func NewDBsFromDirectoryConfig(dbc *DBConfig) ([]*litestream.DB, error) {
if dbc.Dir == "" {
return nil, fmt.Errorf("directory path is required for directory replication")
}

if dbc.Pattern == "" {
return nil, fmt.Errorf("pattern is required for directory replication")
}

dirPath, err := expand(dbc.Dir)
if err != nil {
return nil, err
}

// Find all SQLite databases in the directory
dbPaths, err := FindSQLiteDatabases(dirPath, dbc.Pattern, dbc.Recursive)
if err != nil {
return nil, fmt.Errorf("failed to scan directory %s: %w", dirPath, err)
}

if len(dbPaths) == 0 {
return nil, fmt.Errorf("no SQLite databases found in directory %s with pattern %s", dirPath, dbc.Pattern)
}

// Create DB instances for each found database
var dbs []*litestream.DB
for _, dbPath := range dbPaths {
// Calculate relative path from directory root
relPath, err := filepath.Rel(dirPath, dbPath)
if err != nil {
return nil, fmt.Errorf("failed to calculate relative path for %s: %w", dbPath, err)
}

// Create a copy of the config for each database
dbConfigCopy := *dbc
dbConfigCopy.Path = dbPath
dbConfigCopy.Dir = "" // Clear dir field for individual DB
dbConfigCopy.Pattern = "" // Clear pattern field
dbConfigCopy.Recursive = false // Clear recursive flag

// Deep copy replica config and make path unique per database.
// This prevents all databases from writing to the same replica path.
if dbc.Replica != nil {
replicaCopy, err := cloneReplicaConfigWithRelativePath(dbc.Replica, relPath)
if err != nil {
return nil, fmt.Errorf("failed to configure replica for %s: %w", dbPath, err)
}
dbConfigCopy.Replica = replicaCopy
}

// Also handle deprecated 'replicas' array field.
if len(dbc.Replicas) > 0 {
dbConfigCopy.Replicas = make([]*ReplicaConfig, len(dbc.Replicas))
for i, replica := range dbc.Replicas {
replicaCopy, err := cloneReplicaConfigWithRelativePath(replica, relPath)
if err != nil {
return nil, fmt.Errorf("failed to configure replica %d for %s: %w", i, dbPath, err)
}
dbConfigCopy.Replicas[i] = replicaCopy
}
}

db, err := NewDBFromConfig(&dbConfigCopy)
if err != nil {
return nil, fmt.Errorf("failed to create DB for %s: %w", dbPath, err)
}
dbs = append(dbs, db)
}

return dbs, nil
}

// cloneReplicaConfigWithRelativePath returns a copy of the replica configuration with the
// database-relative path appended to either the replica path or URL, depending on how the
// replica was configured.
func cloneReplicaConfigWithRelativePath(base *ReplicaConfig, relPath string) (*ReplicaConfig, error) {
if base == nil {
return nil, nil
}

replicaCopy := *base
relPath = filepath.ToSlash(relPath)
if relPath == "" || relPath == "." {
return &replicaCopy, nil
}

if replicaCopy.URL != "" {
u, err := url.Parse(replicaCopy.URL)
if err != nil {
return nil, fmt.Errorf("parse replica url: %w", err)
}
appendRelativePathToURL(u, relPath)
replicaCopy.URL = u.String()
return &replicaCopy, nil
}

switch base.ReplicaType() {
case "file":
relOSPath := filepath.FromSlash(relPath)
if replicaCopy.Path != "" {
replicaCopy.Path = filepath.Join(replicaCopy.Path, relOSPath)
} else {
replicaCopy.Path = relOSPath
}
default:
// Normalize to forward slashes for cloud/object storage backends.
basePath := filepath.ToSlash(replicaCopy.Path)
if basePath != "" {
replicaCopy.Path = path.Join(basePath, relPath)
} else {
replicaCopy.Path = relPath
}
}

return &replicaCopy, nil
}

// appendRelativePathToURL appends relPath to the URL's path component, ensuring
// the result remains rooted and uses forward slashes.
func appendRelativePathToURL(u *url.URL, relPath string) {
cleanRel := strings.TrimPrefix(relPath, "/")
if cleanRel == "" || cleanRel == "." {
return
}

basePath := u.Path
var joined string
if basePath == "" {
joined = cleanRel
} else {
joined = path.Join(basePath, cleanRel)
}

joined = "/" + strings.TrimPrefix(joined, "/")
u.Path = joined
}

// FindSQLiteDatabases recursively finds all SQLite database files in a directory.
// Exported for testing.
func FindSQLiteDatabases(dir string, pattern string, recursive bool) ([]string, error) {
var dbPaths []string

err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

// Skip directories unless recursive
if info.IsDir() {
if !recursive && path != dir {
return filepath.SkipDir
}
return nil
}

// Check if file matches pattern
matched, err := filepath.Match(pattern, filepath.Base(path))
if err != nil {
return err
}
if !matched {
return nil
}

// Check if it's a SQLite database
if IsSQLiteDatabase(path) {
dbPaths = append(dbPaths, path)
}

return nil
})

return dbPaths, err
}

// IsSQLiteDatabase checks if a file is a SQLite database by reading its header.
// Exported for testing.
func IsSQLiteDatabase(path string) bool {
file, err := os.Open(path)
if err != nil {
return false
}
defer file.Close()

// SQLite files start with "SQLite format 3\x00"
header := make([]byte, 16)
if _, err := file.Read(header); err != nil {
return false
}

return string(header) == "SQLite format 3\x00"
}

// ByteSize is a custom type for parsing byte sizes from YAML.
// It supports both SI units (KB, MB, GB using base 1000) and IEC units
// (KiB, MiB, GiB using base 1024) as well as short forms (K, M, G).
Expand Down
Loading