Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions auth/instance_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,20 @@ type InstanceJWTClaims struct {
jwt.RegisteredClaims
}

func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) {
func NewInstanceTokenGetter(jwtSecret string) (InstanceTokenGetter, error) {
if jwtSecret == "" {
return nil, fmt.Errorf("jwt secret is required")
}
return &instanceToken{
jwtSecret: jwtSecret,
}, nil
}

type instanceToken struct {
jwtSecret string
}

func (i *instanceToken) NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) {
// Token expiration is equal to the bootstrap timeout set on the pool plus the polling
// interval garm uses to check for timed out runners. Runners that have not sent their info
// by the end of this interval are most likely failed and will be reaped by garm anyway.
Expand All @@ -67,7 +80,7 @@ func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolTy
CreateAttempt: instance.CreateAttempt,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(secret))
tokenString, err := token.SignedString([]byte(i.jwtSecret))
if err != nil {
return "", errors.Wrap(err, "signing token")
}
Expand Down
10 changes: 9 additions & 1 deletion auth/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@

package auth

import "net/http"
import (
"net/http"

"github.com/cloudbase/garm/params"
)

// Middleware defines an authentication middleware
type Middleware interface {
Middleware(next http.Handler) http.Handler
}

type InstanceTokenGetter interface {
NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error)
}
7 changes: 7 additions & 0 deletions cmd/garm-cli/cmd/github_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ func parseCredentialsAddParams() (ret params.CreateGithubCredentialsParams, err
func parseCredentialsUpdateParams() (params.UpdateGithubCredentialsParams, error) {
var updateParams params.UpdateGithubCredentialsParams

if credentialsAppInstallationID != 0 || credentialsAppID != 0 || credentialsPrivateKeyPath != "" {
updateParams.App = &params.GithubApp{}
}

if credentialsName != "" {
updateParams.Name = &credentialsName
}
Expand All @@ -312,6 +316,9 @@ func parseCredentialsUpdateParams() (params.UpdateGithubCredentialsParams, error
}

if credentialsOAuthToken != "" {
if updateParams.PAT == nil {
updateParams.PAT = &params.GithubPAT{}
}
updateParams.PAT.OAuth2Token = credentialsOAuthToken
}

Expand Down
8 changes: 2 additions & 6 deletions database/sql/enterprise.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e
}

func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error {
enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials")
enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching enterprise")
}
Expand Down Expand Up @@ -206,17 +206,13 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string,
return errors.Wrap(q.Error, "saving enterprise")
}

if creds.ID != 0 {
enterprise.Credentials = creds
}

return nil
})
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}

enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials")
enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}
Expand Down
8 changes: 2 additions & 6 deletions database/sql/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio
}

func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err error) {
org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials")
org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching org")
}
Expand Down Expand Up @@ -198,17 +198,13 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para
return errors.Wrap(q.Error, "saving org")
}

if creds.ID != 0 {
org.Credentials = creds
}

return nil
})
if err != nil {
return params.Organization{}, errors.Wrap(err, "saving org")
}

org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials")
org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Organization{}, errors.Wrap(err, "updating enterprise")
}
Expand Down
7 changes: 2 additions & 5 deletions database/sql/repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository,
}

func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) {
repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials")
repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching repo")
}
Expand Down Expand Up @@ -197,16 +197,13 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
return errors.Wrap(q.Error, "saving repo")
}

if creds.ID != 0 {
repo.Credentials = creds
}
return nil
})
if err != nil {
return params.Repository{}, errors.Wrap(err, "saving repo")
}

repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials")
repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Repository{}, errors.Wrap(err, "updating enterprise")
}
Expand Down
26 changes: 26 additions & 0 deletions database/watcher/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
}
}

// WithAll returns a filter function that returns true if all of the provided filters return true.
func WithAll(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
for _, filter := range filters {
if !filter(payload) {
return false
}
}
return true
}
}

// WithEntityTypeFilter returns a filter function that filters payloads by entity type.
// The filter function returns true if the payload's entity type matches the provided entity type.
func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc {
Expand Down Expand Up @@ -139,3 +151,17 @@ func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFun
}
}
}

// WithGithubCredentialsFilter returns a filter function that filters payloads by Github credentials.
func WithGithubCredentialsFilter(creds params.GithubCredentials) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
if payload.EntityType != dbCommon.GithubCredentialsEntityType {
return false
}
credsPayload, ok := payload.Payload.(params.GithubCredentials)
if !ok {
return false
}
return credsPayload.ID == creds.ID
}
}
54 changes: 32 additions & 22 deletions params/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,13 @@ func (r Repository) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("repository has no ID")
}
return GithubEntity{
ID: r.ID,
EntityType: GithubEntityTypeRepository,
Owner: r.Owner,
Name: r.Name,
ID: r.ID,
EntityType: GithubEntityTypeRepository,
Owner: r.Owner,
Name: r.Name,
PoolBalancerType: r.PoolBalancerType,
Credentials: r.Credentials,
WebhookSecret: r.WebhookSecret,
}, nil
}

