diff --git a/tt-metal-yolov4/server/fast_api_yolov4.py b/tt-metal-yolov4/server/fast_api_yolov4.py index 1e759ae8..6f26ab0e 100644 --- a/tt-metal-yolov4/server/fast_api_yolov4.py +++ b/tt-metal-yolov4/server/fast_api_yolov4.py @@ -4,6 +4,7 @@ import os import logging from fastapi import FastAPI, File, HTTPException, Request, status, UploadFile +from fastapi.responses import JSONResponse from functools import wraps from io import BytesIO import jwt @@ -63,6 +64,8 @@ def load_class_names(namesfile): class_names = load_class_names(file_path) global model + global ready + ready = False if ("WH_ARCH_YAML" in os.environ) and os.environ[ "WH_ARCH_YAML" ] == "wormhole_b0_80_arch_eth_dispatch.yaml": @@ -92,6 +95,17 @@ def load_class_names(namesfile): ttnn.enable_program_cache(device) model = Yolov4Trace2CQ() model.initialize_yolov4_trace_2cqs_inference(device) + ready = True + + +@app.get("/health") +async def health_check(): + if not ready: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Server is not ready yet", + ) + return JSONResponse(content={"message": "OK\n"}, status_code=status.HTTP_200_OK) @app.on_event("shutdown") @@ -259,6 +273,7 @@ async def objdetection_v2(request: Request, file: UploadFile = File(...)): contents = await file.read() # Load and convert the image to RGB image = Image.open(BytesIO(contents)).convert("RGB") + image = image.resize((320, 320)) # Resize to target dimensions image = np.array(image) if isinstance(image, np.ndarray) and len(image.shape) == 3: # cv2 image image = torch.from_numpy(image).float().div(255.0).unsqueeze(0)