diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 70e11ff..be42186 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -24,13 +24,13 @@ func Authenticate(req *http.Request, s *terraform.State) (ok bool, err error) { var authenticator Authenticator switch backend { - case "basic": + case basic.Name: viper.SetDefault("auth_basic_enabled", true) if !viper.GetBool("auth_basic_enabled") { return false, fmt.Errorf("basic auth is not enabled") } authenticator = basic.NewBasicAuth() - case "jwt": + case jwt.Name: issuerURL := viper.GetString("auth_jwt_oidc_issuer_url") if addr := viper.GetString("vault_addr"); issuerURL != "" && addr != "" { issuerURL = fmt.Sprintf("%s/v1/identity/oidc", addr) diff --git a/pkg/auth/basic/basic.go b/pkg/auth/basic/basic.go index edff74c..7506e40 100644 --- a/pkg/auth/basic/basic.go +++ b/pkg/auth/basic/basic.go @@ -7,6 +7,8 @@ import ( "github.com/nimbolus/terraform-backend/pkg/terraform" ) +const Name = "basic" + type BasicAuth struct{} func NewBasicAuth() *BasicAuth { @@ -14,7 +16,7 @@ func NewBasicAuth() *BasicAuth { } func (l *BasicAuth) GetName() string { - return "basic" + return Name } func (b *BasicAuth) Authenticate(secret string, s *terraform.State) (bool, error) { diff --git a/pkg/auth/jwt/jwt.go b/pkg/auth/jwt/jwt.go index 399e467..cd785e7 100644 --- a/pkg/auth/jwt/jwt.go +++ b/pkg/auth/jwt/jwt.go @@ -8,6 +8,8 @@ import ( "github.com/nimbolus/terraform-backend/pkg/terraform" ) +const Name = "jwt" + type JWTAuth struct { issuerURL string } @@ -19,7 +21,7 @@ func NewJWTAuth(issuerURL string) *JWTAuth { } func (l *JWTAuth) GetName() string { - return "jwt" + return Name } func (b *JWTAuth) Authenticate(secret string, s *terraform.State) (bool, error) { diff --git a/pkg/kms/local/local.go b/pkg/kms/local/local.go index c50a8ef..b067ced 100644 --- a/pkg/kms/local/local.go +++ b/pkg/kms/local/local.go @@ -9,6 +9,8 @@ import ( "io" ) +const Name = "local" + type KMS struct { cipher cipher.AEAD } @@ -34,7 +36,7 @@ func GenerateKey() (string, error) { } func (v *KMS) GetName() string { - return "local" + return Name } func (s *KMS) Encrypt(d []byte) ([]byte, error) { diff --git a/pkg/kms/transit/transit.go b/pkg/kms/transit/transit.go index f6bc5b8..b22c5d1 100644 --- a/pkg/kms/transit/transit.go +++ b/pkg/kms/transit/transit.go @@ -7,6 +7,8 @@ import ( vaultclient "github.com/nimbolus/terraform-backend/pkg/client/vault" ) +const Name = "transit" + type VaultTransit struct { engine string key string @@ -25,7 +27,7 @@ func NewVaultTransit(engine string, key string) (*VaultTransit, error) { } func (v *VaultTransit) GetName() string { - return "transit" + return Name } func (v *VaultTransit) Encrypt(d []byte) ([]byte, error) { diff --git a/pkg/lock/local/local.go b/pkg/lock/local/local.go index 049a243..784681a 100644 --- a/pkg/lock/local/local.go +++ b/pkg/lock/local/local.go @@ -6,6 +6,8 @@ import ( "github.com/nimbolus/terraform-backend/pkg/terraform" ) +const Name = "local" + type Lock struct { mutex sync.Mutex db map[string][]byte @@ -18,7 +20,7 @@ func NewLock() *Lock { } func (l *Lock) GetName() string { - return "local" + return Name } func (l *Lock) Lock(s *terraform.State) (bool, error) { diff --git a/pkg/lock/postgres/postgres.go b/pkg/lock/postgres/postgres.go index 2847d01..d61a2f5 100644 --- a/pkg/lock/postgres/postgres.go +++ b/pkg/lock/postgres/postgres.go @@ -3,13 +3,14 @@ package postgres import ( "context" "database/sql" - "fmt" "time" pgclient "github.com/nimbolus/terraform-backend/pkg/client/postgres" "github.com/nimbolus/terraform-backend/pkg/terraform" ) +const Name = "postgres" + type Lock struct { db *pgclient.Client } @@ -26,7 +27,7 @@ func NewLock() (*Lock, error) { } func (l *Lock) GetName() string { - return "pg" + return Name } func (l *Lock) Lock(s *terraform.State) (bool, error) { @@ -40,9 +41,9 @@ func (l *Lock) Lock(s *terraform.State) (bool, error) { defer tx.Rollback() - var lockData []byte + var lock []byte - if err := tx.QueryRow(`SELECT lock_data FROM `+l.db.GetLocksTableName()+` WHERE state_id = $1`, s.ID).Scan(&lockData); err != nil { + if err := tx.QueryRow(`SELECT lock_data FROM `+l.db.GetLocksTableName()+` WHERE state_id = $1`, s.ID).Scan(&lock); err != nil { if err == sql.ErrNoRows { if _, err := tx.Exec(`INSERT INTO locks (state_id, lock_data) VALUES ($1, $2)`, s.ID, s.Lock); err != nil { return false, err @@ -58,12 +59,14 @@ func (l *Lock) Lock(s *terraform.State) (bool, error) { return false, err } - if string(lockData) == string(s.Lock) { + if string(lock) == string(s.Lock) { // you already have the lock return true, nil } - return false, fmt.Errorf("lock already taken for id %s: %s", s.ID, string(lockData)) + s.Lock = lock + + return false, nil } func (l *Lock) Unlock(s *terraform.State) (bool, error) { @@ -77,18 +80,18 @@ func (l *Lock) Unlock(s *terraform.State) (bool, error) { defer tx.Rollback() - var lockData []byte + var lock []byte - if err := tx.QueryRow(`SELECT lock_data FROM `+l.db.GetLocksTableName()+` WHERE state_id = $1`, s.ID).Scan(&lockData); err != nil { + if err := tx.QueryRow(`SELECT lock_data FROM `+l.db.GetLocksTableName()+` WHERE state_id = $1`, s.ID).Scan(&lock); err != nil { if err == sql.ErrNoRows { - return false, fmt.Errorf("no lock for id %s found", s.ID) + return false, nil } return false, err } - if string(lockData) != string(s.Lock) { - return false, fmt.Errorf("lock mismatch for id %s", s.ID) + if string(lock) != string(s.Lock) { + return false, nil } if _, err := tx.Exec(`DELETE FROM `+l.db.GetLocksTableName()+` WHERE state_id = $1 AND lock_data = $2`, s.ID, s.Lock); err != nil { diff --git a/pkg/lock/postgres/postgres_test.go b/pkg/lock/postgres/postgres_test.go index c180c75..c724b56 100644 --- a/pkg/lock/postgres/postgres_test.go +++ b/pkg/lock/postgres/postgres_test.go @@ -6,9 +6,15 @@ package postgres import ( "testing" + "github.com/spf13/viper" + "github.com/nimbolus/terraform-backend/pkg/lock/util" ) +func init() { + viper.AutomaticEnv() +} + func TestLock(t *testing.T) { l, err := NewLock() if err != nil { diff --git a/pkg/lock/redis/redis.go b/pkg/lock/redis/redis.go index 72b7eed..277ab8e 100644 --- a/pkg/lock/redis/redis.go +++ b/pkg/lock/redis/redis.go @@ -19,7 +19,10 @@ import ( "github.com/nimbolus/terraform-backend/pkg/terraform" ) -const lockKey = "terraform-backend-state-lock" +const ( + Name = "redis" + lockKey = "terraform-backend-state-lock" +) type Lock struct { pool *redigo.Pool @@ -39,7 +42,7 @@ func NewLock() *Lock { } func (r *Lock) GetName() string { - return "redis" + return Name } func (r *Lock) Lock(s *terraform.State) (locked bool, err error) { diff --git a/pkg/server/kms.go b/pkg/server/kms.go index 2fc251e..ac78662 100644 --- a/pkg/server/kms.go +++ b/pkg/server/kms.go @@ -12,11 +12,11 @@ import ( ) func GetKMS() (k kms.KMS, err error) { - viper.SetDefault("kms_backend", "local") + viper.SetDefault("kms_backend", local.Name) backend := viper.GetString("kms_backend") switch backend { - case "local": + case local.Name: key := viper.GetString("kms_key") if key == "" { key, _ = local.GenerateKey() @@ -38,7 +38,7 @@ func GetKMS() (k kms.KMS, err error) { } k, err = local.NewKMS(key) - case "transit": + case transit.Name: k, err = transit.NewVaultTransit(viper.GetString("kms_transit_engine"), viper.GetString("kms_transit_key")) default: return nil, fmt.Errorf("failed to initialize KMS backend %s: %v", backend, err) diff --git a/pkg/server/locker.go b/pkg/server/locker.go index dee0ccb..81c32c2 100644 --- a/pkg/server/locker.go +++ b/pkg/server/locker.go @@ -12,15 +12,15 @@ import ( ) func GetLocker() (l lock.Locker, err error) { - viper.SetDefault("lock_backend", "local") + viper.SetDefault("lock_backend", local.Name) backend := viper.GetString("lock_backend") switch backend { - case "local": + case local.Name: l = local.NewLock() - case "redis": + case redis.Name: l = redis.NewLock() - case "postgres": + case postgres.Name: l, err = postgres.NewLock() default: err = fmt.Errorf("backend is not implemented") diff --git a/pkg/server/storage.go b/pkg/server/storage.go index c4d02f0..75ba09d 100644 --- a/pkg/server/storage.go +++ b/pkg/server/storage.go @@ -11,14 +11,14 @@ import ( ) func GetStorage() (s storage.Storage, err error) { - viper.SetDefault("storage_backend", "fs") + viper.SetDefault("storage_backend", filesystem.Name) backend := viper.GetString("storage_backend") switch backend { - case "fs": + case filesystem.Name: viper.SetDefault("storage_fs_dir", "./states") s, err = filesystem.NewFileSystemStorage(viper.GetString("storage_fs_dir")) - case "s3": + case s3.Name: viper.SetDefault("storage_s3_endpoint", "s3.amazonaws.com") viper.SetDefault("storage_s3_use_ssl", true) viper.SetDefault("storage_s3_bucket", "terraform-state") diff --git a/pkg/storage/filesystem/filesystem.go b/pkg/storage/filesystem/filesystem.go index 14907f6..c0776f0 100644 --- a/pkg/storage/filesystem/filesystem.go +++ b/pkg/storage/filesystem/filesystem.go @@ -8,6 +8,8 @@ import ( "github.com/nimbolus/terraform-backend/pkg/terraform" ) +const Name = "fs" + type FileSystemStorage struct { directory string } @@ -24,7 +26,7 @@ func NewFileSystemStorage(directory string) (*FileSystemStorage, error) { } func (f *FileSystemStorage) GetName() string { - return "file" + return Name } func (f *FileSystemStorage) SaveState(s *terraform.State) error { diff --git a/pkg/storage/s3/s3.go b/pkg/storage/s3/s3.go index aaba9b8..5b2535c 100644 --- a/pkg/storage/s3/s3.go +++ b/pkg/storage/s3/s3.go @@ -11,6 +11,8 @@ import ( "github.com/nimbolus/terraform-backend/pkg/terraform" ) +const Name = "s3" + type S3Storage struct { client *minio.Client bucket string @@ -38,7 +40,7 @@ func NewS3Storage(endpoint, bucket, accessKey, secretKey string, useSSL bool) (* } func (s *S3Storage) GetName() string { - return "s3" + return Name } func (s *S3Storage) SaveState(state *terraform.State) error {