Skip to content

Commit e6ff395

Browse files
committed
feat: add health endpoint and auto-generate pod_id
- Add /health endpoint to FastAPI server for load balancer health checks
- Auto-generate pod_id when not set: WAVERLESS_POD_ID > DEVICE_ID > {endpoint}-{uuid}
1 parent 84c3560 commit e6ff395

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

src/wavespeed/config.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import sys
5+
import uuid
56
from typing import Optional
67

78
from ._config_module import install_config_module
@@ -85,16 +86,37 @@ def _detect_serverless_env() -> Optional[str]:
8586
"""
8687

8788
# Check for native Waverless environment
88-
if os.environ.get("WAVERLESS_POD_ID"):
89+
if os.environ.get("WAVERLESS_ENDPOINT_ID"):
8990
return "waverless"
9091

9192
# Check for RunPod environment
92-
if os.environ.get("RUNPOD_POD_ID"):
93+
if os.environ.get("RUNPOD_ENDPOINT_ID"):
9394
return "runpod"
9495

9596
return None
9697

9798

99+
def _generate_pod_id(endpoint_id: Optional[str], raw_pod_id: Optional[str]) -> str:
100+
"""Generate or resolve pod_id.
101+
102+
Priority: raw_pod_id > DEVICE_ID > auto-generate
103+
104+
Args:
105+
endpoint_id: The endpoint identifier.
106+
raw_pod_id: The raw pod_id from environment variable.
107+
108+
Returns:
109+
The resolved pod_id.
110+
"""
111+
if raw_pod_id:
112+
return raw_pod_id
113+
device_id = os.environ.get("DEVICE_ID")
114+
if device_id:
115+
return device_id
116+
prefix = endpoint_id or "worker"
117+
return f"{prefix}-{uuid.uuid4().hex}"
118+
119+
98120
def _resolve_runpod_url(url_template: Optional[str], pod_id: str) -> Optional[str]:
99121
"""Replace pod ID placeholder in RunPod URL template.
100122
@@ -127,8 +149,14 @@ def _resolve_waverless_url(url_template: Optional[str], pod_id: str) -> Optional
127149

128150
def _load_runpod_serverless_config() -> None:
129151
"""Load RunPod environment variables into serverless config."""
152+
# Endpoint identification (load first for pod_id generation)
153+
serverless.endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID")
154+
serverless.project_id = os.environ.get("RUNPOD_PROJECT_ID")
155+
130156
# Worker identification
131-
serverless.pod_id = os.environ.get("RUNPOD_POD_ID") or ""
157+
raw_pod_id = os.environ.get("RUNPOD_POD_ID")
158+
serverless.pod_id = _generate_pod_id(serverless.endpoint_id, raw_pod_id)
159+
serverless.pod_hostname = os.environ.get("RUNPOD_POD_HOSTNAME", serverless.pod_id)
132160

133161
# API endpoint templates
134162
serverless.webhook_get_job = os.environ.get("RUNPOD_WEBHOOK_GET_JOB")
@@ -163,11 +191,6 @@ def _load_runpod_serverless_config() -> None:
163191
log_level = os.environ.get("RUNPOD_DEBUG_LEVEL")
164192
serverless.log_level = log_level or "INFO"
165193

166-
# Endpoint identification
167-
serverless.endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID")
168-
serverless.project_id = os.environ.get("RUNPOD_PROJECT_ID")
169-
serverless.pod_hostname = os.environ.get("RUNPOD_POD_HOSTNAME")
170-
171194
# Timing and concurrency
172195
ping_interval = os.environ.get("RUNPOD_PING_INTERVAL")
173196
if ping_interval:
@@ -184,8 +207,17 @@ def _load_runpod_serverless_config() -> None:
184207

185208
def _load_waverless_serverless_config() -> None:
186209
"""Load Waverless environment variables into serverless config."""
210+
# Endpoint identification (load first for pod_id generation)
211+
serverless.endpoint_id = os.environ.get("WAVERLESS_ENDPOINT_ID")
212+
# Endpoint identification (endpoint_id already set above)
213+
serverless.project_id = os.environ.get("WAVERLESS_PROJECT_ID")
214+
187215
# Worker identification
188-
serverless.pod_id = os.environ.get("WAVERLESS_POD_ID") or ""
216+
raw_pod_id = os.environ.get("WAVERLESS_POD_ID")
217+
serverless.pod_id = _generate_pod_id(serverless.endpoint_id, raw_pod_id)
218+
serverless.pod_hostname = os.environ.get(
219+
"WAVERLESS_POD_HOSTNAME", serverless.pod_id
220+
)
189221

190222
# API endpoint templates
191223
serverless.webhook_get_job = os.environ.get("WAVERLESS_WEBHOOK_GET_JOB")
@@ -217,11 +249,6 @@ def _load_waverless_serverless_config() -> None:
217249
# Logging
218250
serverless.log_level = os.environ.get("WAVERLESS_LOG_LEVEL", "INFO")
219251

220-
# Endpoint identification
221-
serverless.endpoint_id = os.environ.get("WAVERLESS_ENDPOINT_ID")
222-
serverless.project_id = os.environ.get("WAVERLESS_PROJECT_ID")
223-
serverless.pod_hostname = os.environ.get("WAVERLESS_POD_HOSTNAME")
224-
225252
# Timing and concurrency
226253
ping_interval = os.environ.get("WAVERLESS_PING_INTERVAL")
227254
if ping_interval:

src/wavespeed/serverless/modules/fastapi.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,15 @@ def __init__(self, config: Dict[str, Any]):
220220
tags=["Status"],
221221
)
222222

223+
# Health check endpoint
224+
router.add_api_route(
225+
"/health",
226+
lambda: {"status": "ok"},
227+
methods=["GET"],
228+
summary="Health check",
229+
tags=["Status"],
230+
)
231+
223232
self.app.include_router(router)
224233

225234
def start(

0 commit comments

Comments
 (0)