forked from hashicorp/vault
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathactivity_log_util_common_test.go
343 lines (306 loc) · 13.1 KB
/
activity_log_util_common_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package vault
import (
"context"
"errors"
"fmt"
"io"
"testing"
"time"
"github.com/axiomhq/hyperloglog"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/vault/activity"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
// Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal builds three months of hyperloglog sketches with
// overlapping clients and calls computeCurrentMonthForBillingPeriodInternal with a current month whose clients
// partially overlap the sketches. It verifies the total and new entity/non-entity client counts in the result.
// It then repeats the call with an empty current month map and verifies that every count is 0.
func Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal(t *testing.T) {
	// monthOneHLL holds clients 0-9, monthTwoHLL holds clients 5-14 and
	// monthThreeHLL holds clients 10-19.
	monthOneHLL := hyperloglog.New()
	monthTwoHLL := hyperloglog.New()
	monthThreeHLL := hyperloglog.New()
	for i := 0; i < 20; i++ {
		clientID := []byte(fmt.Sprintf("client_%d", i))
		if i < 10 {
			monthOneHLL.Insert(clientID)
		}
		if 5 <= i && i < 15 {
			monthTwoHLL.Insert(clientID)
		}
		if 10 <= i && i < 20 {
			monthThreeHLL.Insert(clientID)
		}
	}

	// Resolve the start of the current month exactly once so that the mock
	// getter and the query window below agree even if the wall clock crosses
	// a month boundary while the test is running.
	currMonthStart := timeutil.StartOfMonth(time.Now())

	// mockHLLGetFunc serves the sketch for the month starting at startTime:
	// three months before the current month maps to monthThreeHLL, two months
	// before to monthTwoHLL and one month before to monthOneHLL. Any other
	// start time is an error.
	mockHLLGetFunc := func(ctx context.Context, startTime time.Time) (*hyperloglog.Sketch, error) {
		switch {
		case startTime.Equal(timeutil.MonthsPreviousTo(3, currMonthStart)):
			return monthThreeHLL, nil
		case startTime.Equal(timeutil.MonthsPreviousTo(2, currMonthStart)):
			return monthTwoHLL, nil
		case startTime.Equal(timeutil.MonthsPreviousTo(1, currMonthStart)):
			return monthOneHLL, nil
		}
		return nil, fmt.Errorf("bad start time")
	}

	// clientSet builds a set of "client_<id>" keys, matching the IDs that
	// were inserted into the sketches above.
	clientSet := func(ids ...int) map[string]struct{} {
		set := make(map[string]struct{}, len(ids))
		for _, id := range ids {
			set[fmt.Sprintf("client_%d", id)] = struct{}{}
		}
		return set
	}

	// Let's add 2 entities exclusive to month 1 (clients 0,1),
	// 2 entities shared by month 1 and 2 (clients 5,6),
	// 2 entities shared by month 2 and 3 (clients 10,11), and
	// 2 entities exclusive to month 3 (15,16). Furthermore, we add
	// 3 new entities (clients 20, 21 and 22).
	entitiesStruct := clientSet(0, 1, 5, 6, 10, 11, 15, 16, 20, 21, 22)

	// We will add 3 nonentity clients from month 1 (clients 2,3,4),
	// 3 shared by months 1 and 2 (7,8,9),
	// 3 shared by months 2 and 3 (12,13,14), and
	// 3 exclusive to month 3 (17,18,19). We will also
	// add 4 new nonentity clients (23, 24, 25 and 26).
	nonEntitiesStruct := clientSet(2, 3, 4, 7, 8, 9, 12, 13, 14, 17, 18, 19, 23, 24, 25, 26)

	counts := &processCounts{
		Entities:    entitiesStruct,
		NonEntities: nonEntitiesStruct,
	}
	currentMonthClients := &processMonth{
		Counts:     counts,
		NewClients: &processNewClients{Counts: counts},
	}
	// Technically the keys of currentMonthClientsMap should probably be unix
	// timestamps, but for the purposes of this unit test it doesn't matter
	// what the key actually is.
	currentMonthClientsMap := map[int64]*processMonth{0: currentMonthClients}

	core, _, _ := TestCoreUnsealed(t)
	a := core.activityLog

	// Query the window covering the three months before the current month.
	endTime := currMonthStart
	startTime := timeutil.MonthsPreviousTo(3, endTime)

	monthRecord, err := a.computeCurrentMonthForBillingPeriodInternal(context.Background(), currentMonthClientsMap, mockHLLGetFunc, startTime, endTime)
	if err != nil {
		t.Fatal(err)
	}

	// We should have 11 entity clients and 16 nonentity clients, and 3 new entity clients
	// and 4 new nonentity clients
	if monthRecord.Counts.EntityClients != 11 {
		t.Fatalf("wrong number of entity clients. Expected 11, got %d", monthRecord.Counts.EntityClients)
	}
	if monthRecord.Counts.NonEntityClients != 16 {
		t.Fatalf("wrong number of non entity clients. Expected 16, got %d", monthRecord.Counts.NonEntityClients)
	}
	if monthRecord.NewClients.Counts.EntityClients != 3 {
		t.Fatalf("wrong number of new entity clients. Expected 3, got %d", monthRecord.NewClients.Counts.EntityClients)
	}
	if monthRecord.NewClients.Counts.NonEntityClients != 4 {
		t.Fatalf("wrong number of new non entity clients. Expected 4, got %d", monthRecord.NewClients.Counts.NonEntityClients)
	}

	// Attempt to compute current month when no records exist
	endTime = time.Now().UTC()
	startTime = timeutil.StartOfMonth(endTime)
	emptyClientsMap := map[int64]*processMonth{}
	monthRecord, err = a.computeCurrentMonthForBillingPeriodInternal(context.Background(), emptyClientsMap, mockHLLGetFunc, startTime, endTime)
	if err != nil {
		t.Fatalf("failed to compute empty current month, err: %v", err)
	}

	// We should have 0 entity clients, nonentity clients, new entity clients
	// and new nonentity clients
	if monthRecord.Counts.EntityClients != 0 {
		t.Fatalf("wrong number of entity clients. Expected 0, got %d", monthRecord.Counts.EntityClients)
	}
	if monthRecord.Counts.NonEntityClients != 0 {
		t.Fatalf("wrong number of non entity clients. Expected 0, got %d", monthRecord.Counts.NonEntityClients)
	}
	if monthRecord.NewClients.Counts.EntityClients != 0 {
		t.Fatalf("wrong number of new entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.EntityClients)
	}
	if monthRecord.NewClients.Counts.NonEntityClients != 0 {
		t.Fatalf("wrong number of new non entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.NonEntityClients)
	}
}
// writeEntitySegment marshals item and stores it as the entity segment for
// the given time and index. The test fails immediately if marshaling fails.
func writeEntitySegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.EntityActivityLog) {
	t.Helper()

	segmentPath := makeSegmentPath(t, activityEntityBasePath, ts, index)
	encoded, marshalErr := proto.Marshal(item)
	require.NoError(t, marshalErr)

	WriteToStorage(t, core, segmentPath, encoded)
}
// writeTokenSegment marshals item and stores it as the token segment for
// the given time and index. The test fails immediately if marshaling fails.
func writeTokenSegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.TokenCount) {
	t.Helper()

	segmentPath := makeSegmentPath(t, activityTokenBasePath, ts, index)
	encoded, marshalErr := proto.Marshal(item)
	require.NoError(t, marshalErr)

	WriteToStorage(t, core, segmentPath, encoded)
}
// makeSegmentPath builds the storage path of a segment: ActivityPrefix,
// followed by typ, followed by "<unix seconds of ts>/<index>".
func makeSegmentPath(t *testing.T, typ string, ts time.Time, index int) string {
	t.Helper()

	// The trailing portion identifies the segment by start time and index.
	segment := fmt.Sprintf("%d/%d", ts.Unix(), index)
	return fmt.Sprintf("%s%s%s", ActivityPrefix, typ, segment)
}
// TestSegmentFileReader_BadData verifies that the reader returns an error when a segment's data cannot be parsed.
// The next call to Read*() should still make progress and return the following valid segment without an error.
func TestSegmentFileReader_BadData(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	ctx := context.Background()
	now := time.Now()

	// write bad data that won't be able to be unmarshaled at index 0
	WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, 0), []byte("fake data"))
	WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, 0), []byte("fake data"))

	// write a valid entity at index 1
	entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
		{
			ClientID: "id",
		},
	}}
	writeEntitySegment(t, core, now, 1, entity)

	// write a valid token at index 1
	token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
		"ns": 1,
	}}
	writeTokenSegment(t, core, now, 1, token)

	reader, err := core.activityLog.NewSegmentFileReader(ctx, now)
	require.NoError(t, err)

	// first the bad entity is read, which returns an error
	_, err = reader.ReadEntity(ctx)
	require.Error(t, err)

	// then, the reader can read the good entity at index 1; check the error
	// before inspecting the value so a failed read is reported as such
	gotEntity, err := reader.ReadEntity(ctx)
	require.NoError(t, err)
	require.True(t, proto.Equal(gotEntity, entity))

	// the bad token causes an error
	_, err = reader.ReadToken(ctx)
	require.Error(t, err)

	// but the good token is able to be read
	gotToken, err := reader.ReadToken(ctx)
	require.NoError(t, err)
	require.True(t, proto.Equal(gotToken, token))
}
// TestSegmentFileReader_MissingData verifies that the segment file reader will skip over missing segment paths without
// erroring until it is able to find a valid segment path
func TestSegmentFileReader_MissingData(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	ctx := context.Background()
	now := time.Now()

	// write placeholder entity and token data at indexes 0, 1, 2; these
	// segments are deleted again after the reader has been created
	// NOTE(review): presumably the reader records the existing paths when it
	// is constructed, which is why the data is written first — confirm.
	for i := 0; i < 3; i++ {
		WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, i), []byte("fake data"))
		WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, i), []byte("fake data"))
	}

	// write entity at index 3
	entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
		{
			ClientID: "id",
		},
	}}
	writeEntitySegment(t, core, now, 3, entity)

	// write token at index 3
	token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
		"ns": 1,
	}}
	writeTokenSegment(t, core, now, 3, token)

	reader, err := core.activityLog.NewSegmentFileReader(ctx, now)
	require.NoError(t, err)

	// delete the indexes 0, 1, 2
	for i := 0; i < 3; i++ {
		require.NoError(t, core.barrier.Delete(ctx, makeSegmentPath(t, activityTokenBasePath, now, i)))
		require.NoError(t, core.barrier.Delete(ctx, makeSegmentPath(t, activityEntityBasePath, now, i)))
	}

	// we expect the reader to only return the data at index 3, and then be done
	gotEntity, err := reader.ReadEntity(ctx)
	require.NoError(t, err)
	require.True(t, proto.Equal(gotEntity, entity))
	_, err = reader.ReadEntity(ctx)
	require.ErrorIs(t, err, io.EOF)

	gotToken, err := reader.ReadToken(ctx)
	require.NoError(t, err)
	require.True(t, proto.Equal(gotToken, token))
	_, err = reader.ReadToken(ctx)
	require.ErrorIs(t, err, io.EOF)
}
// TestSegmentFileReader_NoData verifies that the reader returns a nil value and io.EOF for both entities and tokens
// when no segment data has been written
func TestSegmentFileReader_NoData(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	ctx := context.Background()
	now := time.Now()

	reader, err := core.activityLog.NewSegmentFileReader(ctx, now)
	require.NoError(t, err)

	// with nothing stored, the very first entity read reports the end of the data
	entity, err := reader.ReadEntity(ctx)
	require.ErrorIs(t, err, io.EOF)
	require.Nil(t, entity)

	// and so does the very first token read
	token, err := reader.ReadToken(ctx)
	require.ErrorIs(t, err, io.EOF)
	require.Nil(t, token)
}
// TestSegmentFileReader writes three entity segments and three token segments, then checks that the reader hands
// them back in ascending index order and signals io.EOF once each kind is exhausted
func TestSegmentFileReader(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	now := time.Now()

	wantEntities := make([]*activity.EntityActivityLog, 0, 3)
	wantTokens := make([]*activity.TokenCount, 0, 3)

	// store one entity segment and one token segment per index
	for idx := 0; idx < 3; idx++ {
		entityLog := &activity.EntityActivityLog{
			Clients: []*activity.EntityRecord{{ClientID: fmt.Sprintf("id-%d", idx)}},
		}
		tokenCount := &activity.TokenCount{
			CountByNamespaceID: map[string]uint64{fmt.Sprintf("ns-%d", idx): uint64(idx)},
		}

		writeEntitySegment(t, core, now, idx, entityLog)
		writeTokenSegment(t, core, now, idx, tokenCount)

		wantEntities = append(wantEntities, entityLog)
		wantTokens = append(wantTokens, tokenCount)
	}

	reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
	require.NoError(t, err)

	// drain the entities until the reader reports io.EOF
	gotEntities := make([]*activity.EntityActivityLog, 0, 3)
	for {
		entityLog, readErr := reader.ReadEntity(context.Background())
		if errors.Is(readErr, io.EOF) {
			break
		}
		require.NoError(t, readErr)
		gotEntities = append(gotEntities, entityLog)
	}

	// drain the tokens until the reader reports io.EOF
	gotTokens := make([]*activity.TokenCount, 0, 3)
	for {
		tokenCount, readErr := reader.ReadToken(context.Background())
		if errors.Is(readErr, io.EOF) {
			break
		}
		require.NoError(t, readErr)
		gotTokens = append(gotTokens, tokenCount)
	}

	require.Len(t, gotEntities, 3)
	require.Len(t, gotTokens, 3)

	// compare position by position with proto.Equal; require.Equal can't be
	// used here because there are protobuf differences in unexported fields
	for idx := range wantEntities {
		require.True(t, proto.Equal(gotEntities[idx], wantEntities[idx]))
		require.True(t, proto.Equal(gotTokens[idx], wantTokens[idx]))
	}
}