Create a new worker implementation for Apple MLX (lm-sys#2937)
aliasaria authored and zhanghao.smooth committed Jan 26, 2024
1 parent 58805b3 commit 141dabe
Showing 2 changed files with 309 additions and 0 deletions.
23 changes: 23 additions & 0 deletions docs/mlx_integration.md
@@ -0,0 +1,23 @@
# Apple MLX Integration

You can use [Apple MLX](https://github.com/ml-explore/mlx) as an optimized worker implementation in FastChat.

It runs models efficiently on Apple Silicon.

See the supported models [here](https://github.com/ml-explore/mlx-examples/tree/main/llms#supported-models).

Note that on Apple Silicon Macs with less memory, smaller or quantized models are recommended.

## Instructions

1. Install MLX.

```
pip install mlx-lm
```

2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the MLX worker (`fastchat.serve.mlx_worker`); a fuller serving example follows these steps.

```
python3 -m fastchat.serve.mlx_worker --model-path microsoft/phi-2
```
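
To serve a model end to end, the MLX worker slots into FastChat's usual controller / worker / API-server arrangement. Below is a minimal sketch, assuming FastChat's standard `fastchat.serve.controller` and `fastchat.serve.openai_api_server` entry points and their default addresses; on lower-memory Macs, substitute a smaller or quantized model for `microsoft/phi-2`.

```
# Terminal 1: start the controller that workers register with
python3 -m fastchat.serve.controller

# Terminal 2: start the MLX worker (swap in a smaller/quantized model if needed)
python3 -m fastchat.serve.mlx_worker --model-path microsoft/phi-2

# Terminal 3: expose an OpenAI-compatible REST API
python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
```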
286 changes: 286 additions & 0 deletions fastchat/serve/mlx_worker.py
@@ -0,0 +1,286 @@
"""
A model worker using Apple MLX
docs/mlx_integration.md
https://github.com/ml-explore/mlx-examples/tree/main/llms
Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py
You must install MLX python:
pip install mlx-lm
"""

import argparse
import asyncio
import atexit
import json
from typing import List
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
logger,
worker_id,
)
from fastchat.utils import get_context_length, is_partial_stop

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.utils import generate_step

app = FastAPI()


class MLXWorker(BaseModelWorker):
def __init__(
self,
controller_addr: str,
worker_addr: str,
worker_id: str,
model_path: str,
model_names: List[str],
limit_worker_concurrency: int,
no_register: bool,
llm_engine: "MLX",
conv_template: str,
):
super().__init__(
controller_addr,
worker_addr,
worker_id,
model_path,
model_names,
limit_worker_concurrency,
conv_template,
)

logger.info(
f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..."
)

self.model_name = model_path
self.mlx_model, self.mlx_tokenizer = load(model_path)

self.tokenizer = self.mlx_tokenizer
# self.context_len = get_context_length(
# llm_engine.engine.model_config.hf_config)
self.context_len = 2048 # hard code for now -- not sure how to get in MLX

if not no_register:
self.init_heart_beat()

async def generate_stream(self, params):
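        # Streaming protocol (shared with the other FastChat workers such as
        # vllm_worker): each yielded chunk is a JSON object encoded as UTF-8
        # and terminated by a null byte ("\0").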
self.call_ct += 1

context = params.pop("prompt")
request_id = params.pop("request_id")
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = params.get("top_k", -1.0)
presence_penalty = float(params.get("presence_penalty", 0.0))
frequency_penalty = float(params.get("frequency_penalty", 0.0))
max_new_tokens = params.get("max_new_tokens", 256)
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
if self.tokenizer.eos_token_id is not None:
stop_token_ids.append(self.tokenizer.eos_token_id)
echo = params.get("echo", True)
use_beam_search = params.get("use_beam_search", False)
best_of = params.get("best_of", None)
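        # NOTE: top_p, top_k, presence/frequency penalties, use_beam_search and
        # best_of are parsed for API compatibility but are not passed on to
        # MLX's generate_step below, which is only given the temperature.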

# Handle stop_str
stop = set()
if isinstance(stop_str, str) and stop_str != "":
stop.add(stop_str)
elif isinstance(stop_str, list) and stop_str != []:
stop.update(stop_str)

for tid in stop_token_ids:
if tid is not None:
s = self.tokenizer.decode(tid)
if s != "":
stop.add(s)

print("Stop patterns: ", stop)

top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0

tokens = []
skip = 0

        # Tokenize the prompt once so token counts (not character counts) can
        # be reported in the usage stats below.
        prompt_token_ids = self.tokenizer.encode(context)
        context_mlx = mx.array(prompt_token_ids)

finish_reason = "length"

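        # generate_step yields one token at a time; zipping with range() caps
        # generation at max_new_tokens. finish_reason stays "length" unless we
        # hit EOS or a stop pattern first.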
for token, _ in zip(
generate_step(context_mlx, self.mlx_model, temperature),
range(max_new_tokens),
):
if token == self.mlx_tokenizer.eos_token_id:
finish_reason = "stop"
break
tokens.append(token.item())
tokens_decoded = self.mlx_tokenizer.decode(tokens)
last_token_decoded = self.mlx_tokenizer.decode([token.item()])
skip = len(tokens_decoded)

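            # is_partial_stop() is True when the decoded text ends with (a
            # prefix of) one of the stop patterns, so generation halts as soon
            # as a stop string starts to appear.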
partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop)

if partial_stop:
finish_reason = "stop"
break

ret = {
"text": tokens_decoded,
"error_code": 0,
"usage": {
"prompt_tokens": len(context),
"completion_tokens": len(tokens),
"total_tokens": len(context) + len(tokens),
},
"cumulative_logprob": [],
"finish_reason": None, # hard code for now
}
# print(ret)
yield (json.dumps(ret) + "\0").encode()
ret = {
"text": self.mlx_tokenizer.decode(tokens),
"error_code": 0,
"usage": {},
"cumulative_logprob": [],
"finish_reason": finish_reason,
}
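        # As in vllm_worker, the final text is emitted twice: once with
        # finish_reason=None and once more with the actual finish_reason.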
        yield (json.dumps({**ret, "finish_reason": None}) + "\0").encode()
yield (json.dumps(ret) + "\0").encode()

async def generate(self, params):
async for x in self.generate_stream(params):
pass
return json.loads(x[:-1].decode())


def release_worker_semaphore():
worker.semaphore.release()


def acquire_worker_semaphore():
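    # Created lazily on the first request; caps concurrent generate calls at
    # limit_worker_concurrency.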
if worker.semaphore is None:
worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
return worker.semaphore.acquire()


def create_background_tasks(request_id):
async def abort_request() -> None:
print("trying to abort but not implemented")

background_tasks = BackgroundTasks()
background_tasks.add_task(release_worker_semaphore)
background_tasks.add_task(abort_request)
return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
params = await request.json()
await acquire_worker_semaphore()
request_id = uuid.uuid4()
params["request_id"] = str(request_id)
generator = worker.generate_stream(params)
background_tasks = create_background_tasks(request_id)
return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
params = await request.json()
await acquire_worker_semaphore()
request_id = uuid.uuid4()
params["request_id"] = str(request_id)
output = await worker.generate(params)
release_worker_semaphore()
# await engine.abort(request_id)
print("Trying to abort but not implemented")
return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
params = await request.json()
return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
return {"context_length": worker.context_len}


worker = None


def cleanup_at_exit():
global worker
print("Cleaning up...")
del worker


atexit.register(cleanup_at_exit)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
parser.add_argument(
"--controller-address", type=str, default="http://localhost:21001"
)
parser.add_argument("--model-path", type=str, default="microsoft/phi-2")
parser.add_argument(
"--model-names",
type=lambda s: s.split(","),
help="Optional display comma separated names",
)
parser.add_argument(
"--conv-template", type=str, default=None, help="Conversation prompt template."
)
parser.add_argument(
"--trust_remote_code",
action="store_false",
default=True,
help="Trust remote code (e.g., from HuggingFace) when"
"downloading the model and tokenizer.",
)

args, unknown = parser.parse_known_args()

if args.model_path:
args.model = args.model_path

worker = MLXWorker(
args.controller_address,
args.worker_address,
worker_id,
args.model_path,
args.model_names,
        1024,  # limit_worker_concurrency
        False,  # no_register
        "MLX",  # llm_engine (not used directly; the model is loaded in __init__)
args.conv_template,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
