Skip to content

Commit 3136d26

Browse files
committed
Fix status code: split pipeline load from input parsing
pipeline loading -> 500; input parsing -> 400. Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
1 parent 45907cd commit 3136d26

File tree

3 files changed

+36
-6
lines changed

3 files changed

+36
-6
lines changed

api_inference_community/routes.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
IMAGE,
1414
IMAGE_INPUTS,
1515
IMAGE_OUTPUTS,
16+
KNOWN_TASKS,
1617
ffmpeg_convert,
1718
normalize_payload,
1819
parse_accept,
@@ -88,6 +89,18 @@ def already_left(request: Request) -> bool:
8889
async def pipeline_route(request: Request) -> Response:
8990
start = time.time()
9091

92+
task = os.environ["TASK"]
93+
94+
# Shortcut: quickly check the task is in enum: no need to go any further otherwise, as we know for sure that
95+
# normalize_payload will fail below: this avoids waiting for the pipeline to be loaded before returning
96+
if task not in KNOWN_TASKS:
97+
msg = f"The task `{task}` is not recognized by api-inference-community"
98+
logger.error(msg)
99+
# Special case: despite the fact that the task comes from the environment (which could be considered a service
100+
# config error, thus triggering a 500), this var indirectly comes from the user
101+
# so we choose to have a 400 here
102+
return JSONResponse({"error": msg}, status_code=400)
103+
91104
if os.getenv("DISCARD_LEFT", "0").lower() in [
92105
"1",
93106
"true",
@@ -97,16 +110,30 @@ async def pipeline_route(request: Request) -> Response:
97110
return Response(status_code=204)
98111

99112
payload = await request.body()
100-
task = os.environ["TASK"]
113+
101114
if os.getenv("DEBUG", "0") in {"1", "true"}:
102115
pipe = request.app.get_pipeline()
116+
103117
try:
104118
pipe = request.app.get_pipeline()
105119
try:
106120
sampling_rate = pipe.sampling_rate
107121
except Exception:
108122
sampling_rate = None
123+
if task in AUDIO_INPUTS:
124+
msg = f"Sampling rate is expected for model for audio task {task}"
125+
logger.error(msg)
126+
return JSONResponse({"error": msg}, status_code=500)
127+
except Exception as e:
128+
return JSONResponse({"error": str(e)}, status_code=500)
129+
130+
try:
109131
inputs, params = normalize_payload(payload, task, sampling_rate=sampling_rate)
132+
except EnvironmentError as e:
133+
# Since we catch the environment edge cases earlier, this should not happen here anymore
134+
# harmless to keep it, just in case
135+
logger.error("Error while parsing input %s", e)
136+
return JSONResponse({"error": str(e)}, status_code=500)
110137
except ValidationError as e:
111138
errors = []
112139
for error in e.errors():
@@ -120,7 +147,9 @@ async def pipeline_route(request: Request) -> Response:
120147
)
121148
return JSONResponse({"error": errors}, status_code=400)
122149
except Exception as e:
123-
return JSONResponse({"error": str(e)}, status_code=500)
150+
# We assume the payload is bad -> 400
151+
logger.warning("Error while parsing input %s", e)
152+
return JSONResponse({"error": str(e)}, status_code=400)
124153

125154
accept = request.headers.get("accept", "")
126155
lora_adapter = request.headers.get("lora")

api_inference_community/validation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def check_inputs(inputs, tag):
218218
"zero-shot-classification",
219219
}
220220

221+
KNOWN_TASKS = AUDIO_INPUTS.union(IMAGE_INPUTS).union(TEXT_INPUTS)
221222

222223
AUDIO = [
223224
"flac",

tests/test_routes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def startup_event():
6262
self.assertEqual(response.headers["x-compute-characters"], "4")
6363
self.assertEqual(response.content, b'{"some":"json serializable"}')
6464

65-
def test_invalid_pipeline(self):
65+
def test_invalid_task(self):
6666
os.environ["TASK"] = "invalid"
6767

6868
class Pipeline:
@@ -99,15 +99,15 @@ async def startup_event():
9999

100100
self.assertEqual(
101101
response.status_code,
102-
500,
102+
400,
103103
)
104104
self.assertEqual(
105105
response.content,
106106
b'{"error":"The task `invalid` is not recognized by api-inference-community"}',
107107
)
108108

109-
def test_invalid_task(self):
110-
os.environ["TASK"] = "invalid"
109+
def test_invalid_pipeline(self):
110+
os.environ["TASK"] = "text-generation"
111111

112112
def get_pipeline():
113113
raise Exception("We cannot load the pipeline")

0 commit comments

Comments
 (0)