Skip to content

Commit

Permalink
POC of using mostly a SQL API instead
Browse files Browse the repository at this point in the history
  • Loading branch information
Psycojoker committed Oct 29, 2024
1 parent ad3bd65 commit ce7a16c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
31 changes: 21 additions & 10 deletions src/aleph/db/accessors/balances.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,34 @@ def get_total_balance(


def get_total_detailed_balance(
session: DbSession, address: str, include_dapps: bool = False
) -> Dict[str, Decimal]:
where_clause = AlephBalanceDb.address == address
if not include_dapps:
where_clause = where_clause & AlephBalanceDb.dapp.is_(None)
session: DbSession, address: str, chain: Optional[str] = None
) -> tuple[Decimal, Dict[str, Decimal]]:
if chain is not None:
query = (
select(func.sum(AlephBalanceDb.balance))
.where((AlephBalanceDb.address == address) & (AlephBalanceDb.chain == chain))
.group_by(AlephBalanceDb.address)
)

select_stmt = (
result = session.execute(query).first()
return result[0] if result is not None else Decimal(0), {}

query = (
select(AlephBalanceDb.chain, func.sum(AlephBalanceDb.balance).label("balance"))
.where(where_clause)
.where(AlephBalanceDb.address == address)
.group_by(AlephBalanceDb.chain)
)

result = session.execute(select_stmt).fetchall()
balances_by_chain = {row.chain: row.balance or Decimal(0) for row in session.execute(query).fetchall()}

balances_by_chain = {row.chain: row.balance or Decimal(0) for row in result}
query = (
select(func.sum(AlephBalanceDb.balance))
.where(AlephBalanceDb.address == address)
.group_by(AlephBalanceDb.address)
)

return balances_by_chain
result = session.execute(query).first()
return result[0] if result is not None else Decimal(0), balances_by_chain


def update_balances(
Expand Down
8 changes: 1 addition & 7 deletions src/aleph/web/controllers/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,7 @@ async def get_account_balance(request: web.Request):

session_factory: DbSessionFactory = get_session_factory_from_request(request)
with session_factory() as session:
balance_detail = get_total_detailed_balance(session=session, address=address)
if query_params.chain is None:
balance = Decimal(sum(balance_detail.values()))
details = balance_detail
else:
balance = balance_detail.get(query_params.chain, Decimal(0))
details = {}
balance, details = get_total_detailed_balance(session=session, address=address, chain=query_params.chain)
total_cost = get_total_cost_for_address(session=session, address=address)
return web.json_response(
text=GetAccountBalanceResponse(
Expand Down

0 comments on commit ce7a16c

Please sign in to comment.