
Commit cbe65a2

rickyyx and rkooo567 authored
[integration] Make llama3 8B work with integration (vllm-project#535)
Co-authored-by: sang <rkooo567@gmail.com>
1 parent 15a712d commit cbe65a2

File tree

7 files changed: +90 additions, -80 deletions


.buildkite/ci/build_scratch.sh

Lines changed: 8 additions & 25 deletions
@@ -17,8 +17,7 @@ rm -rf ${SCRATCH_DIR}
 git clone git@github.com:anyscale/scratchllm.git ${SCRATCH_DIR}
 pushd ${SCRATCH_DIR}
 
-# TEMPORARY.
-git checkout -b ricky/pr-pybind origin/ricky/pr-pybind
+git checkout a10-deployment
 
 echo "Build glog"
 git clone https://github.com/google/glog.git
@@ -28,30 +27,14 @@ cmake --build build
 sudo cmake --build build --target install
 popd
 
-echo "Build sentencepiece"
-git clone https://github.com/google/sentencepiece.git
-pushd sentencepiece
-mkdir build
-cd build
-cmake ..
-make -j $(nproc)
-sudo make install
-sudo ldconfig -v
-popd
-
-echo "Build tiktokencpp"
-git clone git@github.com:anyscale/tiktokencpp.git
-pushd tiktokencpp
-mkdir build
-cd build
-cmake ..
-make
-sudo make install
-popd
-
 echo "Build scratchllm"
 # used for pybind.
-git submodule update --init --recursive
+chmod 700 setup_pybind.sh
+bash setup_pybind.sh
+
 # TODO(sang): Support custom flags.
-make h=cuda t=f16 b=fullopt scratch_runner
+# SANG-TODO H100
+# make m=ll38b h=cuda t=f16 b=fullopt s=4 scratch_runner
+# SANG-TODO A10
+make m=ll38b h=cuda t=f16 b=fullopt s=1 scratch_runner
 popd
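The build script now checks out the `a10-deployment` branch, delegates pybind setup to `setup_pybind.sh`, and builds a model- and shard-specific runner: `m=ll38b` with `s=1` for a single A10 GPU, with the commented `s=4` line as the H100 variant. The `.so` that setup.py later copies follows the same `<model>-s<shards>-<hw>-<dtype>-<build>` convention. Below is a hedged sketch of that convention; the helper name is hypothetical and not part of the commit.

```python
# Illustrative sketch only: the commit hard-codes m=ll38b / s=1. This mirrors the
# naming convention that the make target and setup.py's copy step both assume.
def scratch_build_artifacts(model: str = "ll38b", shards: int = 1,
                            hw: str = "cuda", dtype: str = "f16",
                            build: str = "fullopt") -> tuple:
    make_cmd = f"make m={model} h={hw} t={dtype} b={build} s={shards} scratch_runner"
    so_name = f"scratch-{model}-s{shards}-{hw}-{dtype}-{build}.cpython-39-x86_64-linux-gnu.so"
    return make_cmd, so_name

print(scratch_build_artifacts())           # A10 build used in this commit
print(scratch_build_artifacts(shards=4))   # H100 variant from the commented-out line
```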

.buildkite/ci/build_wheel.sh

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ sudo apt install -y cmake
 
 
 echo "~~~ :python: Building wheel for ${VLLM_PROJECT}@${GIT_COMMIT}"
-BUILD_BAZEL=1 python setup.py bdist_wheel
+# Build scratch together.
+ANYSCALE_USE_SCRATCH_LLM=1 BUILD_BAZEL=1 python setup.py bdist_wheel
 
 VLLM_WHEEL=$(basename $(ls dist/*.whl))
 COMMIT_PATH="${S3_WHEEL_CACHE}/${VLLM_PROJECT}/${GIT_COMMIT}/${VLLM_WHEEL}"
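Setting `ANYSCALE_USE_SCRATCH_LLM=1` makes the wheel build run the ScratchLLM build path in setup.py in addition to the Bazel build. The flag is parsed the same way `vllm/scratch_env.py` parses it; a minimal sketch of that gating pattern, with a placeholder build directory:

```python
import os
import subprocess

# Same parsing as vllm/scratch_env.py: an int-like string, "1" enables the Scratch path.
USE_SCRATCH = bool(int(os.getenv("ANYSCALE_USE_SCRATCH_LLM", "0")))

if USE_SCRATCH:
    # Simplified stand-in for the Scratch branch of build_extensions() in setup.py.
    subprocess.check_call(
        ["bash", ".buildkite/ci/build_scratch.sh", "/tmp/scratch-build"])
```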

setup.py

Lines changed: 14 additions & 7 deletions
@@ -213,18 +213,25 @@ def build_extensions(self) -> None:
 
         temp_dir_path = os.path.join(ROOT_DIR, self.build_temp)
         print("Build and install ScratchLLM.")
-        print("Make sure to run ")
         subprocess.check_call(["chmod", "700", ".buildkite/ci/build_scratch.sh",])
         subprocess.check_call(["bash", ".buildkite/ci/build_scratch.sh", temp_dir_path])
         print("Copy .so file to vllm folder.")
         # TODO(sang): Support flexible .so file names.
         subprocess.check_call(["ls", f"{temp_dir_path}/scratchllm"])
-        subprocess.check_call([
-            "cp",
-            "-f",
-            f"{temp_dir_path}/scratchllm/scratch.cpython-39-x86_64-linux-gnu.so",
-            os.path.join(ROOT_DIR, "vllm"),
-        ])
+        # SANG-TODO: Support flexible models and shard size.
+        scratch_so_files = [
+            # SANG-TODO H100
+            # "scratch-ll38b-s4-cuda-f16-fullopt.cpython-39-x86_64-linux-gnu.so",
+            # SANG-TODO A10
+            "scratch-ll38b-s1-cuda-f16-fullopt.cpython-39-x86_64-linux-gnu.so",
+        ]
+        for shared_object_file in scratch_so_files:
+            subprocess.check_call([
+                "cp",
+                "-f",
+                f"{temp_dir_path}/scratchllm/{shared_object_file}",
+                os.path.join(ROOT_DIR, "vllm"),
+            ])
         # Anyscale end
 
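setup.py no longer copies a generic `scratch.*.so`; it copies the model-specific A10 artifact, with the H100 `s4` name left commented out. A hedged sketch of how the hard-coded file list could instead be derived from the same prefix and build-type constants introduced in `vllm/scratch_env.py`, per the "Support flexible models and shard size" TODO; the helper name is hypothetical:

```python
import sys

def expected_scratch_so(prefix: str = "ll38b-s1-cuda-f16",
                        build_type: str = "fullopt") -> str:
    """Hypothetical helper: derive the shared-object name setup.py expects to copy."""
    abi = f"cpython-{sys.version_info.major}{sys.version_info.minor}-x86_64-linux-gnu"
    return f"scratch-{prefix}-{build_type}.{abi}.so"

# On Python 3.9 the defaults reproduce the name hard-coded in this commit:
# scratch-ll38b-s1-cuda-f16-fullopt.cpython-39-x86_64-linux-gnu.so
print(expected_scratch_so())
```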

tests/basic_correctness/test_scratch_correctness.py

Lines changed: 4 additions & 3 deletions
@@ -9,11 +9,12 @@
 MODELS = [
     # "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
+    # "meta-llama/Meta-Llama-3-8B",
 ]
 
-# assert USE_SCRATCH, ("ScratchLLM should be enabled to run a test. "
-#                      "Use ANYSCALE_USE_SCRATCH_LLM=1 pytest -vs "
-#                      "tests/basic_correctness/test_scratch_correctness.py")
+assert USE_SCRATCH, ("ScratchLLM should be enabled to run a test. "
+                     "Use ANYSCALE_USE_SCRATCH_LLM=1 pytest -vs "
+                     "tests/basic_correctness/test_scratch_correctness.py")
 
 
 @pytest.mark.parametrize("model", MODELS)
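The module-level assert is re-enabled, so this file now fails at import time unless `ANYSCALE_USE_SCRATCH_LLM=1` is set; Meta-Llama-3-8B is added to `MODELS` but left commented out. If collection without the flag should still succeed, a skip marker is a gentler alternative. A sketch, assuming `USE_SCRATCH` is importable from `vllm.scratch_env`:

```python
import pytest

from vllm.scratch_env import USE_SCRATCH

# Alternative to the module-level assert: skip the whole module instead of
# failing at import/collection time when the env var is not set.
pytestmark = pytest.mark.skipif(
    not USE_SCRATCH,
    reason="Run with ANYSCALE_USE_SCRATCH_LLM=1 to enable ScratchLLM tests.")
```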

vllm/scratch_env.py

Lines changed: 8 additions & 10 deletions
@@ -2,17 +2,15 @@
 
 SCRATCH_ENV_VAR = "ANYSCALE_USE_SCRATCH_LLM"
 USE_SCRATCH = bool(int(os.getenv(SCRATCH_ENV_VAR, False)))
+SCRATCH_EXECUTABLE_PATH_ENV_VAR = "SCRATCH_EXECUTABLE_PATH"
+# SANG-TODO H100
+# SCRATCH_BUILD_PREFIX = "ll38b-s4-cuda-f16"  # CHANGE THIS FOR DIFFERNT MODELS
+# SANG-TODO A10
+SCRATCH_BUILD_PREFIX = "ll38b-s1-cuda-f16"  # CHANGE THIS FOR DIFFERNT MODELS
+SCRATCH_BUILD_TYPE = "fullopt"  # We should remove this, this is needed because weights are the same for all builds types.
+SCRATCH_EXECUTABLE_PATH = os.getenv(SCRATCH_EXECUTABLE_PATH_ENV_VAR, f"./vllm/scratch-{SCRATCH_BUILD_PREFIX}-{SCRATCH_BUILD_TYPE}.cpython-39-x86_64-linux-gnu.so")
 SCRATCH_WEIGHTS_BUCKET_NAME = "scratch-working-dirs"
-SCRATCH_WEIGHTS_PREFIX = "weights/llama-7b/ll27b-cuda-f16/"
+SCRATCH_WEIGHTS_PREFIX = f"staging_weights/{SCRATCH_BUILD_PREFIX}/"
 SCRATCH_WEIGHTS_URI = f"s3://{SCRATCH_WEIGHTS_BUCKET_NAME}/{SCRATCH_WEIGHTS_PREFIX}"
 SCRATCH_TMP_DIR = "/tmp/scratch/"
 SCRATCH_WEIGHTS_PATH = "/tmp/scratch/"
-
-if USE_SCRATCH:
-    try:
-        from vllm.scratch import ScratchAPI
-    except ImportError:
-        raise AssertionError(
-            "Scratch API hasn't been built with vLLM properly. "
-            "See https://docs.google.com/document/d/1O9VIfnhYai-gJ1TLlP-3SQ4wH5LqxafxYeEHmEIPD7Q/edit#heading=h.1j3ik15fr6mh"
-        ) # noqa
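scratch_env.py now derives both the default executable path and the S3 weights prefix from a single `SCRATCH_BUILD_PREFIX`, and the import-time `ScratchAPI` check is gone (the module is now loaded from the `.so` at runtime in `scratch_model_runner.py`). Replaying the string formatting with this commit's values gives:

```python
SCRATCH_BUILD_PREFIX = "ll38b-s1-cuda-f16"
SCRATCH_BUILD_TYPE = "fullopt"
SCRATCH_WEIGHTS_BUCKET_NAME = "scratch-working-dirs"

# Default path of the pybind extension, unless SCRATCH_EXECUTABLE_PATH overrides it.
executable = (f"./vllm/scratch-{SCRATCH_BUILD_PREFIX}-{SCRATCH_BUILD_TYPE}"
              ".cpython-39-x86_64-linux-gnu.so")
# Where the matching weights now live (previously weights/llama-7b/ll27b-cuda-f16/).
weights_uri = f"s3://{SCRATCH_WEIGHTS_BUCKET_NAME}/staging_weights/{SCRATCH_BUILD_PREFIX}/"

print(executable)   # ./vllm/scratch-ll38b-s1-cuda-f16-fullopt.cpython-39-x86_64-linux-gnu.so
print(weights_uri)  # s3://scratch-working-dirs/staging_weights/ll38b-s1-cuda-f16/
```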

vllm/worker/scratch_model_runner.py

Lines changed: 52 additions & 34 deletions
@@ -1,5 +1,7 @@
 from typing import List, Optional, Set, Hashable
 import time
+import importlib.util
+import sys
 
 import torch
 import torch.nn as nn
@@ -38,14 +40,23 @@
 
 LLAMA_7B_VOCAB_SIZE = 32000
 
-from vllm.scratch import ScratchAPI
-from vllm.scratch_env import (SCRATCH_TMP_DIR, SCRATCH_WEIGHTS_PREFIX,
+from vllm.scratch_env import (SCRATCH_EXECUTABLE_PATH, SCRATCH_TMP_DIR, SCRATCH_WEIGHTS_PREFIX,
                               SCRATCH_WEIGHTS_BUCKET_NAME)
 
 # SANG-TODO WORKS?
 MODEL_PARAMS_PATH = "/home/ray/default/weights"
 
 
+def import_scratch(path: Path):
+    SCRATCH_MODULE_NAME = "scratch"
+    logger.info(f"Importing scratch module from {path}")
+    spec = importlib.util.spec_from_file_location(SCRATCH_MODULE_NAME, path.resolve())
+    scratch = importlib.util.module_from_spec(spec)
+    sys.modules[SCRATCH_MODULE_NAME] = scratch
+    spec.loader.exec_module(scratch)
+    return scratch
+
+
 class ScratchSession:
 
     def __init__(self, scratch_session_id: int):
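Instead of a hard compile-time dependency (`from vllm.scratch import ScratchAPI`), the runner now loads whichever extension module `SCRATCH_EXECUTABLE_PATH` points at through `importlib`. A self-contained sketch of the same pattern, with error handling added for illustration; the `ScratchAPI` attribute and the example path come from this diff:

```python
import importlib.util
import sys
from pathlib import Path


def load_extension(path: Path, module_name: str = "scratch"):
    """Load a compiled extension module from an explicit file path."""
    spec = importlib.util.spec_from_file_location(module_name, path.resolve())
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load an extension module from {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before exec, as in the diff
    spec.loader.exec_module(module)
    return module


# Usage mirroring load_model(); the path is the default from vllm/scratch_env.py.
# scratch = load_extension(
#     Path("./vllm/scratch-ll38b-s1-cuda-f16-fullopt.cpython-39-x86_64-linux-gnu.so"))
# api = scratch.ScratchAPI("/tmp/scratch/parameters")
```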
@@ -54,7 +65,7 @@ def __init__(self, scratch_session_id: int):
 
 class ScratchLRUCache(LRUCache[ScratchSession]):
 
-    def __init__(self, capacity: int, scratch_api: ScratchAPI):
+    def __init__(self, capacity: int, scratch_api):
         self._scratch_api = scratch_api
         super().__init__(capacity)
 
@@ -77,7 +88,7 @@ class ScratchSessionManager:
     information to model runner in a few weeks.
     """
 
-    def __init__(self, scratch_api: ScratchAPI, max_num_seqs: int):
+    def __init__(self, scratch_api, max_num_seqs: int):
         # ScratchAPI used to create/delete sessions.
         self._scratch_api = scratch_api
         # Set capacity to max_num_seqs * 2 so that old sequences are
@@ -134,7 +145,7 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
 
         # Lazily initialized.
-        self.scratch: ScratchAPI
+        self.scratch: "ScratchAPI"  # type: ignore
         # Scratch only returns embedding. We need to multiply it to lm_head
         # to get the final logits, and that happens in vLLM. In order to
         # do that, we create a torch module with lm_head weights loaded.
@@ -155,8 +166,10 @@ def _verify_scratch_config(self):
             "Vision model not supported")
         assert self.kv_cache_dtype == "auto", (
             "Currently, Scratch doesn't use kv cache.")
-        assert "llama-2" in self.model_config.model.lower(), (
-            "Only Llama 7B is supported.")
+        # SANG-TODO Support only llama 2 and 3.
+        assert ("llama-2" in self.model_config.model.lower()
+                or "llama-3" in self.model_config.model.lower()), (
+                    "Only Llama 2 7B or llama 3 8B is supported.")
         assert self.lora_manager is None, ("lora is not supported.")
         assert self.model_config.enforce_eager is True, (
             "cuda graph is not needed for Scratch.")
@@ -171,7 +184,12 @@ def load_model(self) -> None:
         weights_dir = tmp_dir / "parameters"
         weights_dir.mkdir(exist_ok=True)
         # TODO(sang): Need to obtain this programmatically.
-        download_dir = weights_dir / "ll27b-s1-cuda-f16-fullopt"
+        # download_dir = weights_dir / "ll27b-s1-cuda-f16-fullopt"
+        scratch_mod = import_scratch(Path(SCRATCH_EXECUTABLE_PATH))
+        base_dir = str(weights_dir.resolve())
+        self.scratch = scratch_mod.ScratchAPI(base_dir)
+        scratch_subdir = self.scratch.get_param_subdir()
+        download_dir = weights_dir / scratch_subdir
         download_dir.mkdir(exist_ok=True)
         download_dir_path = str(download_dir.absolute())
         self.load_config.download_dir = str(weights_dir.absolute())
@@ -190,7 +208,6 @@ def load_model(self) -> None:
             scheduler_config=self.scheduler_config,
             cache_config=self.cache_config,
         )
-        self.scratch = ScratchAPI(str(weights_dir.absolute()))
         self.scratch.start()
         self._scratch_session_manager = ScratchSessionManager(
             self.scratch, self.scheduler_config.max_num_seqs)
@@ -223,7 +240,8 @@ def _download_scratch_weights(self, prefix: str, target_dir: str,
                 dirs.append(k)
             next_token = results.get('NextContinuationToken')
         # Assume there's no subdirectories.
-        assert len(dirs) == 1
+        dirs = {p.rsplit("/", 1)[0] for p in files}
+        assert len(dirs) == 1, dirs
 
         # NOTE(sang): Versioning is not supported now. We assume the
         # weights are always the same.
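Two notable changes here besides dropping the type annotations: `load_model()` now instantiates `ScratchAPI` from the dynamically imported module and asks it for its own parameter subdirectory via `get_param_subdir()`, and `_download_scratch_weights()` derives the set of S3 "directories" from the listed file keys instead of asserting on an accumulated list. A small standalone illustration of that grouping expression on made-up keys:

```python
# Made-up object keys under the staging_weights prefix, for illustration only.
files = [
    "staging_weights/ll38b-s1-cuda-f16/model.bin",
    "staging_weights/ll38b-s1-cuda-f16/tokenizer.model",
    "staging_weights/ll38b-s1-cuda-f16/config.json",
]

# Same expression as the diff: strip the object name, keep its parent prefix.
dirs = {p.rsplit("/", 1)[0] for p in files}
assert len(dirs) == 1, dirs  # all weights must sit in a single flat "directory"
print(dirs)  # {'staging_weights/ll38b-s1-cuda-f16'}
```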
@@ -285,8 +303,8 @@ def execute_model(
             self.device,
             self.pin_memory)
         return self._execute_and_vllm_sample(prefill_groups, decode_groups,
-                                            input_tokens, session_ids,
-                                            parent_ids, sampling_metadata)
+                                             input_tokens, session_ids,
+                                             parent_ids, sampling_metadata)
         # return self._execute_and_scratch_sample(
         #     prefill_groups, decode_groups, input_tokens, session_ids, parent_ids)
 
@@ -327,7 +345,7 @@ def _execute_and_vllm_sample(
             input_tokens_tensor = torch.tensor(input_tokens[i],
                                                device="cuda",
                                                dtype=torch.int)
-            print(f"SANG-TODO {input_tokens_tensor=}")
+            # print(f"SANG-TODO {input_tokens_tensor=}")
             assert input_tokens_tensor.is_contiguous()
             # print(f"SANG-TODO {input_tokens_tensor.shape=}")
 
@@ -338,7 +356,7 @@ def _execute_and_vllm_sample(
             hidden_states_end_index = (len_prefix_before_this + len(input_tokens[i])) * self.model_config.get_hidden_size()
             # print(f"SANG-TODO {hidden_states_start_index=} {hidden_states_end_index=}")
             # print(f"SANG-TODO {hidden_states.shape=}")
-            print(f"SANG-TODO {hidden_states[hidden_states_start_index: hidden_states_end_index].shape=}")
+            # print(f"SANG-TODO {hidden_states[hidden_states_start_index: hidden_states_end_index].shape=}")
             assert hidden_states[hidden_states_start_index: hidden_states_end_index].is_contiguous()
             self.scratch.prefill(
                 session_id,
@@ -363,9 +381,9 @@ def _execute_and_vllm_sample(
                 hidden_states.data_ptr(),
             )
 
-        print(
-            f"SANG-TODO forward takes {(time.time() - s)* 1000} ms. Batch size: {len(session_ids)=} is_prefill: {len(prefill_groups) > 0}"
-        )
+        # print(
+        #     f"SANG-TODO forward takes {(time.time() - s)* 1000} ms. Batch size: {len(session_ids)=} is_prefill: {len(prefill_groups) > 0}"
+        # )
         # print(hidden_states)
         # print(f"SANG-TODO {hidden_states.shape=}")
         # Post process Scratch embeddings.
@@ -375,16 +393,16 @@ def _execute_and_vllm_sample(
         # is this expected?
         hidden_states = hidden_states.view(-1,
                                            self.model_config.get_hidden_size())
-        if len(prefill_groups) > 0:
-            print(f"SANG-TODO before norm {hidden_states=}")
-            print(f"SANG-TODO {hidden_states.shape=}")
+        # if len(prefill_groups) > 0:
+        #     print(f"SANG-TODO before norm {hidden_states=}")
+        #     print(f"SANG-TODO {hidden_states.shape=}")
         # Scratch doesn't apply rms norm in its output, so we should do it ourselves.
         # Residual is set to None because it is already added from Scratch output.
         hidden_states = self.model.norm(hidden_states, None)
-        if len(prefill_groups) > 0:
-            print(f"SANG-TODO norm weights: {self.model.norm.weight=}")
-            print(f"SANG-TODO {hidden_states.shape=}")
-            print(f"SANG-TODO after norm {hidden_states=}")
+        # if len(prefill_groups) > 0:
+        #     print(f"SANG-TODO norm weights: {self.model.norm.weight=}")
+        #     print(f"SANG-TODO {hidden_states.shape=}")
+        #     print(f"SANG-TODO after norm {hidden_states=}")
         # print(f"{hidden_states.shape=}")
 
         # SANG-TODO remove it. Hack. It will work once scrath returns embedding of all tokens correctly.
@@ -401,14 +419,14 @@ def _execute_and_vllm_sample(
             logits=logits,
             sampling_metadata=sampling_metadata,
         )
-        if len(prefill_groups) > 0:
-            print(
-                f"SANG-TODO prefill takes {(time.time() - s)* 1000} ms. Batch size: {len(session_ids)=}"
-            )
-        else:
-            print(
-                f"SANG-TODO decode takes {(time.time() - s)* 1000} ms. Batch size: {len(session_ids)=}"
-            )
+        # if len(prefill_groups) > 0:
+        #     print(
+        #         f"SANG-TODO prefill takes {(time.time() - s)* 1000} ms. Batch size: {len(session_ids)=}"
+        #     )
+        # else:
+        #     print(
+        #         f"SANG-TODO decode takes {(time.time() - s)* 1000} ms. Batch size: {len(session_ids)=}"
+        #     )
         # print(output)
         return output
 
@@ -443,7 +461,7 @@ def _execute_and_scratch_sample(
             batch_size,
             tokens_out.data_ptr(),
         )
-        print(f"SANG-TODO token: {tokens_out}")
+        # print(f"SANG-TODO token: {tokens_out}")
 
         result_tokens = tokens_out.tolist()
         outputs = []
@@ -462,7 +480,7 @@ def _execute_and_scratch_sample(
                 )
             )
         output = SamplerOutput(outputs=outputs)
-        print(output)
+        # print(output)
 
         return output
 
     @torch.inference_mode()
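The remaining hunks mostly comment out `SANG-TODO` debug prints, but together they show the sampling path: Scratch returns raw final-layer embeddings, so the runner reshapes them, applies the model's final RMSNorm itself (the residual is already folded in), projects through `lm_head` to logits, and lets vLLM sample. A compressed, hedged sketch of that order; `logits_from_lm_head` and `sampler` stand in for pieces the diff only alludes to:

```python
def postprocess_and_sample(hidden_states, model, hidden_size,
                           logits_from_lm_head, sampler, sampling_metadata):
    """Sketch of the post-processing order in _execute_and_vllm_sample."""
    # Scratch hands back a flat buffer; reshape to (num_tokens, hidden_size).
    hidden_states = hidden_states.view(-1, hidden_size)
    # Scratch does not apply the final RMSNorm; residual is already added.
    hidden_states = model.norm(hidden_states, None)
    # Project to vocabulary logits and sample with vLLM's sampler.
    logits = logits_from_lm_head(hidden_states)
    return sampler(logits=logits, sampling_metadata=sampling_metadata)
```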

vllm/worker/worker.py

Lines changed: 2 additions & 0 deletions
@@ -371,6 +371,8 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
 
 def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                 max_model_len) -> None:
+    if USE_SCRATCH:
+        return
     if num_gpu_blocks <= 0:
         raise ValueError("No available memory for the cache blocks. "
                          "Try increasing `gpu_memory_utilization` when "
