|
14 | 14 | from contextlib import asynccontextmanager
|
15 | 15 | from functools import partial
|
16 | 16 | from http import HTTPStatus
|
17 |
| -from typing import AsyncIterator, Optional, Set, Tuple |
| 17 | +from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union |
18 | 18 |
|
19 | 19 | import uvloop
|
20 | 20 | from fastapi import APIRouter, FastAPI, HTTPException, Request
|
@@ -420,6 +420,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
420 | 420 | "use the Pooling API (`/pooling`) instead.")
|
421 | 421 |
|
422 | 422 | res = await fallback_handler.create_pooling(request, raw_request)
|
| 423 | + |
| 424 | + generator: Union[ErrorResponse, EmbeddingResponse] |
423 | 425 | if isinstance(res, PoolingResponse):
|
424 | 426 | generator = EmbeddingResponse(
|
425 | 427 | id=res.id,
|
@@ -494,7 +496,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
494 | 496 | return await create_score(request, raw_request)
|
495 | 497 |
|
496 | 498 |
|
497 |
| -TASK_HANDLERS = { |
| 499 | +TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { |
498 | 500 | "generate": {
|
499 | 501 | "messages": (ChatCompletionRequest, create_chat_completion),
|
500 | 502 | "default": (CompletionRequest, create_completion),
|
@@ -652,7 +654,7 @@ async def add_request_id(request: Request, call_next):
|
652 | 654 | module_path, object_name = middleware.rsplit(".", 1)
|
653 | 655 | imported = getattr(importlib.import_module(module_path), object_name)
|
654 | 656 | if inspect.isclass(imported):
|
655 |
| - app.add_middleware(imported) |
| 657 | + app.add_middleware(imported) # type: ignore[arg-type] |
656 | 658 | elif inspect.iscoroutinefunction(imported):
|
657 | 659 | app.middleware("http")(imported)
|
658 | 660 | else:
|
|
0 commit comments