Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backend/ent/schema/usage_log.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,6 @@ func (UsageLog) Indexes() []ent.Index {
// 复合索引用于时间范围查询
index.Fields("user_id", "created_at"),
index.Fields("api_key_id", "created_at"),
index.Fields("group_id", "created_at"),
}
}
93 changes: 93 additions & 0 deletions backend/internal/handler/admin/balance_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package admin

import (
"log/slog"
"net/http"
"strconv"
"strings"
"time"

"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

// BalanceHandler handles admin balance management
type BalanceHandler struct {
usageService *service.UsageService
}

// NewBalanceHandler creates a new admin balance handler
func NewBalanceHandler(usageService *service.UsageService) *BalanceHandler {
return &BalanceHandler{
usageService: usageService,
}
}

// GetStats handles GET /api/v1/admin/balance/stats
func (h *BalanceHandler) GetStats(c *gin.Context) {
groupIDStr := c.Query("group_id")
if groupIDStr == "" {
response.Error(c, http.StatusBadRequest, "group_id is required")
return
}
groupID, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "invalid group_id")
return
}

startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
if startDateStr == "" || endDateStr == "" {
response.Error(c, http.StatusBadRequest, "start_date and end_date are required")
return
}

startDate, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
response.Error(c, http.StatusBadRequest, "invalid start_date format, expected YYYY-MM-DD")
return
}
endDate, err := time.Parse("2006-01-02", endDateStr)
if err != nil {
response.Error(c, http.StatusBadRequest, "invalid end_date format, expected YYYY-MM-DD")
return
}
// end_date 需要加一天,以包含当天的数据
endDate = endDate.AddDate(0, 0, 1)

page, pageSize := response.ParsePagination(c)
sortBy := c.DefaultQuery("sort_by", "total_cost")
sortOrder := c.DefaultQuery("sort_order", "desc")
search := c.Query("search")

params := &usagestats.BalanceGroupUserStatsParams{
GroupID: groupID,
StartDate: &startDate,
EndDate: &endDate,
Page: page,
PageSize: pageSize,
SortBy: sortBy,
SortOrder: sortOrder,
Search: search,
}

result, err := h.usageService.GetBalanceGroupUserStats(c.Request.Context(), params)
if err != nil {
// Service validation errors contain safe messages; internal errors should not be exposed
errMsg := err.Error()
if strings.Contains(errMsg, "is required") ||
strings.Contains(errMsg, "must be after") ||
strings.Contains(errMsg, "must not exceed") {
response.Error(c, http.StatusBadRequest, errMsg)
} else {
slog.Error("failed to get balance group user stats", "error", err)
response.Error(c, http.StatusInternalServerError, "failed to get balance stats")
}
return
}

response.Success(c, result)
}
1 change: 1 addition & 0 deletions backend/internal/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type AdminHandlers struct {
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
Balance *admin.BalanceHandler
}

// Handlers contains all HTTP handlers
Expand Down
3 changes: 3 additions & 0 deletions backend/internal/handler/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func ProvideAdminHandlers(
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
balanceHandler *admin.BalanceHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
Expand All @@ -47,6 +48,7 @@ func ProvideAdminHandlers(
Subscription: subscriptionHandler,
Usage: usageHandler,
UserAttribute: userAttributeHandler,
Balance: balanceHandler,
}
}

Expand Down Expand Up @@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet(
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewBalanceHandler,

// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
Expand Down
32 changes: 32 additions & 0 deletions backend/internal/pkg/usagestats/usage_log_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,35 @@ type AccountUsageStatsResponse struct {
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
}

// BalanceGroupUserStats represents aggregated usage statistics for a single user in a balance group
type BalanceGroupUserStats struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Balance float64 `json:"balance"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
TotalRequests int64 `json:"total_requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
}

// BalanceGroupUserStatsResponse represents the paginated response for balance group user stats
type BalanceGroupUserStatsResponse struct {
Users []BalanceGroupUserStats `json:"users"`
Total int64 `json:"total"`
}

