Skip to content

Commit

Permalink
Merge pull request #7 from nimbolus/verify-lock-on-push
Browse files Browse the repository at this point in the history
only allow state push with correct lock id
  • Loading branch information
lu1as committed Oct 11, 2022
2 parents e50f8d0 + 4319333 commit b0ca4b0
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 72 deletions.
21 changes: 17 additions & 4 deletions pkg/lock/local/local.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package local

import (
"fmt"
"sync"

"github.com/nimbolus/terraform-backend/pkg/terraform"
Expand All @@ -10,12 +11,12 @@ const Name = "local"

type Lock struct {
mutex sync.Mutex
db map[string][]byte
db map[string]terraform.LockInfo
}

func NewLock() *Lock {
return &Lock{
db: make(map[string][]byte),
db: make(map[string]terraform.LockInfo),
}
}

Expand All @@ -29,7 +30,7 @@ func (l *Lock) Lock(s *terraform.State) (bool, error) {

lock, ok := l.db[s.ID]
if ok {
if string(lock) == string(s.Lock) {
if lock.Equal(s.Lock) {
// you already have the lock
return true, nil
}
Expand All @@ -53,7 +54,7 @@ func (l *Lock) Unlock(s *terraform.State) (bool, error) {
return false, nil
}

if string(lock) != string(s.Lock) {
if !lock.Equal(s.Lock) {
s.Lock = lock

return false, nil
Expand All @@ -63,3 +64,15 @@ func (l *Lock) Unlock(s *terraform.State) (bool, error) {

return true, nil
}

func (l *Lock) GetLock(s *terraform.State) (terraform.LockInfo, error) {
l.mutex.Lock()
defer l.mutex.Unlock()

lock, ok := l.db[s.ID]
if !ok {
return terraform.LockInfo{}, fmt.Errorf("no lock found for state %s", s.ID)
}

return lock, nil
}
1 change: 1 addition & 0 deletions pkg/lock/locker.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ type Locker interface {
GetName() string
Lock(s *terraform.State) (ok bool, err error)
Unlock(s *terraform.State) (ok bool, err error)
GetLock(s *terraform.State) (terraform.LockInfo, error)
}
53 changes: 45 additions & 8 deletions pkg/lock/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"

Expand Down Expand Up @@ -67,11 +68,16 @@ func (l *Lock) Lock(s *terraform.State) (bool, error) {

defer tx.Rollback()

var lock []byte
var rawLock []byte

if err := tx.QueryRow(`SELECT lock_data FROM `+l.table+` WHERE state_id = $1`, s.ID).Scan(&lock); err != nil {
if err := tx.QueryRow(`SELECT lock_data FROM `+l.table+` WHERE state_id = $1`, s.ID).Scan(&rawLock); err != nil {
if err == sql.ErrNoRows {
if _, err := tx.Exec(`INSERT INTO `+l.table+` (state_id, lock_data) VALUES ($1, $2)`, s.ID, s.Lock); err != nil {
lockBytes, err := json.Marshal(s.Lock)
if err != nil {
return false, err
}

if _, err := tx.Exec(`INSERT INTO `+l.table+` (state_id, lock_data) VALUES ($1, $2)`, s.ID, lockBytes); err != nil {
return false, err
}

Expand All @@ -85,7 +91,13 @@ func (l *Lock) Lock(s *terraform.State) (bool, error) {
return false, err
}

if string(lock) == string(s.Lock) {
var lock terraform.LockInfo

if err := json.Unmarshal(rawLock, &lock); err != nil {
return false, err
}

if lock.Equal(s.Lock) {
// you already have the lock
return true, nil
}
Expand All @@ -106,21 +118,27 @@ func (l *Lock) Unlock(s *terraform.State) (bool, error) {

defer tx.Rollback()

var lock []byte
var rawLock []byte

if err := tx.QueryRow(`SELECT lock_data FROM `+l.table+` WHERE state_id = $1`, s.ID).Scan(&lock); err != nil {
if err := tx.QueryRow(`SELECT lock_data FROM `+l.table+` WHERE state_id = $1`, s.ID).Scan(&rawLock); err != nil {
if err == sql.ErrNoRows {
return false, nil
}

return false, err
}

if string(lock) != string(s.Lock) {
var lock terraform.LockInfo

if err := json.Unmarshal(rawLock, &lock); err != nil {
return false, err
}

if !lock.Equal(s.Lock) {
return false, nil
}

if _, err := tx.Exec(`DELETE FROM `+l.table+` WHERE state_id = $1 AND lock_data = $2`, s.ID, s.Lock); err != nil {
if _, err := tx.Exec(`DELETE FROM `+l.table+` WHERE state_id = $1 AND lock_data = $2`, s.ID, rawLock); err != nil {
return false, err
}

Expand All @@ -130,3 +148,22 @@ func (l *Lock) Unlock(s *terraform.State) (bool, error) {

return true, nil
}

func (l *Lock) GetLock(s *terraform.State) (terraform.LockInfo, error) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

var rawLock []byte

if err := l.db.QueryRowContext(ctx, `SELECT lock_data FROM `+l.table+` WHERE state_id = $1`, s.ID).Scan(&rawLock); err != nil {
return terraform.LockInfo{}, err
}

var lock terraform.LockInfo

if err := json.Unmarshal(rawLock, &lock); err != nil {
return terraform.LockInfo{}, err
}

return lock, nil
}
56 changes: 48 additions & 8 deletions pkg/lock/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package redis
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"time"
Expand Down Expand Up @@ -86,7 +87,7 @@ func (r *Lock) Lock(s *terraform.State) (locked bool, err error) {
}

// state is locked
if string(lock) == string(s.Lock) {
if lock.Equal(s.Lock) {
return true, nil
}

Expand Down Expand Up @@ -123,7 +124,7 @@ func (r *Lock) Unlock(s *terraform.State) (unlocked bool, err error) {
return false, nil
}

if string(lock) != string(s.Lock) {
if !lock.Equal(s.Lock) {
return false, nil
}

Expand All @@ -134,6 +135,32 @@ func (r *Lock) Unlock(s *terraform.State) (unlocked bool, err error) {
return true, nil
}

func (r *Lock) GetLock(s *terraform.State) (lock terraform.LockInfo, err error) {
mutex := r.client.NewMutex(lockKey, redsync.WithExpiry(12*time.Hour), redsync.WithTries(1), redsync.WithGenValueFunc(func() (string, error) {
return uuid.New().String(), nil
}))

// lock the global redis mutex
if err := mutex.Lock(); err != nil {
log.Errorf("failed to lock redsync mutex: %v", err)

return terraform.LockInfo{}, err
}

defer func() {
// unlock the global redis mutex
if _, mutErr := mutex.Unlock(); mutErr != nil {
log.Errorf("failed to unlock redsync mutex: %v", mutErr)

if err != nil {
err = multierr.Append(err, mutErr)
}
}
}()

return r.getLock(s)
}

func (r *Lock) setLock(s *terraform.State) error {
ctx := context.Background()

Expand All @@ -144,7 +171,14 @@ func (r *Lock) setLock(s *terraform.State) error {

defer conn.Close()

reply, err := redigo.String(conn.Do("SET", s.ID, base64.StdEncoding.EncodeToString(s.Lock), "NX", "PX", int(12*time.Hour/time.Millisecond)))
rawLock, err := json.Marshal(s.Lock)
if err != nil {
return err
}

lockString := base64.StdEncoding.EncodeToString(rawLock)

reply, err := redigo.String(conn.Do("SET", s.ID, lockString, "NX", "PX", int(12*time.Hour/time.Millisecond)))
if err != nil {
return err
}
Expand All @@ -156,24 +190,30 @@ func (r *Lock) setLock(s *terraform.State) error {
return nil
}

func (r *Lock) getLock(s *terraform.State) ([]byte, error) {
func (r *Lock) getLock(s *terraform.State) (terraform.LockInfo, error) {
ctx := context.Background()

conn, err := r.pool.GetContext(ctx)
if err != nil {
return nil, err
return terraform.LockInfo{}, err
}

defer conn.Close()

value, err := redigo.String(conn.Do("GET", s.ID))
if err != nil {
return nil, err
return terraform.LockInfo{}, err
}

lock, err := base64.StdEncoding.DecodeString(value)
rawLock, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return nil, err
return terraform.LockInfo{}, err
}

var lock terraform.LockInfo

if err := json.Unmarshal(rawLock, &lock); err != nil {
return terraform.LockInfo{}, err
}

return lock, nil
Expand Down
6 changes: 3 additions & 3 deletions pkg/lock/redis/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestGetLock(t *testing.T) {
ID: terraform.GetStateID("test", "test"),
Project: "test",
Name: "test",
Lock: []byte(expectedLock),
Lock: terraform.LockInfo{ID: expectedLock},
}

{
Expand All @@ -49,8 +49,8 @@ func TestGetLock(t *testing.T) {
t.Error(err)
}

if string(lock) != string(expectedLock) {
t.Errorf("lock mismatch: %s != %s", string(lock), string(expectedLock))
if lock.ID != expectedLock {
t.Errorf("lock mismatch: %s != %s", lock.ID, expectedLock)
}
}

Expand Down
29 changes: 26 additions & 3 deletions pkg/lock/util/locktest.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package util

import (
"testing"
"time"

"github.com/google/uuid"
"github.com/spf13/viper"
Expand All @@ -19,15 +20,31 @@ func LockTest(t *testing.T, l lock.Locker) {
ID: terraform.GetStateID("test", "test"),
Project: "test",
Name: "test",
Lock: []byte(uuid.New().String()),
Lock: terraform.LockInfo{
ID: uuid.New().String(),
Path: "",
Operation: "LockTest",
Who: "test",
Version: "0.0.0",
Created: time.Now().String(),
Info: "",
},
}
t.Logf("s1: %s", s1.Lock)

s2 := terraform.State{
ID: terraform.GetStateID("test", "test"),
Project: "test",
Name: "test",
Lock: []byte(uuid.New().String()),
Lock: terraform.LockInfo{
ID: uuid.New().String(),
Path: "",
Operation: "LockTest",
Who: "test",
Version: "0.0.0",
Created: time.Now().String(),
Info: "",
},
}
t.Logf("s2: %s", s2.Lock)

Expand All @@ -43,6 +60,12 @@ func LockTest(t *testing.T, l lock.Locker) {
t.Error(err)
}

if lock, err := l.GetLock(&s1); err != nil {
t.Error(err)
} else if !lock.Equal(s1.Lock) {
t.Errorf("lock is not equal: %s != %s", lock, s1.Lock)
}

if locked, err := l.Lock(&s1); err != nil || !locked {
t.Error("should be able to lock twice from the same process")
}
Expand All @@ -51,7 +74,7 @@ func LockTest(t *testing.T, l lock.Locker) {
t.Error("should not be able to lock twice from different processes")
}

if string(s2.Lock) != string(s1.Lock) {
if !s2.Lock.Equal(s1.Lock) {
t.Error("failed Lock() should return the lock information of the current lock")
}

Expand Down
Loading

0 comments on commit b0ca4b0

Please sign in to comment.