Skip to content

Commit cd27d9a

Browse files
rvilgalyscopybara-github
authored andcommitted
Refactor ListUpdate args.
PiperOrigin-RevId: 640194024
1 parent d5412b6 commit cd27d9a

File tree

4 files changed

+18
-16
lines changed

4 files changed

+18
-16
lines changed

api.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ const (
4141

4242
// The api interface specifies wrappers around the Web Risk API.
4343
type api interface {
44-
ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
45-
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error)
44+
ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error)
4645
HashLookup(ctx context.Context, hashPrefix []byte,
4746
threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error)
4847
}
@@ -123,17 +122,16 @@ func (a *netAPI) parseError(httpResp *http.Response) error {
123122
}
124123

125124
// ListUpdate issues a ComputeThreatListDiff API call and returns the response.
126-
func (a *netAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
127-
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) {
125+
func (a *netAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) {
128126
resp := new(pb.ComputeThreatListDiffResponse)
129127
u := *a.url // Make a copy of URL
130128
// Add fields from ComputeThreatListDiffRequest to URL request
131129
q := u.Query()
132-
q.Set(threatTypeString, threatType.String())
133-
if len(versionToken) != 0 {
134-
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(versionToken))
130+
q.Set(threatTypeString, req.GetThreatType().String())
131+
if len(req.GetVersionToken()) != 0 {
132+
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(req.GetVersionToken()))
135133
}
136-
for _, compressionType := range compressionTypes {
134+
for _, compressionType := range req.GetConstraints().GetSupportedCompressions() {
137135
q.Add(supportedCompressionsString, compressionType.String())
138136
}
139137
u.RawQuery = q.Encode()

api_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ type mockAPI struct {
3939
threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error)
4040
}
4141

42-
func (m *mockAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
43-
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) {
44-
return m.listUpdate(ctx, threatType, versionToken, compressionTypes)
42+
func (m *mockAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) {
43+
return m.listUpdate(ctx, req.GetThreatType(), req.GetVersionToken(), req.GetConstraints().GetSupportedCompressions())
4544
}
4645

4746
func (m *mockAPI) HashLookup(ctx context.Context, hashPrefix []byte,
@@ -118,8 +117,14 @@ func TestNetAPI(t *testing.T) {
118117
RawIndices: &pb.RawIndices{Indices: []int32{1, 2, 3}},
119118
},
120119
}
121-
resp1, err := api.ListUpdate(context.Background(), wantReqThreatType, []byte{},
122-
wantReqCompressionTypes)
120+
req := &pb.ComputeThreatListDiffRequest{
121+
ThreatType: wantReqThreatType,
122+
Constraints: &pb.ComputeThreatListDiffRequest_Constraints{
123+
SupportedCompressions: wantReqCompressionTypes,
124+
},
125+
VersionToken: []byte{},
126+
}
127+
resp1, err := api.ListUpdate(context.Background(), req)
123128
gotResp = resp1
124129
if err != nil {
125130
t.Errorf("unexpected ListUpdate error: %v", err)

database.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) {
212212
last := db.config.now()
213213
for _, req := range s {
214214
// Query the API for the threat list and update the database.
215-
resp, err := api.ListUpdate(ctx, req.ThreatType, req.VersionToken, req.Constraints.SupportedCompressions)
215+
resp, err := api.ListUpdate(ctx, req)
216216
if err != nil {
217217
db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err)
218218
db.setError(err)

webrisk_client_system_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ func TestNetworkAPIUpdate(t *testing.T) {
4242
ThreatType: pb.ThreatType_MALWARE,
4343
}
4444

45-
dat, err := nm.ListUpdate(context.Background(), req.ThreatType,
46-
req.VersionToken, []pb.CompressionType{})
45+
dat, err := nm.ListUpdate(context.Background(), req)
4746
if err != nil {
4847
t.Fatal(err)
4948
}

0 commit comments

Comments
 (0)