Skip to content

Commit

Permalink
Allow concurrent calls for Fake methods.
Browse files Browse the repository at this point in the history
Signed-off-by: Nadia Pinaeva <n.m.pinaeva@gmail.com>
  • Loading branch information
npinaeva committed Sep 9, 2024
1 parent 19fb4da commit 6c1e3c3
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,27 @@ import (
"regexp"
"sort"
"strings"
"sync"
)

// Fake is a fake implementation of Interface
type Fake struct {
nftContext
// mutex is used to protect Table and LastTransaction.
// When Table and LastTransaction are accessed directly, the caller must acquire Fake.RLock
// and release when finished.
sync.RWMutex

nextHandle int

// Table contains the Interface's table. This will be `nil` until you `tx.Add()`
// the table.
// Make sure to acquire Fake.RLock before accessing Table in a concurrent environment.
Table *FakeTable

// LastTransaction is the last transaction passed to Run(). It will remain set until the
// next time Run() is called. (It is not affected by Check().)
// Make sure to acquire Fake.RLock before accessing LastTransaction in a concurrent environment.
LastTransaction *Transaction
}

Expand Down Expand Up @@ -94,6 +101,8 @@ var _ Interface = &Fake{}

// List is part of Interface.
func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) {
fake.RLock()
defer fake.RUnlock()
if fake.Table == nil {
return nil, notFoundError("no such table %q", fake.table)
}
Expand Down Expand Up @@ -123,6 +132,8 @@ func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) {

// ListRules is part of Interface
func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) {
fake.RLock()
defer fake.RUnlock()
if fake.Table == nil {
return nil, notFoundError("no such table %q", fake.table)
}
Expand All @@ -145,6 +156,8 @@ func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) {

// ListElements is part of Interface
func (fake *Fake) ListElements(_ context.Context, objectType, name string) ([]*Element, error) {
fake.RLock()
defer fake.RUnlock()
if fake.Table == nil {
return nil, notFoundError("no such %s %q", objectType, name)
}
Expand All @@ -169,6 +182,8 @@ func (fake *Fake) NewTransaction() *Transaction {

// Run is part of Interface
func (fake *Fake) Run(_ context.Context, tx *Transaction) error {
fake.Lock()
defer fake.Unlock()
fake.LastTransaction = tx
updatedTable, err := fake.run(tx)
if err == nil {
Expand All @@ -179,10 +194,13 @@ func (fake *Fake) Run(_ context.Context, tx *Transaction) error {

// Check is part of Interface
func (fake *Fake) Check(_ context.Context, tx *Transaction) error {
fake.RLock()
defer fake.RUnlock()
_, err := fake.run(tx)
return err
}

// must be called with fake.lock held
func (fake *Fake) run(tx *Transaction) (*FakeTable, error) {
if tx.err != nil {
return nil, tx.err
Expand Down Expand Up @@ -480,6 +498,8 @@ func checkElementRefs(element *Element, table *FakeTable) error {

// Dump dumps the current contents of fake, in a way that looks like an nft transaction.
func (fake *Fake) Dump() string {
fake.RLock()
defer fake.RUnlock()
if fake.Table == nil {
return ""
}
Expand Down

0 comments on commit 6c1e3c3

Please sign in to comment.