diff --git a/gateway/api_definition.go b/gateway/api_definition.go index 2a4149f58a4..0b6c54bab5f 100644 --- a/gateway/api_definition.go +++ b/gateway/api_definition.go @@ -47,6 +47,7 @@ import ( "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" "github.com/TykTechnologies/tyk/header" + "github.com/TykTechnologies/tyk/internal/model" "github.com/TykTechnologies/tyk/regexp" "github.com/TykTechnologies/tyk/rpc" "github.com/TykTechnologies/tyk/storage" @@ -522,7 +523,11 @@ func (a APIDefinitionLoader) FromDashboardService(endpoint string) ([]*APISpec, } // Extract tagged APIs# +<<<<<<< HEAD list := &nestedApiDefinitionList{} +======= + list := model.NewMergedAPIList() +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) inBytes, err := io.ReadAll(resp.Body) if err != nil { log.Error("Couldn't read api definition list") @@ -680,15 +685,23 @@ func (a APIDefinitionLoader) FromRPC(store RPCDataLoader, orgId string, gw *Gate } func (a APIDefinitionLoader) processRPCDefinitions(apiCollection string, gw *Gateway) ([]*APISpec, error) { +<<<<<<< HEAD var payload []nestedApiDefinition +======= + var payload []model.MergedAPI +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) if err := json.Unmarshal([]byte(apiCollection), &payload); err != nil { return nil, err } +<<<<<<< HEAD list := &nestedApiDefinitionList{ Message: payload, } +======= + list := model.NewMergedAPIList(payload...) +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) gwConfig := a.Gw.GetConfig() diff --git a/gateway/api_definition_test.go b/gateway/api_definition_test.go index cece0631fb9..373b519bd50 100644 --- a/gateway/api_definition_test.go +++ b/gateway/api_definition_test.go @@ -21,6 +21,11 @@ import ( "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/apidef/oas" "github.com/TykTechnologies/tyk/config" +<<<<<<< HEAD +======= + "github.com/TykTechnologies/tyk/internal/model" + "github.com/TykTechnologies/tyk/internal/policy" +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) "github.com/TykTechnologies/tyk/rpc" "github.com/TykTechnologies/tyk/test" "github.com/TykTechnologies/tyk/user" @@ -1448,7 +1453,7 @@ func Test_LoadAPIsFromRPC(t *testing.T) { loader := APIDefinitionLoader{Gw: ts.Gw} t.Run("load APIs from RPC - success", func(t *testing.T) { - mockedStorage := &RPCDataLoaderMock{ + mockedStorage := &policy.RPCDataLoaderMock{ ShouldConnect: true, Apis: []nestedApiDefinition{ {APIDefinition: &apidef.APIDefinition{Id: objectID, OrgID: "org1", APIID: "api1"}}, @@ -1462,7 +1467,7 @@ func Test_LoadAPIsFromRPC(t *testing.T) { }) t.Run("load APIs from RPC - success - then fail", func(t *testing.T) { - mockedStorage := &RPCDataLoaderMock{ + mockedStorage := &policy.RPCDataLoaderMock{ ShouldConnect: true, Apis: []nestedApiDefinition{ {APIDefinition: &apidef.APIDefinition{Id: objectID, OrgID: "org1", APIID: "api1"}}, diff --git a/gateway/gateway.go b/gateway/gateway.go new file mode 100644 index 00000000000..da5d538ae43 --- /dev/null +++ b/gateway/gateway.go @@ -0,0 +1,62 @@ +package gateway + +import ( + "github.com/TykTechnologies/tyk/internal/policy" + "github.com/TykTechnologies/tyk/user" +) + +// Repository is a description of our Gateway API promises. +type Repository interface { + policy.Repository +} + +// Gateway implements the Repository interface. +var _ Repository = &Gateway{} + +// PolicyIDs returns a list of IDs for each policy loaded in the gateway. +func (gw *Gateway) PolicyIDs() []string { + gw.policiesMu.RLock() + defer gw.policiesMu.RUnlock() + + result := make([]string, 0, len(gw.policiesByID)) + for id := range gw.policiesByID { + result = append(result, id) + } + return result +} + +// PolicyByID will return a Policy matching the passed Policy ID. +func (gw *Gateway) PolicyByID(id string) (user.Policy, bool) { + gw.policiesMu.RLock() + defer gw.policiesMu.RUnlock() + + pol, ok := gw.policiesByID[id] + return pol, ok +} + +// PolicyCount will return the number of policies loaded in the gateway. +func (gw *Gateway) PolicyCount() int { + gw.policiesMu.RLock() + defer gw.policiesMu.RUnlock() + + return len(gw.policiesByID) +} + +// SetPolicies updates the internal policy map with a new policy map. +func (gw *Gateway) SetPolicies(pols map[string]user.Policy) { + gw.policiesMu.Lock() + defer gw.policiesMu.Unlock() + + gw.policiesByID = pols +} + +// SetPoliciesByID will update the internal policiesByID map with new policies. +// The key used will be the policy ID. +func (gw *Gateway) SetPoliciesByID(pols ...user.Policy) { + gw.policiesMu.Lock() + defer gw.policiesMu.Unlock() + + for _, pol := range pols { + gw.policiesByID[pol.ID] = pol + } +} diff --git a/gateway/health_check.go b/gateway/health_check.go index c31ec345c63..b7cb46b241e 100644 --- a/gateway/health_check.go +++ b/gateway/health_check.go @@ -8,15 +8,21 @@ import ( "sync" "time" +<<<<<<< HEAD "github.com/TykTechnologies/tyk/rpc" +======= +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) "github.com/sirupsen/logrus" "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/header" + "github.com/TykTechnologies/tyk/internal/model" + "github.com/TykTechnologies/tyk/rpc" "github.com/TykTechnologies/tyk/storage" ) +<<<<<<< HEAD func (gw *Gateway) setCurrentHealthCheckInfo(h map[string]apidef.HealthCheckItem) { gw.healthCheckInfo.Store(h) } @@ -25,12 +31,40 @@ func (gw *Gateway) getHealthCheckInfo() map[string]apidef.HealthCheckItem { ret, ok := gw.healthCheckInfo.Load().(map[string]apidef.HealthCheckItem) if !ok { return make(map[string]apidef.HealthCheckItem, 0) +======= +type ( + HealthCheckItem = model.HealthCheckItem + HealthCheckStatus = model.HealthCheckStatus + HealthCheckResponse = model.HealthCheckResponse +) + +const ( + Pass = model.Pass + Fail = model.Fail + Warn = model.Warn + Datastore = model.Datastore + System = model.System +) + +func (gw *Gateway) setCurrentHealthCheckInfo(h map[string]model.HealthCheckItem) { + gw.healthCheckInfo.Store(h) +} + +func (gw *Gateway) getHealthCheckInfo() map[string]HealthCheckItem { + ret, ok := gw.healthCheckInfo.Load().(map[string]HealthCheckItem) + if !ok { + return make(map[string]HealthCheckItem, 0) +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) } return ret } func (gw *Gateway) initHealthCheck(ctx context.Context) { +<<<<<<< HEAD gw.setCurrentHealthCheckInfo(make(map[string]apidef.HealthCheckItem, 3)) +======= + gw.setCurrentHealthCheckInfo(make(map[string]HealthCheckItem, 3)) +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) go func(ctx context.Context) { var n = gw.GetConfig().LivenessCheck.CheckDuration @@ -59,12 +93,20 @@ func (gw *Gateway) initHealthCheck(ctx context.Context) { } type SafeHealthCheck struct { +<<<<<<< HEAD info map[string]apidef.HealthCheckItem +======= + info map[string]HealthCheckItem +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) mux sync.Mutex } func (gw *Gateway) gatherHealthChecks() { +<<<<<<< HEAD allInfos := SafeHealthCheck{info: make(map[string]apidef.HealthCheckItem, 3)} +======= + allInfos := SafeHealthCheck{info: make(map[string]HealthCheckItem, 3)} +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) redisStore := storage.RedisCluster{KeyPrefix: "livenesscheck-", ConnectionHandler: gw.StorageConnectionHandler} @@ -76,9 +118,15 @@ func (gw *Gateway) gatherHealthChecks() { go func() { defer wg.Done() +<<<<<<< HEAD var checkItem = apidef.HealthCheckItem{ Status: apidef.Pass, ComponentType: apidef.Datastore, +======= + var checkItem = HealthCheckItem{ + Status: Pass, + ComponentType: Datastore, +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) Time: time.Now().Format(time.RFC3339), } @@ -86,7 +134,11 @@ func (gw *Gateway) gatherHealthChecks() { if err != nil { mainLog.WithField("liveness-check", true).WithError(err).Error("Redis health check failed") checkItem.Output = err.Error() +<<<<<<< HEAD checkItem.Status = apidef.Fail +======= + checkItem.Status = Fail +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) } allInfos.mux.Lock() @@ -100,9 +152,15 @@ func (gw *Gateway) gatherHealthChecks() { go func() { defer wg.Done() +<<<<<<< HEAD var checkItem = apidef.HealthCheckItem{ Status: apidef.Pass, ComponentType: apidef.Datastore, +======= + var checkItem = HealthCheckItem{ + Status: Pass, + ComponentType: Datastore, +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) Time: time.Now().Format(time.RFC3339), } @@ -110,6 +168,7 @@ func (gw *Gateway) gatherHealthChecks() { err := errors.New("Dashboard service not initialized") mainLog.WithField("liveness-check", true).Error(err) checkItem.Output = err.Error() +<<<<<<< HEAD checkItem.Status = apidef.Fail } else if err := gw.DashService.Ping(); err != nil { mainLog.WithField("liveness-check", true).Error(err) @@ -118,6 +177,16 @@ func (gw *Gateway) gatherHealthChecks() { } checkItem.ComponentType = apidef.System +======= + checkItem.Status = Fail + } else if err := gw.DashService.Ping(); err != nil { + mainLog.WithField("liveness-check", true).Error(err) + checkItem.Output = err.Error() + checkItem.Status = Fail + } + + checkItem.ComponentType = System +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) allInfos.mux.Lock() allInfos.info["dashboard"] = checkItem @@ -132,18 +201,31 @@ func (gw *Gateway) gatherHealthChecks() { go func() { defer wg.Done() +<<<<<<< HEAD var checkItem = apidef.HealthCheckItem{ Status: apidef.Pass, ComponentType: apidef.Datastore, +======= + var checkItem = HealthCheckItem{ + Status: Pass, + ComponentType: Datastore, +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) Time: time.Now().Format(time.RFC3339), } if !rpc.Login() { checkItem.Output = "Could not connect to RPC" +<<<<<<< HEAD checkItem.Status = apidef.Fail } checkItem.ComponentType = apidef.System +======= + checkItem.Status = Fail + } + + checkItem.ComponentType = System +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) allInfos.mux.Lock() allInfos.info["rpc"] = checkItem @@ -166,8 +248,13 @@ func (gw *Gateway) liveCheckHandler(w http.ResponseWriter, r *http.Request) { checks := gw.getHealthCheckInfo() +<<<<<<< HEAD res := apidef.HealthCheckResponse{ Status: apidef.Pass, +======= + res := HealthCheckResponse{ + Status: Pass, +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) Version: VERSION, Description: "Tyk GW", Details: checks, @@ -176,11 +263,16 @@ func (gw *Gateway) liveCheckHandler(w http.ResponseWriter, r *http.Request) { var failCount int for _, v := range checks { +<<<<<<< HEAD if v.Status == apidef.Fail { +======= + if v.Status == Fail { +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) failCount++ } } +<<<<<<< HEAD var status apidef.HealthCheckStatus switch failCount { @@ -192,6 +284,19 @@ func (gw *Gateway) liveCheckHandler(w http.ResponseWriter, r *http.Request) { default: status = apidef.Warn +======= + var status HealthCheckStatus + + switch failCount { + case 0: + status = Pass + + case len(checks): + status = Fail + + default: + status = Warn +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) } res.Status = status diff --git a/gateway/policy_test.go b/gateway/policy_test.go index 1481798adad..4d5aaa8726d 100644 --- a/gateway/policy_test.go +++ b/gateway/policy_test.go @@ -20,7 +20,6 @@ import ( "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" "github.com/TykTechnologies/tyk/header" - "github.com/TykTechnologies/tyk/rpc" "github.com/TykTechnologies/tyk/test" "github.com/TykTechnologies/tyk/user" @@ -1596,6 +1595,7 @@ func TestParsePoliciesFromRPC(t *testing.T) { } } +<<<<<<< HEAD type RPCDataLoaderMock struct { ShouldConnect bool @@ -1666,3 +1666,5 @@ func Test_LoadPoliciesFromRPC(t *testing.T) { assert.Equal(t, 1, len(polMap), "expected 0 policies to be loaded from RPC") }) } +======= +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) diff --git a/internal/policy/Taskfile.yml b/internal/policy/Taskfile.yml new file mode 100644 index 00000000000..97c1c82735d --- /dev/null +++ b/internal/policy/Taskfile.yml @@ -0,0 +1,56 @@ +--- +version: "3" + +includes: + services: + taskfile: ../../docker/services/Taskfile.yml + dir: ../../docker/services + +vars: + run: . + +tasks: + default: + desc: "Run tests" + deps: [ services:up ] + requires: + vars: [run] + cmds: + - defer: { task: services:down } + - goimports -w . + - go fmt ./... + - task: test + vars: + run: '{{.run}}' + + test: + desc: "Run tests" + requires: + vars: [run] + cmds: + - go test -count=1 -run='({{.run}})' -cover -coverprofile=pkg.cov -v . + + stress: + desc: "Run stress tests" + requires: + vars: [run] + cmds: + - go test -count=2000 -run='({{.run}})' -cover -coverprofile=pkg.cov . + + cover: + desc: "Show source coverage" + aliases: [coverage, cov] + cmds: + - go tool cover -func=pkg.cov + + uncover: + desc: "Show uncovered source" + cmds: + - uncover pkg.cov + + install:uncover: + desc: "Install uncover" + env: + GOBIN: /usr/local/bin + cmds: + - go install github.com/gregoryv/uncover/...@latest diff --git a/internal/policy/apply.go b/internal/policy/apply.go new file mode 100644 index 00000000000..8a24ce7ab3a --- /dev/null +++ b/internal/policy/apply.go @@ -0,0 +1,600 @@ +package policy + +import ( + "errors" + "fmt" + + "github.com/sirupsen/logrus" + + "github.com/TykTechnologies/tyk/user" +) + +var ( + // ErrMixedPartitionAndPerAPIPolicies is the error to return when a mix of per api and partitioned policies are to be applied in a session. + ErrMixedPartitionAndPerAPIPolicies = errors.New("cannot apply multiple policies when some have per_api set and some are partitioned") +) + +// Repository is a storage encapsulating policy retrieval. +// Gateway implements this object to decouple this package. +type Repository interface { + PolicyCount() int + PolicyIDs() []string + PolicyByID(string) (user.Policy, bool) +} + +// Service represents the policy service for gateway. +type Service struct { + storage Repository + logger *logrus.Logger + + // used for validation if not empty + orgID *string +} + +// New creates a new policy.Service object. +func New(orgID *string, storage Repository, logger *logrus.Logger) *Service { + return &Service{ + orgID: orgID, + storage: storage, + logger: logger, + } +} + +// ClearSession clears the quota, rate limit and complexity values so that partitioned policies can apply their values. +// Otherwise, if the session has already a higher value, an applied policy will not win, and its values will be ignored. +func (t *Service) ClearSession(session *user.SessionState) error { + policies := session.PolicyIDs() + + for _, polID := range policies { + policy, ok := t.storage.PolicyByID(polID) + if !ok { + return fmt.Errorf("policy not found: %s", polID) + } + + all := !(policy.Partitions.Quota || policy.Partitions.RateLimit || policy.Partitions.Acl || policy.Partitions.Complexity) + + if policy.Partitions.Quota || all { + session.QuotaMax = 0 + session.QuotaRemaining = 0 + } + + if policy.Partitions.RateLimit || all { + session.Rate = 0 + session.Per = 0 + session.Smoothing = nil + session.ThrottleRetryLimit = 0 + session.ThrottleInterval = 0 + } + + if policy.Partitions.Complexity || all { + session.MaxQueryDepth = 0 + } + } + + return nil +} + +type applyStatus struct { + didQuota map[string]bool + didRateLimit map[string]bool + didAcl map[string]bool + didComplexity map[string]bool + didPerAPI bool + didPartition bool +} + +// Apply will check if any policies are loaded. If any are, it +// will overwrite the session state to use the policy values. +func (t *Service) Apply(session *user.SessionState) error { + rights := make(map[string]user.AccessDefinition) + tags := make(map[string]bool) + if session.MetaData == nil { + session.MetaData = make(map[string]interface{}) + } + + if err := t.ClearSession(session); err != nil { + t.logger.WithError(err).Warn("error clearing session") + } + + applyState := applyStatus{ + didQuota: make(map[string]bool), + didRateLimit: make(map[string]bool), + didAcl: make(map[string]bool), + didComplexity: make(map[string]bool), + } + + var ( + err error + policyIDs []string + ) + + storage := t.storage + + customPolicies, err := session.GetCustomPolicies() + if err != nil { + policyIDs = session.PolicyIDs() + } else { + storage = NewStore(customPolicies) + policyIDs = storage.PolicyIDs() + } + + for _, polID := range policyIDs { + policy, ok := storage.PolicyByID(polID) + if !ok { + err := fmt.Errorf("policy not found: %q", polID) + t.Logger().Error(err) + if len(policyIDs) > 1 { + continue + } + + return err + } + // Check ownership, policy org owner must be the same as API, + // otherwise you could overwrite a session key with a policy from a different org! + if t.orgID != nil && policy.OrgID != *t.orgID { + err := errors.New("attempting to apply policy from different organisation to key, skipping") + t.Logger().Error(err) + return err + } + + if policy.Partitions.PerAPI && policy.Partitions.Enabled() { + err := fmt.Errorf("cannot apply policy %s which has per_api and any of partitions set", policy.ID) + t.logger.Error(err) + return err + } + + if policy.Partitions.PerAPI { + if err := t.applyPerAPI(policy, session, rights, &applyState); err != nil { + return err + } + } else { + if err := t.applyPartitions(policy, session, rights, &applyState); err != nil { + return err + } + } + + session.IsInactive = session.IsInactive || policy.IsInactive + + for _, tag := range policy.Tags { + tags[tag] = true + } + + for k, v := range policy.MetaData { + session.MetaData[k] = v + } + + if policy.LastUpdated > session.LastUpdated { + session.LastUpdated = policy.LastUpdated + } + } + + for _, tag := range session.Tags { + tags[tag] = true + } + + // set tags + session.Tags = []string{} + for tag := range tags { + session.Tags = appendIfMissing(session.Tags, tag) + } + + if len(policyIDs) == 0 { + for apiID, accessRight := range session.AccessRights { + // check if the api in the session has per api limit + if !accessRight.Limit.IsEmpty() { + accessRight.AllowanceScope = apiID + session.AccessRights[apiID] = accessRight + } + } + } + + distinctACL := make(map[string]bool) + + for _, v := range rights { + if v.Limit.SetBy != "" { + distinctACL[v.Limit.SetBy] = true + } + } + + // If some APIs had only ACL partitions, inherit rest from session level + for k, v := range rights { + if !applyState.didAcl[k] { + delete(rights, k) + continue + } + + if !applyState.didRateLimit[k] { + v.Limit.Rate = session.Rate + v.Limit.Per = session.Per + v.Limit.Smoothing = session.Smoothing + v.Limit.ThrottleInterval = session.ThrottleInterval + v.Limit.ThrottleRetryLimit = session.ThrottleRetryLimit + v.Endpoints = nil + } + + if !applyState.didComplexity[k] { + v.Limit.MaxQueryDepth = session.MaxQueryDepth + } + + if !applyState.didQuota[k] { + v.Limit.QuotaMax = session.QuotaMax + v.Limit.QuotaRenewalRate = session.QuotaRenewalRate + v.Limit.QuotaRenews = session.QuotaRenews + } + + // If multime ACL + if len(distinctACL) > 1 { + if v.AllowanceScope == "" && v.Limit.SetBy != "" { + v.AllowanceScope = v.Limit.SetBy + } + } + + v.Limit.SetBy = "" + + rights[k] = v + } + + // If we have policies defining rules for one single API, update session root vars (legacy) + t.updateSessionRootVars(session, rights, applyState) + + // Override session ACL if at least one policy define it + if len(applyState.didAcl) > 0 { + session.AccessRights = rights + } + + return nil +} + +// Logger implements a typical logger signature with service context. +func (t *Service) Logger() *logrus.Entry { + return logrus.NewEntry(t.logger) +} + +// ApplyRateLimits will write policy limits to session and apiLimits. +// The limits get written if either are empty. +// The limits get written if filled and policyLimits allows a higher request rate. +func (t *Service) ApplyRateLimits(session *user.SessionState, policy user.Policy, apiLimits *user.APILimit) { + policyLimits := policy.APILimit() + if t.emptyRateLimit(policyLimits) { + return + } + + // duration is time between requests, e.g.: + // + // apiLimits: 500ms for 2 requests / second + // policyLimits: 100ms for 10 requests / second + // + // if apiLimits > policyLimits (500ms > 100ms) then + // we apply the higher rate from the policy. + // + // the policy-defined rate limits are enforced as + // a minimum possible api rate limit setting, + // raising apiLimits. + + if t.emptyRateLimit(*apiLimits) || apiLimits.Duration() > policyLimits.Duration() { + apiLimits.Rate = policyLimits.Rate + apiLimits.Per = policyLimits.Per + apiLimits.Smoothing = policyLimits.Smoothing + } + + // sessionLimits, similar to apiLimits, get policy + // rate applied if the policy allows more requests. + sessionLimits := session.APILimit() + if t.emptyRateLimit(sessionLimits) || sessionLimits.Duration() > policyLimits.Duration() { + session.Rate = policyLimits.Rate + session.Per = policyLimits.Per + session.Smoothing = policyLimits.Smoothing + } +} + +func (t *Service) emptyRateLimit(m user.APILimit) bool { + return m.Rate == 0 || m.Per == 0 +} + +func (t *Service) applyPerAPI(policy user.Policy, session *user.SessionState, rights map[string]user.AccessDefinition, + applyState *applyStatus) error { + + if applyState.didPartition { + t.logger.Error(ErrMixedPartitionAndPerAPIPolicies) + return ErrMixedPartitionAndPerAPIPolicies + } + + for apiID, accessRights := range policy.AccessRights { + idForScope := apiID + // check if we don't have limit on API level specified when policy was created + if accessRights.Limit.IsEmpty() { + // limit was not specified on API level so we will populate it from policy + idForScope = policy.ID + accessRights.Limit = policy.APILimit() + } + accessRights.AllowanceScope = idForScope + accessRights.Limit.SetBy = idForScope + + // respect current quota renews (on API limit level) + if r, ok := session.AccessRights[apiID]; ok && !r.Limit.IsEmpty() { + accessRights.Limit.QuotaRenews = r.Limit.QuotaRenews + } + + if r, ok := session.AccessRights[apiID]; ok { + // If GQL introspection is disabled, keep that configuration. + if r.DisableIntrospection { + accessRights.DisableIntrospection = r.DisableIntrospection + } + } + + if currAD, ok := rights[apiID]; ok { + accessRights = t.applyAPILevelLimits(accessRights, currAD) + } + + // overwrite session access right for this API + rights[apiID] = accessRights + + // identify that limit for that API is set (to allow set it only once) + applyState.didAcl[apiID] = true + applyState.didQuota[apiID] = true + applyState.didRateLimit[apiID] = true + applyState.didComplexity[apiID] = true + } + + if len(policy.AccessRights) > 0 { + applyState.didPerAPI = true + } + + return nil +} + +func (t *Service) applyPartitions(policy user.Policy, session *user.SessionState, rights map[string]user.AccessDefinition, + applyState *applyStatus) error { + + usePartitions := policy.Partitions.Enabled() + + if usePartitions && applyState.didPerAPI { + t.logger.Error(ErrMixedPartitionAndPerAPIPolicies) + return ErrMixedPartitionAndPerAPIPolicies + } + + for k, v := range policy.AccessRights { + ar := v + + if !usePartitions || policy.Partitions.Acl { + applyState.didAcl[k] = true + + ar.AllowedURLs = MergeAllowedURLs(ar.AllowedURLs, v.AllowedURLs) + + // Merge ACLs for the same API + if r, ok := rights[k]; ok { + // If GQL introspection is disabled, keep that configuration. + if v.DisableIntrospection { + r.DisableIntrospection = v.DisableIntrospection + } + r.Versions = appendIfMissing(rights[k].Versions, v.Versions...) + + r.AllowedURLs = MergeAllowedURLs(r.AllowedURLs, v.AllowedURLs) + + for _, t := range v.RestrictedTypes { + for ri, rt := range r.RestrictedTypes { + if t.Name == rt.Name { + r.RestrictedTypes[ri].Fields = intersection(rt.Fields, t.Fields) + } + } + } + + for _, t := range v.AllowedTypes { + for ri, rt := range r.AllowedTypes { + if t.Name == rt.Name { + r.AllowedTypes[ri].Fields = intersection(rt.Fields, t.Fields) + } + } + } + + mergeFieldLimits := func(res *user.FieldLimits, new user.FieldLimits) { + if greaterThanInt(new.MaxQueryDepth, res.MaxQueryDepth) { + res.MaxQueryDepth = new.MaxQueryDepth + } + } + + for _, far := range v.FieldAccessRights { + exists := false + for i, rfar := range r.FieldAccessRights { + if far.TypeName == rfar.TypeName && far.FieldName == rfar.FieldName { + exists = true + mergeFieldLimits(&r.FieldAccessRights[i].Limits, far.Limits) + } + } + + if !exists { + r.FieldAccessRights = append(r.FieldAccessRights, far) + } + } + + ar = r + } + + ar.Limit.SetBy = policy.ID + } + + if !usePartitions || policy.Partitions.Quota { + applyState.didQuota[k] = true + if greaterThanInt64(policy.QuotaMax, ar.Limit.QuotaMax) { + + ar.Limit.QuotaMax = policy.QuotaMax + if greaterThanInt64(policy.QuotaMax, session.QuotaMax) { + session.QuotaMax = policy.QuotaMax + } + } + + if policy.QuotaRenewalRate > ar.Limit.QuotaRenewalRate { + ar.Limit.QuotaRenewalRate = policy.QuotaRenewalRate + if policy.QuotaRenewalRate > session.QuotaRenewalRate { + session.QuotaRenewalRate = policy.QuotaRenewalRate + } + } + } + + if !usePartitions || policy.Partitions.RateLimit { + applyState.didRateLimit[k] = true + + t.ApplyRateLimits(session, policy, &ar.Limit) + + if rightsAR, ok := rights[k]; ok { + ar.Endpoints = t.ApplyEndpointLevelLimits(v.Endpoints, rightsAR.Endpoints) + } + + if policy.ThrottleRetryLimit > ar.Limit.ThrottleRetryLimit { + ar.Limit.ThrottleRetryLimit = policy.ThrottleRetryLimit + if policy.ThrottleRetryLimit > session.ThrottleRetryLimit { + session.ThrottleRetryLimit = policy.ThrottleRetryLimit + } + } + + if policy.ThrottleInterval > ar.Limit.ThrottleInterval { + ar.Limit.ThrottleInterval = policy.ThrottleInterval + if policy.ThrottleInterval > session.ThrottleInterval { + session.ThrottleInterval = policy.ThrottleInterval + } + } + } + + if !usePartitions || policy.Partitions.Complexity { + applyState.didComplexity[k] = true + + if greaterThanInt(policy.MaxQueryDepth, ar.Limit.MaxQueryDepth) { + ar.Limit.MaxQueryDepth = policy.MaxQueryDepth + if greaterThanInt(policy.MaxQueryDepth, session.MaxQueryDepth) { + session.MaxQueryDepth = policy.MaxQueryDepth + } + } + } + + // Respect existing QuotaRenews + if r, ok := session.AccessRights[k]; ok && !r.Limit.IsEmpty() { + ar.Limit.QuotaRenews = r.Limit.QuotaRenews + } + + rights[k] = ar + } + + // Master policy case + if len(policy.AccessRights) == 0 { + if !usePartitions || policy.Partitions.RateLimit { + session.Rate = policy.Rate + session.Per = policy.Per + session.Smoothing = policy.Smoothing + session.ThrottleInterval = policy.ThrottleInterval + session.ThrottleRetryLimit = policy.ThrottleRetryLimit + } + + if !usePartitions || policy.Partitions.Complexity { + session.MaxQueryDepth = policy.MaxQueryDepth + } + + if !usePartitions || policy.Partitions.Quota { + session.QuotaMax = policy.QuotaMax + session.QuotaRenewalRate = policy.QuotaRenewalRate + } + } + + if !session.HMACEnabled { + session.HMACEnabled = policy.HMACEnabled + } + + if !session.EnableHTTPSignatureValidation { + session.EnableHTTPSignatureValidation = policy.EnableHTTPSignatureValidation + } + + applyState.didPartition = usePartitions + + return nil +} + +func (t *Service) updateSessionRootVars(session *user.SessionState, rights map[string]user.AccessDefinition, applyState applyStatus) { + if len(applyState.didQuota) == 1 && len(applyState.didRateLimit) == 1 && len(applyState.didComplexity) == 1 { + for _, v := range rights { + if len(applyState.didRateLimit) == 1 { + session.Rate = v.Limit.Rate + session.Per = v.Limit.Per + session.Smoothing = v.Limit.Smoothing + } + + if len(applyState.didQuota) == 1 { + session.QuotaMax = v.Limit.QuotaMax + session.QuotaRenews = v.Limit.QuotaRenews + session.QuotaRenewalRate = v.Limit.QuotaRenewalRate + } + + if len(applyState.didComplexity) == 1 { + session.MaxQueryDepth = v.Limit.MaxQueryDepth + } + } + } +} + +func (t *Service) applyAPILevelLimits(policyAD user.AccessDefinition, currAD user.AccessDefinition) user.AccessDefinition { + var updated bool + if policyAD.Limit.Duration() > currAD.Limit.Duration() { + policyAD.Limit.Per = currAD.Limit.Per + policyAD.Limit.Rate = currAD.Limit.Rate + policyAD.Limit.Smoothing = currAD.Limit.Smoothing + updated = true + } + + if currAD.Limit.QuotaMax != policyAD.Limit.QuotaMax && greaterThanInt64(currAD.Limit.QuotaMax, policyAD.Limit.QuotaMax) { + policyAD.Limit.QuotaMax = currAD.Limit.QuotaMax + updated = true + } + + if greaterThanInt64(currAD.Limit.QuotaRenewalRate, policyAD.Limit.QuotaRenewalRate) { + policyAD.Limit.QuotaRenewalRate = currAD.Limit.QuotaRenewalRate + } + + if policyAD.Limit.QuotaMax == -1 { + policyAD.Limit.QuotaRenewalRate = 0 + } + + if updated { + policyAD.Limit.SetBy = currAD.Limit.SetBy + policyAD.AllowanceScope = currAD.AllowanceScope + } + + policyAD.Endpoints = t.ApplyEndpointLevelLimits(policyAD.Endpoints, currAD.Endpoints) + + return policyAD +} + +// ApplyEndpointLevelLimits combines policyEndpoints and currEndpoints and returns the combined value. +// The returned endpoints would have the highest request rate from policyEndpoints and currEndpoints. +func (t *Service) ApplyEndpointLevelLimits(policyEndpoints user.Endpoints, currEndpoints user.Endpoints) user.Endpoints { + currEPMap := currEndpoints.Map() + if len(currEPMap) == 0 { + return policyEndpoints + } + + result := policyEndpoints.Map() + if len(result) == 0 { + return currEPMap.Endpoints() + } + + for currEP, currRL := range currEPMap { + policyRL, ok := result[currEP] + if !ok { + // merge missing endpoints + result[currEP] = currRL + continue + } + + policyDur, currDur := policyRL.Duration(), currRL.Duration() + if policyDur > currDur { + result[currEP] = currRL + continue + } + + // when duration is equal, use higher rate and per + // eg. when 10 per 60 and 5 per 30 comes in + // Duration would be 6s each, in such a case higher rate of 10 per 60 would be picked up. + if policyDur == currDur && currRL.Rate > policyRL.Rate { + result[currEP] = currRL + } + } + + return result.Endpoints() +} diff --git a/internal/policy/apply_test.go b/internal/policy/apply_test.go new file mode 100644 index 00000000000..37b77e4a4a5 --- /dev/null +++ b/internal/policy/apply_test.go @@ -0,0 +1,1265 @@ +package policy_test + +import ( + "embed" + "encoding/json" + "fmt" + "slices" + "sort" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/graphql-go-tools/pkg/graphql" + "github.com/TykTechnologies/tyk/internal/policy" + "github.com/TykTechnologies/tyk/user" +) + +//go:embed testdata/*.json +var testDataFS embed.FS + +func TestApplyRateLimits_PolicyLimits(t *testing.T) { + t.Run("policy limits unset", func(t *testing.T) { + svc := &policy.Service{} + + session := &user.SessionState{ + Rate: 5, + Per: 10, + } + apiLimits := user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 10, + Per: 10, + }, + } + policy := user.Policy{} + + svc.ApplyRateLimits(session, policy, &apiLimits) + + assert.Equal(t, 10, int(apiLimits.Rate)) + assert.Equal(t, 5, int(session.Rate)) + }) + + t.Run("policy limits apply all", func(t *testing.T) { + svc := &policy.Service{} + + session := &user.SessionState{ + Rate: 5, + Per: 10, + } + apiLimits := user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 5, + Per: 10, + }, + } + policy := user.Policy{ + Rate: 10, + Per: 10, + } + + svc.ApplyRateLimits(session, policy, &apiLimits) + + assert.Equal(t, 10, int(apiLimits.Rate)) + assert.Equal(t, 10, int(session.Rate)) + }) + + // As the policy defined a higher rate than apiLimits, + // changes are applied to api limits, but skipped on + // the session as the session has a higher allowance. + t.Run("policy limits apply per-api", func(t *testing.T) { + svc := &policy.Service{} + + session := &user.SessionState{ + Rate: 15, + Per: 10, + } + apiLimits := user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 5, + Per: 10, + }, + } + policy := user.Policy{ + Rate: 10, + Per: 10, + } + + svc.ApplyRateLimits(session, policy, &apiLimits) + + assert.Equal(t, 10, int(apiLimits.Rate)) + assert.Equal(t, 15, int(session.Rate)) + }) + + // As the policy defined a lower rate than apiLimits, + // no changes to api limits are applied. + t.Run("policy limits skip", func(t *testing.T) { + svc := &policy.Service{} + + session := &user.SessionState{ + Rate: 5, + Per: 10, + } + apiLimits := user.APILimit{ + RateLimit: user.RateLimit{Rate: 15, + Per: 10, + }, + } + policy := user.Policy{ + Rate: 10, + Per: 10, + } + + svc.ApplyRateLimits(session, policy, &apiLimits) + + assert.Equal(t, 15, int(apiLimits.Rate)) + assert.Equal(t, 10, int(session.Rate)) + }) +} + +func TestApplyRateLimits_FromCustomPolicies(t *testing.T) { + svc := &policy.Service{} + + session := &user.SessionState{} + session.SetCustomPolicies([]user.Policy{ + { + ID: "pol1", + Partitions: user.PolicyPartitions{RateLimit: true}, + Rate: 8, + Per: 1, + AccessRights: map[string]user.AccessDefinition{"a": {}}, + }, + { + ID: "pol2", + Partitions: user.PolicyPartitions{RateLimit: true}, + Rate: 10, + Per: 1, + AccessRights: map[string]user.AccessDefinition{"a": {}}, + }, + }) + + assert.NoError(t, svc.Apply(session)) + assert.Equal(t, 10, int(session.Rate)) +} + +func TestApplyEndpointLevelLimits(t *testing.T) { + f, err := testDataFS.ReadFile("testdata/apply_endpoint_rl.json") + assert.NoError(t, err) + + var testCases []struct { + Name string `json:"name"` + PolicyEP user.Endpoints `json:"policyEP"` + CurrEP user.Endpoints `json:"currEP"` + Expected user.Endpoints `json:"expected"` + } + err = json.Unmarshal(f, &testCases) + assert.NoError(t, err) + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + service := policy.Service{} + result := service.ApplyEndpointLevelLimits(tc.PolicyEP, tc.CurrEP) + assert.ElementsMatch(t, tc.Expected, result) + }) + } + +} + +type testApplyPoliciesData struct { + name string + policies []string + errMatch string // substring + sessMatch func(*testing.T, *user.SessionState) // ignored if nil + session *user.SessionState + // reverseOrder executes the tests in reversed order of policies, + // in addition to the order specified in policies + reverseOrder bool +} + +func testPrepareApplyPolicies(tb testing.TB) (*policy.Service, []testApplyPoliciesData) { + tb.Helper() + + f, err := testDataFS.ReadFile("testdata/policies.json") + assert.NoError(tb, err) + + var policies = make(map[string]user.Policy) + err = json.Unmarshal(f, &policies) + assert.NoError(tb, err) + + var repoPols = make(map[string]user.Policy) + err = json.Unmarshal(f, &repoPols) + assert.NoError(tb, err) + + store := policy.NewStoreMap(repoPols) + orgID := "" + service := policy.New(&orgID, store, logrus.StandardLogger()) + + // splitting tests for readability + var tests []testApplyPoliciesData + + nilSessionTCs := []testApplyPoliciesData{ + { + "Empty", nil, + "", nil, nil, false, + }, + { + "Single", []string{"nonpart1"}, + "", nil, nil, false, + }, + { + "Missing", []string{"nonexistent"}, + "not found", nil, nil, false, + }, + { + "DiffOrg", []string{"difforg"}, + "different org", nil, nil, false, + }, + } + tests = append(tests, nilSessionTCs...) + + nonPartitionedTCs := []testApplyPoliciesData{ + { + name: "MultiNonPart", + policies: []string{"nonpart1", "nonpart2", "nonexistent"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + want := map[string]user.AccessDefinition{ + "a": { + Limit: user.APILimit{}, + AllowanceScope: "p1", + }, + "b": { + Limit: user.APILimit{}, + AllowanceScope: "p2", + }, + } + + assert.Equal(t, want, s.AccessRights) + }, + }, + { + name: "MultiACLPolicy", + policies: []string{"nonpart3"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + want := map[string]user.AccessDefinition{ + "a": { + Limit: user.APILimit{}, + }, + "b": { + Limit: user.APILimit{}, + }, + } + + assert.Equal(t, want, s.AccessRights) + }, + }, + } + tests = append(tests, nonPartitionedTCs...) + + quotaPartitionTCs := []testApplyPoliciesData{ + { + "QuotaPart with unlimited", []string{"unlimited-quota"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.QuotaMax != -1 { + t.Fatalf("want unlimited quota to be -1") + } + }, nil, false, + }, + { + "QuotaPart", []string{"quota1"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.QuotaMax != 2 { + t.Fatalf("want QuotaMax to be 2") + } + }, nil, false, + }, + { + "QuotaParts", []string{"quota1", "quota2"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.QuotaMax != 3 { + t.Fatalf("Should pick bigger value") + } + }, nil, false, + }, + { + "QuotaParts with acl", []string{"quota5", "quota4"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.Equal(t, int64(4), s.QuotaMax) + }, nil, false, + }, + { + "QuotaPart with access rights", []string{"quota3"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.QuotaMax != 3 { + t.Fatalf("quota should be the same as policy quota") + } + }, nil, false, + }, + { + "QuotaPart with access rights in multi-policy", []string{"quota4", "nonpart1"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.QuotaMax != 3 { + t.Fatalf("quota should be the same as policy quota") + } + + // Don't apply api 'b' coming from quota4 policy + want := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{}}} + assert.Equal(t, want, s.AccessRights) + }, nil, false, + }, + } + tests = append(tests, quotaPartitionTCs...) + + rateLimitPartitionTCs := []testApplyPoliciesData{ + { + "RatePart with unlimited", []string{"unlimited-rate"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.True(t, s.Rate <= 0, "want unlimited rate to be <= 0") + }, nil, false, + }, + { + "RatePart", []string{"rate1"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.Rate != 3 { + t.Fatalf("want Rate to be 3") + } + }, nil, false, + }, + { + "RateParts", []string{"rate1", "rate2"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.Rate != 4 { + t.Fatalf("Should pick bigger value") + } + }, nil, false, + }, + { + "RateParts with acl", []string{"rate5", "rate4"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.Equal(t, float64(10), s.Rate) + }, nil, false, + }, + { + "RateParts with acl respected by session", []string{"rate4", "rate5"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.Equal(t, float64(10), s.Rate) + }, &user.SessionState{Rate: 20}, false, + }, + { + "Rate with no partition respected by session", []string{"rate-no-partition"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.Equal(t, float64(12), s.Rate) + }, &user.SessionState{Rate: 20}, false, + }, + } + tests = append(tests, rateLimitPartitionTCs...) + + complexityPartitionTCs := []testApplyPoliciesData{ + { + "ComplexityPart with unlimited", []string{"unlimitedComplexity"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.MaxQueryDepth != -1 { + t.Fatalf("unlimitied query depth should be -1") + } + }, nil, false, + }, + { + "ComplexityPart", []string{"complexity1"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.MaxQueryDepth != 2 { + t.Fatalf("want MaxQueryDepth to be 2") + } + }, nil, false, + }, + { + "ComplexityParts", []string{"complexity1", "complexity2"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.MaxQueryDepth != 3 { + t.Fatalf("Should pick bigger value") + } + }, nil, false, + }, + } + tests = append(tests, complexityPartitionTCs...) + + aclPartitionTCs := []testApplyPoliciesData{ + { + "AclPart", []string{"acl1"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{}}} + + assert.Equal(t, want, s.AccessRights) + }, nil, false, + }, + { + "AclPart", []string{"acl1", "acl2"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{}}, "b": {Limit: user.APILimit{}}} + assert.Equal(t, want, s.AccessRights) + }, nil, false, + }, + { + "Acl for a and rate for a,b", []string{"acl1", "rate-for-a-b"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{RateLimit: user.RateLimit{Rate: 4, Per: 1}}}} + assert.Equal(t, want, s.AccessRights) + }, nil, false, + }, + { + "Acl for a,b and individual rate for a,b", []string{"acl-for-a-b", "rate-for-a", "rate-for-b"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{ + "a": {Limit: user.APILimit{RateLimit: user.RateLimit{Rate: 4, Per: 1}}}, + "b": {Limit: user.APILimit{RateLimit: user.RateLimit{Rate: 2, Per: 1}}}, + } + assert.Equal(t, want, s.AccessRights) + }, nil, false, + }, + { + "RightsUpdate", []string{"acl-for-a-b"}, + "", func(t *testing.T, ses *user.SessionState) { + t.Helper() + expectedAccessRights := map[string]user.AccessDefinition{"a": {Limit: user.APILimit{}}, "b": {Limit: user.APILimit{}}} + assert.Equal(t, expectedAccessRights, ses.AccessRights) + }, &user.SessionState{ + AccessRights: map[string]user.AccessDefinition{ + "c": {Limit: user.APILimit{}}, + }, + }, false, + }, + } + tests = append(tests, aclPartitionTCs...) + + inactiveTCs := []testApplyPoliciesData{ + { + "InactiveMergeOne", []string{"tags1", "inactive1"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if !s.IsInactive { + t.Fatalf("want IsInactive to be true") + } + }, nil, false, + }, + { + "InactiveMergeAll", []string{"inactive1", "inactive2"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if !s.IsInactive { + t.Fatalf("want IsInactive to be true") + } + }, nil, false, + }, + { + "InactiveWithSession", []string{"tags1", "tags2"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if !s.IsInactive { + t.Fatalf("want IsInactive to be true") + } + }, &user.SessionState{ + IsInactive: true, + }, false, + }, + } + tests = append(tests, inactiveTCs...) + + perAPITCs := []testApplyPoliciesData{ + { + name: "Per API is set with other partitions to true", + policies: []string{"per_api_and_partitions"}, + errMatch: "cannot apply policy per_api_and_partitions which has per_api and any of partitions set", + }, + { + name: "Per API is set to true with some partitions set to true", + policies: []string{"per_api_and_some_partitions"}, + errMatch: "cannot apply policy per_api_and_some_partitions which has per_api and any of partitions set", + }, + { + name: "Per API is set to true with no other partitions set to true", + policies: []string{"per_api_and_no_other_partitions"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + want := map[string]user.AccessDefinition{ + "c": { + Limit: user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 2000, + Per: 60, + }, + QuotaMax: -1, + }, + AllowanceScope: "c", + }, + "d": { + Limit: user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 20, + Per: 1, + }, + QuotaMax: 1000, + QuotaRenewalRate: 3600, + }, + AllowanceScope: "d", + }, + } + assert.Equal(t, want, s.AccessRights) + }, + }, + { + name: "several policies with Per API set to true specifying limit for the same API", + policies: []string{"per_api_and_no_other_partitions", "per_api_with_api_d"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{ + "c": { + Limit: user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 2000, + Per: 60, + }, + QuotaMax: -1, + }, + AllowanceScope: "c", + }, + "d": { + Limit: user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 200, + Per: 10, + }, + QuotaMax: 5000, + QuotaRenewalRate: 3600, + }, + AllowanceScope: "d", + }, + } + assert.Equal(t, want, s.AccessRights) + }, + }, + { + name: "several policies with Per API set to true specifying limit for the same APIs", + policies: []string{"per_api_and_no_other_partitions", "per_api_with_api_d", "per_api_with_api_c"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{ + "c": { + Limit: user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 3000, + Per: 10, + }, + QuotaMax: -1, + }, + AllowanceScope: "c", + }, + "d": { + Limit: user.APILimit{ + RateLimit: user.RateLimit{ + Rate: 200, + Per: 10, + }, + QuotaMax: 5000, + QuotaRenewalRate: 3600, + }, + AllowanceScope: "d", + }, + } + assert.Equal(t, want, s.AccessRights) + }, + }, + { + name: "several policies, mixed the one which has Per API set to true and partitioned ones", + policies: []string{"per_api_with_api_d", "quota1"}, + errMatch: "cannot apply multiple policies when some have per_api set and some are partitioned", + }, + { + name: "several policies, mixed the one which has Per API set to true and partitioned ones (different order)", + policies: []string{"rate1", "per_api_with_api_d"}, + errMatch: "cannot apply multiple policies when some have per_api set and some are partitioned", + }, + { + name: "Per API is set to true and some API gets limit set from policy's fields", + policies: []string{"per_api_with_limit_set_from_policy"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{ + "e": { + Limit: user.APILimit{ + QuotaMax: -1, + RateLimit: user.RateLimit{ + Rate: 300, + Per: 1, + }, + }, + AllowanceScope: "per_api_with_limit_set_from_policy", + }, + "d": { + Limit: user.APILimit{ + QuotaMax: 5000, + QuotaRenewalRate: 3600, + RateLimit: user.RateLimit{ + Rate: 200, + Per: 10, + }, + }, + AllowanceScope: "d", + }, + } + assert.Equal(t, want, s.AccessRights) + }, + }, + { + name: "Per API with limits override", + policies: []string{ + "per_api_with_limit_set_from_policy", + "per_api_with_api_d", + "per_api_with_higher_rate_on_api_d", + }, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{ + "e": { + Limit: user.APILimit{ + QuotaMax: -1, + RateLimit: user.RateLimit{ + Rate: 300, + Per: 1, + }, + }, + AllowanceScope: "per_api_with_limit_set_from_policy", + }, + "d": { + Limit: user.APILimit{ + QuotaMax: 5000, + QuotaRenewalRate: 3600, + RateLimit: user.RateLimit{ + Rate: 200, + Per: 10, + }, + }, + AllowanceScope: "d", + }, + } + assert.Equal(t, want, s.AccessRights) + }, + }, + } + tests = append(tests, perAPITCs...) + + graphQLTCs := []testApplyPoliciesData{ + { + name: "Merge per path rules for the same API", + policies: []string{"per-path2", "per-path1"}, + sessMatch: func(t *testing.T, sess *user.SessionState) { + t.Helper() + want := map[string]user.AccessDefinition{ + "a": { + AllowedURLs: []user.AccessSpec{ + {URL: "/user", Methods: []string{"GET", "POST"}}, + {URL: "/companies", Methods: []string{"GET", "POST"}}, + }, + Limit: user.APILimit{}, + }, + "b": { + AllowedURLs: []user.AccessSpec{ + {URL: "/", Methods: []string{"PUT"}}, + }, + Limit: user.APILimit{}, + }, + } + + gotPolicy, ok := store.PolicyByID("per-path2") + + assert.True(t, ok) + assert.Equal(t, user.AccessSpec{ + URL: "/user", Methods: []string{"GET"}, + }, gotPolicy.AccessRights["a"].AllowedURLs[0]) + + assert.Equal(t, want, sess.AccessRights) + }, + }, + { + name: "Merge restricted fields for the same GraphQL API", + policies: []string{"restricted-types1", "restricted-types2"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + want := map[string]user.AccessDefinition{ + "a": { // It should get intersection of restricted types. + RestrictedTypes: []graphql.Type{ + {Name: "Country", Fields: []string{"code"}}, + {Name: "Person", Fields: []string{"name"}}, + }, + Limit: user.APILimit{}, + }, + } + + assert.Equal(t, want, s.AccessRights) + }, + }, + { + name: "Merge allowed fields for the same GraphQL API", + policies: []string{"allowed-types1", "allowed-types2"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + want := map[string]user.AccessDefinition{ + "a": { // It should get intersection of restricted types. + AllowedTypes: []graphql.Type{ + {Name: "Country", Fields: []string{"code"}}, + {Name: "Person", Fields: []string{"name"}}, + }, + Limit: user.APILimit{}, + }, + } + + assert.Equal(t, want, s.AccessRights) + }, + }, + { + name: "If GQL introspection is disabled, it remains disabled after merging", + policies: []string{"introspection-disabled", "introspection-enabled"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + want := map[string]user.AccessDefinition{ + "a": { + DisableIntrospection: true, // If GQL introspection is disabled, it remains disabled after merging. + Limit: user.APILimit{}, + }, + } + + assert.Equal(t, want, s.AccessRights) + }, + }, + { + name: "Merge field level depth limit for the same GraphQL API", + policies: []string{"field-level-depth-limit1", "field-level-depth-limit2"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + want := map[string]user.AccessDefinition{ + "graphql-api": { + Limit: user.APILimit{}, + FieldAccessRights: []user.FieldAccessDefinition{ + {TypeName: "Query", FieldName: "people", Limits: user.FieldLimits{MaxQueryDepth: 4}}, + {TypeName: "Mutation", FieldName: "putPerson", Limits: user.FieldLimits{MaxQueryDepth: -1}}, + {TypeName: "Query", FieldName: "countries", Limits: user.FieldLimits{MaxQueryDepth: 3}}, + {TypeName: "Query", FieldName: "continents", Limits: user.FieldLimits{MaxQueryDepth: 4}}, + }, + }, + } + + assert.Equal(t, want, s.AccessRights) + }, + }, + } + tests = append(tests, graphQLTCs...) + + throttleTCs := []testApplyPoliciesData{ + { + "Throttle interval from policy", []string{"throttle1"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + if s.ThrottleInterval != 9 { + t.Fatalf("Throttle interval should be 9 inherited from policy") + } + }, nil, false, + }, + { + name: "Throttle retry limit from policy", + policies: []string{"throttle1"}, + errMatch: "", + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + if s.ThrottleRetryLimit != 99 { + t.Fatalf("Throttle interval should be 9 inherited from policy") + } + }, + session: nil, + }, + } + tests = append(tests, throttleTCs...) + + tagsTCs := []testApplyPoliciesData{ + { + "TagMerge", []string{"tags1", "tags2"}, + "", func(t *testing.T, s *user.SessionState) { + t.Helper() + want := []string{"key-tag", "tagA", "tagX", "tagY"} + sort.Strings(s.Tags) + + assert.Equal(t, want, s.Tags) + }, &user.SessionState{ + Tags: []string{"key-tag"}, + }, false, + }, + } + tests = append(tests, tagsTCs...) + + partitionTCs := []testApplyPoliciesData{ + { + "NonpartAndPart", []string{"nonpart1", "quota1"}, + "", nil, nil, false, + }, + { + name: "inherit quota and rate from partitioned policies", + policies: []string{"quota1", "rate3"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + if s.QuotaMax != 2 { + t.Fatalf("quota should be the same as quota policy") + } + if s.Rate != 4 { + t.Fatalf("rate should be the same as rate policy") + } + if s.Per != 4 { + t.Fatalf("Rate per seconds should be the same as rate policy") + } + }, + }, + { + name: "inherit quota and rate from partitioned policies applied in different order", + policies: []string{"rate3", "quota1"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + + if s.QuotaMax != 2 { + t.Fatalf("quota should be the same as quota policy") + } + if s.Rate != 4 { + t.Fatalf("rate should be the same as rate policy") + } + if s.Per != 4 { + t.Fatalf("Rate per seconds should be the same as rate policy") + } + }, + }, + } + tests = append(tests, partitionTCs...) + + endpointRLTCs := []testApplyPoliciesData{ + { + name: "Per API and per endpoint policies", + policies: []string{"per_api_with_limit_set_from_policy", "per_api_with_endpoint_limits_on_d_and_e"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + endpointsConfig := user.Endpoints{ + { + Path: "/get", + Methods: user.EndpointMethods{ + { + Name: "GET", + Limit: user.RateLimit{ + Rate: -1, + }, + }, + }, + }, + { + Path: "/post", + Methods: user.EndpointMethods{ + { + Name: "POST", + Limit: user.RateLimit{ + Rate: 300, + Per: 10, + }, + }, + }, + }, + } + want := map[string]user.AccessDefinition{ + "e": { + Limit: user.APILimit{ + QuotaMax: -1, + RateLimit: user.RateLimit{ + Rate: 500, + Per: 1, + }, + }, + AllowanceScope: "per_api_with_endpoint_limits_on_d_and_e", + Endpoints: endpointsConfig, + }, + "d": { + Limit: user.APILimit{ + QuotaMax: 5000, + QuotaRenewalRate: 3600, + RateLimit: user.RateLimit{ + Rate: 200, + Per: 10, + }, + }, + AllowanceScope: "d", + Endpoints: endpointsConfig, + }, + } + assert.Equal(t, want, s.AccessRights) + }, + reverseOrder: true, + }, + { + name: "Endpoint level limits overlapping", + policies: []string{ + "per_api_with_limit_set_from_policy", + "per_api_with_endpoint_limits_on_d_and_e", + "per_endpoint_limits_different_on_api_d", + }, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + apiEEndpoints := user.Endpoints{ + { + Path: "/get", + Methods: user.EndpointMethods{ + { + Name: "GET", + Limit: user.RateLimit{ + Rate: -1, + }, + }, + }, + }, + { + Path: "/post", + Methods: user.EndpointMethods{ + { + Name: "POST", + Limit: user.RateLimit{ + Rate: 300, + Per: 10, + }, + }, + }, + }, + } + + assert.ElementsMatch(t, apiEEndpoints, s.AccessRights["e"].Endpoints) + + apiDEndpoints := user.Endpoints{ + { + Path: "/get", + Methods: user.EndpointMethods{ + { + Name: "GET", + Limit: user.RateLimit{ + Rate: -1, + }, + }, + }, + }, + { + Path: "/post", + Methods: user.EndpointMethods{ + { + Name: "POST", + Limit: user.RateLimit{ + Rate: 400, + Per: 11, + }, + }, + }, + }, + { + Path: "/anything", + Methods: user.EndpointMethods{ + { + Name: "PUT", + Limit: user.RateLimit{ + Rate: 500, + Per: 10, + }, + }, + }, + }, + } + + assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) + + apiELimits := user.APILimit{ + QuotaMax: -1, + RateLimit: user.RateLimit{ + Rate: 500, + Per: 1, + }, + } + assert.Equal(t, apiELimits, s.AccessRights["e"].Limit) + + apiDLimits := user.APILimit{ + QuotaMax: 5000, + QuotaRenewalRate: 3600, + RateLimit: user.RateLimit{ + Rate: 200, + Per: 10, + }, + } + assert.Equal(t, apiDLimits, s.AccessRights["d"].Limit) + }, + reverseOrder: true, + }, + { + name: "endpoint_rate_limits_on_acl_partition_only", + policies: []string{"endpoint_rate_limits_on_acl_partition_only"}, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.NotEmpty(t, s.AccessRights) + assert.Empty(t, s.AccessRights["d"].Endpoints) + }, + }, + { + name: "endpoint_rate_limits_when_acl_and_quota_partitions_combined", + policies: []string{ + "endpoint_rate_limits_on_acl_partition_only", + "endpoint_rate_limits_on_quota_partition_only", + }, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.NotEmpty(t, s.AccessRights) + assert.Empty(t, s.AccessRights["d"].Endpoints) + }, + reverseOrder: true, + }, + } + + tests = append(tests, endpointRLTCs...) + + combinedEndpointRLTCs := []testApplyPoliciesData{ + { + name: "combine_non_partitioned_policies_with_endpoint_rate_limits_configured_on_api_d", + policies: []string{ + "api_d_get_endpoint_rl_1_configure_on_non_partitioned_policy", + "api_d_get_endpoint_rl_2_configure_on_non_partitioned_policy", + }, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.NotEmpty(t, s.AccessRights) + apiDEndpoints := user.Endpoints{ + { + Path: "/get", + Methods: user.EndpointMethods{ + { + Name: "GET", + Limit: user.RateLimit{ + Rate: 20, + Per: 60, + }, + }, + }, + }, + } + + assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) + }, + reverseOrder: true, + }, + { + name: "combine_non_partitioned_policies_with_endpoint_rate_limits_no_bound_configured_on_api_d", + policies: []string{ + "api_d_get_endpoint_rl_1_configure_on_non_partitioned_policy", + "api_d_get_endpoint_rl_2_configure_on_non_partitioned_policy", + "api_d_get_endpoint_rl_3_configure_on_non_partitioned_policy", + }, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.NotEmpty(t, s.AccessRights) + apiDEndpoints := user.Endpoints{ + { + Path: "/get", + Methods: user.EndpointMethods{ + { + Name: "GET", + Limit: user.RateLimit{ + Rate: -1, + }, + }, + }, + }, + } + + assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) + }, + reverseOrder: true, + }, + { + name: "combine_non_partitioned_policies_with_multiple_endpoint_rate_limits_configured_on_api_d", + policies: []string{ + "api_d_get_endpoint_rl_1_configure_on_non_partitioned_policy", + "api_d_get_endpoint_rl_2_configure_on_non_partitioned_policy", + "api_d_get_endpoint_rl_3_configure_on_non_partitioned_policy", + "api_d_post_endpoint_rl_1_configure_on_non_partitioned_policy", + }, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.NotEmpty(t, s.AccessRights) + apiDEndpoints := user.Endpoints{ + { + Path: "/get", + Methods: user.EndpointMethods{ + { + Name: "GET", + Limit: user.RateLimit{ + Rate: -1, + }, + }, + }, + }, + { + Path: "/post", + Methods: user.EndpointMethods{ + { + Name: "POST", + Limit: user.RateLimit{ + Rate: 20, + Per: 60, + }, + }, + }, + }, + } + + assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) + }, + reverseOrder: true, + }, + { + name: "combine_non_partitioned_policies_with_endpoint_rate_limits_configured_on_api_d_and_e", + policies: []string{ + "api_d_get_endpoint_rl_1_configure_on_non_partitioned_policy", + "api_d_get_endpoint_rl_2_configure_on_non_partitioned_policy", + "api_d_get_endpoint_rl_3_configure_on_non_partitioned_policy", + "api_d_post_endpoint_rl_1_configure_on_non_partitioned_policy", + "api_e_get_endpoint_rl_1_configure_on_non_partitioned_policy", + }, + sessMatch: func(t *testing.T, s *user.SessionState) { + t.Helper() + assert.NotEmpty(t, s.AccessRights) + apiDEndpoints := user.Endpoints{ + { + Path: "/get", + Methods: user.EndpointMethods{ + { + Name: "GET", + Limit: user.RateLimit{ + Rate: -1, + }, + }, + }, + }, + { + Path: "/post", + Methods: user.EndpointMethods{ + { + Name: "POST", + Limit: user.RateLimit{ + Rate: 20, + Per: 60, + }, + }, + }, + }, + } + + assert.ElementsMatch(t, apiDEndpoints, s.AccessRights["d"].Endpoints) + + apiEEndpoints := user.Endpoints{ + { + Path: "/get", + Methods: user.EndpointMethods{ + { + Name: "GET", + Limit: user.RateLimit{ + Rate: 100, + Per: 60, + }, + }, + }, + }, + } + + assert.ElementsMatch(t, apiEEndpoints, s.AccessRights["e"].Endpoints) + }, + reverseOrder: true, + }, + } + + tests = append(tests, combinedEndpointRLTCs...) + + return service, tests +} + +func TestService_Apply(t *testing.T) { + service, tests := testPrepareApplyPolicies(t) + + for _, tc := range tests { + pols := [][]string{tc.policies} + if tc.reverseOrder { + var copyPols = make([]string, len(tc.policies)) + copy(copyPols, tc.policies) + slices.Reverse(copyPols) + pols = append(pols, copyPols) + } + + for i, policies := range pols { + name := tc.name + if i == 1 { + name = fmt.Sprintf("%s, reversed=%t", name, true) + } + + t.Run(name, func(t *testing.T) { + sess := tc.session + if sess == nil { + sess = &user.SessionState{} + } + sess.SetPolicies(policies...) + if err := service.Apply(sess); err != nil { + assert.ErrorContains(t, err, tc.errMatch) + return + } + + if tc.sessMatch != nil { + tc.sessMatch(t, sess) + } + }) + } + } +} + +func BenchmarkService_Apply(b *testing.B) { + b.ReportAllocs() + + service, tests := testPrepareApplyPolicies(b) + + for i := 0; i < b.N; i++ { + for _, tc := range tests { + sess := &user.SessionState{} + sess.SetPolicies(tc.policies...) + err := service.Apply(sess) + assert.NoError(b, err) + } + } +} diff --git a/internal/policy/rpc.go b/internal/policy/rpc.go new file mode 100644 index 00000000000..31c1db8b9ab --- /dev/null +++ b/internal/policy/rpc.go @@ -0,0 +1,42 @@ +package policy + +import ( + "encoding/json" + + "github.com/TykTechnologies/tyk/internal/model" + "github.com/TykTechnologies/tyk/user" +) + +// RPCDataLoaderMock is a policy-related test utility. +type RPCDataLoaderMock struct { + ShouldConnect bool + Policies []user.Policy + Apis []model.MergedAPI +} + +// Connect will return the connection status. +func (s *RPCDataLoaderMock) Connect() bool { + return s.ShouldConnect +} + +// GetApiDefinitions returns the internal Apis as a json string. +func (s *RPCDataLoaderMock) GetApiDefinitions(_ string, tags []string) string { + if len(tags) > 1 { + panic("not implemented") + } + + apiList, err := json.Marshal(s.Apis) + if err != nil { + return "" + } + return string(apiList) +} + +// GetPolicies returns the internal Policies as a json string. +func (s *RPCDataLoaderMock) GetPolicies(_ string) string { + policyList, err := json.Marshal(s.Policies) + if err != nil { + return "" + } + return string(policyList) +} diff --git a/internal/policy/store.go b/internal/policy/store.go new file mode 100644 index 00000000000..7829659db89 --- /dev/null +++ b/internal/policy/store.go @@ -0,0 +1,48 @@ +package policy + +import ( + "github.com/TykTechnologies/tyk/user" +) + +// Store is an in-memory policy storage object that implements the +// repository for policy access. We do not implement concurrency +// protections here. Where order is important, use this. +type Store struct { + policies []user.Policy +} + +// NewStore returns a new policy.Store. +func NewStore(policies []user.Policy) *Store { + return &Store{ + policies: policies, + } +} + +// PolicyIDs returns a list policy IDs in the store. +// It will return nil if no policies exist. +func (s *Store) PolicyIDs() []string { + if len(s.policies) == 0 { + return nil + } + + policyIDs := make([]string, 0, len(s.policies)) + for _, val := range s.policies { + policyIDs = append(policyIDs, val.ID) + } + return policyIDs +} + +// PolicyByID returns a policy by ID. +func (s *Store) PolicyByID(id string) (user.Policy, bool) { + for _, pol := range s.policies { + if pol.ID == id { + return pol, true + } + } + return user.Policy{}, false +} + +// PolicyCount returns the number of policies in the store. +func (s *Store) PolicyCount() int { + return len(s.policies) +} diff --git a/internal/policy/store_map.go b/internal/policy/store_map.go new file mode 100644 index 00000000000..a035c320a4a --- /dev/null +++ b/internal/policy/store_map.go @@ -0,0 +1,46 @@ +package policy + +import ( + "github.com/TykTechnologies/tyk/user" +) + +// StoreMap is same as Store, but doesn't preserve order. +type StoreMap struct { + policies map[string]user.Policy +} + +// NewStoreMap returns a new policy.StoreMap. +func NewStoreMap(policies map[string]user.Policy) *StoreMap { + if len(policies) == 0 { + policies = make(map[string]user.Policy) + } + + return &StoreMap{ + policies: policies, + } +} + +// PolicyIDs returns a list policy IDs in the store. +// It will return nil if no policies exist. +func (s *StoreMap) PolicyIDs() []string { + if len(s.policies) == 0 { + return nil + } + + policyIDs := make([]string, 0, len(s.policies)) + for _, val := range s.policies { + policyIDs = append(policyIDs, val.ID) + } + return policyIDs +} + +// PolicyByID returns a policy by ID. +func (s *StoreMap) PolicyByID(id string) (user.Policy, bool) { + v, ok := s.policies[id] + return v, ok +} + +// PolicyCount returns the number of policies in the store. +func (s *StoreMap) PolicyCount() int { + return len(s.policies) +} diff --git a/internal/policy/util.go b/internal/policy/util.go new file mode 100644 index 00000000000..8558fed0800 --- /dev/null +++ b/internal/policy/util.go @@ -0,0 +1,105 @@ +package policy + +import ( + "slices" + + "github.com/TykTechnologies/tyk/user" +) + +// MergeAllowedURLs will merge s1 and s2 to produce a merged result. +// It maintains order of keys in s1 and s2 as they are seen. +// If the result is an empty set, nil is returned. +func MergeAllowedURLs(s1, s2 []user.AccessSpec) []user.AccessSpec { + order := []string{} + merged := map[string][]string{} + + // Loop input sets and merge through a map. + for _, src := range [][]user.AccessSpec{s1, s2} { + for _, r := range src { + url := r.URL + v, ok := merged[url] + if !ok { + // First time we see the spec + merged[url] = r.Methods + + // Maintain order + order = append(order, url) + + continue + } + merged[url] = appendIfMissing(v, r.Methods...) + } + } + + // Early exit without allocating. + if len(order) == 0 { + return nil + } + + // Provide results in desired order. + result := make([]user.AccessSpec, 0, len(order)) + for _, key := range order { + spec := user.AccessSpec{ + Methods: merged[key], + URL: key, + } + result = append(result, spec) + } + return result +} + +// appendIfMissing ensures dest slice is unique with new items. +func appendIfMissing(dest []string, in ...string) []string { + for _, v := range in { + if slices.Contains(dest, v) { + continue + } + dest = append(dest, v) + } + return dest +} + +// intersection gets intersection of the given two slices. +func intersection(a []string, b []string) (inter []string) { + m := make(map[string]bool) + + for _, item := range a { + m[item] = true + } + + for _, item := range b { + if _, ok := m[item]; ok { + inter = append(inter, item) + } + } + + return +} + +// greaterThanInt64 checks whether first int64 value is bigger than second int64 value. +// -1 means infinite and the biggest value. +func greaterThanInt64(first, second int64) bool { + if first == -1 { + return true + } + + if second == -1 { + return false + } + + return first > second +} + +// greaterThanInt checks whether first int value is bigger than second int value. +// -1 means infinite and the biggest value. +func greaterThanInt(first, second int) bool { + if first == -1 { + return true + } + + if second == -1 { + return false + } + + return first > second +} diff --git a/internal/policy/util_test.go b/internal/policy/util_test.go new file mode 100644 index 00000000000..460d0cfb119 --- /dev/null +++ b/internal/policy/util_test.go @@ -0,0 +1,64 @@ +package policy_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/tyk/internal/policy" + "github.com/TykTechnologies/tyk/user" +) + +func TestMergeAllowedURLs(t *testing.T) { + svc := &policy.Service{} + + session := &user.SessionState{} + policies := []user.Policy{ + { + ID: "pol1", + AccessRights: map[string]user.AccessDefinition{ + "a": { + AllowedURLs: []user.AccessSpec{ + {URL: "/user", Methods: []string{"GET"}}, + {URL: "/companies", Methods: []string{"GET"}}, + }, + }, + }, + }, + { + ID: "pol2", + AccessRights: map[string]user.AccessDefinition{ + "a": { + AllowedURLs: []user.AccessSpec{ + {URL: "/user", Methods: []string{"POST", "PATCH", "PUT"}}, + {URL: "/companies", Methods: []string{"POST"}}, + {URL: "/admin", Methods: []string{"GET", "POST"}}, + }, + }, + }, + }, + { + ID: "pol3", + AccessRights: map[string]user.AccessDefinition{ + "a": { + AllowedURLs: []user.AccessSpec{ + {URL: "/admin/cache", Methods: []string{"DELETE"}}, + }, + }, + }, + }, + } + + session.SetCustomPolicies(policies) + + assert.NoError(t, svc.Apply(session)) + + want := []user.AccessSpec{ + {URL: "/user", Methods: []string{"GET", "POST", "PATCH", "PUT"}}, + {URL: "/companies", Methods: []string{"GET", "POST"}}, + {URL: "/admin", Methods: []string{"GET", "POST"}}, + {URL: "/admin/cache", Methods: []string{"DELETE"}}, + } + + assert.Equal(t, want, session.AccessRights["a"].AllowedURLs) +} diff --git a/tests/policy/Taskfile.yml b/tests/policy/Taskfile.yml new file mode 100644 index 00000000000..f5ab84d0aa9 --- /dev/null +++ b/tests/policy/Taskfile.yml @@ -0,0 +1,53 @@ +--- +version: "3" + +includes: + services: + taskfile: ../../docker/services/Taskfile.yml + dir: ../../docker/services + +vars: + coverage: policy.cov + testArgs: -v + +tasks: + test: + desc: "Run tests (requires redis)" + deps: [ services:up ] + cmds: + - defer: + task: services:down + - task: fmt + - go test {{.testArgs}} -count=1 -cover -coverprofile={{.coverage}} -coverpkg=./... ./... + + bench: + desc: "Run benchmarks" + cmds: + - task: fmt + - go test {{.testArgs}} -count=1 -tags integration -run=^$ -bench=. -benchtime=10s -benchmem ./... + + fmt: + internal: true + desc: "Invoke fmt" + cmds: + - goimports -w . + - go fmt ./... + + cover: + desc: "Show source coverage" + aliases: [coverage, cov] + cmds: + - go tool cover -func={{.coverage}} + + uncover: + desc: "Show uncovered source" + cmds: + - uncover {{.coverage}} + + install:uncover: + desc: "Install uncover" + internal: true + env: + GOBIN: /usr/local/bin + cmds: + - go install github.com/gregoryv/uncover/...@latest diff --git a/tests/policy/allowed_urls_test.go b/tests/policy/allowed_urls_test.go new file mode 100644 index 00000000000..a81cc5c3ed2 --- /dev/null +++ b/tests/policy/allowed_urls_test.go @@ -0,0 +1,172 @@ +package policy + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/tyk/internal/policy" + "github.com/TykTechnologies/tyk/internal/uuid" + "github.com/TykTechnologies/tyk/test" + "github.com/TykTechnologies/tyk/user" +) + +// The integration test. +func TestAllowedURLs(t *testing.T) { + ts := StartTest(nil) + t.Cleanup(ts.Close) + + policyBase := user.Policy{ + ID: uuid.New(), + Per: 1, + Rate: 1000, + QuotaMax: 50, + QuotaRenewalRate: 3600, + OrgID: DefaultOrg, + AccessRights: map[string]user.AccessDefinition{ + "api1": { + Versions: []string{"v1"}, + Limit: user.APILimit{ + QuotaMax: 100, + QuotaRenewalRate: 3600, + RateLimit: user.RateLimit{ + Rate: 1000, + Per: 1, + }, + }, + AllowedURLs: []user.AccessSpec{ + {URL: "/user", Methods: []string{"GET"}}, + {URL: "/companies", Methods: []string{"GET"}}, + }, + }, + "api2": { + Versions: []string{"v1"}, + Limit: user.APILimit{ + QuotaMax: 200, + QuotaRenewalRate: 3600, + RateLimit: user.RateLimit{ + Rate: 1000, + Per: 1, + }, + }, + AllowedURLs: []user.AccessSpec{ + {URL: "/user", Methods: []string{"POST", "PATCH", "PUT"}}, + {URL: "/companies", Methods: []string{"POST"}}, + {URL: "/admin", Methods: []string{"GET", "POST"}}, + }, + }, + "api3": { + Versions: []string{"v1"}, + AllowedURLs: []user.AccessSpec{ + {URL: "/admin/cache", Methods: []string{"DELETE"}}, + }, + }, + }, + } + + policyWithPaths := user.Policy{ + ID: uuid.New(), + Per: 1, + Rate: 1000, + QuotaMax: 50, + QuotaRenewalRate: 3600, + OrgID: DefaultOrg, + AccessRights: map[string]user.AccessDefinition{ + "api1": { + Versions: []string{"v1"}, + AllowedURLs: []user.AccessSpec{ + {URL: "/appended", Methods: []string{"GET"}}, + }, + }, + "api2": { + Versions: []string{"v1"}, + AllowedURLs: []user.AccessSpec{ + {URL: "/appended", Methods: []string{"GET"}}, + }, + }, + "api3": { + Versions: []string{"v1"}, + AllowedURLs: []user.AccessSpec{ + {URL: "/appended", Methods: []string{"GET"}}, + }, + }, + }, + } + + ts.Gw.SetPoliciesByID(policyBase, policyWithPaths) + + // load APIs + ts.Gw.BuildAndLoadAPI( + func(spec *APISpec) { + spec.Name = "api 1" + spec.APIID = "api1" + spec.UseKeylessAccess = false + spec.Proxy.ListenPath = "/api1" + spec.OrgID = DefaultOrg + }, + func(spec *APISpec) { + spec.Name = "api 2" + spec.APIID = "api2" + spec.UseKeylessAccess = false + spec.Proxy.ListenPath = "/api2" + spec.OrgID = DefaultOrg + }, + func(spec *APISpec) { + spec.Name = "api 3" + spec.APIID = "api3" + spec.UseKeylessAccess = false + spec.Proxy.ListenPath = "/api3" + spec.OrgID = DefaultOrg + }, + ) + + // create test session + session := &user.SessionState{ + ApplyPolicies: []string{policyBase.ID, policyWithPaths.ID}, + OrgID: DefaultOrg, + AccessRights: map[string]user.AccessDefinition{ + "api1": { + APIID: "api1", + Versions: []string{"v1"}, + }, + "api2": { + APIID: "api2", + Versions: []string{"v1"}, + }, + "api3": { + APIID: "api3", + Versions: []string{"v1"}, + }, + }, + } + + // create key + key := uuid.New() + ts.Run(t, test.TestCase{Method: http.MethodPost, Path: "/tyk/keys/" + key, Data: session, AdminAuth: true, Code: 200}) + + // check key session + t.Run("Check key session", func(t *testing.T) { + ts.Run(t, []test.TestCase{ + { + Method: http.MethodGet, + Path: fmt.Sprintf("/tyk/keys/%v?org_id=%v", key, DefaultOrg), + AdminAuth: true, + Code: http.StatusOK, + BodyMatchFunc: func(data []byte) bool { + session := user.SessionState{} + assert.NoError(t, json.Unmarshal(data, &session)) + + for _, apiName := range []string{"api1", "api2", "api3"} { + want := policy.MergeAllowedURLs(policyBase.AccessRights[apiName].AllowedURLs, policyWithPaths.AccessRights[apiName].AllowedURLs) + assert.Equal(t, want, session.AccessRights[apiName].AllowedURLs, fmt.Sprintf("api %q allowed urls don't match", apiName)) + } + + return true + }, + }, + }...) + }) +} diff --git a/tests/policy/shim.go b/tests/policy/shim.go new file mode 100644 index 00000000000..cf6fe5f61af --- /dev/null +++ b/tests/policy/shim.go @@ -0,0 +1,9 @@ +package policy + +import "github.com/TykTechnologies/tyk/gateway" + +const DefaultOrg = "default-org-id" + +type APISpec = gateway.APISpec + +var StartTest = gateway.StartTest diff --git a/tests/quota/Taskfile.yml b/tests/quota/Taskfile.yml index d65f4e35699..a0c0fa22df7 100644 --- a/tests/quota/Taskfile.yml +++ b/tests/quota/Taskfile.yml @@ -2,7 +2,9 @@ version: "3" includes: - services: ../../docker/services/Taskfile.yml + services: + taskfile: ../../docker/services/Taskfile.yml + dir: ../../docker/services vars: coverage: quota.cov diff --git a/user/custom_policies.go b/user/custom_policies.go index bdbb7f3d12a..3ac8c852b92 100644 --- a/user/custom_policies.go +++ b/user/custom_policies.go @@ -6,10 +6,26 @@ import ( "fmt" ) +// CustomPolicies returns a map of custom policies on the session. +// To preserve policy order, use GetCustomPolicies instead. func (s *SessionState) CustomPolicies() (map[string]Policy, error) { + customPolicies, err := s.GetCustomPolicies() + if err != nil { + return nil, err + } + + result := make(map[string]Policy, len(customPolicies)) + for i := 0; i < len(customPolicies); i++ { + result[customPolicies[i].ID] = customPolicies[i] + } + + return result, nil +} + +// GetCustomPolicies is like CustomPolicies but returns the list, preserving order. +func (s *SessionState) GetCustomPolicies() ([]Policy, error) { var ( customPolicies []Policy - ret map[string]Policy ) metadataPolicies, found := s.MetaData["policies"].([]interface{}) @@ -22,16 +38,14 @@ func (s *SessionState) CustomPolicies() (map[string]Policy, error) { return nil, fmt.Errorf("failed to marshal metadata policies: %w", err) } - _ = json.Unmarshal(polJSON, &customPolicies) - - ret = make(map[string]Policy, len(customPolicies)) - for i := 0; i < len(customPolicies); i++ { - ret[customPolicies[i].ID] = customPolicies[i] + if err := json.Unmarshal(polJSON, &customPolicies); err != nil { + return nil, fmt.Errorf("failed to unmarshal metadata policies: %w", err) } - return ret, nil + return customPolicies, err } +// SetCustomPolicies sets custom policies into session metadata. func (s *SessionState) SetCustomPolicies(list []Policy) { if s.MetaData == nil { s.MetaData = make(map[string]interface{}) diff --git a/user/session.go b/user/session.go index 611cb287e1f..8a23358927d 100644 --- a/user/session.go +++ b/user/session.go @@ -7,11 +7,13 @@ import ( "github.com/TykTechnologies/graphql-go-tools/pkg/graphql" +<<<<<<< HEAD logger "github.com/TykTechnologies/tyk/log" +======= + "github.com/TykTechnologies/tyk/apidef" +>>>>>>> e31a08f08... [TT-12897] Merge path based permissions when combining policies (#6597) ) -var log = logger.Get() - type HashType string const (