Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pkg/credentials/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@ func (s *StoreFactory) NewStore(credCtxs []string) (CredentialStore, error) {
return nil, err
}
if s.file {
return withOverride{
target: Store{
return &withOverride{
target: &Store{
credCtxs: credCtxs,
cfg: s.cfg,
},
overrides: s.overrides,
credContext: credCtxs,
}, nil
}
return withOverride{
target: Store{
return &withOverride{
target: &Store{
credCtxs: credCtxs,
cfg: s.cfg,
program: s.program,
Expand Down
4 changes: 4 additions & 0 deletions pkg/credentials/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ func (s NoopStore) Remove(context.Context, string) error {
func (s NoopStore) List(context.Context) ([]Credential, error) {
return nil, nil
}

func (s NoopStore) RecreateAll(context.Context) error {
return nil
}
4 changes: 4 additions & 0 deletions pkg/credentials/overrides.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,7 @@ func (w withOverride) List(ctx context.Context) ([]Credential, error) {

return creds, nil
}

func (w withOverride) RecreateAll(ctx context.Context) error {
return w.target.RecreateAll(ctx)
}
82 changes: 74 additions & 8 deletions pkg/credentials/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"regexp"
"slices"
"sync"

"github.com/docker/cli/cli/config/credentials"
"github.com/docker/cli/cli/config/types"
Expand All @@ -24,15 +25,20 @@ type CredentialStore interface {
Refresh(ctx context.Context, cred Credential) error
Remove(ctx context.Context, toolName string) error
List(ctx context.Context) ([]Credential, error)
RecreateAll(ctx context.Context) error
}

type Store struct {
credCtxs []string
cfg *config.CLIConfig
program client.ProgramFunc
credCtxs []string
cfg *config.CLIConfig
program client.ProgramFunc
recreateAllLock sync.RWMutex
}

func (s Store) Get(_ context.Context, toolName string) (*Credential, bool, error) {
func (s *Store) Get(_ context.Context, toolName string) (*Credential, bool, error) {
s.recreateAllLock.RLock()
defer s.recreateAllLock.RUnlock()

if len(s.credCtxs) > 0 && s.credCtxs[0] == AllCredentialContexts {
return nil, false, fmt.Errorf("cannot get a credential with context %q", AllCredentialContexts)
}
Expand Down Expand Up @@ -80,7 +86,10 @@ func (s Store) Get(_ context.Context, toolName string) (*Credential, bool, error

// Add adds a new credential to the credential store.
// Any context set on the credential object will be overwritten with the first context of the credential store.
func (s Store) Add(_ context.Context, cred Credential) error {
func (s *Store) Add(_ context.Context, cred Credential) error {
s.recreateAllLock.RLock()
defer s.recreateAllLock.RUnlock()

first := first(s.credCtxs)
if first == AllCredentialContexts {
return fmt.Errorf("cannot add a credential with context %q", AllCredentialContexts)
Expand All @@ -99,7 +108,10 @@ func (s Store) Add(_ context.Context, cred Credential) error {
}

// Refresh updates an existing credential in the credential store.
func (s Store) Refresh(_ context.Context, cred Credential) error {
func (s *Store) Refresh(_ context.Context, cred Credential) error {
s.recreateAllLock.RLock()
defer s.recreateAllLock.RUnlock()

if !slices.Contains(s.credCtxs, cred.Context) {
return fmt.Errorf("context %q not in list of valid contexts for this credential store", cred.Context)
}
Expand All @@ -115,7 +127,10 @@ func (s Store) Refresh(_ context.Context, cred Credential) error {
return store.Store(auth)
}

func (s Store) Remove(_ context.Context, toolName string) error {
func (s *Store) Remove(_ context.Context, toolName string) error {
s.recreateAllLock.RLock()
defer s.recreateAllLock.RUnlock()

first := first(s.credCtxs)
if len(s.credCtxs) > 1 || first == AllCredentialContexts {
return fmt.Errorf("error: credential deletion is not supported when multiple credential contexts are provided")
Expand All @@ -129,7 +144,10 @@ func (s Store) Remove(_ context.Context, toolName string) error {
return store.Erase(toolNameWithCtx(toolName, first))
}

func (s Store) List(_ context.Context) ([]Credential, error) {
func (s *Store) List(_ context.Context) ([]Credential, error) {
s.recreateAllLock.RLock()
defer s.recreateAllLock.RUnlock()

store, err := s.getStore()
if err != nil {
return nil, err
Expand Down Expand Up @@ -199,6 +217,54 @@ func (s Store) List(_ context.Context) ([]Credential, error) {
return maps.Values(credsByName), nil
}

func (s *Store) RecreateAll(_ context.Context) error {
store, err := s.getStore()
if err != nil {
return err
}

// We repeatedly lock and unlock the mutex in this function to give other threads a chance to talk to the credential store.
// It can take several minutes to recreate the credentials if there are hundreds of them, and we don't want to
// block all other threads while we do that.
// New credentials might be created after our GetAll, but they will be created with the current encryption configuration,
// so it's okay that they are skipped by this function.

s.recreateAllLock.Lock()
all, err := store.GetAll()
s.recreateAllLock.Unlock()
if err != nil {
return err
}

// Loop through and recreate each individual credential.
for serverAddress := range all {
s.recreateAllLock.Lock()
authConfig, err := store.Get(serverAddress)
if err != nil {
s.recreateAllLock.Unlock()

if IsCredentialsNotFoundError(err) {
// This can happen if the credential was deleted between the GetAll and the Get by another thread.
continue
}
return err
}

if err := store.Erase(serverAddress); err != nil {
s.recreateAllLock.Unlock()
return err
}

if err := store.Store(authConfig); err != nil {
s.recreateAllLock.Unlock()
return err
}
s.recreateAllLock.Unlock()
}

return nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Writing a helper function for this will allow you to use defer s.recreateAllLock.Unlock().

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh good call

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}

func (s *Store) getStore() (credentials.Store, error) {
if s.program != nil {
return &toolCredentialStore{
Expand Down
17 changes: 17 additions & 0 deletions pkg/sdkserver/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@ func (s *server) initializeCredentialStore(_ context.Context, credCtxs []string)
return store, nil
}

func (s *server) recreateAllCredentials(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())

store, err := s.initializeCredentialStore(r.Context(), []string{credentials.AllCredentialContexts})
if err != nil {
writeError(logger, w, http.StatusInternalServerError, err)
return
}

if err := store.RecreateAll(r.Context()); err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to recreate all credentials: %w", err))
return
}

writeResponse(logger, w, map[string]any{"stdout": "All credentials recreated successfully"})
}

func (s *server) listCredentials(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())
req := new(credentialsRequest)
Expand Down
1 change: 1 addition & 0 deletions pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (s *server) addRoutes(mux *http.ServeMux) {
mux.HandleFunc("POST /credentials/create", s.createCredential)
mux.HandleFunc("POST /credentials/reveal", s.revealCredential)
mux.HandleFunc("POST /credentials/delete", s.deleteCredential)
mux.HandleFunc("POST /credentials/recreate-all", s.recreateAllCredentials)

mux.HandleFunc("POST /datasets", s.listDatasets)
mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)
Expand Down