Commit 27f1410

New weight loader without np copy (#52)
1 parent 4858f3b commit 27f1410

12 files changed: +289 -357 lines


benchmark/benchmark_latency.py
Lines changed: 4 additions & 42 deletions

@@ -6,53 +6,15 @@
 import numpy as np
 import torch
 
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
+from cacheflow.master.server import (
+    add_server_arguments, process_server_arguments,
+    init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 def main(args: argparse.Namespace):
-    # TODO(zhuohan): Support pipeline parallelism.
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
+    server, frontend = init_local_server_and_frontend_with_arguments(args)
 
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=args.use_ray,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
-
-    # Create a server.
-    server = Server(
-        model=args.model,
-        model_path=args.model_path,
-        use_dummy_weights=args.use_dummy_weights,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
-        use_ray=args.use_ray,
-    )
-
-    # Create a frontend.
-    frontend = SimpleFrontend(
-        model_name=args.model,
-        block_size=args.block_size,
-    )
     sampling_params_dict = {
         'n': args.n,
         'temperature': 0.0 if args.use_beam_search else 1.0,

benchmark/benchmark_text_completion.py
Lines changed: 4 additions & 43 deletions

@@ -9,57 +9,18 @@
 from transformers import AutoConfig
 
 from benchmark.trace import generate_text_completion_requests
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
+from cacheflow.master.server import (
+    add_server_arguments, process_server_arguments,
+    init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 logger = logging.getLogger(__name__)
 
 
 def main(args: argparse.Namespace):
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
+    server, frontend = init_local_server_and_frontend_with_arguments(args)
 
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=args.use_ray,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
-
-    # Create a server.
-    server = Server(
-        model=args.model,
-        model_path=args.model_path,
-        use_dummy_weights=args.use_dummy_weights,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
-        use_ray=args.use_ray,
-        collect_stats=True,
-        do_memory_analysis=args.do_memory_analysis,
-    )
-
-    # Create a frontend.
-    frontend = SimpleFrontend(
-        model_name=args.model,
-        block_size=args.block_size,
-    )
     # Generate requests.
     requests = generate_text_completion_requests(
         args.dataset,
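
Both benchmark entrypoints now reduce to the same three calls. Below is a minimal sketch of that shared pattern, not part of the commit itself; it assumes a single-GPU environment (the helper really does bring up a local Server and SimpleFrontend), and the explicit model flag is illustrative only.

import argparse

from cacheflow.master.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)

# Build the common server argument set, then hand the parsed namespace to the
# new helper, which performs the cluster init and Server/SimpleFrontend setup
# that each benchmark previously inlined.
parser = add_server_arguments(argparse.ArgumentParser())
args = process_server_arguments(
    parser.parse_args(['--model', 'facebook/opt-125m']))
server, frontend = init_local_server_and_frontend_with_arguments(args)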

cacheflow/http_frontend/fastapi_frontend.py
Lines changed: 10 additions & 7 deletions

@@ -1,7 +1,7 @@
 import argparse
 import asyncio
 import time
-from typing import List, Dict
+from typing import List, Dict, Optional
 import json
 
 import ray

@@ -22,11 +22,12 @@
 app = FastAPI()
 
 
-class FastAPIFrontend:
+class FastAPIServer:
     def __init__(
         self,
         model: str,
-        model_path: str,
+        cache_dir: Optional[str],
+        use_np_cache: bool,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         block_size: int,

@@ -52,8 +53,9 @@ def __init__(
         remote_server_class = ray.remote(num_gpus=1)(Server)
         self.server = remote_server_class.remote(
             model=model,
-            model_path=model_path,
+            cache_dir=cache_dir,
             use_dummy_weights=False,
+            use_np_cache=use_np_cache,
             pipeline_parallel_size=pipeline_parallel_size,
             tensor_parallel_size=tensor_parallel_size,
             block_size=block_size,

@@ -148,7 +150,7 @@ async def generate(self, request_dict: Dict):
 @app.post("/generate")
 async def generate_stream(request: Request):
     request_dict = await request.json()
-    return StreamingResponse(frontend.generate(request_dict))
+    return StreamingResponse(server.generate(request_dict))
 
 
 if __name__ == "__main__":

@@ -170,9 +172,10 @@ async def generate_stream(request: Request):
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))
 
-    frontend = FastAPIFrontend(
+    server = FastAPIServer(
         model=args.model,
-        model_path=args.model_path,
+        cache_dir=args.cache_dir,
+        use_np_cache=args.use_np_cache,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size,
         block_size=args.block_size,

cacheflow/master/server.py
Lines changed: 58 additions & 7 deletions

@@ -9,18 +9,21 @@
     ray = None
 
 from cacheflow.master.scheduler import Scheduler
+from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.models import get_memory_analyzer
 from cacheflow.worker.controller import Controller, DeviceID
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 class Server:
     def __init__(
         self,
         model: str,
-        model_path: str,
+        cache_dir: Optional[str],
         use_dummy_weights: bool,
+        use_np_cache: bool,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         block_size: int,

@@ -78,8 +81,9 @@ def __init__(
             num_cpu_blocks=self.num_cpu_blocks,
             dtype=dtype,
             seed=seed,
-            model_path=model_path,
+            cache_dir=cache_dir,
             use_dummy_weights=use_dummy_weights,
+            use_np_cache=use_np_cache,
             max_num_batched_tokens=max_num_batched_tokens,
             use_ray=use_ray,
         )

@@ -203,25 +207,72 @@ def initialize_cluster(
 def add_server_arguments(parser: argparse.ArgumentParser):
     # Model arguments
     parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-    parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
-                        help='model path to download and load the weights')
+    parser.add_argument('--cache-dir', type=str, default=None,
+                        help='cache dir to download and load the weights, '
+                             'default to the default cache dir of huggingface')
+    parser.add_argument('--use-np-cache', action='store_true',
+                        help='save a numpy copy of model weights for faster loading')
+    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
+    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
     parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
     parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
     parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
-    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     return parser
 
+
 def process_server_arguments(args: argparse.Namespace):
     if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
         args.use_ray = True
     return args
+
+
+def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
+    # TODO(zhuohan): Support pipeline parallelism.
+    assert args.pipeline_parallel_size == 1, (
+        'Pipeline parallelism is not supported yet.')
+
+    (num_nodes, num_devices_per_node, distributed_init_method,
+     all_stage_devices) = (
+        initialize_cluster(
+            use_ray=args.use_ray,
+            pipeline_parallel_size=args.pipeline_parallel_size,
+            tensor_parallel_size=args.tensor_parallel_size))
+
+    # Create a server.
+    server = Server(
+        model=args.model,
+        cache_dir=args.cache_dir,
+        use_dummy_weights=args.use_dummy_weights,
+        use_np_cache=args.use_np_cache,
+        pipeline_parallel_size=args.pipeline_parallel_size,
+        tensor_parallel_size=args.tensor_parallel_size,
+        block_size=args.block_size,
+        dtype=args.dtype,
+        seed=args.seed,
+        swap_space=args.swap_space,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        max_num_sequences=args.max_num_sequences,
+        num_nodes=num_nodes,
+        num_devices_per_node=num_devices_per_node,
+        distributed_init_method=distributed_init_method,
+        all_stage_devices=all_stage_devices,
+        gpu_memory=get_gpu_memory(),
+        cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
+    )
+
+    # Create a frontend.
+    frontend = SimpleFrontend(
+        model_name=args.model,
+        block_size=args.block_size,
+    )
+    return server, frontend
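
For reference, a small sketch of how the reworked flags behave under the argparse definitions above. Nothing here is part of the commit; the assertions only restate the declared defaults and the process_server_arguments rule: --cache-dir falls back to the standard HuggingFace cache, --use-np-cache is off unless requested, and any pipeline or tensor parallelism still forces Ray on.

import argparse

from cacheflow.master.server import add_server_arguments, process_server_arguments

parser = add_server_arguments(argparse.ArgumentParser())

# Defaults: weights come straight from the HuggingFace cache, no numpy copy.
args = process_server_arguments(parser.parse_args([]))
assert args.cache_dir is None and args.use_np_cache is False
assert args.use_ray is False          # single worker, Ray not forced

# More than one GPU worth of parallelism auto-enables Ray.
args = process_server_arguments(parser.parse_args(['-tp', '2']))
assert args.use_ray is True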

cacheflow/models/gpt_neox.py
Lines changed: 15 additions & 59 deletions

@@ -1,18 +1,14 @@
 """1D GPT-NeoX model compatible with HuggingFace weights."""
-import os
-import glob
-import filelock
-from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
-from huggingface_hub import snapshot_download
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.sample import Sampler
+from cacheflow.models.utils import (hf_model_weights_iterator,
+                                    load_tensor_parallel_weights)
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,

@@ -196,17 +192,22 @@ def forward(
     _column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
     _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
 
-    def load_weights(self, weights_path: str):
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
-        for name, param in state_dict.items():
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if ("attention.bias" in name or "attention.masked_bias" in name
+                    or "rotary_emb.inv_freq" in name):
+                continue
+            param = state_dict[name]
             if "query_key_value" in name:
                 # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
                 # [num_heads * 3 * head_size, num_heads * head_size], while the
                 # required shape is [3 * num_heads * head_size, num_heads * head_size].
                 # Thus, we need weight conversion.
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
                 shard_size = param.shape[0]
                 loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
                                               :shard_size * (tensor_model_parallel_rank + 1)]

@@ -223,55 +224,10 @@ def load_weights(self, weights_path: str):
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1).contiguous()
                 else:
-                    assert False
-            else:
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
-                for p in self._column_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[0]
-                        loaded_weight = loaded_weight[
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-                for p in self._row_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[1]
-                        loaded_weight = loaded_weight[
-                            :,
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-
-            assert param.shape == loaded_weight.shape
-            param.data.copy_(loaded_weight)
-
-    @staticmethod
-    def get_weights(model_name: str, path: str):
-        path = os.path.join(path, f"{model_name}-np")
-        path = os.path.abspath(os.path.expanduser(path))
-        os.makedirs(path, exist_ok=True)
-        lock_path = os.path.join(path, "file_lock")
-        lock = filelock.FileLock(lock_path)
-
-        with lock:
-            test_weight_path = os.path.join(
-                path, "gpt_neox.embed_in.weight")
-            if os.path.exists(test_weight_path):
-                return path
-
-            folder = snapshot_download(model_name, allow_patterns="*.bin",
-                                       cache_dir=os.path.join(path, "cache"))
-            bin_files = glob.glob(os.path.join(folder, "*.bin"))
-
-            for bin_file in tqdm(bin_files, desc="Convert format"):
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in tqdm(state.items(), leave=False):
-                    param_path = os.path.join(path, name)
-                    with open(param_path, "wb") as f:
-                        np.save(f, param.cpu().detach().numpy())
-
-            return path
+                    raise ValueError(f"Unexpected weight name: {name}")
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights)
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
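
The new load_weights above delegates to two helpers from cacheflow/models/utils.py, which is among the 12 changed files but not shown in this excerpt: hf_model_weights_iterator streams (name, tensor) pairs out of the HuggingFace checkpoint, and load_tensor_parallel_weights presumably factors out the column/row sharding loop deleted above. The following is only an assumed sketch of what the iterator might look like, inferred from its call site; the real implementation, including the opt-in use_np_cache path, may differ.

import glob
import os
from typing import Iterator, Optional, Tuple

import torch
from huggingface_hub import snapshot_download


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    # Resolve a local directory, downloading *.bin shards if a hub ID is given.
    if os.path.isdir(model_name_or_path):
        folder = model_name_or_path
    else:
        folder = snapshot_download(model_name_or_path,
                                   allow_patterns="*.bin",
                                   cache_dir=cache_dir)
    bin_files = glob.glob(os.path.join(folder, "*.bin"))

    # Default path: stream tensors straight from the torch checkpoint shards,
    # with no intermediate numpy copy. (A use_np_cache=True branch would save
    # and memory-map .npy files, roughly like the deleted get_weights did.)
    for bin_file in bin_files:
        state = torch.load(bin_file, map_location="cpu")
        for name, param in state.items():
            yield name, param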
