1919
2020import aiohttp
2121import fastapi
22- from fastapi import Depends , HTTPException , responses
22+ from fastapi import Depends , File , HTTPException , Request , responses , Form , UploadFile
2323from fastapi .exceptions import RequestValidationError
2424from fastapi .middleware .cors import CORSMiddleware
2525from fastapi .responses import StreamingResponse , JSONResponse , FileResponse
2626from fastapi .security .http import HTTPAuthorizationCredentials , HTTPBearer
2727import httpx
28+ import base64
2829
2930try :
3031 from pydantic .v1 import BaseSettings , validator
@@ -194,18 +195,18 @@ def create_error_response(code: int, message: str) -> JSONResponse:
194195
195196
196197@app .exception_handler (RequestValidationError )
197- async def validation_exception_handler (request , exc ):
198+ async def validation_exception_handler (request : Request , exc : RequestValidationError ):
198199 return create_error_response (ErrorCode .VALIDATION_TYPE_ERROR , str (exc ))
199200
200201
201- def check_model (request ) -> Optional [JSONResponse ]:
202+ def check_model (model : str ) -> Optional [JSONResponse ]:
202203 global model_address_map , models_
203204 ret = None
204205 models = models_
205- if request . model not in models_ :
206+ if model not in models_ :
206207 ret = create_error_response (
207208 ErrorCode .INVALID_MODEL ,
208- f"Only { '&&' .join (models )} allowed now, your model { request . model } " ,
209+ f"Only { '&&' .join (models )} allowed now, your model { model } " ,
209210 )
210211 return ret
211212
@@ -418,7 +419,7 @@ def get_model_address_map():
418419)
419420async def create_chat_completion (request : CustomChatCompletionRequest ):
420421 """Creates a completion for the chat message"""
421- error_check_ret = check_model (request )
422+ error_check_ret = check_model (request . model )
422423 if error_check_ret is not None :
423424 return error_check_ret
424425 worker_addr = get_worker_address (request .model )
@@ -554,7 +555,7 @@ async def chat_completion_stream_generator(
554555 response_class = responses .ORJSONResponse ,
555556)
556557async def create_completion (request : CompletionRequest ):
557- error_check_ret = check_model (request )
558+ error_check_ret = check_model (request . model )
558559 if error_check_ret is not None :
559560 return error_check_ret
560561
@@ -714,7 +715,6 @@ async def generate_completion(payload: Dict[str, Any], worker_addr: str):
714715 SpeechRequest ,
715716 OpenAISpeechRequest ,
716717 ImagesGenRequest ,
717- ImagesEditsRequest ,
718718)
719719
720720
@@ -729,17 +729,27 @@ async def get_images_edits(payload: Dict[str, Any]):
729729
730730
731731@app .post ("/v1/images/edits" , dependencies = [Depends (check_api_key )])
732- async def images_edits (request : ImagesEditsRequest ):
732+ async def images_edits (
733+ model : str = Form (...),
734+ image : UploadFile = File (media_type = "application/octet-stream" ),
735+ prompt : Optional [Union [str , List [str ]]] = Form (None ),
736+ # negative_prompt: Optional[Union[str, List[str]]] = Form(None),
737+ response_format : Optional [str ] = Form ("url" ),
738+ output_format : Optional [str ] = Form ("png" ),
739+ ):
733740 """图片编辑"""
734- error_check_ret = check_model (request )
741+
742+ error_check_ret = check_model (model )
735743 if error_check_ret is not None :
736744 return error_check_ret
737745 payload = {
738- "image" : request .image ,
739- "model" : request .model ,
740- "prompt" : request .prompt ,
741- "output_format" : request .output_format ,
742- "response_format" : request .response_format ,
746+ "image" : base64 .b64encode (await image .read ()).decode (
747+ "utf-8"
748+ ), # bytes → Base64 字符串,
749+ "model" : model ,
750+ "prompt" : prompt ,
751+ "output_format" : output_format ,
752+ "response_format" : response_format ,
743753 }
744754 result = await get_images_edits (payload = payload )
745755 return result
@@ -758,7 +768,7 @@ async def get_images_gen(payload: Dict[str, Any]):
758768@app .post ("/v1/images/generations" , dependencies = [Depends (check_api_key )])
759769async def images_generations (request : ImagesGenRequest ):
760770 """文生图"""
761- error_check_ret = check_model (request )
771+ error_check_ret = check_model (request . model )
762772 if error_check_ret is not None :
763773 return error_check_ret
764774 payload = {
@@ -877,10 +887,6 @@ async def get_transcriptions(payload: Dict[str, Any]):
877887 return json .loads (transcription )
878888
879889
880- from fastapi import UploadFile , Form
881- import base64
882-
883-
884890@app .post (
885891 "/v1/audio/transcriptions" ,
886892 dependencies = [Depends (check_api_key )],
@@ -915,7 +921,7 @@ async def transcriptions(file: UploadFile, model: str = Form()):
915921 response_class = responses .ORJSONResponse ,
916922)
917923async def classify (request : ModerationsRequest ):
918- error_check_ret = check_model (request )
924+ error_check_ret = check_model (request . model )
919925 if error_check_ret is not None :
920926 return error_check_ret
921927 request .input = process_input (request .model , request .input )
@@ -958,7 +964,7 @@ async def classify(request: ModerationsRequest):
958964 response_class = responses .ORJSONResponse ,
959965)
960966async def rerank (request : RerankRequest ):
961- error_check_ret = check_model (request )
967+ error_check_ret = check_model (request . model )
962968 if error_check_ret is not None :
963969 return error_check_ret
964970 request .documents = process_input (request .model , request .documents )
@@ -1009,7 +1015,7 @@ async def create_embeddings(request: CustomEmbeddingsRequest, model_name: str =
10091015 """Creates embeddings for the text"""
10101016 if request .model is None :
10111017 request .model = model_name
1012- error_check_ret = check_model (request )
1018+ error_check_ret = check_model (request . model )
10131019 if error_check_ret is not None :
10141020 return error_check_ret
10151021
@@ -1111,7 +1117,7 @@ async def count_tokens(request: APITokenCheckRequest):
11111117@app .post ("/api/v1/chat/completions" )
11121118async def create_chat_completion (request : APIChatCompletionRequest ):
11131119 """Creates a completion for the chat message"""
1114- error_check_ret = check_model (request )
1120+ error_check_ret = check_model (request . model )
11151121 if error_check_ret is not None :
11161122 return error_check_ret
11171123
0 commit comments