Skip to content

Commit

Permalink
make postgres queries composable into transaction
Browse files Browse the repository at this point in the history
Create sub functions which requires sql.Tx as parameters, so they will
run queries in transaction.
The queries can be combined and reused in functions

That's a start to make possible to roll back if we detect that client
can't read the secret (bad decrypt key for example)
  • Loading branch information
Ajnasz committed May 4, 2022
1 parent 794273d commit e217df9
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 256 deletions.
38 changes: 22 additions & 16 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func TestCreateEntryForm(t *testing.T) {
value := "Foo"
connection := postgresql.NewStorage(testhelper.GetPSQLTestConn())
t.Cleanup(func() {
connection.Close()
defer connection.Close()
})

data, multi, err := createMultipart(map[string]io.Reader{
Expand Down Expand Up @@ -233,13 +233,13 @@ func TestRequestPathsCreateEntry(t *testing.T) {
{Name: "/ path", Path: "/", StatusCode: 200},
{Name: "Longer path", Path: "/other", StatusCode: 404},
}
connection := postgresql.NewStorage(testhelper.GetPSQLTestConn())
t.Cleanup(func() {
connection.Close()
})

for _, testCase := range testCases {
t.Run(testCase.Name, func(t *testing.T) {
connection := postgresql.NewStorage(testhelper.GetPSQLTestConn())
t.Cleanup(func() {
connection.Close()
})
req := httptest.NewRequest("POST", fmt.Sprintf("http://example.com%s", testCase.Path), bytes.NewReader([]byte("ASDF")))
w := httptest.NewRecorder()
NewSecretHandler(NewHandlerConfig(connection)).ServeHTTP(w, req)
Expand All @@ -264,17 +264,17 @@ func TestGetEntry(t *testing.T) {
{
"first",
"foo",
"3f356f6c-c8b1-4b48-8243-aa04d07b8873",
uuid.NewUUIDString(),
},
}

connection := postgresql.NewStorage(testhelper.GetPSQLTestConn())
t.Cleanup(func() {
connection.Close()
})

for _, testCase := range testCases {
t.Run(testCase.Name, func(t *testing.T) {
connection := postgresql.NewStorage(testhelper.GetPSQLTestConn())
t.Cleanup(func() {
connection.Close()
})

k := key.NewKey()
if err := k.Generate(); err != nil {
t.Error(err)
Expand Down Expand Up @@ -303,13 +303,12 @@ func TestGetEntry(t *testing.T) {
}
})
}

}

func TestGetEntryJSON(t *testing.T) {
connection := postgresql.NewStorage(testhelper.GetPSQLTestConn())
t.Cleanup(func() {
connection.Close()
defer connection.Close()
})
testCase := struct {
Name string
Expand All @@ -319,7 +318,7 @@ func TestGetEntryJSON(t *testing.T) {

"first",
"foo",
"3f356f6c-c8b1-4b48-8243-aa04d07b8873",
uuid.NewUUIDString(),
}

k := key.NewKey()
Expand All @@ -335,14 +334,21 @@ func TestGetEntryJSON(t *testing.T) {
}

ctx := context.Background()
connection.Write(ctx, testCase.UUID, encryptedData, time.Second*10, 1)
if err := connection.Write(ctx, testCase.UUID, encryptedData, time.Second*10, 1); err != nil {
t.Error(err)
}
fmt.Println("Wrote", testCase.UUID)

req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%s/%s", testCase.UUID, hex.EncodeToString(rsakey)), nil)
req.Header.Add("Accept", "application/json")
w := httptest.NewRecorder()
NewSecretHandler(NewHandlerConfig(connection)).ServeHTTP(w, req)

resp := w.Result()
fmt.Println(resp.Header)
if resp.StatusCode != 200 {
t.Errorf("non 200 http statuscode: %d", resp.StatusCode)
}
var encode entries.SecretResponse
err = json.NewDecoder(resp.Body).Decode(&encode)

Expand Down Expand Up @@ -540,7 +546,7 @@ func FuzzSetAndGetEntry(f *testing.F) {
}
connection := postgresql.NewStorage(testhelper.GetPSQLTestConn())
f.Cleanup(func() {
connection.Close()
defer connection.Close()
})

f.Fuzz(func(t *testing.T, testCase string) {
Expand Down
21 changes: 21 additions & 0 deletions entries/entry_meta_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package entries

import (
"testing"
"time"
)

func Test_EntryMeta(t *testing.T) {
expire := time.Now()

meta := EntryMeta{Expire: expire.Add(time.Second)}

if meta.IsExpired() {
t.Error("entry meta should not be expired")
}

meta = EntryMeta{Expire: expire.Add(-time.Second)}
if !meta.IsExpired() {
t.Error("entry meta should be expired")
}
}
16 changes: 7 additions & 9 deletions storage/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,16 @@ import (
)

func TestStorages(t *testing.T) {
connection := postgresql.NewStorage(testhelper.GetPSQLTestConn())
psqlStorage := postgresql.NewStorage(testhelper.GetPSQLTestConn())
t.Cleanup(func() {
connection.Close()
psqlStorage.Close()
})
psqlStorage := postgresql.NewPostgresCleanableStorage(connection)

storages := map[string]storage.Cleanable{
storages := map[string]storage.Storage{
"Postgres": psqlStorage,
"Secret": secret.NewCleanableSecretStorage(
secret.NewSecretStorage(
psqlStorage,
dummy.NewEncrypter(),
),
"Secret": secret.NewSecretStorage(
psqlStorage,
dummy.NewEncrypter(),
),
}

Expand All @@ -54,6 +50,7 @@ func TestStorages(t *testing.T) {
t.Errorf("Expected expire error but got %v", err)
}
})

t.Run("Read", func(t *testing.T) {
UUID := uuid.NewUUIDString()
err := storage.Write(ctx, UUID, []byte("foo"), time.Second*-10, 1)
Expand All @@ -72,6 +69,7 @@ func TestStorages(t *testing.T) {
t.Errorf("Expected expire error but got %v", err)
}
})

t.Run("Delete", func(t *testing.T) {
UUID := uuid.NewUUIDString()
err := storage.Write(ctx, UUID, []byte("foo"), time.Second*-10, 1)
Expand Down
Loading

0 comments on commit e217df9

Please sign in to comment.