Skip to content

Commit 298d84f

Browse files
committed
add file size sending and improve tests for sigv4
Signed-off-by: Ashley Davis <ashley.davis@cyberark.com>
1 parent f8ee4b1 commit 298d84f

File tree

3 files changed

+156
-36
lines changed

3 files changed

+156
-36
lines changed

internal/cyberark/dataupload/dataupload.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) err
124124
checksumHex := hex.EncodeToString(checksum)
125125
checksumBase64 := base64.StdEncoding.EncodeToString(checksum)
126126

127-
presignedUploadURL, username, err := c.retrievePresignedUploadURL(ctx, checksumHex, snapshot.ClusterID)
127+
presignedUploadURL, username, err := c.retrievePresignedUploadURL(ctx, checksumHex, snapshot.ClusterID, int64(encodedBody.Len()))
128128
if err != nil {
129129
return fmt.Errorf("while retrieving snapshot upload URL: %s", err)
130130
}
@@ -168,20 +168,30 @@ func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) err
168168
return nil
169169
}
170170

171-
func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksum string, clusterID string) (string, string, error) {
171+
// RetrievePresignedUploadURLRequest is the JSON body sent to the inventory API to request a presigned upload URL.
172+
type RetrievePresignedUploadURLRequest struct {
173+
ClusterID string `json:"cluster_id"`
174+
Checksum string `json:"checksum_sha256"`
175+
176+
// AgentVersion is the v-prefixed version of the agent uploading the snapshot.
177+
// Note that the backend relies on this version being v-prefixed semver.
178+
AgentVersion string `json:"agent_version"`
179+
180+
// FileSize is the size of the data we'll upload in bytes
181+
FileSize int64 `json:"file_size"`
182+
}
183+
184+
func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksum string, clusterID string, fileSize int64) (string, string, error) {
172185
uploadURL, err := url.JoinPath(c.baseURL, apiPathSnapshotLinks)
173186
if err != nil {
174187
return "", "", err
175188
}
176189

177-
request := struct {
178-
ClusterID string `json:"cluster_id"`
179-
Checksum string `json:"checksum_sha256"`
180-
AgentVersion string `json:"agent_version"`
181-
}{
190+
request := RetrievePresignedUploadURLRequest{
182191
ClusterID: clusterID,
183192
Checksum: checksum,
184193
AgentVersion: version.PreflightVersion,
194+
FileSize: fileSize,
185195
}
186196

187197
encodedBody := &bytes.Buffer{}

internal/cyberark/dataupload/mock.go

Lines changed: 138 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ package dataupload
22

33
import (
44
"bytes"
5+
"crypto/rand"
56
"crypto/sha256"
67
"encoding/base64"
8+
"encoding/hex"
79
"encoding/json"
810
"fmt"
911
"io"
1012
"net/http"
1113
"net/http/httptest"
14+
"net/url"
15+
"sync"
1216
"testing"
1317

1418
"github.com/stretchr/testify/assert"
@@ -25,9 +29,19 @@ const (
2529
successClusterID = "ffffffff-ffff-ffff-ffff-ffffffffffff"
2630
)
2731

32+
type uploadValues struct {
33+
ClusterID string
34+
FileSize int64
35+
}
36+
2837
type mockDataUploadServer struct {
2938
t testing.TB
3039
serverURL string
40+
41+
mux *http.ServeMux
42+
43+
expectedUploadValues map[string]uploadValues
44+
expectedUploadValuesMutex sync.Mutex
3145
}
3246

3347
// MockDataUploadServer starts a server which mocks the CyberArk
@@ -45,39 +59,48 @@ type mockDataUploadServer struct {
4559
// responses.
4660
func MockDataUploadServer(t testing.TB) (string, *http.Client) {
4761
mux := http.NewServeMux()
48-
server := httptest.NewTLSServer(mux)
49-
t.Cleanup(server.Close)
5062
mds := &mockDataUploadServer{
51-
t: t,
52-
serverURL: server.URL,
63+
t: t,
64+
65+
expectedUploadValues: make(map[string]uploadValues),
5366
}
54-
mux.Handle("/", mds)
67+
68+
mux.HandleFunc("POST "+apiPathSnapshotLinks, mds.handleSnapshotLinks)
69+
70+
// The path includes random data to ensure that each request is treated separately by the mock server, allowing us to track data across calls.
71+
// It also ensures that the client isn't using some pre-saved path and is actually using the presigned URL returned by the mock server in the previous step, which is important for test validity.
72+
mux.HandleFunc("PUT /presigned-upload/{randData}", mds.handlePresignedUpload)
73+
74+
server := httptest.NewTLSServer(mds)
75+
t.Cleanup(server.Close)
76+
77+
mds.mux = mux
78+
mds.serverURL = server.URL
79+
5580
httpClient := server.Client()
5681
httpClient.Transport = transport.NewDebuggingRoundTripper(httpClient.Transport, transport.DebugByContext)
5782
return server.URL, httpClient
5883
}
5984

6085
func (mds *mockDataUploadServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
6186
mds.t.Log(r.Method, r.RequestURI)
62-
switch r.URL.Path {
63-
case apiPathSnapshotLinks:
64-
mds.handleSnapshotLinks(w, r)
65-
return
66-
case "/presigned-upload":
67-
mds.handlePresignedUpload(w, r)
68-
return
69-
default:
70-
w.WriteHeader(http.StatusNotFound)
71-
}
87+
88+
mds.mux.ServeHTTP(w, r)
7289
}
7390

74-
func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *http.Request) {
75-
if r.Method != http.MethodPost {
76-
w.WriteHeader(http.StatusMethodNotAllowed)
77-
_, _ = w.Write([]byte(`{"message":"method not allowed"}`))
78-
return
91+
// randHex reads 8 random bytes and returns them as a hex string. It is used to generate
92+
// unique paths per-request to ensure that file size is tracked across calls.
93+
func randHex() string {
94+
b := make([]byte, 8)
95+
_, err := rand.Read(b)
96+
if err != nil {
97+
panic("failed to read random bytes: " + err.Error())
7998
}
8099

100+
return hex.EncodeToString(b)
101+
}
102+
103+
func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *http.Request) {
81104
if r.Header.Get("User-Agent") != version.UserAgent() {
82105
http.Error(w, "should set user agent on all requests", http.StatusInternalServerError)
83106
return
@@ -99,13 +122,11 @@ func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *h
99122
return
100123
}
101124

125+
var req RetrievePresignedUploadURLRequest
126+
102127
decoder := json.NewDecoder(r.Body)
103-
var req struct {
104-
ClusterID string `json:"cluster_id"`
105-
Checksum string `json:"checksum_sha256"`
106-
AgentVersion string `json:"agent_version"`
107-
}
108128
decoder.DisallowUnknownFields()
129+
109130
if err := decoder.Decode(&req); err != nil {
110131
http.Error(w, `{"error": "Invalid request format"}`, http.StatusBadRequest)
111132
return
@@ -135,10 +156,33 @@ func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *h
135156
return
136157
}
137158

159+
if req.FileSize <= 0 {
160+
http.Error(w, "file size must be greater than 0", http.StatusInternalServerError)
161+
return
162+
}
163+
164+
randomData := randHex()
165+
166+
mds.expectedUploadValuesMutex.Lock()
167+
defer mds.expectedUploadValuesMutex.Unlock()
168+
169+
uploadValues := uploadValues{
170+
ClusterID: req.ClusterID,
171+
FileSize: req.FileSize,
172+
}
173+
174+
mds.expectedUploadValues[randomData] = uploadValues
175+
176+
presignedURL, err := url.JoinPath(mds.serverURL, "presigned-upload", randomData)
177+
if err != nil {
178+
http.Error(w, "failed to generate presigned URL", http.StatusInternalServerError)
179+
mds.t.Logf("failed to generate presigned URL: %v", err)
180+
return
181+
}
182+
138183
// Write response body
139184
w.WriteHeader(http.StatusOK)
140185
w.Header().Set("Content-Type", "application/json")
141-
presignedURL := mds.serverURL + "/presigned-upload"
142186
_ = json.NewEncoder(w).Encode(struct {
143187
URL string `json:"url"`
144188
}{presignedURL})
@@ -155,9 +199,18 @@ const amzExampleChecksumError = `<?xml version="1.0" encoding="UTF-8"?>
155199
</Error>`
156200

157201
func (mds *mockDataUploadServer) handlePresignedUpload(w http.ResponseWriter, r *http.Request) {
158-
if r.Method != http.MethodPut {
159-
w.WriteHeader(http.StatusMethodNotAllowed)
160-
_, _ = w.Write([]byte(`{"message":"method not allowed"}`))
202+
randData := r.PathValue("randData")
203+
if randData == "" {
204+
http.Error(w, "missing randData in path; should match that returned in presigned url", http.StatusInternalServerError)
205+
return
206+
}
207+
208+
mds.expectedUploadValuesMutex.Lock()
209+
uploadValues, ok := mds.expectedUploadValues[randData]
210+
mds.expectedUploadValuesMutex.Unlock()
211+
212+
if !ok {
213+
http.Error(w, "didn't find a prior call to generate presigned URL", http.StatusInternalServerError)
161214
return
162215
}
163216

@@ -178,9 +231,65 @@ func (mds *mockDataUploadServer) handlePresignedUpload(w http.ResponseWriter, r
178231
return
179232
}
180233

234+
sseHeader := r.Header.Get("X-Amz-Server-Side-Encryption")
235+
if sseHeader != "AES256" {
236+
http.Error(w, "should set x-amz-server-side-encryption header to AES256 on all requests", http.StatusInternalServerError)
237+
return
238+
}
239+
240+
taggingHeader := r.Header.Get("X-Amz-Tagging")
241+
if taggingHeader == "" {
242+
http.Error(w, "should set x-amz-tagging header on all requests", http.StatusInternalServerError)
243+
return
244+
}
245+
246+
tags, err := url.ParseQuery(taggingHeader)
247+
if err != nil {
248+
http.Error(w, "x-amz-tagging header should be encoded as a valid query string", http.StatusInternalServerError)
249+
return
250+
}
251+
252+
if tags.Get("agent_version") != version.PreflightVersion {
253+
http.Error(w, fmt.Sprintf("x-amz-tagging should contain an agent_version tag with value %s", version.PreflightVersion), http.StatusInternalServerError)
254+
return
255+
}
256+
257+
if tags.Get("tenant_id") == "" {
258+
// TODO: if we change setup a bit, we can check the tenant_id matches the expected tenant_id from the test config, but for now, just check it's set
259+
http.Error(w, "x-amz-tagging should contain a tenant_id tag", http.StatusInternalServerError)
260+
return
261+
}
262+
263+
if tags.Get("upload_type") != "k8s_snapshot" {
264+
http.Error(w, "x-amz-tagging should contain an upload_type tag with value k8s_snapshot", http.StatusInternalServerError)
265+
return
266+
}
267+
268+
if tags.Get("uploader_id") != uploadValues.ClusterID {
269+
http.Error(w, "x-amz-tagging should contain an uploader_id tag which matches the cluster ID sent in the RetrievePresignedUploadURL request", http.StatusInternalServerError)
270+
return
271+
}
272+
273+
if tags.Get("username") == "" {
274+
// TODO: if we change setup a bit, we can check the username matches the expected username from the test config
275+
// but for now, just check it's set
276+
http.Error(w, "x-amz-tagging should contain a username tag", http.StatusInternalServerError)
277+
return
278+
}
279+
280+
if tags.Get("vendor") != "k8s" {
281+
http.Error(w, "x-amz-tagging should contain a vendor tag with value k8s", http.StatusInternalServerError)
282+
return
283+
}
284+
181285
body, err := io.ReadAll(r.Body)
182286
require.NoError(mds.t, err)
183287

288+
if uploadValues.FileSize != int64(len(body)) {
289+
http.Error(w, fmt.Sprintf("file size in request body should match that sent in RetrievePresignedUploadURL request; expected %d, got %d", uploadValues.FileSize, len(body)), http.StatusInternalServerError)
290+
return
291+
}
292+
184293
hash := sha256.New()
185294
_, err = hash.Write(body)
186295
require.NoError(mds.t, err)

internal/cyberark/servicediscovery/testdata/discovery_success.json.template

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"dr_region": "us-east-2",
44
"subdomain": "venafi-test",
55
"platform_id": "platform-123",
6+
"tenant_id": "tenant-123",
67
"identity_id": "identity-456",
78
"default_url": "https://venafi-test.integration-cyberark.cloud",
89
"tenant_flags": {

0 commit comments

Comments
 (0)