From 46d5ddc8d17eb7ecec4f24cd4081b77a2150e0e0 Mon Sep 17 00:00:00 2001 From: Sid Kattoju Date: Tue, 2 Jul 2024 10:48:03 -0400 Subject: [PATCH] Adding a token getter to get service account tokens --- internal/authentication/tokengetter.go | 95 +++++++++++ internal/authentication/tokengetter_test.go | 167 ++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 internal/authentication/tokengetter.go create mode 100644 internal/authentication/tokengetter_test.go diff --git a/internal/authentication/tokengetter.go b/internal/authentication/tokengetter.go new file mode 100644 index 000000000..e7d0fead4 --- /dev/null +++ b/internal/authentication/tokengetter.go @@ -0,0 +1,95 @@ +package authentication + +import ( + "context" + "sync" + "time" + + authv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/utils/ptr" +) + +type TokenGetter struct { + client corev1.ServiceAccountsGetter + expirationSeconds int64 + tokens map[types.NamespacedName]*authv1.TokenRequestStatus + tokenLocks keyLock[types.NamespacedName] + mu sync.RWMutex +} + +func NewTokenGetter(client corev1.ServiceAccountsGetter, expirationSeconds int64) *TokenGetter { + return &TokenGetter{ + client: client, + expirationSeconds: expirationSeconds, + tokens: map[types.NamespacedName]*authv1.TokenRequestStatus{}, + tokenLocks: newKeyLock[types.NamespacedName](), + } +} + +type keyLock[K comparable] struct { + locks map[K]*sync.Mutex + mu sync.Mutex +} + +func newKeyLock[K comparable]() keyLock[K] { + return keyLock[K]{locks: map[K]*sync.Mutex{}} +} + +func (k *keyLock[K]) Lock(key K) { + k.getLock(key).Lock() +} + +func (k *keyLock[K]) Unlock(key K) { + k.getLock(key).Unlock() +} + +func (k *keyLock[K]) getLock(key K) *sync.Mutex { + k.mu.Lock() + defer k.mu.Unlock() + + lock, ok := k.locks[key] + if !ok { + lock = &sync.Mutex{} + k.locks[key] = lock + } + return lock +} + +func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) { + t.tokenLocks.Lock(key) + defer t.tokenLocks.Unlock(key) + + t.mu.RLock() + token, ok := t.tokens[key] + t.mu.RUnlock() + + expireTime := time.Time{} + if ok { + expireTime = token.ExpirationTimestamp.Time + } + + fiveMinutesAfterNow := metav1.Now().Add(5 * time.Minute) + if expireTime.Before(fiveMinutesAfterNow) { + var err error + token, err = t.getToken(ctx, key) + if err != nil { + return "", err + } + t.mu.Lock() + t.tokens[key] = token + t.mu.Unlock() + } + + return token.Token, nil +} + +func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authv1.TokenRequestStatus, error) { + req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx, key.Name, &authv1.TokenRequest{Spec: authv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](3600)}}, metav1.CreateOptions{}) + if err != nil { + return nil, err + } + return &req.Status, nil +} diff --git a/internal/authentication/tokengetter_test.go b/internal/authentication/tokengetter_test.go new file mode 100644 index 000000000..d8debbf7d --- /dev/null +++ b/internal/authentication/tokengetter_test.go @@ -0,0 +1,167 @@ +package authentication + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + authv1 "k8s.io/api/authentication/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/watch" + corev1 "k8s.io/client-go/applyconfigurations/core/v1" + "k8s.io/client-go/kubernetes/fake" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + ctest "k8s.io/client-go/testing" +) + +func TestNewTokenGetter(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + fakeClient.PrependReactor("create", "serviceaccounts/token", func(action ctest.Action) (handled bool, ret runtime.Object, err error) { + act, ok := action.(ctest.CreateActionImpl) + if !ok { + return false, nil, nil + } + tokenRequest := act.GetObject().(*authv1.TokenRequest) + if act.Name == "test-service-account-1" { + tokenRequest.Status = authv1.TokenRequestStatus{ + Token: "test-token-1", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(5 * time.Minute)), + } + } + if act.Name == "test-service-account-2" { + tokenRequest.Status = authv1.TokenRequestStatus{ + Token: "test-token-2", + ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)), + } + } + + return true, tokenRequest, nil + }) + tg := NewTokenGetter(fakeClient.CoreV1(), int64(5*time.Minute)) + t.Log("Testing NewTokenGetter with fake client") + token, err := tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-1", + Name: "test-service-account-1", + }) + if err != nil { + t.Fatalf("failed to get token: %v", err) + return + } + t.Log("token:", token) + if token != "test-token-1" { + t.Errorf("token does not match") + } + t.Log("Testing getting token from cache") + token, err = tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-1", + Name: "test-service-account-1", + }) + if err != nil { + t.Fatalf("failed to get token from cache: %v", err) + return + } + t.Log("token:", token) + if token != "test-token-1" { + t.Errorf("token does not match") + } + t.Log("Testing getting short lived token from fake client") + token, err = tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-2", + Name: "test-service-account-2", + }) + if err != nil { + t.Fatalf("failed to get token: %v", err) + return + } + t.Log("token:", token) + if token != "test-token-2" { + t.Errorf("token does not match") + } + //wait for token to expired + time.Sleep(1 * time.Second) + t.Log("Testing getting expired token from cache") + token, err = tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace-2", + Name: "test-service-account-2", + }) + if err != nil { + t.Fatalf("failed to refresh token: %v", err) + return + } + t.Log("token:", token) + if token != "test-token-2" { + t.Errorf("token does not match") + } +} + +type ServiceAccountsGetterImpl struct{} + +func (ServiceAccountsGetterImpl) ServiceAccounts(namespace string) corev1client.ServiceAccountInterface { + return &ServiceAccountTokenInterfaceImpl{} +} + +type ServiceAccountTokenInterfaceImpl struct{} + +func (i ServiceAccountTokenInterfaceImpl) Apply(ctx context.Context, serviceAccount *corev1.ServiceAccountApplyConfiguration, opts metav1.ApplyOptions) (result *v1.ServiceAccount, err error) { + panic("placeholder, not implemented") +} + +func (i ServiceAccountTokenInterfaceImpl) Create(ctx context.Context, serviceAccount *v1.ServiceAccount, opts metav1.CreateOptions) (*v1.ServiceAccount, error) { + panic("placeholder, not implemented") +} + +func (i ServiceAccountTokenInterfaceImpl) Update(ctx context.Context, serviceAccount *v1.ServiceAccount, opts metav1.UpdateOptions) (*v1.ServiceAccount, error) { + panic("placeholder, not implemented") + +} + +func (i ServiceAccountTokenInterfaceImpl) Delete(ctx context.Context, name string, opts metav1.DeleteOptions) error { + panic("placeholder, not implemented") + +} + +func (i ServiceAccountTokenInterfaceImpl) DeleteCollection(ctx context.Context, opts metav1.DeleteOptions, listOpts metav1.ListOptions) error { + panic("placeholder, not implemented") + +} + +func (i ServiceAccountTokenInterfaceImpl) Get(ctx context.Context, name string, opts metav1.GetOptions) (*v1.ServiceAccount, error) { + panic("placeholder, not implemented") + +} + +func (i ServiceAccountTokenInterfaceImpl) List(ctx context.Context, opts metav1.ListOptions) (*v1.ServiceAccountList, error) { + panic("placeholder, not implemented") + +} + +func (i ServiceAccountTokenInterfaceImpl) Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) { + panic("placeholder, not implemented") + +} + +func (i ServiceAccountTokenInterfaceImpl) Patch(ctx context.Context, name string, pt types.PatchType, data []byte, opts metav1.PatchOptions, subresources ...string) (result *v1.ServiceAccount, err error) { + panic("placeholder, not implemented") + +} + +func (ServiceAccountTokenInterfaceImpl) CreateToken(ctx context.Context, serviceAccountName string, tokenRequest *authv1.TokenRequest, opts metav1.CreateOptions) (*authv1.TokenRequest, error) { + err := fmt.Errorf("error when fetching token") + return nil, err +} + +func TestTokenGetter_GetToken(t *testing.T) { + t.Log("Testing NewTokenGetter with test service account getter implementation") + saGetter := &ServiceAccountsGetterImpl{} + tg := NewTokenGetter(saGetter, int64(5*time.Minute)) + _, err := tg.Get(context.Background(), types.NamespacedName{ + Namespace: "test-namespace", + Name: "test-service-account-3", + }) + assert.EqualError(t, err, "error when fetching token") +}