1919from typing import AsyncIterator , Dict , Optional , Set , Tuple , Union
2020
2121import uvloop
22- from fastapi import APIRouter , FastAPI , HTTPException , Request
22+ from fastapi import APIRouter , Depends , FastAPI , HTTPException , Request
2323from fastapi .exceptions import RequestValidationError
2424from fastapi .middleware .cors import CORSMiddleware
2525from fastapi .responses import JSONResponse , Response , StreamingResponse
@@ -252,6 +252,15 @@ def _cleanup_ipc_path():
252252 multiprocess .mark_process_dead (engine_process .pid )
253253
254254
async def validate_json_request(raw_request: Request):
    """Reject requests whose body is not declared as JSON.

    Intended to be used as a FastAPI route dependency (via ``Depends``) so
    the check runs before the route handler.

    Only the media type itself is compared. Optional parameters after a
    ``;`` (for example ``application/json; charset=utf-8``) and surrounding
    whitespace are ignored, and the comparison is case-insensitive.

    Args:
        raw_request: The incoming request whose ``Content-Type`` header is
            inspected. A missing header is treated as an empty string.

    Raises:
        HTTPException: With status 415 (Unsupported Media Type) when the
            media type is anything other than ``application/json``.
    """
    content_type = raw_request.headers.get("content-type", "").lower()
    # Drop media-type parameters such as "; charset=utf-8" before comparing;
    # an exact match on the raw header would reject valid JSON requests.
    media_type = content_type.split(";", maxsplit=1)[0].strip()
    if media_type != "application/json":
        raise HTTPException(
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported Media Type: Only 'application/json' is allowed"
        )
262+
263+
255264router = APIRouter ()
256265
257266
@@ -335,7 +344,7 @@ async def ping(raw_request: Request) -> Response:
335344 return await health (raw_request )
336345
337346
338- @router .post ("/tokenize" )
347+ @router .post ("/tokenize" , dependencies = [ Depends ( validate_json_request )] )
339348@with_cancellation
340349async def tokenize (request : TokenizeRequest , raw_request : Request ):
341350 handler = tokenization (raw_request )
@@ -350,7 +359,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
350359 assert_never (generator )
351360
352361
353- @router .post ("/detokenize" )
362+ @router .post ("/detokenize" , dependencies = [ Depends ( validate_json_request )] )
354363@with_cancellation
355364async def detokenize (request : DetokenizeRequest , raw_request : Request ):
356365 handler = tokenization (raw_request )
@@ -379,7 +388,8 @@ async def show_version():
379388 return JSONResponse (content = ver )
380389
381390
382- @router .post ("/v1/chat/completions" )
391+ @router .post ("/v1/chat/completions" ,
392+ dependencies = [Depends (validate_json_request )])
383393@with_cancellation
384394async def create_chat_completion (request : ChatCompletionRequest ,
385395 raw_request : Request ):
@@ -400,7 +410,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
400410 return StreamingResponse (content = generator , media_type = "text/event-stream" )
401411
402412
403- @router .post ("/v1/completions" )
413+ @router .post ("/v1/completions" , dependencies = [ Depends ( validate_json_request )] )
404414@with_cancellation
405415async def create_completion (request : CompletionRequest , raw_request : Request ):
406416 handler = completion (raw_request )
@@ -418,7 +428,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
418428 return StreamingResponse (content = generator , media_type = "text/event-stream" )
419429
420430
421- @router .post ("/v1/embeddings" )
431+ @router .post ("/v1/embeddings" , dependencies = [ Depends ( validate_json_request )] )
422432@with_cancellation
423433async def create_embedding (request : EmbeddingRequest , raw_request : Request ):
424434 handler = embedding (raw_request )
@@ -464,7 +474,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
464474 assert_never (generator )
465475
466476
467- @router .post ("/pooling" )
477+ @router .post ("/pooling" , dependencies = [ Depends ( validate_json_request )] )
468478@with_cancellation
469479async def create_pooling (request : PoolingRequest , raw_request : Request ):
470480 handler = pooling (raw_request )
@@ -482,7 +492,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
482492 assert_never (generator )
483493
484494
485- @router .post ("/score" )
495+ @router .post ("/score" , dependencies = [ Depends ( validate_json_request )] )
486496@with_cancellation
487497async def create_score (request : ScoreRequest , raw_request : Request ):
488498 handler = score (raw_request )
@@ -500,7 +510,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
500510 assert_never (generator )
501511
502512
503- @router .post ("/v1/score" )
513+ @router .post ("/v1/score" , dependencies = [ Depends ( validate_json_request )] )
504514@with_cancellation
505515async def create_score_v1 (request : ScoreRequest , raw_request : Request ):
506516 logger .warning (
@@ -510,7 +520,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
510520 return await create_score (request , raw_request )
511521
512522
513- @router .post ("/rerank" )
523+ @router .post ("/rerank" , dependencies = [ Depends ( validate_json_request )] )
514524@with_cancellation
515525async def do_rerank (request : RerankRequest , raw_request : Request ):
516526 handler = rerank (raw_request )
@@ -527,7 +537,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
527537 assert_never (generator )
528538
529539
530- @router .post ("/v1/rerank" )
540+ @router .post ("/v1/rerank" , dependencies = [ Depends ( validate_json_request )] )
531541@with_cancellation
532542async def do_rerank_v1 (request : RerankRequest , raw_request : Request ):
533543 logger .warning_once (
@@ -538,7 +548,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
538548 return await do_rerank (request , raw_request )
539549
540550
541- @router .post ("/v2/rerank" )
551+ @router .post ("/v2/rerank" , dependencies = [ Depends ( validate_json_request )] )
542552@with_cancellation
543553async def do_rerank_v2 (request : RerankRequest , raw_request : Request ):
544554 return await do_rerank (request , raw_request )
@@ -582,7 +592,7 @@ async def reset_prefix_cache(raw_request: Request):
582592 return Response (status_code = 200 )
583593
584594
585- @router .post ("/invocations" )
595+ @router .post ("/invocations" , dependencies = [ Depends ( validate_json_request )] )
586596async def invocations (raw_request : Request ):
587597 """
588598 For SageMaker, routes requests to other handlers based on model `task`.
@@ -632,7 +642,8 @@ async def stop_profile(raw_request: Request):
632642 "Lora dynamic loading & unloading is enabled in the API server. "
633643 "This should ONLY be used for local development!" )
634644
635- @router .post ("/v1/load_lora_adapter" )
645+ @router .post ("/v1/load_lora_adapter" ,
646+ dependencies = [Depends (validate_json_request )])
636647 async def load_lora_adapter (request : LoadLoraAdapterRequest ,
637648 raw_request : Request ):
638649 handler = models (raw_request )
@@ -643,7 +654,8 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest,
643654
644655 return Response (status_code = 200 , content = response )
645656
646- @router .post ("/v1/unload_lora_adapter" )
657+ @router .post ("/v1/unload_lora_adapter" ,
658+ dependencies = [Depends (validate_json_request )])
647659 async def unload_lora_adapter (request : UnloadLoraAdapterRequest ,
648660 raw_request : Request ):
649661 handler = models (raw_request )
0 commit comments