From c7ddfa6a22113af3e95cb9e0bc457c7244133e07 Mon Sep 17 00:00:00 2001 From: Laurentiu Badea Date: Wed, 14 Oct 2020 12:16:00 -0700 Subject: [PATCH] Add Account cache (#1519) --- config/stored_requests.go | 78 +++++++++++++----- config/stored_requests_test.go | 60 +++++++++++--- stored_requests/caches/memory/cache.go | 3 +- stored_requests/config/config.go | 18 ++-- stored_requests/config/config_test.go | 110 +++++++++---------------- stored_requests/fetcher.go | 14 +++- stored_requests/fetcher_test.go | 54 +++++++++++- 7 files changed, 226 insertions(+), 111 deletions(-) diff --git a/config/stored_requests.go b/config/stored_requests.go index 61db7eb03d0..249569461bc 100644 --- a/config/stored_requests.go +++ b/config/stored_requests.go @@ -22,20 +22,31 @@ const ( ) // Section returns the config section this type is defined in -func (sr *StoredRequests) Section() string { +func (dataType DataType) Section() string { return map[DataType]string{ RequestDataType: "stored_requests", CategoryDataType: "categories", VideoDataType: "stored_video_req", AMPRequestDataType: "stored_amp_req", AccountDataType: "accounts", - }[sr.dataType] + }[dataType] +} + +// Section returns the config section +func (sr *StoredRequests) Section() string { + return sr.dataType.Section() } +// DataType returns the DataType associated with this config func (sr *StoredRequests) DataType() DataType { return sr.dataType } +// SetDataType sets the DataType on this config. Needed for tests. +func (sr *StoredRequests) SetDataType(dataType DataType) { + sr.dataType = dataType +} + // StoredRequests struct defines options for stored requests for each data type // including some amp stored_requests options type StoredRequests struct { @@ -132,7 +143,7 @@ func (cfg *StoredRequests) validate(errs configErrors) configErrors { if cfg.DataType() == AccountDataType && cfg.Postgres.ConnectionInfo.Database != "" { errs = append(errs, fmt.Errorf("%s.postgres: retrieving accounts via postgres not available, use accounts.files", cfg.Section())) } else { - errs = cfg.Postgres.validate(cfg.Section(), errs) + errs = cfg.Postgres.validate(cfg.DataType(), errs) } // Categories do not use cache so none of the following checks apply @@ -156,7 +167,7 @@ func (cfg *StoredRequests) validate(errs configErrors) configErrors { errs = append(errs, fmt.Errorf("%s: postgres.initialize_caches.query must be empty if in_memory_cache=none", cfg.Section())) } } - errs = cfg.InMemoryCache.validate(cfg.Section(), errs) + errs = cfg.InMemoryCache.validate(cfg.DataType(), errs) return errs } @@ -169,12 +180,12 @@ type PostgresConfig struct { PollUpdates PostgresUpdatePolling `mapstructure:"poll_for_updates"` } -func (cfg *PostgresConfig) validate(section string, errs configErrors) configErrors { +func (cfg *PostgresConfig) validate(dataType DataType, errs configErrors) configErrors { if cfg.ConnectionInfo.Database == "" { return errs } - return cfg.PollUpdates.validate(section, errs) + return cfg.PollUpdates.validate(dataType, errs) } // PostgresConnection has options which put types to the Postgres Connection string. See: @@ -269,7 +280,8 @@ type PostgresCacheInitializer struct { AmpQuery string `mapstructure:"amp_query"` } -func (cfg *PostgresCacheInitializer) validate(section string, errs configErrors) configErrors { +func (cfg *PostgresCacheInitializer) validate(dataType DataType, errs configErrors) configErrors { + section := dataType.Section() if cfg.Query == "" { return errs } @@ -305,7 +317,8 @@ type PostgresUpdatePolling struct { AmpQuery string `mapstructure:"amp_query"` } -func (cfg *PostgresUpdatePolling) validate(section string, errs configErrors) configErrors { +func (cfg *PostgresUpdatePolling) validate(dataType DataType, errs configErrors) configErrors { + section := dataType.Section() if cfg.Query == "" { return errs } @@ -384,32 +397,57 @@ type InMemoryCache struct { // TTL is the maximum number of seconds that an unused value will stay in the cache. // TTL <= 0 can be used for "no ttl". Elements will still be evicted based on the Size. TTL int `mapstructure:"ttl_seconds"` + // Size is the max total cache size allowed for single caches + Size int `mapstructure:"size_bytes"` // RequestCacheSize is the max number of bytes allowed in the cache for Stored Requests. Values <= 0 will have no limit RequestCacheSize int `mapstructure:"request_cache_size_bytes"` // ImpCacheSize is the max number of bytes allowed in the cache for Stored Imps. Values <= 0 will have no limit ImpCacheSize int `mapstructure:"imp_cache_size_bytes"` } -func (cfg *InMemoryCache) validate(section string, errs configErrors) configErrors { +func (cfg *InMemoryCache) validate(dataType DataType, errs configErrors) configErrors { + section := dataType.Section() switch cfg.Type { case "none": // No errors for no config options case "unbounded": if cfg.TTL != 0 { - errs = append(errs, fmt.Errorf("%s: in_memory_cache must be 0 for unbounded caches. Got %d", section, cfg.TTL)) + errs = append(errs, fmt.Errorf("%s: in_memory_cache.ttl_seconds is not supported for unbounded caches. Got %d", section, cfg.TTL)) } - if cfg.RequestCacheSize != 0 { - errs = append(errs, fmt.Errorf("%s: in_memory_cache.request_cache_size_bytes must be 0 for unbounded caches. Got %d", section, cfg.RequestCacheSize)) - } - if cfg.ImpCacheSize != 0 { - errs = append(errs, fmt.Errorf("%s: in_memory_cache.imp_cache_size_bytes must be 0 for unbounded caches. Got %d", section, cfg.ImpCacheSize)) + if dataType == AccountDataType { + // single cache + if cfg.Size != 0 { + errs = append(errs, fmt.Errorf("%s: in_memory_cache.size_bytes is not supported for unbounded caches. Got %d", section, cfg.Size)) + } + } else { + // dual (request and imp) caches + if cfg.RequestCacheSize != 0 { + errs = append(errs, fmt.Errorf("%s: in_memory_cache.request_cache_size_bytes is not supported for unbounded caches. Got %d", section, cfg.RequestCacheSize)) + } + if cfg.ImpCacheSize != 0 { + errs = append(errs, fmt.Errorf("%s: in_memory_cache.imp_cache_size_bytes is not supported for unbounded caches. Got %d", section, cfg.ImpCacheSize)) + } } case "lru": - if cfg.RequestCacheSize <= 0 { - errs = append(errs, fmt.Errorf("%s: in_memory_cache.request_cache_size_bytes must be >= 0 when in_memory_cache.type=lru. Got %d", section, cfg.RequestCacheSize)) - } - if cfg.ImpCacheSize <= 0 { - errs = append(errs, fmt.Errorf("%s: in_memory_cache.imp_cache_size_bytes must be >= 0 when in_memory_cache.type=lru. Got %d", section, cfg.ImpCacheSize)) + if dataType == AccountDataType { + // single cache + if cfg.Size <= 0 { + errs = append(errs, fmt.Errorf("%s: in_memory_cache.size_bytes must be >= 0 when in_memory_cache.type=lru. Got %d", section, cfg.Size)) + } + if cfg.RequestCacheSize > 0 || cfg.ImpCacheSize > 0 { + glog.Warningf("%s: in_memory_cache.request_cache_size_bytes and imp_cache_size_bytes do not apply to this section and will be ignored", section) + } + } else { + // dual (request and imp) caches + if cfg.RequestCacheSize <= 0 { + errs = append(errs, fmt.Errorf("%s: in_memory_cache.request_cache_size_bytes must be >= 0 when in_memory_cache.type=lru. Got %d", section, cfg.RequestCacheSize)) + } + if cfg.ImpCacheSize <= 0 { + errs = append(errs, fmt.Errorf("%s: in_memory_cache.imp_cache_size_bytes must be >= 0 when in_memory_cache.type=lru. Got %d", section, cfg.ImpCacheSize)) + } + if cfg.Size > 0 { + glog.Warningf("%s: in_memory_cache.size_bytes does not apply in this section and will be ignored", section) + } } default: errs = append(errs, fmt.Errorf("%s: in_memory_cache.type %s is invalid", section, cfg.Type)) diff --git a/config/stored_requests_test.go b/config/stored_requests_test.go index 36a5e3793ed..a3bd5a96820 100644 --- a/config/stored_requests_test.go +++ b/config/stored_requests_test.go @@ -75,43 +75,83 @@ func TestPostgressConnString(t *testing.T) { assertHasValue(t, params, "sslmode", "disable") } -func TestInMemoryCacheValidation(t *testing.T) { +func TestInMemoryCacheValidationStoredRequests(t *testing.T) { assertNoErrs(t, (&InMemoryCache{ Type: "unbounded", - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) assertNoErrs(t, (&InMemoryCache{ Type: "none", - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) assertNoErrs(t, (&InMemoryCache{ Type: "lru", RequestCacheSize: 1000, ImpCacheSize: 1000, - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) assertErrsExist(t, (&InMemoryCache{ Type: "unrecognized", - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) assertErrsExist(t, (&InMemoryCache{ Type: "unbounded", ImpCacheSize: 1000, - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) assertErrsExist(t, (&InMemoryCache{ Type: "unbounded", RequestCacheSize: 1000, - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) assertErrsExist(t, (&InMemoryCache{ Type: "unbounded", TTL: 500, - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) assertErrsExist(t, (&InMemoryCache{ Type: "lru", RequestCacheSize: 0, ImpCacheSize: 1000, - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) assertErrsExist(t, (&InMemoryCache{ Type: "lru", RequestCacheSize: 1000, ImpCacheSize: 0, - }).validate("Test", nil)) + }).validate(RequestDataType, nil)) + assertErrsExist(t, (&InMemoryCache{ + Type: "lru", + Size: 1000, + }).validate(RequestDataType, nil)) +} + +func TestInMemoryCacheValidationSingleCache(t *testing.T) { + assertNoErrs(t, (&InMemoryCache{ + Type: "unbounded", + }).validate(AccountDataType, nil)) + assertNoErrs(t, (&InMemoryCache{ + Type: "none", + }).validate(AccountDataType, nil)) + assertNoErrs(t, (&InMemoryCache{ + Type: "lru", + Size: 1000, + }).validate(AccountDataType, nil)) + assertErrsExist(t, (&InMemoryCache{ + Type: "unrecognized", + }).validate(AccountDataType, nil)) + assertErrsExist(t, (&InMemoryCache{ + Type: "unbounded", + Size: 1000, + }).validate(AccountDataType, nil)) + assertErrsExist(t, (&InMemoryCache{ + Type: "unbounded", + TTL: 500, + }).validate(AccountDataType, nil)) + assertErrsExist(t, (&InMemoryCache{ + Type: "lru", + Size: 0, + }).validate(AccountDataType, nil)) + assertErrsExist(t, (&InMemoryCache{ + Type: "lru", + RequestCacheSize: 1000, + }).validate(AccountDataType, nil)) + assertErrsExist(t, (&InMemoryCache{ + Type: "lru", + ImpCacheSize: 1000, + }).validate(AccountDataType, nil)) } func assertErrsExist(t *testing.T, err configErrors) { diff --git a/stored_requests/caches/memory/cache.go b/stored_requests/caches/memory/cache.go index aea087e6d19..5939c26ddec 100644 --- a/stored_requests/caches/memory/cache.go +++ b/stored_requests/caches/memory/cache.go @@ -18,7 +18,8 @@ import ( // For no TTL, use ttlSeconds <= 0 func NewCache(size int, ttl int, dataType string) stored_requests.CacheJSON { if ttl > 0 && size <= 0 { - glog.Fatalf("No in-memory %s caches defined with a finite TTL but unbounded size. Config validation should have caught this. Failing fast because something is buggy.", dataType) + // a positive ttl indicates "LRU" cache type, while unlimited size indicates an "unbounded" cache type + glog.Fatalf("unbounded in-memory %s cache with TTL not allowed. Config validation should have caught this. Failing fast because something is buggy.", dataType) } if size > 0 { glog.Infof("Using a Stored %s in-memory cache. Max size: %d bytes. TTL: %d seconds.", dataType, size, ttl) diff --git a/stored_requests/config/config.go b/stored_requests/config/config.go index c8231969d00..d32773a7a1d 100644 --- a/stored_requests/config/config.go +++ b/stored_requests/config/config.go @@ -177,15 +177,17 @@ func newFetcher(cfg *config.StoredRequests, client *http.Client, db *sql.DB) (fe } func newCache(cfg *config.StoredRequests) stored_requests.Cache { - if cfg.InMemoryCache.Type == "none" { - glog.Infof("No Stored %s cache configured. The %s Fetcher backend will be used for all data requests", cfg.DataType(), cfg.DataType()) - return stored_requests.Cache{&nil_cache.NilCache{}, &nil_cache.NilCache{}} - } - - return stored_requests.Cache{ - Requests: memory.NewCache(cfg.InMemoryCache.RequestCacheSize, cfg.InMemoryCache.TTL, "Requests"), - Imps: memory.NewCache(cfg.InMemoryCache.ImpCacheSize, cfg.InMemoryCache.TTL, "Imps"), + cache := stored_requests.Cache{&nil_cache.NilCache{}, &nil_cache.NilCache{}, &nil_cache.NilCache{}} + switch { + case cfg.InMemoryCache.Type == "none": + glog.Warningf("No %s cache configured. The %s Fetcher backend will be used for all data requests", cfg.DataType(), cfg.DataType()) + case cfg.DataType() == config.AccountDataType: + cache.Accounts = memory.NewCache(cfg.InMemoryCache.Size, cfg.InMemoryCache.TTL, "Accounts") + default: + cache.Requests = memory.NewCache(cfg.InMemoryCache.RequestCacheSize, cfg.InMemoryCache.TTL, "Requests") + cache.Imps = memory.NewCache(cfg.InMemoryCache.ImpCacheSize, cfg.InMemoryCache.TTL, "Imps") } + return cache } func newEventProducers(cfg *config.StoredRequests, client *http.Client, db *sql.DB, router *httprouter.Router) (eventProducers []events.EventProducer) { diff --git a/stored_requests/config/config_test.go b/stored_requests/config/config_test.go index 3332103649e..f225f74bad0 100644 --- a/stored_requests/config/config_test.go +++ b/stored_requests/config/config_test.go @@ -9,27 +9,43 @@ import ( "regexp" "testing" + "github.com/stretchr/testify/assert" + sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/julienschmidt/httprouter" "github.com/prebid/prebid-server/config" + "github.com/prebid/prebid-server/stored_requests" "github.com/prebid/prebid-server/stored_requests/backends/empty_fetcher" "github.com/prebid/prebid-server/stored_requests/backends/http_fetcher" "github.com/prebid/prebid-server/stored_requests/events" httpEvents "github.com/prebid/prebid-server/stored_requests/events/http" ) +func typedConfig(dataType config.DataType, sr *config.StoredRequests) *config.StoredRequests { + sr.SetDataType(dataType) + return sr +} + +func isEmptyCacheType(cache stored_requests.CacheJSON) bool { + cache.Save(context.Background(), map[string]json.RawMessage{"foo": json.RawMessage("true")}) + objs := cache.Get(context.Background(), []string{"foo"}) + return len(objs) == 0 +} + +func isMemoryCacheType(cache stored_requests.CacheJSON) bool { + cache.Save(context.Background(), map[string]json.RawMessage{"foo": json.RawMessage("true")}) + objs := cache.Get(context.Background(), []string{"foo"}) + return len(objs) == 1 +} + func TestNewEmptyFetcher(t *testing.T) { fetcher := newFetcher(&config.StoredRequests{}, nil, nil) - ampFetcher := newFetcher(&config.StoredRequests{}, nil, nil) - if fetcher == nil || ampFetcher == nil { - t.Errorf("The fetchers should be non-nil, even with an empty config.") + if fetcher == nil { + t.Errorf("The fetcher should be non-nil, even with an empty config.") } if _, ok := fetcher.(empty_fetcher.EmptyFetcher); !ok { t.Errorf("If the config is empty, and EmptyFetcher should be returned") } - if _, ok := ampFetcher.(empty_fetcher.EmptyFetcher); !ok { - t.Errorf("If the config is empty, and EmptyFetcher should be returned for AMP") - } } func TestNewHTTPFetcher(t *testing.T) { @@ -38,47 +54,12 @@ func TestNewHTTPFetcher(t *testing.T) { Endpoint: "stored-requests.prebid.com", }, }, nil, nil) - ampFetcher := newFetcher(&config.StoredRequests{ - HTTP: config.HTTPFetcherConfig{ - Endpoint: "stored-requests.prebid.com?type=amp", - }, - }, nil, nil) - if httpFetcher, ok := fetcher.(*http_fetcher.HttpFetcher); ok { - if httpFetcher.Endpoint != "stored-requests.prebid.com?" { - t.Errorf("The HTTP fetcher is using the wrong endpoint. Expected %s, got %s", "stored-requests.prebid.com?", httpFetcher.Endpoint) - } - } else { - t.Errorf("An HTTP Fetching config should return an HTTPFetcher. Got %v", ampFetcher) - } - if httpFetcher, ok := ampFetcher.(*http_fetcher.HttpFetcher); ok { - if httpFetcher.Endpoint != "stored-requests.prebid.com?type=amp&" { - t.Errorf("The AMP HTTP fetcher is using the wrong endpoint. Expected %s, got %s", "stored-requests.prebid.com?type=amp&", httpFetcher.Endpoint) - } - } else { - t.Errorf("An HTTP Fetching config should return an HTTPFetcher. Got %v", ampFetcher) - } -} - -func TestNewHTTPFetcherNoAmp(t *testing.T) { - fetcher := newFetcher(&config.StoredRequests{ - HTTP: config.HTTPFetcherConfig{ - Endpoint: "stored-requests.prebid.com", - }, - }, nil, nil) - ampFetcher := newFetcher(&config.StoredRequests{ - HTTP: config.HTTPFetcherConfig{ - Endpoint: "", - }, - }, nil, nil) if httpFetcher, ok := fetcher.(*http_fetcher.HttpFetcher); ok { if httpFetcher.Endpoint != "stored-requests.prebid.com?" { t.Errorf("The HTTP fetcher is using the wrong endpoint. Expected %s, got %s", "stored-requests.prebid.com?", httpFetcher.Endpoint) } } else { - t.Errorf("An HTTP Fetching config should return an HTTPFetcher. Got %v", ampFetcher) - } - if httpAmpFetcher, ok := ampFetcher.(*http_fetcher.HttpFetcher); ok && httpAmpFetcher == nil { - t.Errorf("An HTTP Fetching config should not return an Amp HTTP fetcher in this case. Got %v (%v)", ampFetcher, httpAmpFetcher) + t.Errorf("An HTTP Fetching config should return an HTTPFetcher. Got %v", fetcher) } } @@ -102,11 +83,9 @@ func TestNewHTTPEvents(t *testing.T) { func TestNewEmptyCache(t *testing.T) { cache := newCache(&config.StoredRequests{InMemoryCache: config.InMemoryCache{Type: "none"}}) - cache.Requests.Save(context.Background(), map[string]json.RawMessage{"foo": json.RawMessage("true")}) - reqs := cache.Requests.Get(context.Background(), []string{"foo"}) - if len(reqs) != 0 { - t.Errorf("The newCache method should return an empty cache if the config asks for it.") - } + assert.True(t, isEmptyCacheType(cache.Requests), "The newCache method should return an empty Request cache") + assert.True(t, isEmptyCacheType(cache.Imps), "The newCache method should return an empty Imp cache") + assert.True(t, isEmptyCacheType(cache.Accounts), "The newCache method should return an empty Account cache") } func TestNewInMemoryCache(t *testing.T) { @@ -117,11 +96,21 @@ func TestNewInMemoryCache(t *testing.T) { ImpCacheSize: 100, }, }) - cache.Requests.Save(context.Background(), map[string]json.RawMessage{"foo": json.RawMessage("true")}) - reqs := cache.Requests.Get(context.Background(), []string{"foo"}) - if len(reqs) != 1 { - t.Errorf("The newCache method should return an in-memory cache if the config asks for it.") - } + assert.True(t, isMemoryCacheType(cache.Requests), "The newCache method should return an in-memory Request cache for StoredRequests config") + assert.True(t, isMemoryCacheType(cache.Imps), "The newCache method should return an in-memory Imp cache for StoredRequests config") + assert.True(t, isEmptyCacheType(cache.Accounts), "The newCache method should return an empty Account cache for StoredRequests config") +} + +func TestNewInMemoryAccountCache(t *testing.T) { + cache := newCache(typedConfig(config.AccountDataType, &config.StoredRequests{ + InMemoryCache: config.InMemoryCache{ + TTL: 60, + Size: 100, + }, + })) + assert.True(t, isMemoryCacheType(cache.Accounts), "The newCache method should return an in-memory Account cache for Accounts config") + assert.True(t, isEmptyCacheType(cache.Requests), "The newCache method should return an empty Request cache for Accounts config") + assert.True(t, isEmptyCacheType(cache.Imps), "The newCache method should return an empty Imp cache for Accounts config") } func TestNewPostgresEventProducers(t *testing.T) { @@ -138,33 +127,16 @@ func TestNewPostgresEventProducers(t *testing.T) { }, }, } - ampCfg := &config.StoredRequests{ - Postgres: config.PostgresConfig{ - CacheInitialization: config.PostgresCacheInitializer{ - Timeout: 50, - Query: "SELECT id, requestData, type FROM stored_amp_data", - }, - PollUpdates: config.PostgresUpdatePolling{ - RefreshRate: 20, - Timeout: 50, - Query: "SELECT id, requestData, type FROM stored_amp_data WHERE last_updated > $1", - }, - }, - } client := &http.Client{} db, mock, err := sqlmock.New() if err != nil { t.Fatalf("Failed to create mock: %v", err) } mock.ExpectQuery("^" + regexp.QuoteMeta(cfg.Postgres.CacheInitialization.Query) + "$").WillReturnError(errors.New("Query failed")) - mock.ExpectQuery("^" + regexp.QuoteMeta(ampCfg.Postgres.CacheInitialization.Query) + "$").WillReturnError(errors.New("Query failed")) evProducers := newEventProducers(cfg, client, db, nil) assertProducerLength(t, evProducers, 1) - ampEvProducers := newEventProducers(ampCfg, client, db, nil) - assertProducerLength(t, ampEvProducers, 1) - assertExpectationsMet(t, mock) } diff --git a/stored_requests/fetcher.go b/stored_requests/fetcher.go index e9716e08a23..1773c966f32 100644 --- a/stored_requests/fetcher.go +++ b/stored_requests/fetcher.go @@ -65,6 +65,7 @@ func (e NotFoundError) Error() string { type Cache struct { Requests CacheJSON Imps CacheJSON + Accounts CacheJSON } type CacheJSON interface { // Get works much like Fetcher.FetchRequests, with a few exceptions: @@ -190,8 +191,17 @@ func (f *fetcherWithCache) FetchRequests(ctx context.Context, requestIDs []strin return } -func (f *fetcherWithCache) FetchAccount(ctx context.Context, accountID string) (json.RawMessage, []error) { - return f.fetcher.FetchAccount(ctx, accountID) +func (f *fetcherWithCache) FetchAccount(ctx context.Context, accountID string) (account json.RawMessage, errs []error) { + accountData := f.cache.Accounts.Get(ctx, []string{accountID}) + // TODO: add metrics + if account, ok := accountData[accountID]; ok { + return account, errs + } + account, errs = f.fetcher.FetchAccount(ctx, accountID) + if len(errs) == 0 { + f.cache.Accounts.Save(ctx, map[string]json.RawMessage{accountID: account}) + } + return account, errs } func (f *fetcherWithCache) FetchCategories(ctx context.Context, primaryAdServer, publisherId, iabCategory string) (string, error) { diff --git a/stored_requests/fetcher_test.go b/stored_requests/fetcher_test.go index 396ba3d04b2..7a6a06fb923 100644 --- a/stored_requests/fetcher_test.go +++ b/stored_requests/fetcher_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/prebid/prebid-server/pbsmetrics" + "github.com/prebid/prebid-server/stored_requests/caches/nil_cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -17,7 +18,7 @@ func setupFetcherWithCacheDeps() (*mockCache, *mockCache, *mockFetcher, AllFetch impCache := &mockCache{} metricsEngine := &pbsmetrics.MetricsEngineMock{} fetcher := &mockFetcher{} - afetcherWithCache := WithCache(fetcher, Cache{reqCache, impCache}, metricsEngine) + afetcherWithCache := WithCache(fetcher, Cache{reqCache, impCache, &nil_cache.NilCache{}}, metricsEngine) return reqCache, impCache, fetcher, afetcherWithCache, metricsEngine } @@ -158,6 +159,57 @@ func TestCacheSaves(t *testing.T) { assert.Len(t, errs, 0, "FetchRequests with duplicate IDs shouldn't return an error") } +func setupAccountFetcherWithCacheDeps() (*mockCache, *mockFetcher, AllFetcher, *pbsmetrics.MetricsEngineMock) { + accCache := &mockCache{} + metricsEngine := &pbsmetrics.MetricsEngineMock{} + fetcher := &mockFetcher{} + afetcherWithCache := WithCache(fetcher, Cache{&nil_cache.NilCache{}, &nil_cache.NilCache{}, accCache}, metricsEngine) + + return accCache, fetcher, afetcherWithCache, metricsEngine +} + +func TestAccountCacheHit(t *testing.T) { + accCache, fetcher, aFetcherWithCache, metricsEngine := setupAccountFetcherWithCacheDeps() + cachedAccounts := []string{"known"} + ctx := context.Background() + + // Test read from cache + accCache.On("Get", ctx, cachedAccounts).Return( + map[string]json.RawMessage{ + "known": json.RawMessage(`true`), + }) + + account, errs := aFetcherWithCache.FetchAccount(ctx, "known") + + accCache.AssertExpectations(t) + fetcher.AssertExpectations(t) + metricsEngine.AssertExpectations(t) + assert.JSONEq(t, `true`, string(account), "FetchAccount should fetch the right account data") + assert.Len(t, errs, 0, "FetchAccount shouldn't return any errors") +} + +func TestAccountCacheMiss(t *testing.T) { + accCache, fetcher, aFetcherWithCache, metricsEngine := setupAccountFetcherWithCacheDeps() + uncachedAccounts := []string{"uncached"} + uncachedAccountsData := map[string]json.RawMessage{ + "uncached": json.RawMessage(`true`), + } + ctx := context.Background() + + // Test read from cache + accCache.On("Get", ctx, uncachedAccounts).Return(map[string]json.RawMessage{}) + accCache.On("Save", ctx, uncachedAccountsData) + fetcher.On("FetchAccount", ctx, "uncached").Return(uncachedAccountsData["uncached"], []error{}) + + account, errs := aFetcherWithCache.FetchAccount(ctx, "uncached") + + accCache.AssertExpectations(t) + fetcher.AssertExpectations(t) + metricsEngine.AssertExpectations(t) + assert.JSONEq(t, `true`, string(account), "FetchAccount should fetch the right account data") + assert.Len(t, errs, 0, "FetchAccount shouldn't return any errors") +} + func TestComposedCache(t *testing.T) { c1 := &mockCache{} c2 := &mockCache{}