forked from lm-sys/FastChat
Commit: Create a new worker implementation for Apple MLX (lm-sys#2937)
Showing 2 changed files with 309 additions and 0 deletions.
@@ -0,0 +1,23 @@
# Apple MLX Integration

You can use [Apple MLX](https://github.com/ml-explore/mlx) as an optimized worker implementation in FastChat.

It runs models efficiently on Apple Silicon.

See the supported models [here](https://github.com/ml-explore/mlx-examples/tree/main/llms#supported-models).

Note that for Apple Silicon Macs with less memory, smaller models (or quantized models) are recommended.

## Instructions

1. Install MLX.

   ```
   pip install mlx-lm
   ```

2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the MLX worker (`fastchat.serve.mlx_worker`):

   ```
   python3 -m fastchat.serve.mlx_worker --model-path microsoft/phi-2
   ```
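Once the worker is up, you can sanity-check it without going through the FastChat controller by POSTing directly to its `/worker_generate_stream` endpoint, which streams null-terminated JSON chunks. A minimal sketch, assuming the default worker port `21002` and that the `requests` package is installed:

```
import json

import requests

# Direct smoke test against a locally running MLX worker (default port 21002).
payload = {
    "prompt": "Hello, my name is",
    "temperature": 0.7,
    "max_new_tokens": 64,
}
response = requests.post(
    "http://localhost:21002/worker_generate_stream",
    json=payload,
    stream=True,
)

# The worker emits null-terminated JSON objects; each carries the full text generated so far.
for chunk in response.iter_lines(delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        print(data["text"], data["finish_reason"])
```

In normal use, the FastChat controller and the OpenAI-compatible API server route requests to the worker for you; the direct call above is only a debugging aid.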
@@ -0,0 +1,286 @@
"""
A model worker using Apple MLX.

See docs/mlx_integration.md and
https://github.com/ml-explore/mlx-examples/tree/main/llms

Code based on vllm_worker:
https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py

You must install the MLX Python package first:
pip install mlx-lm
"""

import argparse
import asyncio
import atexit
import json
import uuid
from typing import List

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
    logger,
    worker_id,
)
from fastchat.utils import get_context_length, is_partial_stop

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.utils import generate_step

app = FastAPI()


class MLXWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        llm_engine: "MLX",
        conv_template: str,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
        )

        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..."
        )

        self.model_name = model_path
        self.mlx_model, self.mlx_tokenizer = load(model_path)

        self.tokenizer = self.mlx_tokenizer
        # self.context_len = get_context_length(
        #     llm_engine.engine.model_config.hf_config)
        self.context_len = 2048  # hard code for now -- not sure how to get in MLX

        if not no_register:
            self.init_heart_beat()

    async def generate_stream(self, params):
        self.call_ct += 1

        # Request parameters (mirroring vllm_worker). Only temperature,
        # max_new_tokens, and the stop strings/token ids are used by the MLX
        # generation below; the rest are accepted for API compatibility.
        context = params.pop("prompt")
        request_id = params.pop("request_id")
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = params.get("top_k", -1.0)
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        max_new_tokens = params.get("max_new_tokens", 256)
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_token_ids", None) or []
        if self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)
        echo = params.get("echo", True)
        use_beam_search = params.get("use_beam_search", False)
        best_of = params.get("best_of", None)

        # Handle stop_str
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        # Treat stop token ids as stop strings as well.
        for tid in stop_token_ids:
            if tid is not None:
                s = self.tokenizer.decode(tid)
                if s != "":
                    stop.add(s)

        print("Stop patterns: ", stop)

        top_p = max(top_p, 1e-5)
        if temperature <= 1e-5:
            top_p = 1.0

        tokens = []
        skip = 0

        context_mlx = mx.array(self.tokenizer.encode(context))

        finish_reason = "length"

        # Cap generation at max_new_tokens by zipping the token iterator
        # with a range().
        for token, _ in zip(
            generate_step(context_mlx, self.mlx_model, temperature),
            range(max_new_tokens),
        ):
            if token == self.mlx_tokenizer.eos_token_id:
                finish_reason = "stop"
                break
            tokens.append(token.item())
            tokens_decoded = self.mlx_tokenizer.decode(tokens)
            last_token_decoded = self.mlx_tokenizer.decode([token.item()])
            skip = len(tokens_decoded)

            partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop)

            if partial_stop:
                finish_reason = "stop"
                break

            # Each chunk is a null-terminated JSON object carrying the full
            # text generated so far (FastChat's worker streaming protocol).
            ret = {
                "text": tokens_decoded,
                "error_code": 0,
                "usage": {
                    "prompt_tokens": len(context),  # note: character count, not token count
                    "completion_tokens": len(tokens),
                    "total_tokens": len(context) + len(tokens),
                },
                "cumulative_logprob": [],
                "finish_reason": None,  # hard code for now
            }
            # print(ret)
            yield (json.dumps(ret) + "\0").encode()

        # Final message: first with finish_reason cleared, then with the real
        # finish_reason, mirroring vllm_worker's behavior.
        ret = {
            "text": self.mlx_tokenizer.decode(tokens),
            "error_code": 0,
            "usage": {},
            "cumulative_logprob": [],
            "finish_reason": finish_reason,
        }
        yield (json.dumps({**ret, "finish_reason": None}) + "\0").encode()
        yield (json.dumps(ret) + "\0").encode()

    async def generate(self, params):
        # Drain the stream and return the last (complete) chunk, minus the
        # trailing null byte.
        async for x in self.generate_stream(params):
            pass
        return json.loads(x[:-1].decode())


def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(request_id):
    async def abort_request() -> None:
        print("trying to abort but not implemented")

    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)
    background_tasks.add_task(abort_request)
    return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = uuid.uuid4()
    params["request_id"] = str(request_id)
    generator = worker.generate_stream(params)
    background_tasks = create_background_tasks(request_id)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = uuid.uuid4()
    params["request_id"] = str(request_id)
    output = await worker.generate(params)
    release_worker_semaphore()
    # await engine.abort(request_id)
    print("Trying to abort but not implemented")
    return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}


worker = None


def cleanup_at_exit():
    global worker
    print("Cleaning up...")
    del worker


atexit.register(cleanup_at_exit)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--model-path", type=str, default="microsoft/phi-2")
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional comma-separated list of display names.",
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_false",
        default=True,
        help="Trust remote code (e.g., from HuggingFace) when "
        "downloading the model and tokenizer.",
    )

    args, unknown = parser.parse_known_args()

    if args.model_path:
        args.model = args.model_path

    worker = MLXWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_path,
        args.model_names,
        1024,  # limit_worker_concurrency
        False,  # no_register
        "MLX",  # llm_engine (placeholder, unused)
        args.conv_template,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
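For completeness, the worker above also exposes a couple of small diagnostic routes. A minimal sketch of calling them, again assuming the default port `21002` and the `requests` package:

```
import requests

base = "http://localhost:21002"  # default --port from the argument parser above

# Both routes are POST endpoints defined in the worker above.
print(requests.post(f"{base}/worker_get_status").json())  # worker status (model names, queue length, ...)
print(requests.post(f"{base}/model_details").json())      # {"context_length": 2048}
```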