13
13
IMAGE ,
14
14
IMAGE_INPUTS ,
15
15
IMAGE_OUTPUTS ,
16
+ KNOWN_TASKS ,
16
17
ffmpeg_convert ,
17
18
normalize_payload ,
18
19
parse_accept ,
@@ -88,6 +89,18 @@ def already_left(request: Request) -> bool:
88
89
async def pipeline_route (request : Request ) -> Response :
89
90
start = time .time ()
90
91
92
+ task = os .environ ["TASK" ]
93
+
94
+ # Shortcut: quickly check the task is in enum: no need to go any further otherwise, as we know for sure that
95
+ # normalize_payload will fail below: this avoids us to wait for the pipeline to be loaded to return
96
+ if task not in KNOWN_TASKS :
97
+ msg = f"The task `{ task } ` is not recognized by api-inference-community"
98
+ logger .error (msg )
99
+ # Special case: despite the fact that the task comes from environment (which could be considered a service
100
+ # config error, thus triggering a 500), this var indirectly comes from the user
101
+ # so we choose to have a 400 here
102
+ return JSONResponse ({"error" : msg }, status_code = 400 )
103
+
91
104
if os .getenv ("DISCARD_LEFT" , "0" ).lower () in [
92
105
"1" ,
93
106
"true" ,
@@ -97,16 +110,30 @@ async def pipeline_route(request: Request) -> Response:
97
110
return Response (status_code = 204 )
98
111
99
112
payload = await request .body ()
100
- task = os . environ [ "TASK" ]
113
+
101
114
if os .getenv ("DEBUG" , "0" ) in {"1" , "true" }:
102
115
pipe = request .app .get_pipeline ()
116
+
103
117
try :
104
118
pipe = request .app .get_pipeline ()
105
119
try :
106
120
sampling_rate = pipe .sampling_rate
107
121
except Exception :
108
122
sampling_rate = None
123
+ if task in AUDIO_INPUTS :
124
+ msg = f"Sampling rate is expected for model for audio task { task } "
125
+ logger .error (msg )
126
+ return JSONResponse ({"error" : msg }, status_code = 500 )
127
+ except Exception as e :
128
+ return JSONResponse ({"error" : str (e )}, status_code = 500 )
129
+
130
+ try :
109
131
inputs , params = normalize_payload (payload , task , sampling_rate = sampling_rate )
132
+ except EnvironmentError as e :
133
+ # Since we catch the environment edge cases earlier above, this should not happen here anymore
134
+ # harmless to keep it, just in case
135
+ logger .error ("Error while parsing input %s" , e )
136
+ return JSONResponse ({"error" : str (e )}, status_code = 500 )
110
137
except ValidationError as e :
111
138
errors = []
112
139
for error in e .errors ():
@@ -120,7 +147,9 @@ async def pipeline_route(request: Request) -> Response:
120
147
)
121
148
return JSONResponse ({"error" : errors }, status_code = 400 )
122
149
except Exception as e :
123
- return JSONResponse ({"error" : str (e )}, status_code = 500 )
150
+ # We assume the payload is bad -> 400
151
+ logger .warning ("Error while parsing input %s" , e )
152
+ return JSONResponse ({"error" : str (e )}, status_code = 400 )
124
153
125
154
accept = request .headers .get ("accept" , "" )
126
155
lora_adapter = request .headers .get ("lora" )
0 commit comments