diff --git a/gateway/middleware.go b/gateway/middleware.go index a56cdff9231e..ec8d6cf265e4 100644 --- a/gateway/middleware.go +++ b/gateway/middleware.go @@ -489,14 +489,21 @@ func (t *BaseMiddleware) ApplyPolicies(session *user.SessionState) error { } else { usePartitions := policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity + // Ensure `rights` is filled with known APIs to ensure that + // a policy with acl rights gets honored even if not first. + for k := range policy.AccessRights { + if _, ok := rights[k]; ok { + continue + } + rights[k] = user.AccessDefinition{} + } + for k, v := range policy.AccessRights { - ar := v + ar := rights[k] if !usePartitions || policy.Partitions.Acl { didACL[k] = true - ar.AllowedURLs = mergeAllowedURLs(ar.AllowedURLs, v.AllowedURLs) - // Merge ACLs for the same API if r, ok := rights[k]; ok { // If GQL introspection is disabled, keep that configuration. @@ -507,18 +514,26 @@ func (t *BaseMiddleware) ApplyPolicies(session *user.SessionState) error { r.AllowedURLs = mergeAllowedURLs(r.AllowedURLs, v.AllowedURLs) - for _, t := range v.RestrictedTypes { - for ri, rt := range r.RestrictedTypes { - if t.Name == rt.Name { - r.RestrictedTypes[ri].Fields = intersection(rt.Fields, t.Fields) + if len(r.RestrictedTypes) == 0 { + r.RestrictedTypes = v.RestrictedTypes + } else { + for _, t := range v.RestrictedTypes { + for ri, rt := range r.RestrictedTypes { + if t.Name == rt.Name { + r.RestrictedTypes[ri].Fields = intersection(rt.Fields, t.Fields) + } } } } - for _, t := range v.AllowedTypes { - for ri, rt := range r.AllowedTypes { - if t.Name == rt.Name { - r.AllowedTypes[ri].Fields = intersection(rt.Fields, t.Fields) + if len(r.AllowedTypes) == 0 { + r.AllowedTypes = v.AllowedTypes + } else { + for _, t := range v.AllowedTypes { + for ri, rt := range r.AllowedTypes { + if t.Name == rt.Name { + r.AllowedTypes[ri].Fields = intersection(rt.Fields, t.Fields) + } } } } @@ -529,17 +544,21 @@ func (t *BaseMiddleware) ApplyPolicies(session *user.SessionState) error { } } - for _, far := range v.FieldAccessRights { - exists := false - for i, rfar := range r.FieldAccessRights { - if far.TypeName == rfar.TypeName && far.FieldName == rfar.FieldName { - exists = true - mergeFieldLimits(&r.FieldAccessRights[i].Limits, far.Limits) + if len(r.FieldAccessRights) == 0 { + r.FieldAccessRights = v.FieldAccessRights + } else { + for _, far := range v.FieldAccessRights { + exists := false + for i, rfar := range r.FieldAccessRights { + if far.TypeName == rfar.TypeName && far.FieldName == rfar.FieldName { + exists = true + mergeFieldLimits(&r.FieldAccessRights[i].Limits, far.Limits) + } } - } - if !exists { - r.FieldAccessRights = append(r.FieldAccessRights, far) + if !exists { + r.FieldAccessRights = append(r.FieldAccessRights, far) + } } } @@ -551,8 +570,8 @@ func (t *BaseMiddleware) ApplyPolicies(session *user.SessionState) error { if !usePartitions || policy.Partitions.Quota { didQuota[k] = true - if greaterThanInt64(policy.QuotaMax, ar.Limit.QuotaMax) { + if greaterThanInt64(policy.QuotaMax, ar.Limit.QuotaMax) { ar.Limit.QuotaMax = policy.QuotaMax if greaterThanInt64(policy.QuotaMax, session.QuotaMax) { session.QuotaMax = policy.QuotaMax diff --git a/internal/policy/apply.go b/internal/policy/apply.go deleted file mode 100644 index aa811f05ba8a..000000000000 --- a/internal/policy/apply.go +++ /dev/null @@ -1,620 +0,0 @@ -package policy - -import ( - "errors" - "fmt" - - "github.com/sirupsen/logrus" - - "github.com/TykTechnologies/tyk/user" -) - -var ( - // ErrMixedPartitionAndPerAPIPolicies is the error to return when a mix of per api and partitioned policies are to be applied in a session. - ErrMixedPartitionAndPerAPIPolicies = errors.New("cannot apply multiple policies when some have per_api set and some are partitioned") -) - -// Repository is a storage encapsulating policy retrieval. -// Gateway implements this object to decouple this package. -type Repository interface { - PolicyCount() int - PolicyIDs() []string - PolicyByID(string) (user.Policy, bool) -} - -// Service represents the policy service for gateway. -type Service struct { - storage Repository - logger *logrus.Logger - - // used for validation if not empty - orgID *string -} - -// New creates a new policy.Service object. -func New(orgID *string, storage Repository, logger *logrus.Logger) *Service { - return &Service{ - orgID: orgID, - storage: storage, - logger: logger, - } -} - -// ClearSession clears the quota, rate limit and complexity values so that partitioned policies can apply their values. -// Otherwise, if the session has already a higher value, an applied policy will not win, and its values will be ignored. -func (t *Service) ClearSession(session *user.SessionState) error { - policies := session.PolicyIDs() - - for _, polID := range policies { - policy, ok := t.storage.PolicyByID(polID) - if !ok { - return fmt.Errorf("policy not found: %s", polID) - } - - all := !(policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity) - - if policy.Partitions.Quota || all { - session.QuotaMax = 0 - session.QuotaRemaining = 0 - } - - if policy.Partitions.RateLimit || all { - session.Rate = 0 - session.Per = 0 - session.Smoothing = nil - session.ThrottleRetryLimit = 0 - session.ThrottleInterval = 0 - } - - if policy.Partitions.Complexity || all { - session.MaxQueryDepth = 0 - } - } - - return nil -} - -type applyStatus struct { - didQuota map[string]bool - didRateLimit map[string]bool - didAcl map[string]bool - didComplexity map[string]bool - didPerAPI bool - didPartition bool -} - -// Apply will check if any policies are loaded. If any are, it -// will overwrite the session state to use the policy values. -func (t *Service) Apply(session *user.SessionState) error { - rights := make(map[string]user.AccessDefinition) - tags := make(map[string]bool) - if session.MetaData == nil { - session.MetaData = make(map[string]interface{}) - } - - if err := t.ClearSession(session); err != nil { - t.logger.WithError(err).Warn("error clearing session") - } - - applyState := applyStatus{ - didQuota: make(map[string]bool), - didRateLimit: make(map[string]bool), - didAcl: make(map[string]bool), - didComplexity: make(map[string]bool), - } - - var ( - err error - policyIDs []string - ) - - storage := t.storage - - customPolicies, err := session.GetCustomPolicies() - if err != nil { - policyIDs = session.PolicyIDs() - } else { - storage = NewStore(customPolicies) - policyIDs = storage.PolicyIDs() - } - - for _, polID := range policyIDs { - policy, ok := storage.PolicyByID(polID) - if !ok { - err := fmt.Errorf("policy not found: %q", polID) - t.Logger().Error(err) - if len(policyIDs) > 1 { - continue - } - - return err - } - // Check ownership, policy org owner must be the same as API, - // otherwise you could overwrite a session key with a policy from a different org! - if t.orgID != nil && policy.OrgID != *t.orgID { - err := errors.New("attempting to apply policy from different organisation to key, skipping") - t.Logger().Error(err) - return err - } - - if policy.Partitions.PerAPI && policy.Partitions.Enabled() { - err := fmt.Errorf("cannot apply policy %s which has per_api and any of partitions set", policy.ID) - t.logger.Error(err) - return err - } - - if policy.Partitions.PerAPI { - if err := t.applyPerAPI(policy, session, rights, &applyState); err != nil { - return err - } - } else { - if err := t.applyPartitions(policy, session, rights, &applyState); err != nil { - return err - } - } - - session.IsInactive = session.IsInactive || policy.IsInactive - - for _, tag := range policy.Tags { - tags[tag] = true - } - - for k, v := range policy.MetaData { - session.MetaData[k] = v - } - - if policy.LastUpdated > session.LastUpdated { - session.LastUpdated = policy.LastUpdated - } - } - - for _, tag := range session.Tags { - tags[tag] = true - } - - // set tags - session.Tags = []string{} - for tag := range tags { - session.Tags = appendIfMissing(session.Tags, tag) - } - - if len(policyIDs) == 0 { - for apiID, accessRight := range session.AccessRights { - // check if the api in the session has per api limit - if !accessRight.Limit.IsEmpty() { - accessRight.AllowanceScope = apiID - session.AccessRights[apiID] = accessRight - } - } - } - - distinctACL := make(map[string]bool) - - for _, v := range rights { - if v.Limit.SetBy != "" { - distinctACL[v.Limit.SetBy] = true - } - } - - // If some APIs had only ACL partitions, inherit rest from session level - for k, v := range rights { - if !applyState.didAcl[k] { - delete(rights, k) - continue - } - - if !applyState.didRateLimit[k] { - v.Limit.Rate = session.Rate - v.Limit.Per = session.Per - v.Limit.Smoothing = session.Smoothing - v.Limit.ThrottleInterval = session.ThrottleInterval - v.Limit.ThrottleRetryLimit = session.ThrottleRetryLimit - v.Endpoints = nil - } - - if !applyState.didComplexity[k] { - v.Limit.MaxQueryDepth = session.MaxQueryDepth - } - - if !applyState.didQuota[k] { - v.Limit.QuotaMax = session.QuotaMax - v.Limit.QuotaRenewalRate = session.QuotaRenewalRate - v.Limit.QuotaRenews = session.QuotaRenews - } - - // If multime ACL - if len(distinctACL) > 1 { - if v.AllowanceScope == "" && v.Limit.SetBy != "" { - v.AllowanceScope = v.Limit.SetBy - } - } - - v.Limit.SetBy = "" - - rights[k] = v - } - - // If we have policies defining rules for one single API, update session root vars (legacy) - t.updateSessionRootVars(session, rights, applyState) - - // Override session ACL if at least one policy define it - if len(applyState.didAcl) > 0 { - session.AccessRights = rights - } - - return nil -} - -// Logger implements a typical logger signature with service context. -func (t *Service) Logger() *logrus.Entry { - return logrus.NewEntry(t.logger) -} - -// ApplyRateLimits will write policy limits to session and apiLimits. -// The limits get written if either are empty. -// The limits get written if filled and policyLimits allows a higher request rate. -func (t *Service) ApplyRateLimits(session *user.SessionState, policy user.Policy, apiLimits *user.APILimit) { - policyLimits := policy.APILimit() - if t.emptyRateLimit(policyLimits) { - return - } - - // duration is time between requests, e.g.: - // - // apiLimits: 500ms for 2 requests / second - // policyLimits: 100ms for 10 requests / second - // - // if apiLimits > policyLimits (500ms > 100ms) then - // we apply the higher rate from the policy. - // - // the policy-defined rate limits are enforced as - // a minimum possible api rate limit setting, - // raising apiLimits. - - if t.emptyRateLimit(*apiLimits) || apiLimits.Duration() > policyLimits.Duration() { - apiLimits.Rate = policyLimits.Rate - apiLimits.Per = policyLimits.Per - apiLimits.Smoothing = policyLimits.Smoothing - } - - // sessionLimits, similar to apiLimits, get policy - // rate applied if the policy allows more requests. - sessionLimits := session.APILimit() - if t.emptyRateLimit(sessionLimits) || sessionLimits.Duration() > policyLimits.Duration() { - session.Rate = policyLimits.Rate - session.Per = policyLimits.Per - session.Smoothing = policyLimits.Smoothing - } -} - -func (t *Service) emptyRateLimit(m user.APILimit) bool { - return m.Rate == 0 || m.Per == 0 -} - -func (t *Service) applyPerAPI(policy user.Policy, session *user.SessionState, rights map[string]user.AccessDefinition, - applyState *applyStatus) error { - - if applyState.didPartition { - t.logger.Error(ErrMixedPartitionAndPerAPIPolicies) - return ErrMixedPartitionAndPerAPIPolicies - } - - for apiID, accessRights := range policy.AccessRights { - idForScope := apiID - // check if we don't have limit on API level specified when policy was created - if accessRights.Limit.IsEmpty() { - // limit was not specified on API level so we will populate it from policy - idForScope = policy.ID - accessRights.Limit = policy.APILimit() - } - accessRights.AllowanceScope = idForScope - accessRights.Limit.SetBy = idForScope - - // respect current quota renews (on API limit level) - if r, ok := session.AccessRights[apiID]; ok && !r.Limit.IsEmpty() { - accessRights.Limit.QuotaRenews = r.Limit.QuotaRenews - } - - if r, ok := session.AccessRights[apiID]; ok { - // If GQL introspection is disabled, keep that configuration. - if r.DisableIntrospection { - accessRights.DisableIntrospection = r.DisableIntrospection - } - } - - if currAD, ok := rights[apiID]; ok { - accessRights = t.applyAPILevelLimits(accessRights, currAD) - } - - // overwrite session access right for this API - rights[apiID] = accessRights - - // identify that limit for that API is set (to allow set it only once) - applyState.didAcl[apiID] = true - applyState.didQuota[apiID] = true - applyState.didRateLimit[apiID] = true - applyState.didComplexity[apiID] = true - } - - if len(policy.AccessRights) > 0 { - applyState.didPerAPI = true - } - - return nil -} - -func (t *Service) applyPartitions(policy user.Policy, session *user.SessionState, rights map[string]user.AccessDefinition, - applyState *applyStatus) error { - - usePartitions := policy.Partitions.Enabled() - - if usePartitions && applyState.didPerAPI { - t.logger.Error(ErrMixedPartitionAndPerAPIPolicies) - return ErrMixedPartitionAndPerAPIPolicies - } - - // Ensure `rights` is filled with known APIs to ensure that - // a policy with acl rights gets honored even if not first. - for k := range policy.AccessRights { - if _, ok := rights[k]; ok { - continue - } - rights[k] = user.AccessDefinition{} - } - - for k, v := range policy.AccessRights { - // Use rights[k], which holds previously seen/merged policy access rights. - ar := rights[k] - - if !usePartitions || policy.Partitions.Acl { - applyState.didAcl[k] = true - - // Merge ACLs for the same API - if r, ok := rights[k]; ok { - // If GQL introspection is disabled, keep that configuration. - if v.DisableIntrospection { - r.DisableIntrospection = v.DisableIntrospection - } - r.Versions = appendIfMissing(rights[k].Versions, v.Versions...) - - r.AllowedURLs = MergeAllowedURLs(r.AllowedURLs, v.AllowedURLs) - - if len(r.RestrictedTypes) == 0 { - r.RestrictedTypes = v.RestrictedTypes - } else { - for _, t := range v.RestrictedTypes { - for ri, rt := range r.RestrictedTypes { - if t.Name == rt.Name { - r.RestrictedTypes[ri].Fields = intersection(rt.Fields, t.Fields) - } - } - } - } - - if len(r.AllowedTypes) == 0 { - r.AllowedTypes = v.AllowedTypes - } else { - for _, t := range v.AllowedTypes { - for ri, rt := range r.AllowedTypes { - if t.Name == rt.Name { - r.AllowedTypes[ri].Fields = intersection(rt.Fields, t.Fields) - } - } - } - } - - mergeFieldLimits := func(res *user.FieldLimits, new user.FieldLimits) { - if greaterThanInt(new.MaxQueryDepth, res.MaxQueryDepth) { - res.MaxQueryDepth = new.MaxQueryDepth - } - } - - if len(r.FieldAccessRights) == 0 { - r.FieldAccessRights = v.FieldAccessRights - } else { - for _, far := range v.FieldAccessRights { - exists := false - for i, rfar := range r.FieldAccessRights { - if far.TypeName == rfar.TypeName && far.FieldName == rfar.FieldName { - exists = true - mergeFieldLimits(&r.FieldAccessRights[i].Limits, far.Limits) - } - } - - if !exists { - r.FieldAccessRights = append(r.FieldAccessRights, far) - } - } - } - - ar = r - } - - ar.Limit.SetBy = policy.ID - } - - if !usePartitions || policy.Partitions.Quota { - applyState.didQuota[k] = true - - if greaterThanInt64(policy.QuotaMax, ar.Limit.QuotaMax) { - ar.Limit.QuotaMax = policy.QuotaMax - if greaterThanInt64(policy.QuotaMax, session.QuotaMax) { - session.QuotaMax = policy.QuotaMax - } - } - - if policy.QuotaRenewalRate > ar.Limit.QuotaRenewalRate { - ar.Limit.QuotaRenewalRate = policy.QuotaRenewalRate - if policy.QuotaRenewalRate > session.QuotaRenewalRate { - session.QuotaRenewalRate = policy.QuotaRenewalRate - } - } - } - - if !usePartitions || policy.Partitions.RateLimit { - applyState.didRateLimit[k] = true - - t.ApplyRateLimits(session, policy, &ar.Limit) - - if rightsAR, ok := rights[k]; ok { - ar.Endpoints = t.ApplyEndpointLevelLimits(v.Endpoints, rightsAR.Endpoints) - } - - if policy.ThrottleRetryLimit > ar.Limit.ThrottleRetryLimit { - ar.Limit.ThrottleRetryLimit = policy.ThrottleRetryLimit - if policy.ThrottleRetryLimit > session.ThrottleRetryLimit { - session.ThrottleRetryLimit = policy.ThrottleRetryLimit - } - } - - if policy.ThrottleInterval > ar.Limit.ThrottleInterval { - ar.Limit.ThrottleInterval = policy.ThrottleInterval - if policy.ThrottleInterval > session.ThrottleInterval { - session.ThrottleInterval = policy.ThrottleInterval - } - } - } - - if !usePartitions || policy.Partitions.Complexity { - applyState.didComplexity[k] = true - - if greaterThanInt(policy.MaxQueryDepth, ar.Limit.MaxQueryDepth) { - ar.Limit.MaxQueryDepth = policy.MaxQueryDepth - if greaterThanInt(policy.MaxQueryDepth, session.MaxQueryDepth) { - session.MaxQueryDepth = policy.MaxQueryDepth - } - } - } - - // Respect existing QuotaRenews - if r, ok := session.AccessRights[k]; ok && !r.Limit.IsEmpty() { - ar.Limit.QuotaRenews = r.Limit.QuotaRenews - } - - rights[k] = ar - } - - // Master policy case - if len(policy.AccessRights) == 0 { - if !usePartitions || policy.Partitions.RateLimit { - session.Rate = policy.Rate - session.Per = policy.Per - session.Smoothing = policy.Smoothing - session.ThrottleInterval = policy.ThrottleInterval - session.ThrottleRetryLimit = policy.ThrottleRetryLimit - } - - if !usePartitions || policy.Partitions.Complexity { - session.MaxQueryDepth = policy.MaxQueryDepth - } - - if !usePartitions || policy.Partitions.Quota { - session.QuotaMax = policy.QuotaMax - session.QuotaRenewalRate = policy.QuotaRenewalRate - } - } - - if !session.HMACEnabled { - session.HMACEnabled = policy.HMACEnabled - } - - if !session.EnableHTTPSignatureValidation { - session.EnableHTTPSignatureValidation = policy.EnableHTTPSignatureValidation - } - - applyState.didPartition = usePartitions - - return nil -} - -func (t *Service) updateSessionRootVars(session *user.SessionState, rights map[string]user.AccessDefinition, applyState applyStatus) { - if len(applyState.didQuota) == 1 && len(applyState.didRateLimit) == 1 && len(applyState.didComplexity) == 1 { - for _, v := range rights { - if len(applyState.didRateLimit) == 1 { - session.Rate = v.Limit.Rate - session.Per = v.Limit.Per - session.Smoothing = v.Limit.Smoothing - } - - if len(applyState.didQuota) == 1 { - session.QuotaMax = v.Limit.QuotaMax - session.QuotaRenews = v.Limit.QuotaRenews - session.QuotaRenewalRate = v.Limit.QuotaRenewalRate - } - - if len(applyState.didComplexity) == 1 { - session.MaxQueryDepth = v.Limit.MaxQueryDepth - } - } - } -} - -func (t *Service) applyAPILevelLimits(policyAD user.AccessDefinition, currAD user.AccessDefinition) user.AccessDefinition { - var updated bool - if policyAD.Limit.Duration() > currAD.Limit.Duration() { - policyAD.Limit.Per = currAD.Limit.Per - policyAD.Limit.Rate = currAD.Limit.Rate - policyAD.Limit.Smoothing = currAD.Limit.Smoothing - updated = true - } - - if currAD.Limit.QuotaMax != policyAD.Limit.QuotaMax && greaterThanInt64(currAD.Limit.QuotaMax, policyAD.Limit.QuotaMax) { - policyAD.Limit.QuotaMax = currAD.Limit.QuotaMax - updated = true - } - - if greaterThanInt64(currAD.Limit.QuotaRenewalRate, policyAD.Limit.QuotaRenewalRate) { - policyAD.Limit.QuotaRenewalRate = currAD.Limit.QuotaRenewalRate - } - - if policyAD.Limit.QuotaMax == -1 { - policyAD.Limit.QuotaRenewalRate = 0 - } - - if updated { - policyAD.Limit.SetBy = currAD.Limit.SetBy - policyAD.AllowanceScope = currAD.AllowanceScope - } - - policyAD.Endpoints = t.ApplyEndpointLevelLimits(policyAD.Endpoints, currAD.Endpoints) - - return policyAD -} - -// ApplyEndpointLevelLimits combines policyEndpoints and currEndpoints and returns the combined value. -// The returned endpoints would have the highest request rate from policyEndpoints and currEndpoints. -func (t *Service) ApplyEndpointLevelLimits(policyEndpoints user.Endpoints, currEndpoints user.Endpoints) user.Endpoints { - currEPMap := currEndpoints.Map() - if len(currEPMap) == 0 { - return policyEndpoints - } - - result := policyEndpoints.Map() - if len(result) == 0 { - return currEPMap.Endpoints() - } - - for currEP, currRL := range currEPMap { - policyRL, ok := result[currEP] - if !ok { - // merge missing endpoints - result[currEP] = currRL - continue - } - - policyDur, currDur := policyRL.Duration(), currRL.Duration() - if policyDur > currDur { - result[currEP] = currRL - continue - } - - // when duration is equal, use higher rate and per - // eg. when 10 per 60 and 5 per 30 comes in - // Duration would be 6s each, in such a case higher rate of 10 per 60 would be picked up. - if policyDur == currDur && currRL.Rate > policyRL.Rate { - result[currEP] = currRL - } - } - - return result.Endpoints() -} diff --git a/internal/policy/apply_test.go b/internal/policy/apply_test.go deleted file mode 100644 index e22da5c5b052..000000000000 --- a/internal/policy/apply_test.go +++ /dev/null @@ -1,1312 +0,0 @@ -package policy_test - -import ( - "embed" - "encoding/json" - "fmt" - "slices" - "sort" - "testing" - - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - - "github.com/TykTechnologies/graphql-go-tools/pkg/graphql" - "github.com/TykTechnologies/tyk/internal/policy" - "github.com/TykTechnologies/tyk/user" -) - -//go:embed testdata/*.json -var testDataFS embed.FS - -func TestApplyRateLimits_PolicyLimits(t *testing.T) { - t.Run("policy limits unset", func(t *testing.T) { - svc := &policy.Service{} - - session := &user.SessionState{ - Rate: 5, - Per: 10, - } - apiLimits := user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 10, - Per: 10, - }, - } - policy := user.Policy{} - - svc.ApplyRateLimits(session, policy, &apiLimits) - - assert.Equal(t, 10, int(apiLimits.Rate)) - assert.Equal(t, 5, int(session.Rate)) - }) - - t.Run("policy limits apply all", func(t *testing.T) { - svc := &policy.Service{} - - session := &user.SessionState{ - Rate: 5, - Per: 10, - } - apiLimits := user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 5, - Per: 10, - }, - } - policy := user.Policy{ - Rate: 10, - Per: 10, - } - - svc.ApplyRateLimits(session, policy, &apiLimits) - - assert.Equal(t, 10, int(apiLimits.Rate)) - assert.Equal(t, 10, int(session.Rate)) - }) - - // As the policy defined a higher rate than apiLimits, - // changes are applied to api limits, but skipped on - // the session as the session has a higher allowance. - t.Run("policy limits apply per-api", func(t *testing.T) { - svc := &policy.Service{} - - session := &user.SessionState{ - Rate: 15, - Per: 10, - } - apiLimits := user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 5, - Per: 10, - }, - } - policy := user.Policy{ - Rate: 10, - Per: 10, - } - - svc.ApplyRateLimits(session, policy, &apiLimits) - - assert.Equal(t, 10, int(apiLimits.Rate)) - assert.Equal(t, 15, int(session.Rate)) - }) - - // As the policy defined a lower rate than apiLimits, - // no changes to api limits are applied. - t.Run("policy limits skip", func(t *testing.T) { - svc := &policy.Service{} - - session := &user.SessionState{ - Rate: 5, - Per: 10, - } - apiLimits := user.APILimit{ - RateLimit: user.RateLimit{Rate: 15, - Per: 10, - }, - } - policy := user.Policy{ - Rate: 10, - Per: 10, - } - - svc.ApplyRateLimits(session, policy, &apiLimits) - - assert.Equal(t, 15, int(apiLimits.Rate)) - assert.Equal(t, 10, int(session.Rate)) - }) -} - -func TestApplyRateLimits_FromCustomPolicies(t *testing.T) { - svc := &policy.Service{} - - session := &user.SessionState{} - session.SetCustomPolicies([]user.Policy{ - { - ID: "pol1", - Partitions: user.PolicyPartitions{RateLimit: true}, - Rate: 8, - Per: 1, - AccessRights: map[string]user.AccessDefinition{"a": {}}, - }, - { - ID: "pol2", - Partitions: user.PolicyPartitions{RateLimit: true}, - Rate: 10, - Per: 1, - AccessRights: map[string]user.AccessDefinition{"a": {}}, - }, - }) - - assert.NoError(t, svc.Apply(session)) - assert.Equal(t, 10, int(session.Rate)) -} - -func TestApplyACL_FromCustomPolicies(t *testing.T) { - svc := &policy.Service{} - - pol1 := user.Policy{ - ID: "pol1", - Partitions: user.PolicyPartitions{RateLimit: true}, - Rate: 8, - Per: 1, - AccessRights: map[string]user.AccessDefinition{ - "a": {}, - }, - } - - pol2 := user.Policy{ - ID: "pol2", - Partitions: user.PolicyPartitions{Acl: true}, - Rate: 10, - Per: 1, - AccessRights: map[string]user.AccessDefinition{ - "a": { - AllowedURLs: []user.AccessSpec{ - {URL: "/user", Methods: []string{"GET", "POST"}}, - {URL: "/companies", Methods: []string{"GET", "POST"}}, - }, - }, - }, - } - - t.Run("RateLimit first", func(t *testing.T) { - session := &user.SessionState{} - session.SetCustomPolicies([]user.Policy{pol1, pol2}) - - assert.NoError(t, svc.Apply(session)) - assert.Equal(t, pol2.AccessRights["a"].AllowedURLs, session.AccessRights["a"].AllowedURLs) - assert.Equal(t, 8, int(session.Rate)) - }) - - t.Run("ACL first", func(t *testing.T) { - session := &user.SessionState{} - session.SetCustomPolicies([]user.Policy{pol2, pol1}) - - assert.NoError(t, svc.Apply(session)) - assert.Equal(t, pol2.AccessRights["a"].AllowedURLs, session.AccessRights["a"].AllowedURLs) - assert.Equal(t, 8, int(session.Rate)) - }) -} - -func TestApplyEndpointLevelLimits(t *testing.T) { - f, err := testDataFS.ReadFile("testdata/apply_endpoint_rl.json") - assert.NoError(t, err) - - var testCases []struct { - Name string `json:"name"` - PolicyEP user.Endpoints `json:"policyEP"` - CurrEP user.Endpoints `json:"currEP"` - Expected user.Endpoints `json:"expected"` - } - err = json.Unmarshal(f, &testCases) - assert.NoError(t, err) - - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - service := policy.Service{} - result := service.ApplyEndpointLevelLimits(tc.PolicyEP, tc.CurrEP) - assert.ElementsMatch(t, tc.Expected, result) - }) - } - -} - -type testApplyPoliciesData struct { - name string - policies []string - errMatch string // substring - sessMatch func(*testing.T, *user.SessionState) // ignored if nil - session *user.SessionState - // reverseOrder executes the tests in reversed order of policies, - // in addition to the order specified in policies - reverseOrder bool -} - -func testPrepareApplyPolicies(tb testing.TB) (*policy.Service, []testApplyPoliciesData) { - tb.Helper() - - f, err := testDataFS.ReadFile("testdata/policies.json") - assert.NoError(tb, err) - - var policies = make(map[string]user.Policy) - err = json.Unmarshal(f, &policies) - assert.NoError(tb, err) - - var repoPols = make(map[string]user.Policy) - err = json.Unmarshal(f, &repoPols) - assert.NoError(tb, err) - - store := policy.NewStoreMap(repoPols) - orgID := "" - service := policy.New(&orgID, store, logrus.StandardLogger()) - - // splitting tests for readability - var tests []testApplyPoliciesData - - nilSessionTCs := []testApplyPoliciesData{ - { - "Empty", nil, - "", nil, nil, false, - }, - { - "Single", []string{"nonpart1"}, - "", nil, nil, false, - }, - { - "Missing", []string{"nonexistent"}, - "not found", nil, nil, false, - }, - { - "DiffOrg", []string{"difforg"}, - "different org", nil, nil, false, - }, - } - tests = append(tests, nilSessionTCs...) - - nonPartitionedTCs := []testApplyPoliciesData{ - { - name: "MultiNonPart", - policies: []string{"nonpart1", "nonpart2", "nonexistent"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - want := map[string]user.AccessDefinition{ - "a": { - Limit: user.APILimit{}, - AllowanceScope: "p1", - }, - "b": { - Limit: user.APILimit{}, - AllowanceScope: "p2", - }, - } - - assert.Equal(t, want, s.AccessRights) - }, - }, - { - name: "MultiACLPolicy", - policies: []string{"nonpart3"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - want := map[string]user.AccessDefinition{ - "a": { - Limit: user.APILimit{}, - }, - "b": { - Limit: user.APILimit{}, - }, - } - - assert.Equal(t, want, s.AccessRights) - }, - }, - } - tests = append(tests, nonPartitionedTCs...) - - quotaPartitionTCs := []testApplyPoliciesData{ - { - "QuotaPart with unlimited", []string{"unlimited-quota"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.QuotaMax != -1 { - t.Fatalf("want unlimited quota to be -1") - } - }, nil, false, - }, - { - "QuotaPart", []string{"quota1"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.QuotaMax != 2 { - t.Fatalf("want QuotaMax to be 2") - } - }, nil, false, - }, - { - "QuotaParts", []string{"quota1", "quota2"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.QuotaMax != 3 { - t.Fatalf("Should pick bigger value") - } - }, nil, false, - }, - { - "QuotaParts with acl", []string{"quota5", "quota4"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.Equal(t, int64(4), s.QuotaMax) - }, nil, false, - }, - { - "QuotaPart with access rights", []string{"quota3"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.QuotaMax != 3 { - t.Fatalf("quota should be the same as policy quota") - } - }, nil, false, - }, - { - "QuotaPart with access rights in multi-policy", []string{"quota4", "nonpart1"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.QuotaMax != 3 { - t.Fatalf("quota should be the same as policy quota") - } - - // Don't apply api 'b' coming from quota4 policy - want := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{}}} - assert.Equal(t, want, s.AccessRights) - }, nil, false, - }, - } - tests = append(tests, quotaPartitionTCs...) - - rateLimitPartitionTCs := []testApplyPoliciesData{ - { - "RatePart with unlimited", []string{"unlimited-rate"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.True(t, s.Rate <= 0, "want unlimited rate to be <= 0") - }, nil, false, - }, - { - "RatePart", []string{"rate1"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.Rate != 3 { - t.Fatalf("want Rate to be 3") - } - }, nil, false, - }, - { - "RateParts", []string{"rate1", "rate2"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.Rate != 4 { - t.Fatalf("Should pick bigger value") - } - }, nil, false, - }, - { - "RateParts with acl", []string{"rate5", "rate4"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.Equal(t, float64(10), s.Rate) - }, nil, false, - }, - { - "RateParts with acl respected by session", []string{"rate4", "rate5"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.Equal(t, float64(10), s.Rate) - }, &user.SessionState{Rate: 20}, false, - }, - { - "Rate with no partition respected by session", []string{"rate-no-partition"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.Equal(t, float64(12), s.Rate) - }, &user.SessionState{Rate: 20}, false, - }, - } - tests = append(tests, rateLimitPartitionTCs...) - - complexityPartitionTCs := []testApplyPoliciesData{ - { - "ComplexityPart with unlimited", []string{"unlimitedComplexity"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.MaxQueryDepth != -1 { - t.Fatalf("unlimitied query depth should be -1") - } - }, nil, false, - }, - { - "ComplexityPart", []string{"complexity1"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.MaxQueryDepth != 2 { - t.Fatalf("want MaxQueryDepth to be 2") - } - }, nil, false, - }, - { - "ComplexityParts", []string{"complexity1", "complexity2"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.MaxQueryDepth != 3 { - t.Fatalf("Should pick bigger value") - } - }, nil, false, - }, - } - tests = append(tests, complexityPartitionTCs...) - - aclPartitionTCs := []testApplyPoliciesData{ - { - "AclPart", []string{"acl1"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{}}} - - assert.Equal(t, want, s.AccessRights) - }, nil, false, - }, - { - "AclPart", []string{"acl1", "acl2"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{}}, "b": {Limit: user.APILimit{}}} - assert.Equal(t, want, s.AccessRights) - }, nil, false, - }, - { - "Acl for a and rate for a,b", []string{"acl1", "rate-for-a-b"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{RateLimit: user.RateLimit{Rate: 4, Per: 1}}}} - assert.Equal(t, want, s.AccessRights) - }, nil, false, - }, - { - "Acl for a,b and individual rate for a,b", []string{"acl-for-a-b", "rate-for-a", "rate-for-b"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{ - "a": {Limit: user.APILimit{RateLimit: user.RateLimit{Rate: 4, Per: 1}}}, - "b": {Limit: user.APILimit{RateLimit: user.RateLimit{Rate: 2, Per: 1}}}, - } - assert.Equal(t, want, s.AccessRights) - }, nil, false, - }, - { - "RightsUpdate", []string{"acl-for-a-b"}, - "", func(t *testing.T, ses *user.SessionState) { - t.Helper() - expectedAccessRights := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{}}, "b": {Limit: user.APILimit{}}} - assert.Equal(t, expectedAccessRights, ses.AccessRights) - }, &user.SessionState{ - AccessRights: map[string]user.AccessDefinition{ - "c": {Limit: user.APILimit{}}, - }, - }, false, - }, - } - tests = append(tests, aclPartitionTCs...) - - inactiveTCs := []testApplyPoliciesData{ - { - "InactiveMergeOne", []string{"tags1", "inactive1"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if !s.IsInactive { - t.Fatalf("want IsInactive to be true") - } - }, nil, false, - }, - { - "InactiveMergeAll", []string{"inactive1", "inactive2"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if !s.IsInactive { - t.Fatalf("want IsInactive to be true") - } - }, nil, false, - }, - { - "InactiveWithSession", []string{"tags1", "tags2"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if !s.IsInactive { - t.Fatalf("want IsInactive to be true") - } - }, &user.SessionState{ - IsInactive: true, - }, false, - }, - } - tests = append(tests, inactiveTCs...) - - perAPITCs := []testApplyPoliciesData{ - { - name: "Per API is set with other partitions to true", - policies: []string{"per_api_and_partitions"}, - errMatch: "cannot apply policy per_api_and_partitions which has per_api and any of partitions set", - }, - { - name: "Per API is set to true with some partitions set to true", - policies: []string{"per_api_and_some_partitions"}, - errMatch: "cannot apply policy per_api_and_some_partitions which has per_api and any of partitions set", - }, - { - name: "Per API is set to true with no other partitions set to true", - policies: []string{"per_api_and_no_other_partitions"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - want := map[string]user.AccessDefinition{ - "c": { - Limit: user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 2000, - Per: 60, - }, - QuotaMax: -1, - }, - AllowanceScope: "c", - }, - "d": { - Limit: user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 20, - Per: 1, - }, - QuotaMax: 1000, - QuotaRenewalRate: 3600, - }, - AllowanceScope: "d", - }, - } - assert.Equal(t, want, s.AccessRights) - }, - }, - { - name: "several policies with Per API set to true specifying limit for the same API", - policies: []string{"per_api_and_no_other_partitions", "per_api_with_api_d"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{ - "c": { - Limit: user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 2000, - Per: 60, - }, - QuotaMax: -1, - }, - AllowanceScope: "c", - }, - "d": { - Limit: user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 200, - Per: 10, - }, - QuotaMax: 5000, - QuotaRenewalRate: 3600, - }, - AllowanceScope: "d", - }, - } - assert.Equal(t, want, s.AccessRights) - }, - }, - { - name: "several policies with Per API set to true specifying limit for the same APIs", - policies: []string{"per_api_and_no_other_partitions", "per_api_with_api_d", "per_api_with_api_c"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{ - "c": { - Limit: user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 3000, - Per: 10, - }, - QuotaMax: -1, - }, - AllowanceScope: "c", - }, - "d": { - Limit: user.APILimit{ - RateLimit: user.RateLimit{ - Rate: 200, - Per: 10, - }, - QuotaMax: 5000, - QuotaRenewalRate: 3600, - }, - AllowanceScope: "d", - }, - } - assert.Equal(t, want, s.AccessRights) - }, - }, - { - name: "several policies, mixed the one which has Per API set to true and partitioned ones", - policies: []string{"per_api_with_api_d", "quota1"}, - errMatch: "cannot apply multiple policies when some have per_api set and some are partitioned", - }, - { - name: "several policies, mixed the one which has Per API set to true and partitioned ones (different order)", - policies: []string{"rate1", "per_api_with_api_d"}, - errMatch: "cannot apply multiple policies when some have per_api set and some are partitioned", - }, - { - name: "Per API is set to true and some API gets limit set from policy's fields", - policies: []string{"per_api_with_limit_set_from_policy"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{ - "e": { - Limit: user.APILimit{ - QuotaMax: -1, - RateLimit: user.RateLimit{ - Rate: 300, - Per: 1, - }, - }, - AllowanceScope: "per_api_with_limit_set_from_policy", - }, - "d": { - Limit: user.APILimit{ - QuotaMax: 5000, - QuotaRenewalRate: 3600, - RateLimit: user.RateLimit{ - Rate: 200, - Per: 10, - }, - }, - AllowanceScope: "d", - }, - } - assert.Equal(t, want, s.AccessRights) - }, - }, - { - name: "Per API with limits override", - policies: []string{ - "per_api_with_limit_set_from_policy", - "per_api_with_api_d", - "per_api_with_higher_rate_on_api_d", - }, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{ - "e": { - Limit: user.APILimit{ - QuotaMax: -1, - RateLimit: user.RateLimit{ - Rate: 300, - Per: 1, - }, - }, - AllowanceScope: "per_api_with_limit_set_from_policy", - }, - "d": { - Limit: user.APILimit{ - QuotaMax: 5000, - QuotaRenewalRate: 3600, - RateLimit: user.RateLimit{ - Rate: 200, - Per: 10, - }, - }, - AllowanceScope: "d", - }, - } - assert.Equal(t, want, s.AccessRights) - }, - }, - } - tests = append(tests, perAPITCs...) - - graphQLTCs := []testApplyPoliciesData{ - { - name: "Merge per path rules for the same API", - policies: []string{"per-path2", "per-path1"}, - sessMatch: func(t *testing.T, sess *user.SessionState) { - t.Helper() - want := map[string]user.AccessDefinition{ - "a": { - AllowedURLs: []user.AccessSpec{ - {URL: "/user", Methods: []string{"GET", "POST"}}, - {URL: "/companies", Methods: []string{"GET", "POST"}}, - }, - Limit: user.APILimit{}, - }, - "b": { - AllowedURLs: []user.AccessSpec{ - {URL: "/", Methods: []string{"PUT"}}, - }, - Limit: user.APILimit{}, - }, - } - - gotPolicy, ok := store.PolicyByID("per-path2") - - assert.True(t, ok) - assert.Equal(t, user.AccessSpec{ - URL: "/user", Methods: []string{"GET"}, - }, gotPolicy.AccessRights["a"].AllowedURLs[0]) - - assert.Equal(t, want, sess.AccessRights) - }, - }, - { - name: "Merge restricted fields for the same GraphQL API", - policies: []string{"restricted-types1", "restricted-types2"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - want := map[string]user.AccessDefinition{ - "a": { // It should get intersection of restricted types. - RestrictedTypes: []graphql.Type{ - {Name: "Country", Fields: []string{"code"}}, - {Name: "Person", Fields: []string{"name"}}, - }, - Limit: user.APILimit{}, - }, - } - - assert.Equal(t, want, s.AccessRights) - }, - }, - { - name: "Merge allowed fields for the same GraphQL API", - policies: []string{"allowed-types1", "allowed-types2"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - want := map[string]user.AccessDefinition{ - "a": { // It should get intersection of restricted types. - AllowedTypes: []graphql.Type{ - {Name: "Country", Fields: []string{"code"}}, - {Name: "Person", Fields: []string{"name"}}, - }, - Limit: user.APILimit{}, - }, - } - - assert.Equal(t, want, s.AccessRights) - }, - }, - { - name: "If GQL introspection is disabled, it remains disabled after merging", - policies: []string{"introspection-disabled", "introspection-enabled"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - want := map[string]user.AccessDefinition{ - "a": { - DisableIntrospection: true, // If GQL introspection is disabled, it remains disabled after merging. - Limit: user.APILimit{}, - }, - } - - assert.Equal(t, want, s.AccessRights) - }, - }, - { - name: "Merge field level depth limit for the same GraphQL API", - policies: []string{"field-level-depth-limit1", "field-level-depth-limit2"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - want := map[string]user.AccessDefinition{ - "graphql-api": { - Limit: user.APILimit{}, - FieldAccessRights: []user.FieldAccessDefinition{ - {TypeName: "Query", FieldName: "people", Limits: user.FieldLimits{MaxQueryDepth: 4}}, - {TypeName: "Mutation", FieldName: "putPerson", Limits: user.FieldLimits{MaxQueryDepth: -1}}, - {TypeName: "Query", FieldName: "countries", Limits: user.FieldLimits{MaxQueryDepth: 3}}, - {TypeName: "Query", FieldName: "continents", Limits: user.FieldLimits{MaxQueryDepth: 4}}, - }, - }, - } - - assert.Equal(t, want, s.AccessRights) - }, - }, - } - tests = append(tests, graphQLTCs...) - - throttleTCs := []testApplyPoliciesData{ - { - "Throttle interval from policy", []string{"throttle1"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - if s.ThrottleInterval != 9 { - t.Fatalf("Throttle interval should be 9 inherited from policy") - } - }, nil, false, - }, - { - name: "Throttle retry limit from policy", - policies: []string{"throttle1"}, - errMatch: "", - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - if s.ThrottleRetryLimit != 99 { - t.Fatalf("Throttle interval should be 9 inherited from policy") - } - }, - session: nil, - }, - } - tests = append(tests, throttleTCs...) - - tagsTCs := []testApplyPoliciesData{ - { - "TagMerge", []string{"tags1", "tags2"}, - "", func(t *testing.T, s *user.SessionState) { - t.Helper() - want := []string{"key-tag", "tagA", "tagX", "tagY"} - sort.Strings(s.Tags) - - assert.Equal(t, want, s.Tags) - }, &user.SessionState{ - Tags: []string{"key-tag"}, - }, false, - }, - } - tests = append(tests, tagsTCs...) - - partitionTCs := []testApplyPoliciesData{ - { - "NonpartAndPart", []string{"nonpart1", "quota1"}, - "", nil, nil, false, - }, - { - name: "inherit quota and rate from partitioned policies", - policies: []string{"quota1", "rate3"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - if s.QuotaMax != 2 { - t.Fatalf("quota should be the same as quota policy") - } - if s.Rate != 4 { - t.Fatalf("rate should be the same as rate policy") - } - if s.Per != 4 { - t.Fatalf("Rate per seconds should be the same as rate policy") - } - }, - }, - { - name: "inherit quota and rate from partitioned policies applied in different order", - policies: []string{"rate3", "quota1"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - - if s.QuotaMax != 2 { - t.Fatalf("quota should be the same as quota policy") - } - if s.Rate != 4 { - t.Fatalf("rate should be the same as rate policy") - } - if s.Per != 4 { - t.Fatalf("Rate per seconds should be the same as rate policy") - } - }, - }, - } - tests = append(tests, partitionTCs...) - - endpointRLTCs := []testApplyPoliciesData{ - { - name: "Per API and per endpoint policies", - policies: []string{"per_api_with_limit_set_from_policy", "per_api_with_endpoint_limits_on_d_and_e"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - endpointsConfig := user.Endpoints{ - { - Path: "/get", - Methods: user.EndpointMethods{ - { - Name: "GET", - Limit: user.RateLimit{ - Rate: -1, - }, - }, - }, - }, - { - Path: "/post", - Methods: user.EndpointMethods{ - { - Name: "POST", - Limit: user.RateLimit{ - Rate: 300, - Per: 10, - }, - }, - }, - }, - } - want := map[string]user.AccessDefinition{ - "e": { - Limit: user.APILimit{ - QuotaMax: -1, - RateLimit: user.RateLimit{ - Rate: 500, - Per: 1, - }, - }, - AllowanceScope: "per_api_with_endpoint_limits_on_d_and_e", - Endpoints: endpointsConfig, - }, - "d": { - Limit: user.APILimit{ - QuotaMax: 5000, - QuotaRenewalRate: 3600, - RateLimit: user.RateLimit{ - Rate: 200, - Per: 10, - }, - }, - AllowanceScope: "d", - Endpoints: endpointsConfig, - }, - } - assert.Equal(t, want, s.AccessRights) - }, - reverseOrder: true, - }, - { - name: "Endpoint level limits overlapping", - policies: []string{ - "per_api_with_limit_set_from_policy", - "per_api_with_endpoint_limits_on_d_and_e", - "per_endpoint_limits_different_on_api_d", - }, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - apiEEndpoints := user.Endpoints{ - { - Path: "/get", - Methods: user.EndpointMethods{ - { - Name: "GET", - Limit: user.RateLimit{ - Rate: -1, - }, - }, - }, - }, - { - Path: "/post", - Methods: user.EndpointMethods{ - { - Name: "POST", - Limit: user.RateLimit{ - Rate: 300, - Per: 10, - }, - }, - }, - }, - } - - assert.ElementsMatch(t, apiEEndpoints, s.AccessRights["e"].Endpoints) - - apiDEndpoints := user.Endpoints{ - { - Path: "/get", - Methods: user.EndpointMethods{ - { - Name: "GET", - Limit: user.RateLimit{ - Rate: -1, - }, - }, - }, - }, - { - Path: "/post", - Methods: user.EndpointMethods{ - { - Name: "POST", - Limit: user.RateLimit{ - Rate: 400, - Per: 11, - }, - }, - }, - }, - { - Path: "/anything", - Methods: user.EndpointMethods{ - { - Name: "PUT", - Limit: user.RateLimit{ - Rate: 500, - Per: 10, - }, - }, - }, - }, - } - - assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) - - apiELimits := user.APILimit{ - QuotaMax: -1, - RateLimit: user.RateLimit{ - Rate: 500, - Per: 1, - }, - } - assert.Equal(t, apiELimits, s.AccessRights["e"].Limit) - - apiDLimits := user.APILimit{ - QuotaMax: 5000, - QuotaRenewalRate: 3600, - RateLimit: user.RateLimit{ - Rate: 200, - Per: 10, - }, - } - assert.Equal(t, apiDLimits, s.AccessRights["d"].Limit) - }, - reverseOrder: true, - }, - { - name: "endpoint_rate_limits_on_acl_partition_only", - policies: []string{"endpoint_rate_limits_on_acl_partition_only"}, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.NotEmpty(t, s.AccessRights) - assert.Empty(t, s.AccessRights["d"].Endpoints) - }, - }, - { - name: "endpoint_rate_limits_when_acl_and_quota_partitions_combined", - policies: []string{ - "endpoint_rate_limits_on_acl_partition_only", - "endpoint_rate_limits_on_quota_partition_only", - }, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.NotEmpty(t, s.AccessRights) - assert.Empty(t, s.AccessRights["d"].Endpoints) - }, - reverseOrder: true, - }, - } - - tests = append(tests, endpointRLTCs...) - - combinedEndpointRLTCs := []testApplyPoliciesData{ - { - name: "combine_non_partitioned_policies_with_endpoint_rate_limits_configured_on_api_d", - policies: []string{ - "api_d_get_endpoint_rl_1_configure_on_non_partitioned_policy", - "api_d_get_endpoint_rl_2_configure_on_non_partitioned_policy", - }, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.NotEmpty(t, s.AccessRights) - apiDEndpoints := user.Endpoints{ - { - Path: "/get", - Methods: user.EndpointMethods{ - { - Name: "GET", - Limit: user.RateLimit{ - Rate: 20, - Per: 60, - }, - }, - }, - }, - } - - assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) - }, - reverseOrder: true, - }, - { - name: "combine_non_partitioned_policies_with_endpoint_rate_limits_no_bound_configured_on_api_d", - policies: []string{ - "api_d_get_endpoint_rl_1_configure_on_non_partitioned_policy", - "api_d_get_endpoint_rl_2_configure_on_non_partitioned_policy", - "api_d_get_endpoint_rl_3_configure_on_non_partitioned_policy", - }, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.NotEmpty(t, s.AccessRights) - apiDEndpoints := user.Endpoints{ - { - Path: "/get", - Methods: user.EndpointMethods{ - { - Name: "GET", - Limit: user.RateLimit{ - Rate: -1, - }, - }, - }, - }, - } - - assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) - }, - reverseOrder: true, - }, - { - name: "combine_non_partitioned_policies_with_multiple_endpoint_rate_limits_configured_on_api_d", - policies: []string{ - "api_d_get_endpoint_rl_1_configure_on_non_partitioned_policy", - "api_d_get_endpoint_rl_2_configure_on_non_partitioned_policy", - "api_d_get_endpoint_rl_3_configure_on_non_partitioned_policy", - "api_d_post_endpoint_rl_1_configure_on_non_partitioned_policy", - }, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.NotEmpty(t, s.AccessRights) - apiDEndpoints := user.Endpoints{ - { - Path: "/get", - Methods: user.EndpointMethods{ - { - Name: "GET", - Limit: user.RateLimit{ - Rate: -1, - }, - }, - }, - }, - { - Path: "/post", - Methods: user.EndpointMethods{ - { - Name: "POST", - Limit: user.RateLimit{ - Rate: 20, - Per: 60, - }, - }, - }, - }, - } - - assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) - }, - reverseOrder: true, - }, - { - name: "combine_non_partitioned_policies_with_endpoint_rate_limits_configured_on_api_d_and_e", - policies: []string{ - "api_d_get_endpoint_rl_1_configure_on_non_partitioned_policy", - "api_d_get_endpoint_rl_2_configure_on_non_partitioned_policy", - "api_d_get_endpoint_rl_3_configure_on_non_partitioned_policy", - "api_d_post_endpoint_rl_1_configure_on_non_partitioned_policy", - "api_e_get_endpoint_rl_1_configure_on_non_partitioned_policy", - }, - sessMatch: func(t *testing.T, s *user.SessionState) { - t.Helper() - assert.NotEmpty(t, s.AccessRights) - apiDEndpoints := user.Endpoints{ - { - Path: "/get", - Methods: user.EndpointMethods{ - { - Name: "GET", - Limit: user.RateLimit{ - Rate: -1, - }, - }, - }, - }, - { - Path: "/post", - Methods: user.EndpointMethods{ - { - Name: "POST", - Limit: user.RateLimit{ - Rate: 20, - Per: 60, - }, - }, - }, - }, - } - - assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) - - apiEEndpoints := user.Endpoints{ - { - Path: "/get", - Methods: user.EndpointMethods{ - { - Name: "GET", - Limit: user.RateLimit{ - Rate: 100, - Per: 60, - }, - }, - }, - }, - } - - assert.ElementsMatch(t, apiEEndpoints, s.AccessRights["e"].Endpoints) - }, - reverseOrder: true, - }, - } - - tests = append(tests, combinedEndpointRLTCs...) - - return service, tests -} - -func TestService_Apply(t *testing.T) { - service, tests := testPrepareApplyPolicies(t) - - for _, tc := range tests { - pols := [][]string{tc.policies} - if tc.reverseOrder { - var copyPols = make([]string, len(tc.policies)) - copy(copyPols, tc.policies) - slices.Reverse(copyPols) - pols = append(pols, copyPols) - } - - for i, policies := range pols { - name := tc.name - if i == 1 { - name = fmt.Sprintf("%s, reversed=%t", name, true) - } - - t.Run(name, func(t *testing.T) { - sess := tc.session - if sess == nil { - sess = &user.SessionState{} - } - sess.SetPolicies(policies...) - if err := service.Apply(sess); err != nil { - assert.ErrorContains(t, err, tc.errMatch) - return - } - - if tc.sessMatch != nil { - tc.sessMatch(t, sess) - } - }) - } - } -} - -func BenchmarkService_Apply(b *testing.B) { - b.ReportAllocs() - - service, tests := testPrepareApplyPolicies(b) - - for i := 0; i < b.N; i++ { - for _, tc := range tests { - sess := &user.SessionState{} - sess.SetPolicies(tc.policies...) - err := service.Apply(sess) - assert.NoError(b, err) - } - } -} diff --git a/tests/policy/apply_acl_test.go b/tests/policy/apply_acl_test.go new file mode 100644 index 000000000000..b7cea4871278 --- /dev/null +++ b/tests/policy/apply_acl_test.go @@ -0,0 +1,68 @@ +package policy + +import ( + "testing" + + "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/user" + "github.com/stretchr/testify/assert" +) + +type ApplyPolicyFunc func(*user.SessionState) error + +func testApplyPolicyFn(t *testing.T) ApplyPolicyFunc { + bmid := &BaseMiddleware{ + Spec: &APISpec{ + APIDefinition: &apidef.APIDefinition{}, + }, + Gw: &Gateway{}, + } + return bmid.ApplyPolicies +} + +func TestApplyACL_FromCustomPolicies(t *testing.T) { + applyPolicy := testApplyPolicyFn(t) + + pol1 := user.Policy{ + ID: "pol1", + Partitions: user.PolicyPartitions{RateLimit: true}, + Rate: 8, + Per: 1, + AccessRights: map[string]user.AccessDefinition{ + "a": {}, + }, + } + + pol2 := user.Policy{ + ID: "pol2", + Partitions: user.PolicyPartitions{Acl: true}, + Rate: 10, + Per: 1, + AccessRights: map[string]user.AccessDefinition{ + "a": { + AllowedURLs: []user.AccessSpec{ + {URL: "/user", Methods: []string{"GET", "POST"}}, + {URL: "/companies", Methods: []string{"GET", "POST"}}, + }, + }, + }, + } + + t.Run("RateLimit first", func(t *testing.T) { + session := &user.SessionState{} + session.SetCustomPolicies([]user.Policy{pol1, pol2}) + + assert.NoError(t, applyPolicy(session)) + assert.Equal(t, pol2.AccessRights["a"].AllowedURLs, session.AccessRights["a"].AllowedURLs) + assert.Equal(t, 8, int(session.Rate)) + }) + + t.Run("ACL first", func(t *testing.T) { + session := &user.SessionState{} + session.SetCustomPolicies([]user.Policy{pol2, pol1}) + + assert.NoError(t, applyPolicy(session)) + assert.Equal(t, pol2.AccessRights["a"].AllowedURLs, session.AccessRights["a"].AllowedURLs) + assert.Equal(t, 8, int(session.Rate)) + }) +} diff --git a/tests/policy/shim.go b/tests/policy/shim.go index cf6fe5f61afc..bd6ae20fd47f 100644 --- a/tests/policy/shim.go +++ b/tests/policy/shim.go @@ -4,6 +4,10 @@ import "github.com/TykTechnologies/tyk/gateway" const DefaultOrg = "default-org-id" -type APISpec = gateway.APISpec +type ( + Gateway = gateway.Gateway + APISpec = gateway.APISpec + BaseMiddleware = gateway.BaseMiddleware +) var StartTest = gateway.StartTest