Skip to content

Commit 49e7da6

Browse files
committed
fix: potential invalid balance during upgrade
1 parent c39e4d4 commit 49e7da6

File tree

2 files changed

+144
-9
lines changed

2 files changed

+144
-9
lines changed

internal/storage/ledger/balances.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ func (store *Store) GetBalances(ctx context.Context, query ledgercontroller.Bala
2121
}
2222

2323
if isUpToDate {
24-
return store.getBalancesAfterUpgrade(ctx, query)
24+
return store.GetBalancesAfterUpgrade(ctx, query)
2525
} else {
26-
return store.getBalancesWhenUpgrading(ctx, query)
26+
return store.GetBalancesWhenUpgrading(ctx, query)
2727
}
2828
}
2929

30-
func (store *Store) getBalancesWhenUpgrading(ctx context.Context, query ledgercontroller.BalanceQuery) (ledgercontroller.Balances, error) {
30+
func (store *Store) GetBalancesWhenUpgrading(ctx context.Context, query ledgercontroller.BalanceQuery) (ledgercontroller.Balances, error) {
3131
return tracing.TraceWithMetric(
3232
ctx,
3333
"GetBalances",
@@ -46,12 +46,13 @@ func (store *Store) getBalancesWhenUpgrading(ctx context.Context, query ledgerco
4646
type AccountsVolumesWithLedger struct {
4747
Ledger string `bun:"ledger,type:varchar"`
4848
ledger.AccountsVolumes `bun:",extend"`
49+
Priority int `bun:"priority"` // for ordering (keep at 0)
4950
}
5051

51-
accountsVolumes := make([]AccountsVolumesWithLedger, 0)
52+
defaultAccountsVolumes := make([]AccountsVolumesWithLedger, 0)
5253
for account, assets := range query {
5354
for _, asset := range assets {
54-
accountsVolumes = append(accountsVolumes, AccountsVolumesWithLedger{
55+
defaultAccountsVolumes = append(defaultAccountsVolumes, AccountsVolumesWithLedger{
5556
Ledger: store.ledger.Name,
5657
AccountsVolumes: ledger.AccountsVolumes{
5758
Account: account,
@@ -64,7 +65,7 @@ func (store *Store) getBalancesWhenUpgrading(ctx context.Context, query ledgerco
6465
}
6566

6667
// prevent deadlocks by sorting the accountsVolumes slice
67-
slices.SortStableFunc(accountsVolumes, func(i, j AccountsVolumesWithLedger) int {
68+
slices.SortStableFunc(defaultAccountsVolumes, func(i, j AccountsVolumesWithLedger) int {
6869
if i.Account < j.Account {
6970
return -1
7071
} else if i.Account > j.Account {
@@ -94,18 +95,22 @@ func (store *Store) getBalancesWhenUpgrading(ctx context.Context, query ledgerco
9495
Column("ledger", "accounts_address", "asset").
9596
ColumnExpr("(post_commit_volumes).inputs as input").
9697
ColumnExpr("(post_commit_volumes).outputs as output").
98+
ColumnExpr("1 as priority").
9799
UnionAll(
98100
store.db.NewSelect().
99101
TableExpr(
100102
"(?) data",
101-
store.db.NewSelect().NewValues(&accountsVolumes),
103+
store.db.NewSelect().
104+
NewValues(&defaultAccountsVolumes),
102105
).
103106
Column("*"),
104107
)
105108

106109
zeroValueOrMoves := store.db.NewSelect().
107110
TableExpr("(?) data", zeroValuesAndMoves).
108-
Column("ledger", "accounts_address", "asset", "input", "output").
111+
Column("ledger", "accounts_address", "asset").
112+
ColumnExpr("first_value(input) over (partition by ledger, accounts_address, asset order by priority desc) as input").
113+
ColumnExpr("first_value(output) over (partition by ledger, accounts_address, asset order by priority desc) as output").
109114
DistinctOn("ledger, accounts_address, asset")
110115

111116
insertDefaultValue := store.db.NewInsert().
@@ -122,6 +127,7 @@ func (store *Store) getBalancesWhenUpgrading(ctx context.Context, query ledgerco
122127
// notes(gfyrag): Keep order, it ensures consistent locking order and limit deadlocks
123128
Order("accounts_address", "asset")
124129

130+
accountsVolumes := make([]ledger.AccountsVolumes, 0)
125131
finalQuery := store.db.NewSelect().
126132
With("inserted", insertDefaultValue).
127133
With("existing", selectExistingValues).
@@ -163,7 +169,7 @@ func (store *Store) getBalancesWhenUpgrading(ctx context.Context, query ledgerco
163169
)
164170
}
165171

166-
func (store *Store) getBalancesAfterUpgrade(ctx context.Context, query ledgercontroller.BalanceQuery) (ledgercontroller.Balances, error) {
172+
func (store *Store) GetBalancesAfterUpgrade(ctx context.Context, query ledgercontroller.BalanceQuery) (ledgercontroller.Balances, error) {
167173
return tracing.TraceWithMetric(
168174
ctx,
169175
"GetBalances",

internal/storage/ledger/balances_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package ledger_test
44

55
import (
66
"database/sql"
7+
"github.com/formancehq/go-libs/v2/bun/bunpaginate"
78
"math/big"
89
"testing"
910

@@ -129,6 +130,134 @@ func TestBalancesGet(t *testing.T) {
129130
require.NoError(t, err)
130131
require.Equal(t, 2, count)
131132
})
133+
134+
t.Run("with balance from move", func(t *testing.T) {
135+
t.Parallel()
136+
137+
tx := ledger.NewTransaction().WithPostings(
138+
ledger.NewPosting("world", "bank", "USD", big.NewInt(100)),
139+
ledger.NewPosting("world", "bank", "EUR", big.NewInt(200)),
140+
)
141+
err := store.InsertTransaction(ctx, &tx)
142+
require.NoError(t, err)
143+
144+
err = store.UpsertAccounts(ctx,
145+
&ledger.Account{
146+
Address: "world",
147+
},
148+
&ledger.Account{
149+
Address: "bank",
150+
},
151+
&ledger.Account{
152+
Address: "not-existing",
153+
},
154+
)
155+
require.NoError(t, err)
156+
157+
err = store.InsertMoves(ctx, &ledger.Move{
158+
TransactionID: *tx.ID,
159+
IsSource: true,
160+
Account: "world",
161+
Amount: (*bunpaginate.BigInt)(big.NewInt(100)),
162+
Asset: "USD",
163+
InsertionDate: tx.InsertedAt,
164+
EffectiveDate: tx.InsertedAt,
165+
PostCommitVolumes: pointer.For(ledger.NewVolumesInt64(0, 100)),
166+
})
167+
require.NoError(t, err)
168+
169+
err = store.InsertMoves(ctx, &ledger.Move{
170+
TransactionID: *tx.ID,
171+
IsSource: false,
172+
Account: "bank",
173+
Amount: (*bunpaginate.BigInt)(big.NewInt(100)),
174+
Asset: "USD",
175+
InsertionDate: tx.InsertedAt,
176+
EffectiveDate: tx.InsertedAt,
177+
PostCommitVolumes: pointer.For(ledger.NewVolumesInt64(100, 0)),
178+
})
179+
require.NoError(t, err)
180+
181+
err = store.InsertMoves(ctx, &ledger.Move{
182+
TransactionID: *tx.ID,
183+
IsSource: true,
184+
Account: "world",
185+
Amount: (*bunpaginate.BigInt)(big.NewInt(200)),
186+
Asset: "EUR",
187+
InsertionDate: tx.InsertedAt,
188+
EffectiveDate: tx.InsertedAt,
189+
PostCommitVolumes: pointer.For(ledger.NewVolumesInt64(0, 200)),
190+
})
191+
require.NoError(t, err)
192+
193+
err = store.InsertMoves(ctx, &ledger.Move{
194+
TransactionID: *tx.ID,
195+
IsSource: false,
196+
Account: "bank",
197+
Amount: (*bunpaginate.BigInt)(big.NewInt(200)),
198+
Asset: "EUR",
199+
InsertionDate: tx.InsertedAt,
200+
EffectiveDate: tx.InsertedAt,
201+
PostCommitVolumes: pointer.For(ledger.NewVolumesInt64(200, 0)),
202+
})
203+
require.NoError(t, err)
204+
205+
balances, err := store.GetBalancesWhenUpgrading(ctx, ledgercontroller.BalanceQuery{
206+
"bank": {"USD"},
207+
"world": {"USD"},
208+
"not-existing": {"USD"},
209+
})
210+
require.NoError(t, err)
211+
212+
require.NotNil(t, balances["bank"])
213+
RequireEqual(t, big.NewInt(100), balances["bank"]["USD"])
214+
RequireEqual(t, big.NewInt(-100), balances["world"]["USD"])
215+
RequireEqual(t, big.NewInt(0), balances["not-existing"]["USD"])
216+
217+
// Check a new line has been inserted into accounts_volumes table
218+
volumes := &ledger.AccountsVolumes{}
219+
err = store.GetDB().NewSelect().
220+
ModelTableExpr(store.GetPrefixedRelationName("accounts_volumes")).
221+
Where("accounts_address = ? and ledger = ? and asset = 'USD'", "bank", store.GetLedger().Name).
222+
Scan(ctx, volumes)
223+
require.NoError(t, err)
224+
225+
RequireEqual(t, big.NewInt(100), volumes.Input)
226+
RequireEqual(t, big.NewInt(0), volumes.Output)
227+
228+
err = store.GetDB().NewSelect().
229+
ModelTableExpr(store.GetPrefixedRelationName("accounts_volumes")).
230+
Where("accounts_address = ? and ledger = ? and asset = 'USD'", "world", store.GetLedger().Name).
231+
Scan(ctx, volumes)
232+
require.NoError(t, err)
233+
234+
RequireEqual(t, big.NewInt(0), volumes.Input)
235+
RequireEqual(t, big.NewInt(100), volumes.Output)
236+
237+
err = store.GetDB().NewSelect().
238+
ModelTableExpr(store.GetPrefixedRelationName("accounts_volumes")).
239+
Where("accounts_address = ? and ledger = ? and asset = 'USD'", "not-existing", store.GetLedger().Name).
240+
Scan(ctx, volumes)
241+
require.NoError(t, err)
242+
243+
RequireEqual(t, big.NewInt(0), volumes.Input)
244+
RequireEqual(t, big.NewInt(0), volumes.Output)
245+
246+
balances, err = store.GetBalancesWhenUpgrading(ctx, ledgercontroller.BalanceQuery{
247+
"bank": {"USD", "EUR"},
248+
"world": {"USD", "EUR"},
249+
"not-existing": {"USD", "EUR"},
250+
})
251+
require.NoError(t, err)
252+
253+
require.NotNil(t, balances["bank"])
254+
RequireEqual(t, big.NewInt(100), balances["bank"]["USD"])
255+
RequireEqual(t, big.NewInt(200), balances["bank"]["EUR"])
256+
RequireEqual(t, big.NewInt(-100), balances["world"]["USD"])
257+
RequireEqual(t, big.NewInt(-200), balances["world"]["EUR"])
258+
RequireEqual(t, big.NewInt(0), balances["not-existing"]["USD"])
259+
RequireEqual(t, big.NewInt(0), balances["not-existing"]["EUR"])
260+
})
132261
}
133262

134263
func TestBalancesAggregates(t *testing.T) {

0 commit comments

Comments
 (0)