diff --git a/api/api_test.go b/api/api_test.go index a020141..8a69128 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -164,7 +164,7 @@ func TestCreateEntryForm(t *testing.T) { value := "Foo" connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) t.Cleanup(func() { - connection.Close() + defer connection.Close() }) data, multi, err := createMultipart(map[string]io.Reader{ @@ -233,13 +233,13 @@ func TestRequestPathsCreateEntry(t *testing.T) { {Name: "/ path", Path: "/", StatusCode: 200}, {Name: "Longer path", Path: "/other", StatusCode: 404}, } + connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) + t.Cleanup(func() { + connection.Close() + }) for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { - connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) - t.Cleanup(func() { - connection.Close() - }) req := httptest.NewRequest("POST", fmt.Sprintf("http://example.com%s", testCase.Path), bytes.NewReader([]byte("ASDF"))) w := httptest.NewRecorder() NewSecretHandler(NewHandlerConfig(connection)).ServeHTTP(w, req) @@ -264,17 +264,17 @@ func TestGetEntry(t *testing.T) { { "first", "foo", - "3f356f6c-c8b1-4b48-8243-aa04d07b8873", + uuid.NewUUIDString(), }, } + connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) + t.Cleanup(func() { + connection.Close() + }) + for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { - connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) - t.Cleanup(func() { - connection.Close() - }) - k := key.NewKey() if err := k.Generate(); err != nil { t.Error(err) @@ -303,13 +303,12 @@ func TestGetEntry(t *testing.T) { } }) } - } func TestGetEntryJSON(t *testing.T) { connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) t.Cleanup(func() { - connection.Close() + defer connection.Close() }) testCase := struct { Name string @@ -319,7 +318,7 @@ func TestGetEntryJSON(t *testing.T) { "first", "foo", - "3f356f6c-c8b1-4b48-8243-aa04d07b8873", + uuid.NewUUIDString(), } k := key.NewKey() @@ -335,7 +334,10 @@ func TestGetEntryJSON(t *testing.T) { } ctx := context.Background() - connection.Write(ctx, testCase.UUID, encryptedData, time.Second*10, 1) + if err := connection.Write(ctx, testCase.UUID, encryptedData, time.Second*10, 1); err != nil { + t.Error(err) + } + fmt.Println("Wrote", testCase.UUID) req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%s/%s", testCase.UUID, hex.EncodeToString(rsakey)), nil) req.Header.Add("Accept", "application/json") @@ -343,6 +345,10 @@ func TestGetEntryJSON(t *testing.T) { NewSecretHandler(NewHandlerConfig(connection)).ServeHTTP(w, req) resp := w.Result() + fmt.Println(resp.Header) + if resp.StatusCode != 200 { + t.Errorf("non 200 http statuscode: %d", resp.StatusCode) + } var encode entries.SecretResponse err = json.NewDecoder(resp.Body).Decode(&encode) @@ -540,7 +546,7 @@ func FuzzSetAndGetEntry(f *testing.F) { } connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) f.Cleanup(func() { - connection.Close() + defer connection.Close() }) f.Fuzz(func(t *testing.T, testCase string) { diff --git a/entries/entry_meta_test.go b/entries/entry_meta_test.go new file mode 100644 index 0000000..c1656cf --- /dev/null +++ b/entries/entry_meta_test.go @@ -0,0 +1,21 @@ +package entries + +import ( + "testing" + "time" +) + +func Test_EntryMeta(t *testing.T) { + expire := time.Now() + + meta := EntryMeta{Expire: expire.Add(time.Second)} + + if meta.IsExpired() { + t.Error("entry meta should not be expired") + } + + meta = EntryMeta{Expire: expire.Add(-time.Second)} + if !meta.IsExpired() { + t.Error("entry meta should be expired") + } +} diff --git a/storage/integration/integration_test.go b/storage/integration/integration_test.go index f32fb35..847568e 100644 --- a/storage/integration/integration_test.go +++ b/storage/integration/integration_test.go @@ -15,20 +15,16 @@ import ( ) func TestStorages(t *testing.T) { - connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) + psqlStorage := postgresql.NewStorage(testhelper.GetPSQLTestConn()) t.Cleanup(func() { - connection.Close() + psqlStorage.Close() }) - psqlStorage := postgresql.NewPostgresCleanableStorage(connection) - storages := map[string]storage.Cleanable{ + storages := map[string]storage.Storage{ "Postgres": psqlStorage, - "Secret": secret.NewCleanableSecretStorage( - secret.NewSecretStorage( - psqlStorage, - dummy.NewEncrypter(), - ), + "Secret": secret.NewSecretStorage( psqlStorage, + dummy.NewEncrypter(), ), } @@ -54,6 +50,7 @@ func TestStorages(t *testing.T) { t.Errorf("Expected expire error but got %v", err) } }) + t.Run("Read", func(t *testing.T) { UUID := uuid.NewUUIDString() err := storage.Write(ctx, UUID, []byte("foo"), time.Second*-10, 1) @@ -72,6 +69,7 @@ func TestStorages(t *testing.T) { t.Errorf("Expected expire error but got %v", err) } }) + t.Run("Delete", func(t *testing.T) { UUID := uuid.NewUUIDString() err := storage.Write(ctx, UUID, []byte("foo"), time.Second*-10, 1) diff --git a/storage/postgresql/postgresql_storage.go b/storage/postgresql/postgresql_storage.go index fd3a4b7..7ac3ee8 100644 --- a/storage/postgresql/postgresql_storage.go +++ b/storage/postgresql/postgresql_storage.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "log" "strings" "time" @@ -27,15 +26,18 @@ func (s Storage) Close() error { // Write stores a new entry in database func (s Storage) Write(ctx context.Context, UUID string, entry []byte, expire time.Duration, remainingReads int) error { - now := time.Now() - k, err := key.NewGeneratedKey() + tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } - deleteKey := k.ToHex() - _, err = s.db.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key) VALUES ($1, $2, $3, $4, $5, $6) RETURNING uuid, delete_key;`, UUID, entry, now, now.Add(expire), remainingReads, deleteKey) - return err + if err = s.write(tx, UUID, entry, expire, remainingReads); err != nil { + tx.Rollback() + return err + } + + tx.Commit() + return nil } // ReadMeta to get entry metadata (without the actual secret) @@ -47,7 +49,59 @@ func (s Storage) ReadMeta(ctx context.Context, UUID string) (*entries.EntryMeta, return nil, err } - row := tx.QueryRowContext(ctx, ` + meta, err := s.readMeta(tx, UUID) + if err != nil { + tx.Rollback() + if err == sql.ErrNoRows { + return nil, entries.ErrEntryNotFound + } + + return nil, err + } + + if meta.IsExpired() { + if err := s.setAccessed(tx, UUID); err != nil { + tx.Rollback() + return nil, err + } + + if err = tx.Commit(); err != nil { + return nil, err + } + + return nil, entries.ErrEntryExpired + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return meta, nil +} + +func (s Storage) write(tx *sql.Tx, UUID string, entry []byte, expire time.Duration, remainingReads int) error { + now := time.Now() + k, err := key.NewGeneratedKey() + if err != nil { + return err + } + deleteKey := k.ToHex() + + _, err = tx.Exec(`INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key) VALUES ($1, $2, $3, $4, $5, $6) RETURNING uuid, delete_key;`, UUID, entry, now, now.Add(expire), remainingReads, deleteKey) + + return err +} + +func (s Storage) setAccessed(tx *sql.Tx, UUID string) error { + if _, err := tx.Exec("UPDATE entries SET accessed=$1 WHERE uuid=$2", time.Now(), UUID); err != nil { + return err + } + + return nil +} + +func (s Storage) readMeta(tx *sql.Tx, UUID string) (*entries.EntryMeta, error) { + row := tx.QueryRow(` SELECT created, accessed, @@ -58,6 +112,7 @@ func (s Storage) ReadMeta(ctx context.Context, UUID string) (*entries.EntryMeta, entries WHERE uuid=$1 + AND remaining_reads > 0 `, UUID) var created time.Time @@ -65,13 +120,9 @@ func (s Storage) ReadMeta(ctx context.Context, UUID string) (*entries.EntryMeta, var expireNullTime sql.NullTime var remainingReadsNullInt32 sql.NullInt32 var deleteKeyNullString sql.NullString - err = row.Scan(&created, &accessedNullTime, &expireNullTime, &remainingReadsNullInt32, &deleteKeyNullString) + err := row.Scan(&created, &accessedNullTime, &expireNullTime, &remainingReadsNullInt32, &deleteKeyNullString) if err != nil { - tx.Rollback() - if err == sql.ErrNoRows { - return nil, entries.ErrEntryNotFound - } return nil, err } @@ -102,62 +153,36 @@ func (s Storage) ReadMeta(ctx context.Context, UUID string) (*entries.EntryMeta, DeleteKey: deleteKey, } - if meta.IsExpired() { - _, err = tx.ExecContext(ctx, ` - UPDATE entries - SET data=$1, accessed=$2 - WHERE uuid=$3 - `, nil, time.Now(), UUID) - - if err != nil { - tx.Rollback() - return nil, err - } - err := tx.Commit() - if err != nil { - return nil, err - } - - return nil, entries.ErrEntryExpired - } - - err = tx.Commit() - - if err != nil { - return nil, err - } - return meta, nil } -// Get to get entry including the actual secret +func (s Storage) updateReadCount(tx *sql.Tx, UUID string) error { + _, err := tx.Exec("UPDATE entries SET remaining_reads = remaining_reads - 1 WHERE uuid=$1;", UUID) + return err +} + +// read to get entry including the actual secret // returns the data if the secret not expired yet // updates read count -func (s Storage) Get(ctx context.Context, UUID string) (*entries.Entry, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } +func (s Storage) read(tx *sql.Tx, UUID string) (*entries.Entry, error) { - row := tx.QueryRowContext(ctx, "SELECT data, created, accessed, expire FROM entries WHERE uuid=$1", UUID) + row := tx.QueryRow(`SELECT data, created, accessed, expire, remaining_reads FROM entries + WHERE uuid=$1 + AND remaining_reads > 0 + LIMIT 1`, UUID) var data []byte var created time.Time var accessedNullTime sql.NullTime var expireNullTime sql.NullTime - err = row.Scan(&data, &created, &accessedNullTime, &expireNullTime) - - if err != nil { - tx.Rollback() - if err == sql.ErrNoRows { - return nil, entries.ErrEntryNotFound - } - + var remainingReadsNullInt32 sql.NullInt32 + if err := row.Scan(&data, &created, &accessedNullTime, &expireNullTime, &remainingReadsNullInt32); err != nil { return nil, err } var accessed time.Time var expire time.Time + var maxReads int32 if accessedNullTime.Valid { accessed = accessedNullTime.Time @@ -165,33 +190,16 @@ func (s Storage) Get(ctx context.Context, UUID string) (*entries.Entry, error) { if expireNullTime.Valid { expire = expireNullTime.Time } + if remainingReadsNullInt32.Valid { + maxReads = remainingReadsNullInt32.Int32 + } meta := entries.EntryMeta{ UUID: UUID, Created: created, Accessed: accessed, Expire: expire, - } - - if meta.IsExpired() { - _, err = tx.ExecContext(ctx, "UPDATE entries SET data=$1, accessed=$2 WHERE uuid=$3", nil, time.Now(), UUID) - - if err != nil { - tx.Rollback() - return nil, err - } - err := tx.Commit() - if err != nil { - return nil, err - } - - return nil, entries.ErrEntryExpired - } - - err = tx.Commit() - - if err != nil { - return nil, err + MaxReads: maxReads, } return &entries.Entry{ @@ -210,13 +218,7 @@ func (s Storage) Read(ctx context.Context, UUID string) (*entries.Entry, error) return nil, err } - row := tx.QueryRowContext(ctx, "SELECT data, created, accessed, expire FROM entries WHERE uuid=$1", UUID) - - var data []byte - var created time.Time - var accessedNullTime sql.NullTime - var expireNullTime sql.NullTime - err = row.Scan(&data, &created, &accessedNullTime, &expireNullTime) + entry, err := s.read(tx, UUID) if err != nil { tx.Rollback() @@ -226,16 +228,15 @@ func (s Storage) Read(ctx context.Context, UUID string) (*entries.Entry, error) return nil, err } - queries := []string{ - "UPDATE entries SET remaining_reads = remaining_reads - 1 WHERE uuid=$1;", - "DELETE FROM entries WHERE uuid=$1 AND remaining_reads < 1;", + if entry.IsExpired() { + s.setAccessed(tx, UUID) + tx.Commit() + return nil, entries.ErrEntryExpired } - for _, query := range queries { - _, err = tx.ExecContext(ctx, query, UUID) - if err != nil { - tx.Rollback() - return nil, err - } + + if err := s.updateReadCount(tx, UUID); err != nil { + tx.Rollback() + return nil, err } err = tx.Commit() @@ -243,31 +244,7 @@ func (s Storage) Read(ctx context.Context, UUID string) (*entries.Entry, error) return nil, err } - var accessed time.Time - var expire time.Time - - if accessedNullTime.Valid { - accessed = accessedNullTime.Time - } - if expireNullTime.Valid { - expire = expireNullTime.Time - } - - meta := entries.EntryMeta{ - UUID: UUID, - Created: created, - Accessed: accessed, - Expire: expire, - } - - if meta.IsExpired() { - return nil, entries.ErrEntryExpired - } - - return &entries.Entry{ - EntryMeta: meta, - Data: data, - }, nil + return entry, nil } // Delete deletes the entry from the database @@ -346,29 +323,10 @@ func (s Storage) DeleteExpired(ctx context.Context) error { _, err = tx.ExecContext(ctx, "DELETE FROM entries WHERE expire < NOW() OR remaining_reads < 1;") if err != nil { + fmt.Println("DELETE ERRRO", err) tx.Rollback() return err } return tx.Commit() } - -// NewPostgresCleanableStorage Creates a cleanable psql storage instance -func NewPostgresCleanableStorage(s *Storage) *PostgresCleanableStorage { - return &PostgresCleanableStorage{s} -} - -// PostgresCleanableStorage extends the regular PostgresqlStorage with a Clean -// method to remove all entries -type PostgresCleanableStorage struct { - *Storage -} - -// Clean deletes all entries from the database -func (s PostgresCleanableStorage) Clean() { - _, err := s.db.Exec("TRUNCATE entries;") - - if err != nil { - log.Fatal(err) - } -} diff --git a/storage/postgresql/postgresql_storage_test.go b/storage/postgresql/postgresql_storage_test.go index 72d56c0..ad3e7e3 100644 --- a/storage/postgresql/postgresql_storage_test.go +++ b/storage/postgresql/postgresql_storage_test.go @@ -2,7 +2,6 @@ package postgresql import ( "context" - "database/sql" "testing" "time" @@ -10,39 +9,7 @@ import ( "github.com/Ajnasz/sekret.link/uuid" ) -// func TestPostgresqlStorageWriteGet(t *testing.T) { -// psqlConn := testhelper.GetPSQLTestConn() -// storage := NewStorage(psqlConn) -// t.Cleanup(func() { -// defer storage.Close() -// }) -// testCases := []string{ -// "foo", -// } - -// for _, testCase := range testCases { -// t.Run(testCase, func(t *testing.T) { - -// UUID := uuid.NewUUIDString() -// err := storage.Write(UUID, []byte("foo"), time.Second*10, 1) - -// if err != nil { -// t.Fatal(err) -// } -// res, err := storage.Get(UUID) -// if err != nil { -// t.Fatal(err) -// } - -// actual := string(res.Data) -// if actual != testCase { -// t.Errorf("expected: %s, actual: %s", testCase, actual) -// } -// }) -// } -// } - -func TestPostgresqlStorageWrite(t *testing.T) { +func Test_PostgresqlStorageWrite(t *testing.T) { psqlConn := testhelper.GetPSQLTestConn() storage := NewStorage(psqlConn) t.Cleanup(func() { @@ -50,38 +17,34 @@ func TestPostgresqlStorageWrite(t *testing.T) { }) testCases := []struct { - Name string - Secret string - Reads int - Remaining int - ExistanceErr error + Name string + Secret string + Reads int + Remaining int }{ { - Name: "Simple get", - Secret: "foo", - Reads: 1, - Remaining: 0, - ExistanceErr: sql.ErrNoRows, + Name: "Simple get", + Secret: "foo", + Reads: 1, + Remaining: 0, }, { - Name: "Exist get", - Secret: "bar", - Reads: 2, - Remaining: 1, - ExistanceErr: nil, + Name: "Exist get", + Secret: "bar", + Reads: 2, + Remaining: 1, }, { - Name: "Exist get 2", - Secret: "bar", - Reads: 3, - Remaining: 2, - ExistanceErr: nil, + Name: "Exist get 2", + Secret: "bar", + Reads: 3, + Remaining: 2, }, } for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { - + t.Logf("%+v", testCase) UUID := uuid.NewUUIDString() ctx := context.Background() err := storage.Write(ctx, UUID, []byte(testCase.Secret), time.Second*10, testCase.Reads) @@ -96,16 +59,15 @@ func TestPostgresqlStorageWrite(t *testing.T) { actual := string(res.Data) if actual != testCase.Secret { - t.Errorf("expected: %s, actual: %s", testCase.Secret, actual) + t.Errorf("%s expected: %s, actual: %s", UUID, testCase.Secret, actual) } - var data []byte var remainingReads int - row := storage.db.QueryRow("SELECT data, remaining_reads FROM entries WHERE uuid=$1", UUID) - err = row.Scan(&data, &remainingReads) - if err != testCase.ExistanceErr { - t.Fatal(err) + row := storage.db.QueryRow("SELECT remaining_reads FROM entries WHERE uuid=$1", UUID) + err = row.Scan(&remainingReads) + if err != nil { + t.Fatalf("%s: %v", UUID, err) } if remainingReads != testCase.Remaining { @@ -121,7 +83,6 @@ func TestPostgresqlStorageVerifyDelete(t *testing.T) { t.Cleanup(func() { storage.Close() }) - defer storage.Close() testCases := []struct { UUID string Key string diff --git a/storage/secret/secret_storage.go b/storage/secret/secret_storage.go index 14a9361..1e4017d 100644 --- a/storage/secret/secret_storage.go +++ b/storage/secret/secret_storage.go @@ -94,20 +94,3 @@ func (s SecretStorage) Delete(ctx context.Context, UUID string) error { func (s SecretStorage) DeleteExpired(ctx context.Context) error { return s.internalStorage.DeleteExpired(ctx) } - -// NewCleanableSecretStorage Creates a cleanable secret storage -func NewCleanableSecretStorage(s *SecretStorage, internal storage.Cleanable) CleanableSecretStorage { - return CleanableSecretStorage{s, internal} -} - -// CleanableSecretStorage Storage which implements CleanableStorage interface, -// to allow to clean every entry from the underlying storage -type CleanableSecretStorage struct { - *SecretStorage - internalStorage storage.Cleanable -} - -// Clean Executes the clean call on the storage -func (s CleanableSecretStorage) Clean() { - s.internalStorage.Clean() -} diff --git a/storage/secret/secret_storage_test.go b/storage/secret/secret_storage_test.go index bd5c514..394fda9 100644 --- a/storage/secret/secret_storage_test.go +++ b/storage/secret/secret_storage_test.go @@ -14,20 +14,15 @@ import ( func TestSecretStorage(t *testing.T) { testData := "Lorem ipusm dolor sit amet" - connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) - t.Cleanup(func() { - connection.Close() - }) - psqlStorage := postgresql.PostgresCleanableStorage{connection} - storage := &CleanableSecretStorage{ - NewSecretStorage( - psqlStorage, - dummy.NewEncrypter(), - ), + psqlStorage := postgresql.NewStorage(testhelper.GetPSQLTestConn()) + storage := NewSecretStorage( psqlStorage, - } - // TODO defer storage.Close() + dummy.NewEncrypter(), + ) + t.Cleanup(func() { + storage.Close() + }) UUID := uuid.NewUUIDString() ctx := context.Background() err := storage.Write(ctx, UUID, []byte(testData), time.Second*10, 1) diff --git a/storage/storage.go b/storage/storage.go index c64fb16..bf633bd 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -37,12 +37,6 @@ type Storage interface { Writer } -// Cleanable Interface which enables to remove every entry from a storae -type Cleanable interface { - Storage - Clean() -} - // Verifyable an interface which extends the EntryStorage with a // VerifyDelete method type Verifyable interface {