Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions nonebot_plugin_value/api/api_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ async def batch_del_balance(
data_list: list[UserAccountData] = []
if currency_id is None:
currency_id = (await _get_default()).id
await _batch_del(
updates, currency_id, source, fail_then_rollback=True, return_all_on_fail=True
)
await _batch_del(updates, currency_id, source, return_all_on_fail=True)
for user_id, _ in updates:
data_list.append(await get_or_create_account(user_id, currency_id))
return data_list
Expand All @@ -139,9 +137,7 @@ async def batch_add_balance(
data_list: list[UserAccountData] = []
if currency_id is None:
currency_id = (await _get_default()).id
await _batch_add(
updates, currency_id, source, fail_then_rollback=True, return_all_on_fail=True
)
await _batch_add(updates, currency_id, source, return_all_on_fail=True)
for user_id, _ in updates:
data_list.append(await get_or_create_account(user_id, currency_id))
return data_list
Expand Down
5 changes: 4 additions & 1 deletion nonebot_plugin_value/hook/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# 事件钩子上下文
from pydantic import BaseModel, Field

from .exception import CancelAction
from .exception import CancelAction, DataUpdate


class TransactionContext(BaseModel):
Expand All @@ -19,6 +19,9 @@ class TransactionContext(BaseModel):
def cancel(self, reason: str = ""):
raise CancelAction(reason)

def commit_update(self):
raise DataUpdate(amount=self.amount)


class TransactionComplete(BaseModel):
"""Transaction complete
Expand Down
14 changes: 13 additions & 1 deletion nonebot_plugin_value/hook/exception.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
from typing import Any


class BaseException(Exception):
"""
Base exception class for this module.
"""

def __init__(self, message: str = ""):
def __init__(self, message: str = "", data: Any | None = None):
self.message = message
self.data = data


class CancelAction(BaseException):
"""
Exception raised when the user cancels an action.
"""

class DataUpdate(Exception):
"""
Exception raised when the data updated
"""

def __init__(self, amount: float) -> None:
self.amount = amount
166 changes: 86 additions & 80 deletions nonebot_plugin_value/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,26 @@ async def createcurrency(self, currency_data: CurrencyData):
async def update_currency(self, currency_data: CurrencyData) -> CurrencyMeta:
"""更新货币信息"""
async with self.session as session:
stmt = (
update(CurrencyMeta)
.where(CurrencyMeta.id == currency_data.id)
.values(**dict(currency_data))
)
await session.execute(stmt)
await session.commit()
stmt = (
select(CurrencyMeta)
.where(CurrencyMeta.id == currency_data.id)
.with_for_update()
)
result = await session.execute(stmt)
currency_meta = result.scalar_one()
session.add(currency_meta)
return currency_meta
try:
stmt = (
update(CurrencyMeta)
.where(CurrencyMeta.id == currency_data.id)
.values(**dict(currency_data))
)
await session.execute(stmt)
await session.commit()
stmt = (
select(CurrencyMeta)
.where(CurrencyMeta.id == currency_data.id)
.with_for_update()
)
result = await session.execute(stmt)
currency_meta = result.scalar_one()
session.add(currency_meta)
return currency_meta
except Exception:
await session.rollback()
raise

async def get_currency(self, currency_id: str) -> CurrencyMeta | None:
"""获取货币信息"""
Expand Down Expand Up @@ -127,44 +131,41 @@ async def get_or_create_account(
) -> UserAccount:
async with self.session as session:
"""获取或创建用户账户"""
# 获取货币配置
stmt = select(CurrencyMeta).where(CurrencyMeta.id == currency_id)
result = await session.execute(stmt)
currency = result.scalar_one_or_none()
if currency is None:
raise CurrencyNotFound(f"Currency {currency_id} not found")

# 检查账户是否存在
stmt = (
select(UserAccount)
.where(UserAccount.uni_id == get_uni_id(user_id, currency_id))
.with_for_update()
)
result = await session.execute(stmt)
account = result.scalar_one_or_none()

if account is not None:
try:
# 获取货币配置
stmt = select(CurrencyMeta).where(CurrencyMeta.id == currency_id)
result = await session.execute(stmt)
currency = result.scalar_one_or_none()
if currency is None:
raise CurrencyNotFound(f"Currency {currency_id} not found")

# 检查账户是否存在
stmt = (
select(UserAccount)
.where(UserAccount.uni_id == get_uni_id(user_id, currency_id))
.with_for_update()
)
result = await session.execute(stmt)
account = result.scalar_one_or_none()

if account is not None:
session.add(account)
return account

session.add(currency)
account = UserAccount(
uni_id=get_uni_id(user_id, currency_id),
id=user_id,
currency_id=currency_id,
balance=currency.default_balance,
last_updated=datetime.now(timezone.utc),
)
session.add(account)
await session.commit()
return account

session.add(currency)
account = UserAccount(
uni_id=get_uni_id(user_id, currency_id),
id=user_id,
currency_id=currency_id,
balance=currency.default_balance,
last_updated=datetime.now(timezone.utc),
)
session.add(account)
await session.commit()

stmt = select(UserAccount).where(
UserAccount.uni_id == get_uni_id(user_id, currency_id)
)
result = await session.execute(stmt)
account = result.scalar_one()
session.add(account)
return account
except Exception:
await session.rollback()
raise

async def set_account_frozen(
self,
Expand Down Expand Up @@ -220,41 +221,46 @@ async def update_balance(
) -> tuple[float, float]:
async with self.session as session:
"""更新余额"""
try:
# 获取账户
account = (
await session.execute(
select(UserAccount)
.where(
UserAccount.uni_id == get_uni_id(account_id, currency_id)
)
.with_for_update()
)
).scalar_one_or_none()

# 获取账户
account = (
await session.execute(
select(UserAccount)
.where(UserAccount.uni_id == get_uni_id(account_id, currency_id))
.with_for_update()
)
).scalar_one_or_none()

if account is None:
raise AccountNotFound("Account not found")
session.add(account)
if account is None:
raise AccountNotFound("Account not found")
session.add(account)

if account.frozen:
raise AccountFrozen(
f"Account {account_id} on currency {currency_id} is frozen"
)
if account.frozen:
raise AccountFrozen(
f"Account {account_id} on currency {currency_id} is frozen"
)

# 获取货币规则
currency = await session.get(CurrencyMeta, account.currency_id)
session.add(currency)
# 获取货币规则
currency = await session.get(CurrencyMeta, account.currency_id)
session.add(currency)

# 负余额检查
if amount < 0 and not getattr(currency, "allow_negative", False):
raise TransactionException("Insufficient funds")
# 负余额检查
if amount < 0 and not getattr(currency, "allow_negative", False):
raise TransactionException("Insufficient funds")

# 记录原始余额
old_balance = account.balance
# 记录原始余额
old_balance = account.balance

# 更新余额
account.balance = amount
await session.commit()
# 更新余额
account.balance = amount
await session.commit()

return old_balance, amount
return old_balance, amount
except Exception:
await session.rollback()
raise

async def list_accounts(
self, currency_id: str | None = None
Expand Down
Loading
Loading