Skip to content

Commit

Permalink
[management] Add transaction to addPeer (netbirdio#2469)
Browse files Browse the repository at this point in the history
This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
  • Loading branch information
pascal-fischer authored Sep 16, 2024
1 parent 730dd17 commit 6c50b0c
Show file tree
Hide file tree
Showing 13 changed files with 1,085 additions and 206 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/golang-test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code

- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...

test_client_on_docker:
runs-on: ubuntu-20.04
Expand Down
42 changes: 25 additions & 17 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}

// Subclass used in gorm to only load network and not whole account
type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
}

type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
Expand Down Expand Up @@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}

func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}

func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
Expand All @@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList
}

func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}

func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
Expand Down Expand Up @@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
}

func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil {
return false, err
}
Expand All @@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
return false, nil
}

func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}

labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}

if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}

return newLabel, nil
}

// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
Expand Down
3 changes: 2 additions & 1 deletion management/server/ephemeral_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)

type MockStore struct {
Expand All @@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil
}

return nil, fmt.Errorf("account not found")
return nil, status.NewPeerNotFoundError(peerId)
}

type MocAccountManager struct {
Expand Down
168 changes: 161 additions & 7 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package server

import (
"context"
"errors"
"net"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -46,6 +48,158 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"`
}

func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
return f(s)
}

func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
s.mux.Lock()
defer s.mux.Unlock()

accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
if !ok {
return status.NewSetupKeyNotFoundError()
}

account, err := s.getAccount(accountID)
if err != nil {
return err
}

account.SetupKeys[setupKeyID].UsedTimes++

return s.SaveAccount(ctx, account)
}

func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
s.mux.Lock()
defer s.mux.Unlock()

account, err := s.getAccount(accountID)
if err != nil {
return err
}

allGroup, err := account.GetGroupAll()
if err != nil || allGroup == nil {
return errors.New("all group not found")
}

allGroup.Peers = append(allGroup.Peers, peerID)

return nil
}

func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
s.mux.Lock()
defer s.mux.Unlock()

account, err := s.getAccount(accountId)
if err != nil {
return err
}

account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)

return nil
}

func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()

account, ok := s.Accounts[peer.AccountID]
if !ok {
return status.NewAccountNotFoundError(peer.AccountID)
}

account.Peers[peer.ID] = peer
return s.SaveAccount(ctx, account)
}

func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
s.mux.Lock()
defer s.mux.Unlock()

account, ok := s.Accounts[accountId]
if !ok {
return status.NewAccountNotFoundError(accountId)
}

account.Network.Serial++

return s.SaveAccount(ctx, account)
}

func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
if !ok {
return nil, status.NewSetupKeyNotFoundError()
}

account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

setupKey, ok := account.SetupKeys[key]
if !ok {
return nil, status.Errorf(status.NotFound, "setup key not found")
}

return setupKey, nil
}

func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
s.mux.Lock()
defer s.mux.Unlock()

account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

var takenIps []net.IP
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
}

return takenIps, nil
}

func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
s.mux.Lock()
defer s.mux.Unlock()

account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

existingLabels := []string{}
for _, peer := range account.Peers {
if peer.DNSLabel != "" {
existingLabels = append(existingLabels, peer.DNSLabel)
}
}
return existingLabels, nil
}

func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
s.mux.Lock()
defer s.mux.Unlock()

account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

return account.Network, nil
}

type StoredAccount struct{}

// NewFileStore restores a store from the file located in the datadir
Expand Down Expand Up @@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A

accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
return nil, status.NewSetupKeyNotFoundError()
}

account, err := s.getAccount(accountID)
Expand Down Expand Up @@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}

func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
Expand Down Expand Up @@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewAccountNotFoundError(accountID)
}

return account, nil
Expand Down Expand Up @@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (

accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
return "", status.NewSetupKeyNotFoundError()
}

return accountID, nil
}

func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()

Expand All @@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey)
}

func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()

Expand Down Expand Up @@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
}

// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
s.mux.Lock()
defer s.mux.Unlock()

Expand Down
Loading

0 comments on commit 6c50b0c

Please sign in to comment.