diff --git a/db/migrations/202407012100_transactions.go b/db/migrations/202407012100_transactions.go index 62d6f71d..986475f7 100644 --- a/db/migrations/202407012100_transactions.go +++ b/db/migrations/202407012100_transactions.go @@ -7,14 +7,15 @@ import ( "gorm.io/gorm" ) +// This migration +// - Replaces the old payments table with a new transactions table +// - Adds new properties to app_permissions +// - balance_type string - isolated | full +// - visibility string - isolated | full var _202407012100_transactions = &gormigrate.Migration{ ID: "202407012100_transactions", Migrate: func(tx *gorm.DB) error { - // request_event_id and app_id are not FKs, as apps and request events can be deleted - // TODO: create indexes - // type + payment hash - // if err := tx.Exec(` CREATE TABLE transactions( id integer, @@ -38,6 +39,13 @@ CREATE TABLE transactions( ); DROP TABLE payments; + +ALTER TABLE app_permissions ADD balance_type string; +ALTER TABLE app_permissions ADD visibility string; + +UPDATE app_permissions set balance_type = "full"; +UPDATE app_permissions set visibility = "full"; + `).Error; err != nil { return err } diff --git a/db/models.go b/db/models.go index 504fba21..722b8c61 100644 --- a/db/models.go +++ b/db/models.go @@ -30,6 +30,9 @@ type AppPermission struct { ExpiresAt *time.Time CreatedAt time.Time UpdatedAt time.Time + + BalanceType string + Visibility string } type RequestEvent struct { diff --git a/nip47/controllers/get_balance_controller.go b/nip47/controllers/get_balance_controller.go index 66ca9b94..c1763c0d 100644 --- a/nip47/controllers/get_balance_controller.go +++ b/nip47/controllers/get_balance_controller.go @@ -3,8 +3,10 @@ package controllers import ( "context" + "github.com/getAlby/hub/db" "github.com/getAlby/hub/logger" "github.com/getAlby/hub/nip47/models" + "github.com/getAlby/hub/transactions" "github.com/nbd-wtf/go-nostr" "github.com/sirupsen/logrus" ) @@ -14,13 +16,13 @@ const ( ) type getBalanceResponse struct { - Balance int64 `json:"balance"` + Balance uint64 `json:"balance"` // MaxAmount int `json:"max_amount"` // BudgetRenewal string `json:"budget_renewal"` } // TODO: remove checkPermission - can it be a middleware? -func (controller *nip47Controller) HandleGetBalanceEvent(ctx context.Context, nip47Request *models.Request, requestEventId uint, checkPermission checkPermissionFunc, publishResponse publishFunc) { +func (controller *nip47Controller) HandleGetBalanceEvent(ctx context.Context, nip47Request *models.Request, requestEventId uint, appId uint, checkPermission checkPermissionFunc, publishResponse publishFunc) { // basic permissions check resp := checkPermission(0) if resp != nil { @@ -32,19 +34,47 @@ func (controller *nip47Controller) HandleGetBalanceEvent(ctx context.Context, ni "request_event_id": requestEventId, }).Info("Getting balance") - balance, err := controller.lnClient.GetBalance(ctx) - if err != nil { - logger.Logger.WithFields(logrus.Fields{ - "request_event_id": requestEventId, - }).WithError(err).Error("Failed to fetch balance") - publishResponse(&models.Response{ - ResultType: nip47Request.Method, - Error: &models.Error{ - Code: models.ERROR_INTERNAL, - Message: err.Error(), - }, - }, nostr.Tags{}) - return + // TODO: optimize + var appPermission db.AppPermission + controller.db.Find(&appPermission, &db.AppPermission{ + AppId: appId, + }) + balance := uint64(0) + if appPermission.BalanceType == "isolated" { + // TODO: remove duplication in transactions service + var received struct { + Sum uint64 + } + controller.db. + Table("transactions"). + Select("SUM(amount) as sum"). + Where("app_id = ? AND type = ? AND state = ?", appPermission.AppId, transactions.TRANSACTION_TYPE_INCOMING, transactions.TRANSACTION_STATE_SETTLED).Scan(&received) + + var spent struct { + Sum uint64 + } + controller.db. + Table("transactions"). + Select("SUM(amount + fee) as sum"). + Where("app_id = ? AND type = ? AND (state = ? OR state = ?)", appPermission.AppId, transactions.TRANSACTION_TYPE_OUTGOING, transactions.TRANSACTION_STATE_SETTLED, transactions.TRANSACTION_STATE_PENDING).Scan(&spent) + + balance = received.Sum - spent.Sum + } else { + balance_signed, err := controller.lnClient.GetBalance(ctx) + balance = uint64(balance_signed) + if err != nil { + logger.Logger.WithFields(logrus.Fields{ + "request_event_id": requestEventId, + }).WithError(err).Error("Failed to fetch balance") + publishResponse(&models.Response{ + ResultType: nip47Request.Method, + Error: &models.Error{ + Code: models.ERROR_INTERNAL, + Message: err.Error(), + }, + }, nostr.Tags{}) + return + } } responsePayload := &getBalanceResponse{ diff --git a/nip47/controllers/get_balance_controller_test.go b/nip47/controllers/get_balance_controller_test.go index 259cc67c..9d43e94f 100644 --- a/nip47/controllers/get_balance_controller_test.go +++ b/nip47/controllers/get_balance_controller_test.go @@ -53,7 +53,7 @@ func TestHandleGetBalanceEvent_NoPermission(t *testing.T) { permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) transactionsSvc := transactions.NewTransactionsService(svc.DB) NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). - HandleGetBalanceEvent(ctx, nip47Request, dbRequestEvent.ID, checkPermission, publishResponse) + HandleGetBalanceEvent(ctx, nip47Request, dbRequestEvent.ID, *dbRequestEvent.AppId, checkPermission, publishResponse) assert.Nil(t, publishedResponse.Result) assert.Equal(t, models.ERROR_RESTRICTED, publishedResponse.Error.Code) @@ -86,7 +86,7 @@ func TestHandleGetBalanceEvent_WithPermission(t *testing.T) { permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) transactionsSvc := transactions.NewTransactionsService(svc.DB) NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). - HandleGetBalanceEvent(ctx, nip47Request, dbRequestEvent.ID, checkPermission, publishResponse) + HandleGetBalanceEvent(ctx, nip47Request, dbRequestEvent.ID, *dbRequestEvent.AppId, checkPermission, publishResponse) assert.Equal(t, int64(21000), publishedResponse.Result.(*getBalanceResponse).Balance) assert.Nil(t, publishedResponse.Error) diff --git a/nip47/controllers/list_transactions_controller_test.go b/nip47/controllers/list_transactions_controller_test.go index 71cbea3f..62a972cb 100644 --- a/nip47/controllers/list_transactions_controller_test.go +++ b/nip47/controllers/list_transactions_controller_test.go @@ -61,7 +61,7 @@ func TestHandleListTransactionsEvent_NoPermission(t *testing.T) { permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) transactionsSvc := transactions.NewTransactionsService(svc.DB) NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). - HandleListTransactionsEvent(ctx, nip47Request, dbRequestEvent.ID, checkPermission, publishResponse) + HandleListTransactionsEvent(ctx, nip47Request, dbRequestEvent.ID, *dbRequestEvent.AppId, checkPermission, publishResponse) assert.Nil(t, publishedResponse.Result) assert.Equal(t, models.ERROR_RESTRICTED, publishedResponse.Error.Code) @@ -117,7 +117,7 @@ func TestHandleListTransactionsEvent_WithPermission(t *testing.T) { permissionsSvc := permissions.NewPermissionsService(svc.DB, svc.EventPublisher) transactionsSvc := transactions.NewTransactionsService(svc.DB) NewNip47Controller(svc.LNClient, svc.DB, svc.EventPublisher, permissionsSvc, transactionsSvc). - HandleListTransactionsEvent(ctx, nip47Request, dbRequestEvent.ID, checkPermission, publishResponse) + HandleListTransactionsEvent(ctx, nip47Request, dbRequestEvent.ID, *dbRequestEvent.AppId, checkPermission, publishResponse) assert.Nil(t, publishedResponse.Error) diff --git a/nip47/event_handler.go b/nip47/event_handler.go index fe883a88..fde62132 100644 --- a/nip47/event_handler.go +++ b/nip47/event_handler.go @@ -307,7 +307,7 @@ func (svc *nip47Service) HandleEvent(ctx context.Context, sub *nostr.Subscriptio HandlePayKeysendEvent(ctx, nip47Request, requestEvent.ID, &app, checkPermission, publishResponse, nostr.Tags{}) case models.GET_BALANCE_METHOD: controller. - HandleGetBalanceEvent(ctx, nip47Request, requestEvent.ID, checkPermission, publishResponse) + HandleGetBalanceEvent(ctx, nip47Request, requestEvent.ID, app.ID, checkPermission, publishResponse) case models.MAKE_INVOICE_METHOD: controller. HandleMakeInvoiceEvent(ctx, nip47Request, requestEvent.ID, app.ID, checkPermission, publishResponse) diff --git a/transactions/transactions_service.go b/transactions/transactions_service.go index 7ce9806c..044fd3eb 100644 --- a/transactions/transactions_service.go +++ b/transactions/transactions_service.go @@ -119,26 +119,66 @@ func (svc *transactionsService) SendPaymentSync(ctx context.Context, payReq stri return nil, err } - // TODO: in transaction, ensure budget - var expiresAt *time.Time - if paymentRequest.Expiry > 0 { - expiresAtValue := time.Now().Add(time.Duration(paymentRequest.Expiry) * time.Second) - expiresAt = &expiresAtValue - } - dbTransaction := &db.Transaction{ - AppId: appId, - RequestEventId: requestEventId, - Type: TRANSACTION_TYPE_OUTGOING, - State: TRANSACTION_STATE_PENDING, - Amount: uint64(paymentRequest.MSatoshi), - PaymentRequest: payReq, - PaymentHash: paymentRequest.PaymentHash, - Description: paymentRequest.Description, - DescriptionHash: paymentRequest.DescriptionHash, - ExpiresAt: expiresAt, - // Metadata: metadata, - } - err = svc.db.Create(dbTransaction).Error + var dbTransaction *db.Transaction + + err = svc.db.Transaction(func(tx *gorm.DB) error { + // ensure balance for isolated apps + if appId != nil { + var appPermission db.AppPermission + tx.Find(&appPermission, &db.AppPermission{ + AppId: *appId, + }) + + if appPermission.BalanceType == "isolated" { + var received struct { + Sum uint64 + } + tx. + Table("transactions"). + Select("SUM(amount) as sum"). + Where("app_id = ? AND type = ? AND state = ?", appPermission.AppId, TRANSACTION_TYPE_INCOMING, TRANSACTION_STATE_SETTLED).Scan(&received) + + var spent struct { + Sum uint64 + } + tx. + Table("transactions"). + Select("SUM(amount + fee) as sum"). + Where("app_id = ? AND type = ? AND (state = ? OR state = ?)", appPermission.AppId, TRANSACTION_TYPE_OUTGOING, TRANSACTION_STATE_SETTLED, TRANSACTION_STATE_PENDING).Scan(&spent) + + // TODO: ensure fee reserve for external payment + balance := received.Sum - spent.Sum + if balance < uint64(paymentRequest.MSatoshi) { + // TODO: add a proper error type so INSUFFICIENT_BALANCE is returned + return errors.New("Insufficient balance") + } + } + } + + // TODO: ensure budget is not exceeded + + var expiresAt *time.Time + if paymentRequest.Expiry > 0 { + expiresAtValue := time.Now().Add(time.Duration(paymentRequest.Expiry) * time.Second) + expiresAt = &expiresAtValue + } + dbTransaction = &db.Transaction{ + AppId: appId, + RequestEventId: requestEventId, + Type: TRANSACTION_TYPE_OUTGOING, + State: TRANSACTION_STATE_PENDING, + Amount: uint64(paymentRequest.MSatoshi), + PaymentRequest: payReq, + PaymentHash: paymentRequest.PaymentHash, + Description: paymentRequest.Description, + DescriptionHash: paymentRequest.DescriptionHash, + ExpiresAt: expiresAt, + // Metadata: metadata, + } + err = tx.Create(dbTransaction).Error + return err + }) + if err != nil { logger.Logger.WithFields(logrus.Fields{ "bolt11": payReq, @@ -148,26 +188,7 @@ func (svc *transactionsService) SendPaymentSync(ctx context.Context, payReq stri var response *lnclient.PayInvoiceResponse if paymentRequest.Payee != "" && paymentRequest.Payee == lnClient.GetPubkey() { - transaction := db.Transaction{} - result := svc.db.Find(&transaction, &db.Transaction{ - Type: TRANSACTION_TYPE_INCOMING, - PaymentHash: dbTransaction.PaymentHash, - AppId: appId, - }) - err = result.Error - if err == nil && result.RowsAffected == 0 { - err = NewNotFoundError() - } - if transaction.Preimage == nil { - err = errors.New("preimage is not set on transaction. Self payments not supported.") - } - if err == nil { - fee := uint64(0) - response = &lnclient.PayInvoiceResponse{ - Preimage: *transaction.Preimage, - Fee: &fee, - } - } + response, err = svc.interceptSelfPayment(paymentRequest.PaymentHash) } else { response, err = lnClient.SendPaymentSync(ctx, payReq) } @@ -218,7 +239,7 @@ func (svc *transactionsService) SendPaymentSync(ctx context.Context, payReq stri } func (svc *transactionsService) SendKeysend(ctx context.Context, amount uint64, destination string, customRecords []lnclient.TLVRecord, lnClient lnclient.LNClient, appId *uint, requestEventId *uint) (*Transaction, error) { - // TODO: in transaction, ensure budget + // TODO: add same transaction as SendPayment to ensure balance and budget are not exceeded metadata := map[string]interface{}{} @@ -312,11 +333,23 @@ func (svc *transactionsService) SendKeysend(ctx context.Context, amount uint64, func (svc *transactionsService) LookupTransaction(ctx context.Context, paymentHash string, lnClient lnclient.LNClient, appId *uint) (*Transaction, error) { transaction := db.Transaction{} + tx := svc.db + + if appId != nil { + // TODO: optimize + var appPermission db.AppPermission + svc.db.Find(&appPermission, &db.AppPermission{ + AppId: *appId, + }) + if appPermission.Visibility == "isolated" { + tx = tx.Where("app_id == ?", *appId) + } + } + // FIXME: this is currently not unique - result := svc.db.Find(&transaction, &db.Transaction{ + result := tx.Find(&transaction, &db.Transaction{ //Type: transactionType, PaymentHash: paymentHash, - AppId: appId, }) if result.Error != nil { @@ -349,7 +382,14 @@ func (svc *transactionsService) ListTransactions(ctx context.Context, from, unti } if appId != nil { - tx = tx.Where("app_id == ?", *appId) + // TODO: optimize + var appPermission db.AppPermission + svc.db.Find(&appPermission, &db.AppPermission{ + AppId: *appId, + }) + if appPermission.Visibility == "isolated" { + tx = tx.Where("app_id == ?", *appId) + } } tx = tx.Order("created_at desc") @@ -561,3 +601,42 @@ func (svc *transactionsService) ConsumeEvent(ctx context.Context, event *events. return nil } + +func (svc *transactionsService) interceptSelfPayment(paymentHash string) (*lnclient.PayInvoiceResponse, error) { + // TODO: extract into separate function + incomingTransaction := db.Transaction{} + result := svc.db.Find(&incomingTransaction, &db.Transaction{ + Type: TRANSACTION_TYPE_INCOMING, + State: TRANSACTION_STATE_PENDING, + PaymentHash: paymentHash, + }) + if result.Error != nil { + return nil, result.Error + } + + if result.RowsAffected == 0 { + return nil, NewNotFoundError() + } + if incomingTransaction.Preimage == nil { + return nil, errors.New("preimage is not set on transaction. Self payments not supported.") + } + + // update the incoming transaction + now := time.Now() + fee := uint64(0) + err := svc.db.Model(incomingTransaction).Updates(&db.Transaction{ + State: TRANSACTION_STATE_SETTLED, + Fee: &fee, + SettledAt: &now, + }).Error + if err != nil { + return nil, err + } + + // TODO: publish event for self payment + + return &lnclient.PayInvoiceResponse{ + Preimage: *incomingTransaction.Preimage, + Fee: &fee, + }, nil +}