[Feature] Simple API token authentication #1106

Merged · 2 commits · Jan 23, 2024

32 changes: 3 additions & 29 deletions docs/source/getting_started/quickstart.rst
@@ -63,38 +63,10 @@ Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM

The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.


API Server
----------

vLLM can be deployed as an LLM service. We provide an example `FastAPI <https://fastapi.tiangolo.com/>`_ server. Check `vllm/entrypoints/api_server.py <https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py>`_ for the server implementation. The server uses the ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests.

Start the server:

.. code-block:: console

$ python -m vllm.entrypoints.api_server

By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.

Query the model in shell:

.. code-block:: console

$ curl http://localhost:8000/generate \
$     -d '{
$         "prompt": "San Francisco is a",
$         "use_beam_search": true,
$         "n": 4,
$         "temperature": 0
$     }'

See `examples/api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/api_client.py>`_ for a more detailed client example.
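
The same request can also be issued from Python. Below is a minimal sketch using the ``requests`` package that mirrors the ``curl`` call above; it simply prints the raw JSON response rather than assuming a particular response schema.

.. code-block:: python

   import requests

   # Assumes the server was started with: python -m vllm.entrypoints.api_server
   response = requests.post(
       "http://localhost:8000/generate",
       json={
           "prompt": "San Francisco is a",
           "use_beam_search": True,
           "n": 4,
           "temperature": 0,
       },
   )
   print(response.json())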

OpenAI-Compatible Server
------------------------

vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
vLLM can be deployed as a server that implements the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
By default, it starts the server at ``http://localhost:8000``. You can specify the address with the ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the command below) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.

Start the server:
@@ -118,6 +90,8 @@ This server can be queried in the same format as OpenAI API. For example, list the models:

$ curl http://localhost:8000/v1/models

You can pass the ``--api-key`` argument or set the ``VLLM_API_KEY`` environment variable to make the server check for the API key in the headers of incoming requests.
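
For example, if the server is started with ``--api-key my-secret-token`` (the token value here is just a placeholder), clients must send that key as a bearer token. A minimal sketch using the official ``openai`` Python client:

.. code-block:: python

   from openai import OpenAI

   # "my-secret-token" must match the value passed via --api-key or the
   # VLLM_API_KEY environment variable on the server side.
   client = OpenAI(
       base_url="http://localhost:8000/v1",
       api_key="my-secret-token",
   )

   # The client sends the key as an "Authorization: Bearer my-secret-token" header.
   models = client.models.list()
   print([model.id for model in models.data])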

Using OpenAI Completions API with vLLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

45 changes: 45 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -2,6 +2,10 @@
import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib
import inspect

from aioprometheus import MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
import fastapi
@@ -64,6 +68,13 @@ def parse_args():
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help=
        "If provided, the server will require this key to be presented in the header."
    )
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
@@ -94,6 +105,17 @@ def parse_args():
                        type=str,
                        default=None,
                        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument(
        "--middleware",
        type=str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
        "If a class is provided, vLLM will add it to the server using app.add_middleware(). "
    )

    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser.parse_args()
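
To illustrate the two forms that ``--middleware`` accepts, here is a hedged sketch of a user-supplied module (the module name ``my_middleware``, the object names, and the ``X-Request-Id`` header are made up for this example). It could be registered with either ``--middleware my_middleware.add_request_id`` or ``--middleware my_middleware.RequestIdMiddleware``:

# my_middleware.py -- hypothetical module passed via --middleware
import uuid

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request


async def add_request_id(request: Request, call_next):
    # Function form: vLLM registers this with @app.middleware("http").
    response = await call_next(request)
    response.headers["X-Request-Id"] = str(uuid.uuid4())
    return response


class RequestIdMiddleware(BaseHTTPMiddleware):
    # Class form: vLLM registers this with app.add_middleware().
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        response.headers["X-Request-Id"] = str(uuid.uuid4())
        return response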
@@ -161,6 +183,29 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
        allow_headers=args.allowed_headers,
    )

    if token := os.environ.get("VLLM_API_KEY") or args.api_key:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            if request.headers.get("Authorization") != "Bearer " + token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    logger.info(f"args: {args}")

    if args.served_model_name is not None:
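
As a quick check of the authentication middleware above, the following hedged sketch (assuming a local server started with ``--api-key my-secret-token``, where the token is a placeholder) shows how requests to ``/v1`` routes are rejected without the bearer token and accepted with it:

import requests

BASE_URL = "http://localhost:8000"

# Without the Authorization header, /v1 routes are rejected by the middleware.
resp = requests.get(f"{BASE_URL}/v1/models")
print(resp.status_code)  # expected: 401

# With the matching bearer token, the request is passed through to the endpoint.
resp = requests.get(f"{BASE_URL}/v1/models",
                    headers={"Authorization": "Bearer my-secret-token"})
print(resp.status_code)  # expected: 200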