Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package main

import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"syscall"
"time"
)

const (
cleanupInterval = 1 * time.Hour
)

// CacheManager handles the storage and cleanup of cached Docker images
type CacheManager struct {
dir string
maxCacheAge time.Duration
}

// NewCacheManager creates a new CacheManager instance
func NewCacheManager(dir string, maxCacheAge time.Duration) (*CacheManager, error) {
if dir == "" {
tmpDir, err := os.MkdirTemp("", "docker-image-cache-*")
if err != nil {
return nil, fmt.Errorf("failed to create temporary cache directory: %w", err)
}
dir = tmpDir
} else if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create cache directory: %w", err)
}

return &CacheManager{dir: dir, maxCacheAge: maxCacheAge}, nil
}

// StartCleanup starts a background goroutine that periodically removes old files
func (c *CacheManager) StartCleanup(ctx context.Context) {
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()

// Run initial cleanup
c.PerformCleanup()

for {
select {
case <-ticker.C:
c.PerformCleanup()
case <-ctx.Done():
log.Println("Stopping cache cleanup background task")
return
}
}
}

// PerformCleanup removes files from the cache directory that are older than maxCacheAge
func (c *CacheManager) PerformCleanup() {
files, err := os.ReadDir(c.dir)
if err != nil {
log.Printf("Failed to read cache directory during cleanup: %v\n", err)
return
}

now := time.Now()
for _, file := range files {
if file.IsDir() {
continue
}

info, err := file.Info()
if err != nil {
log.Printf("Failed to get info for file %s during cleanup: %v\n", file.Name(), err)
continue
}

stat := info.Sys().(*syscall.Stat_t)
atime := time.Unix(int64(stat.Atim.Sec), int64(stat.Atim.Nsec))

if now.Sub(atime) > c.maxCacheAge {
path := filepath.Join(c.dir, file.Name())
log.Printf("Removing old cached file: %s (age: %v)\n", file.Name(), now.Sub(atime))
if err := os.Remove(path); err != nil {
log.Printf("Failed to remove old cached file %s: %v\n", file.Name(), err)
}
}
}
}

// GetCachePath returns the full path for a cached image
func (c *CacheManager) GetCachePath(imageName string) string {
return filepath.Join(c.dir, c.GetCacheFilename(imageName))
}

// GetCacheFilename generates a safe filename for caching
func (c *CacheManager) GetCacheFilename(imageName string) string {
ref := ParseImageReference(imageName)
safeImageName := sanitizeFilenameComponent(ref.Repository)
safeTag := sanitizeFilenameComponent(ref.Tag)
return fmt.Sprintf("%s_%s.tar.gz", safeImageName, safeTag)
}

// Dir returns the cache directory path
func (c *CacheManager) Dir() string {
return c.dir
}
86 changes: 86 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package main

import (
"os"
"path/filepath"
"testing"
"time"
)

func TestPerformCleanup(t *testing.T) {
tempDir, err := os.MkdirTemp("", "test-cleanup-*")
if err != nil {
t.Fatal(err)
}
defer cleanupTempDir(t, tempDir)

maxAge := 2 * time.Hour

// Create a new file (should NOT be removed)
newFile := filepath.Join(tempDir, "new_file.tar.gz")
if err := os.WriteFile(newFile, []byte("new"), 0644); err != nil {
t.Fatal(err)
}

// Create an old file (should be removed)
oldFile := filepath.Join(tempDir, "old_file.tar.gz")
if err := os.WriteFile(oldFile, []byte("old"), 0644); err != nil {
t.Fatal(err)
}

// Manually set the modification time to be older than maxAge
oldTime := time.Now().Add(-maxAge - time.Hour)
if err := os.Chtimes(oldFile, oldTime, oldTime); err != nil {
t.Fatal(err)
}

cache, _ := NewCacheManager(tempDir, maxAge)
cache.PerformCleanup()

// Check if the new file still exists
if _, err := os.Stat(newFile); os.IsNotExist(err) {
t.Errorf("new file was incorrectly removed")
}

// Check if the old file was removed
if _, err := os.Stat(oldFile); !os.IsNotExist(err) {
t.Errorf("old file was not removed")
}
}

