108 changes: 105 additions & 3 deletions api/services/billing_service.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from collections.abc import Sequence
@@ -31,6 +32,11 @@ class BillingService:

     compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
 
+    # Redis key prefix for tenant plan cache
+    _PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
+    # Cache TTL: 10 minutes
+    _PLAN_CACHE_TTL = 600
+
     @classmethod
     def get_info(cls, tenant_id: str):
         params = {"tenant_id": tenant_id}
@@ -272,14 +278,110 @@ def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
                 data = resp.get("data", {})
 
                 for tenant_id, plan in data.items():
-                    subscription_plan = subscription_adapter.validate_python(plan)
-                    results[tenant_id] = subscription_plan
+                    try:
+                        subscription_plan = subscription_adapter.validate_python(plan)
+                        results[tenant_id] = subscription_plan
+                    except Exception:
+                        logger.exception(
+                            "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
+                        )
+                        continue
             except Exception:
-                logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
+                logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
                 continue
 
         return results
 
+    @classmethod
+    def _make_plan_cache_key(cls, tenant_id: str) -> str:
+        return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"
+
+    @classmethod
+    def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+        """
+        Bulk fetch billing subscription plans with a cache to reduce billing API load in batch-job scenarios.
+
+        NOTE: if you need strong data consistency, use get_plan_bulk instead.
+
+        Returns:
+            Mapping of tenant_id -> SubscriptionPlan ({plan: str, expiration_date: int})
+        """
+        tenant_plans: dict[str, SubscriptionPlan] = {}
+
+        if not tenant_ids:
+            return tenant_plans
+
+        subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+        # Step 1: Batch fetch from Redis cache using mget
+        redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
+        try:
+            cached_values = redis_client.mget(redis_keys)
+
+            if len(cached_values) != len(tenant_ids):
+                raise Exception(
+                    "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
+                )
+
+            # Map cached values back to tenant_ids
+            cache_misses: list[str] = []
+
+            for tenant_id, cached_value in zip(tenant_ids, cached_values):
+                if cached_value:
+                    try:
+                        # Redis returns bytes, decode to string and parse JSON
+                        json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
+                        plan_dict = json.loads(json_str)
+                        subscription_plan = subscription_adapter.validate_python(plan_dict)
+                        tenant_plans[tenant_id] = subscription_plan
+                    except Exception:
+                        logger.exception(
+                            "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
+                        )
+                        cache_misses.append(tenant_id)
+                else:
+                    cache_misses.append(tenant_id)
+
+            logger.info(
+                "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
+                len(tenant_plans),
+                len(cache_misses),
+            )
+        except Exception:
+            logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
+            cache_misses = list(tenant_ids)
+
+        # Step 2: Fetch missing plans from billing API
+        if cache_misses:
+            bulk_plans = BillingService.get_plan_bulk(cache_misses)
+
+            if bulk_plans:
+                plans_to_cache: dict[str, SubscriptionPlan] = {}
+
+                for tenant_id, subscription_plan in bulk_plans.items():
+                    tenant_plans[tenant_id] = subscription_plan
+                    plans_to_cache[tenant_id] = subscription_plan
+
+                # Step 3: Batch update Redis cache using pipeline
+                if plans_to_cache:
+                    try:
+                        pipe = redis_client.pipeline()
+                        for tenant_id, subscription_plan in plans_to_cache.items():
+                            redis_key = cls._make_plan_cache_key(tenant_id)
+                            # Serialize dict to JSON string
+                            json_str = json.dumps(subscription_plan)
+                            pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
+                        pipe.execute()
+
+                        logger.info(
+                            "get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
+                            len(plans_to_cache),
+                        )
+                    except Exception:
+                        logger.exception("get_plan_bulk_with_cache: redis pipeline failed")
+
+        return tenant_plans
+
     @classmethod
     def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
         resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
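A minimal usage sketch (not part of the diff) of how a batch job might call the new cached bulk lookup. The import path and the tenant IDs are illustrative assumptions; only BillingService.get_plan_bulk_with_cache and its return type come from the PR.

# Illustrative only: assumes this runs inside the api package, where
# services.billing_service is importable, and that Redis is configured.
from services.billing_service import BillingService

tenant_ids = ["tenant-a", "tenant-b", "tenant-c"]  # hypothetical tenant IDs

# First call: cache misses fall through to the billing API and the results
# are written to Redis with a 10-minute TTL.
plans = BillingService.get_plan_bulk_with_cache(tenant_ids)

# Calls within the TTL are served from Redis via a single MGET.
for tenant_id, plan in BillingService.get_plan_bulk_with_cache(tenant_ids).items():
    print(tenant_id, plan)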