Expand Down Expand Up @@ -470,10 +473,12 @@ func (o Organization) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("organization has no ID")
}
return GithubEntity{
ID: o.ID,
EntityType: GithubEntityTypeOrganization,
Owner: o.Name,
WebhookSecret: o.WebhookSecret,
ID: o.ID,
EntityType: GithubEntityTypeOrganization,
Owner: o.Name,
WebhookSecret: o.WebhookSecret,
PoolBalancerType: o.PoolBalancerType,
Credentials: o.Credentials,
}, nil
}

Expand Down Expand Up @@ -517,10 +522,12 @@ func (e Enterprise) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("enterprise has no ID")
}
return GithubEntity{
ID: e.ID,
EntityType: GithubEntityTypeEnterprise,
Owner: e.Name,
WebhookSecret: e.WebhookSecret,
ID: e.ID,
EntityType: GithubEntityTypeEnterprise,
Owner: e.Name,
WebhookSecret: e.WebhookSecret,
PoolBalancerType: e.PoolBalancerType,
Credentials: e.Credentials,
}, nil
}

Expand Down Expand Up @@ -685,11 +692,6 @@ type Provider struct {
// used by swagger client generated code
type Providers []Provider

type UpdatePoolStateParams struct {
WebhookSecret string
InternalConfig *Internal
}

type PoolManagerStatus struct {
IsRunning bool `json:"running"`
FailureReason string `json:"failure_reason,omitempty"`
Expand Down Expand Up @@ -788,15 +790,23 @@ type UpdateSystemInfoParams struct {
}

type GithubEntity struct {
Owner string `json:"owner"`
Name string `json:"name"`
ID string `json:"id"`
EntityType GithubEntityType `json:"entity_type"`
Credentials GithubCredentials `json:"credentials"`
Owner string `json:"owner"`
Name string `json:"name"`
ID string `json:"id"`
EntityType GithubEntityType `json:"entity_type"`
Credentials GithubCredentials `json:"credentials"`
PoolBalancerType PoolBalancerType `json:"pool_balancing_type"`

WebhookSecret string `json:"-"`
}

func (g GithubEntity) GetPoolBalancerType() PoolBalancerType {
if g.PoolBalancerType == "" {
return PoolBalancerTypeRoundRobin
}
return g.PoolBalancerType
}

func (g GithubEntity) LabelScope() string {
switch g.EntityType {
case GithubEntityTypeRepository:
Expand Down
18 changes: 0 additions & 18 deletions runner/common/mocks/PoolManager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions runner/common/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ type PoolManager interface {
// a repo, org or enterprise, we determine the destination of that webhook, retrieve the pool manager
// for it and call this function with the WorkflowJob as a parameter.
HandleWorkflowJob(job params.WorkflowJob) error
// RefreshState allows us to update webhook secrets and configuration for a pool manager.
RefreshState(param params.UpdatePoolStateParams) error

// DeleteRunner will attempt to remove a runner from the pool. If forceRemove is true, any error
// received from the provider will be ignored and we will proceed to remove the runner from the database.
Expand Down
6 changes: 2 additions & 4 deletions runner/enterprises.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,9 @@ func (r *Runner) UpdateEnterprise(ctx context.Context, enterpriseID string, para
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}

// Use the admin context in the pool manager. Any access control is already done above when
// updating the store.
poolMgr, err := r.poolManagerCtrl.UpdateEnterprisePoolManager(r.ctx, enterprise)
poolMgr, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise)
if err != nil {
return params.Enterprise{}, fmt.Errorf("failed to update enterprise pool manager: %w", err)
return params.Enterprise{}, fmt.Errorf("failed to get enterprise pool manager: %w", err)
}

enterprise.PoolManagerStatus = poolMgr.Status()
Expand Down
14 changes: 5 additions & 9 deletions runner/enterprises_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ type EnterpriseTestFixtures struct {
CreateInstanceParams params.CreateInstanceParams
UpdateRepoParams params.UpdateEntityParams
UpdatePoolParams params.UpdatePoolParams
UpdatePoolStateParams params.UpdatePoolStateParams
ErrMock error
ProviderMock *runnerCommonMocks.Provider
PoolMgrMock *runnerCommonMocks.PoolManager
Expand Down Expand Up @@ -138,9 +137,6 @@ func (s *EnterpriseTestSuite) SetupTest() {
Image: "test-images-updated",
Flavor: "test-flavor-updated",
},
UpdatePoolStateParams: params.UpdatePoolStateParams{
WebhookSecret: "test-update-repo-webhook-secret",
},
ErrMock: fmt.Errorf("mock error"),
ProviderMock: providerMock,
PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()),
Expand Down Expand Up @@ -298,7 +294,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolMgrFailed() {
}

func (s *EnterpriseTestSuite) TestUpdateEnterprise() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil)

param := s.Fixtures.UpdateRepoParams
Expand Down Expand Up @@ -330,21 +326,21 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidCreds() {
}

func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)

_, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams)

s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("failed to update enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}

func (s *EnterpriseTestSuite) TestUpdateEnterpriseCreateEnterprisePoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)

_, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams)

s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("failed to update enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}

func (s *EnterpriseTestSuite) TestCreateEnterprisePool() {
Expand Down
Loading