func TestGetCacheFilename(t *testing.T) {
cache, _ := NewCacheManager("", 1*time.Hour)
defer func(path string) {
err := os.RemoveAll(path)
if err != nil {
t.Errorf("failed to remove temporary directory: %v", err)
}
}(cache.Dir())

tests := []struct {
imageName string
expected string
}{
{
imageName: "alpine:latest",
expected: "library_alpine_latest.tar.gz",
},
{
imageName: "library/ubuntu:20.04",
expected: "library_ubuntu_20.04.tar.gz",
},
{
imageName: "ghcr.io/username/repo:v1.2.3",
expected: "username_repo_v1.2.3.tar.gz",
},
}

for _, tt := range tests {
t.Run(tt.imageName, func(t *testing.T) {
got := cache.GetCacheFilename(tt.imageName)
if got != tt.expected {
t.Errorf("GetCacheFilename(%q) = %q, want %q", tt.imageName, got, tt.expected)
}
})
}
}
4 changes: 4 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ port: 8080
# If not specified, uses a temporary directory
cache_dir: /tmp/docker-images

# Maximum age for cached images before they are considered stale and eligible for cleanup.
# Supports duration formats like "24h", "30m".
max_cache_age: 48h

# Per-registry credentials
# Use registry hostname as the key
registries:
Expand Down
11 changes: 8 additions & 3 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ package main
import (
"fmt"
"os"
"time"

"gopkg.in/yaml.v3"
)

// Config represents the application configuration
type Config struct {
Port int `yaml:"port"`
CacheDir string `yaml:"cache_dir"`
Registries map[string]RegistryConfig `yaml:"registries"`
Port int `yaml:"port"`
CacheDir string `yaml:"cache_dir"`
MaxCacheAge time.Duration `yaml:"max_cache_age"`
Registries map[string]RegistryConfig `yaml:"registries"`
}

// RegistryConfig holds credentials for a specific registry
Expand Down Expand Up @@ -46,6 +48,9 @@ func (c *Config) ApplyDefaults() {
if c.Port == 0 {
c.Port = 8080
}
if c.MaxCacheAge == 0 {
c.MaxCacheAge = 48 * time.Hour
}
}

// Validate checks if the configuration is valid
Expand Down
13 changes: 2 additions & 11 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package main
import (
"archive/tar"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"io"
"log"
"os"
Expand Down Expand Up @@ -48,15 +46,8 @@ func decompressGzip(src, dst string) error {
}
defer closeWithLog(dstFile, "destination file")

hasher := sha256.New()
writer := io.MultiWriter(dstFile, hasher)
_, err = io.Copy(writer, gzReader)
if err != nil {
return err
}

_ = hex.EncodeToString(hasher.Sum(nil))
return nil
_, err = io.Copy(dstFile, gzReader)
return err
}

// createTar creates a gzip-compressed tar archive from a source directory
Expand Down
35 changes: 21 additions & 14 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,29 @@ func main() {

var addr string
var cacheDir string
var maxCacheAge time.Duration

if *configPath != "" {
config, err := LoadConfig(*configPath)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}

config, err := LoadConfig(*configPath)
if err != nil {
log.Printf("No config file loaded, using defaults: %v", err)
addr = ":8080"
cacheDir = ""
} else {
addr = fmt.Sprintf(":%d", config.Port)
cacheDir = config.CacheDir
config.ApplyCredentials()
maxCacheAge = config.MaxCacheAge

log.Printf("Loaded configuration from %s", *configPath)
} else {
addr = ":8080"
cacheDir = ""
log.Printf("Using cache directory: %s and maximum age %s", cacheDir, maxCacheAge)
}

server := NewServer(addr, cacheDir)
srv, err := server.Start()
server := NewServer(addr, cacheDir, maxCacheAge)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

srv, err := server.Start(ctx)
if err != nil {
log.Fatalf("Failed to start server: %v", err)
}
Expand All @@ -58,9 +62,12 @@ func main() {
sig := <-quit
log.Printf("Received %s, shutting down gracefully...", sig)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
// Cancel the context for background tasks
cancel()

shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err)
}

Expand Down
Loading