// BalanceGroupUserStatsParams represents the query parameters for balance group user stats
type BalanceGroupUserStatsParams struct {
GroupID int64 `json:"group_id"`
StartDate *time.Time `json:"start_date"`
EndDate *time.Time `json:"end_date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
SortBy string `json:"sort_by"`
SortOrder string `json:"sort_order"`
Search string `json:"search"`
}
140 changes: 140 additions & 0 deletions backend/internal/repository/usage_log_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -2386,3 +2386,143 @@ func setToSlice(set map[int64]struct{}) []int64 {
}
return out
}

// BalanceGroupUserStats type alias
type BalanceGroupUserStats = usagestats.BalanceGroupUserStats
type BalanceGroupUserStatsResponse = usagestats.BalanceGroupUserStatsResponse
type BalanceGroupUserStatsParams = usagestats.BalanceGroupUserStatsParams

// GetBalanceGroupUserStats returns aggregated usage statistics per user for a balance group within a time range.
func (r *usageLogRepository) GetBalanceGroupUserStats(ctx context.Context, params *BalanceGroupUserStatsParams) (resp *BalanceGroupUserStatsResponse, err error) {
if params.StartDate == nil || params.EndDate == nil {
return nil, fmt.Errorf("start_date and end_date are required")
}

// Build dynamic WHERE clause for search
args := []any{params.GroupID, *params.StartDate, *params.EndDate}
argPos := 4

searchClause := ""
if params.Search != "" {
searchClause = fmt.Sprintf(" AND (u.email ILIKE $%d OR u.username ILIKE $%d)", argPos, argPos+1)
// Escape ILIKE metacharacters to prevent unintended wildcard matching
escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(params.Search)
searchPattern := "%" + escaped + "%"
args = append(args, searchPattern, searchPattern)
argPos += 2
}

// Validate and set sort_by
sortColumn := "total_cost"
allowedSorts := map[string]string{
"total_cost": "total_cost",
"actual_cost": "actual_cost",
"total_requests": "total_requests",
"input_tokens": "input_tokens",
"output_tokens": "output_tokens",
"cache_read_tokens": "cache_read_tokens",
"balance": "balance",
}
if col, ok := allowedSorts[params.SortBy]; ok {
sortColumn = col
}

sortOrder := "DESC"
if strings.EqualFold(params.SortOrder, "asc") {
sortOrder = "ASC"
}

// Count query for total
countQuery := fmt.Sprintf(`
WITH usage_agg AS (
SELECT user_id
FROM usage_logs
WHERE group_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY user_id
)
SELECT COUNT(DISTINCT ua.user_id)
FROM usage_agg ua
JOIN users u ON u.id = ua.user_id
WHERE 1=1%s
`, searchClause)

var total int64
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
return nil, err
}

// Data query with pagination
limit := params.PageSize
offset := (params.Page - 1) * params.PageSize

dataArgs := make([]any, len(args))
copy(dataArgs, args)
limitPos := argPos
offsetPos := argPos + 1
dataArgs = append(dataArgs, limit, offset)

dataQuery := fmt.Sprintf(`
WITH usage_agg AS (
SELECT
user_id,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as actual_cost,
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens
FROM usage_logs
WHERE group_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY user_id
)
SELECT
ua.user_id,
COALESCE(u.email, '') as email,
COALESCE(u.username, '') as username,
COALESCE(u.balance, 0) as balance,
ua.total_cost,
ua.actual_cost,
ua.total_requests,
ua.input_tokens,
ua.output_tokens,
ua.cache_read_tokens
FROM usage_agg ua
JOIN users u ON u.id = ua.user_id
WHERE 1=1%s
ORDER BY %s %s
LIMIT $%d OFFSET $%d
`, searchClause, sortColumn, sortOrder, limitPos, offsetPos)

rows, err := r.sql.QueryContext(ctx, dataQuery, dataArgs...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
resp = nil
}
}()

users := make([]BalanceGroupUserStats, 0)
for rows.Next() {
var s BalanceGroupUserStats
if err = rows.Scan(
&s.UserID, &s.Email, &s.Username, &s.Balance,
&s.TotalCost, &s.ActualCost, &s.TotalRequests,
&s.InputTokens, &s.OutputTokens, &s.CacheReadTokens,
); err != nil {
return nil, err
}
users = append(users, s)
}
if err = rows.Err(); err != nil {
return nil, err
}

resp = &BalanceGroupUserStatsResponse{
Users: users,
Total: total,
}
return resp, nil
}
4 changes: 4 additions & 0 deletions backend/internal/server/api_contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1671,6 +1671,10 @@ func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usag
return nil, errors.New("not implemented")
}

func (r *stubUsageLogRepo) GetBalanceGroupUserStats(ctx context.Context, params *usagestats.BalanceGroupUserStatsParams) (*usagestats.BalanceGroupUserStatsResponse, error) {
return nil, errors.New("not implemented")
}

type stubSettingRepo struct {
all map[string]string
}
Expand Down
10 changes: 10 additions & 0 deletions backend/internal/server/routes/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ func RegisterAdminRoutes(
// 使用记录管理
registerUsageRoutes(admin, h)

// 余额管理
registerBalanceRoutes(admin, h)

// 用户属性管理
registerUserAttributeRoutes(admin, h)
}
Expand Down Expand Up @@ -387,3 +390,10 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}

func registerBalanceRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
balance := admin.Group("/balance")
{
balance.GET("/stats", h.Admin.Balance.GetStats)
}
}
3 changes: 3 additions & 0 deletions backend/internal/service/account_usage_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ type UsageLogRepository interface {
// Account stats
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)

// Balance group stats
GetBalanceGroupUserStats(ctx context.Context, params *usagestats.BalanceGroupUserStatsParams) (*usagestats.BalanceGroupUserStatsResponse, error)

// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
Expand Down
Loading
Loading