Remove monkey_patch_vllm_dummy_weight_loader (#2064)
merrymercy authored Nov 17, 2024
1 parent c1f401f commit 38625e2
Showing 6 changed files with 17 additions and 70 deletions.
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/scheduler.py
@@ -895,7 +895,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
             logits_output, next_token_ids, bid = result

             if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             else:
                 # Move next_token_ids and logprobs to cpu
                 if batch.return_logprob:
@@ -970,7 +970,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -141,7 +141,7 @@ def forward_thread_func_(self):
             self.launch_event.set()
             self.output_queue.put((copy_event, logits_output, next_token_ids))

-    def resulve_batch_result(self, bid: int):
+    def resolve_batch_result(self, bid: int):
         copy_event, logits_output, next_token_ids = self.output_queue.get()
         while not copy_event.query():
             time.sleep(1e-5)
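Note: the `resolve_batch_result` rename above belongs to the overlap-scheduling path, where a background forward thread enqueues results and the scheduler collects them later. Below is a minimal, self-contained sketch of that producer/consumer pattern; the class name and the use of `threading.Event` as a stand-in for the CUDA copy event are illustrative only, not SGLang's actual implementation.

# Minimal sketch (not the real SGLang classes): a forward thread pushes
# (copy_event, outputs) onto a queue; the scheduler resolves them later.
import queue
import threading
import time


class OverlapWorkerSketch:
    def __init__(self):
        self.output_queue = queue.Queue()

    def forward_thread_func_(self, bid):
        # Stand-in for the GPU forward pass plus the async device-to-host copy;
        # threading.Event plays the role of the CUDA copy event here.
        copy_event = threading.Event()
        logits_output, next_token_ids = f"logits_for_batch_{bid}", [1, 2, 3]
        self.output_queue.put((copy_event, logits_output, next_token_ids))
        copy_event.set()

    def resolve_batch_result(self, bid):
        # Mirrors the polling loop in the diff: wait until the copy has finished.
        copy_event, logits_output, next_token_ids = self.output_queue.get()
        while not copy_event.is_set():
            time.sleep(1e-5)
        return logits_output, next_token_ids


if __name__ == "__main__":
    worker = OverlapWorkerSketch()
    threading.Thread(target=worker.forward_thread_func_, args=(0,)).start()
    print(worker.resolve_batch_result(0))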
4 changes: 1 addition & 3 deletions python/sglang/srt/model_executor/model_runner.py
@@ -58,7 +58,6 @@
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
-    monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )

@@ -242,7 +241,6 @@ def load_model(self):
             raise RuntimeError("SGLang only supports sm75 and above.")

         # Prepare the vllm model config
-        monkey_patch_vllm_dummy_weight_loader()
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
@@ -261,7 +259,6 @@ def load_model(self):
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
-        self.dtype = self.vllm_model_config.dtype

         # Load the model
         self.model = get_model(
@@ -278,6 +275,7 @@ def load_model(self):
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.dtype = self.vllm_model_config.dtype

         logger.info(
             f"Load weight end. "
51 changes: 0 additions & 51 deletions python/sglang/srt/utils.py
@@ -405,57 +405,6 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)


-def monkey_patch_vllm_dummy_weight_loader():
-    """
-    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
-    """
-
-    from vllm.model_executor.model_loader.loader import (
-        CacheConfig,
-        DeviceConfig,
-        DummyModelLoader,
-        LoRAConfig,
-        ModelConfig,
-        ParallelConfig,
-        SchedulerConfig,
-        _initialize_model,
-        initialize_dummy_weights,
-        nn,
-        set_default_torch_dtype,
-    )
-
-    def load_model(
-        self,
-        *,
-        model_config: ModelConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        cache_config: CacheConfig,
-    ) -> nn.Module:
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-                model = _initialize_model(
-                    model_config,
-                    self.load_config,
-                    lora_config,
-                    cache_config,
-                )
-
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    quant_method.process_weights_after_loading(module)
-
-            # NOTE(woosuk): For accurate performance evaluation, we assign
-            # random values to the weights.
-            initialize_dummy_weights(model)
-            return model.eval()
-
-    setattr(DummyModelLoader, "load_model", load_model)
-
-
 vllm_all_gather_backup = None


4 changes: 2 additions & 2 deletions test/srt/test_bench_latency.py
@@ -13,15 +13,15 @@ def test_default(self):
         output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, [])

         if is_in_ci():
-            assert output_throughput > 130, f"{output_throughput=}"
+            self.assertGreater(output_throughput, 135)

     def test_moe_default(self):
         output_throughput = run_bench_latency(
             DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"]
         )

         if is_in_ci():
-            assert output_throughput > 125, f"{output_throughput=}"
+            self.assertGreater(output_throughput, 125)


 if __name__ == "__main__":
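Note: the test changes in this file and the next replace bare asserts with unittest assertion methods. A small stand-alone illustration (hypothetical values, not part of the SGLang test suite) of why that helps: on failure, assertGreater reports both operands automatically, with no hand-written f-string message.

# Standalone example; the throughput value is made up.
import unittest


class ThroughputCheckExample(unittest.TestCase):
    def test_throughput_threshold(self):
        output_throughput = 140.0
        # A bare assert needs an explicit message to show the measured value:
        #   assert output_throughput > 135, f"{output_throughput=}"
        # assertGreater would report e.g. "120.0 not greater than 135" by itself.
        self.assertGreater(output_throughput, 135)


if __name__ == "__main__":
    unittest.main()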
22 changes: 11 additions & 11 deletions test/srt/test_bench_serving.py
@@ -20,7 +20,7 @@ def test_offline_throughput_default(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2830
+            self.assertGreater(res["output_throughput"], 2850)

     def test_offline_throughput_non_stream_small_batch_size(self):
         res = run_bench_serving(
@@ -35,7 +35,7 @@ def test_offline_throughput_non_stream_small_batch_size(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 1000
+            self.assertGreater(res["output_throughput"], 950)

     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
@@ -46,7 +46,7 @@ def test_offline_throughput_without_radix_cache(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2880
+            self.assertGreater(res["output_throughput"], 2900)

     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(
@@ -57,7 +57,7 @@ def test_offline_throughput_without_chunked_prefill(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            self.assertGreater(res["output_throughput"], 2600)

     def test_offline_throughput_with_triton_attention_backend(self):
         res = run_bench_serving(
@@ -73,7 +73,7 @@ def test_offline_throughput_with_triton_attention_backend(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2930
+            self.assertGreater(res["output_throughput"], 2950)

     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(
@@ -84,7 +84,7 @@ def test_offline_throughput_default_fp8(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 3100
+            self.assertGreater(res["output_throughput"], 3200)

     def test_online_latency_default(self):
         res = run_bench_serving(
@@ -95,9 +95,9 @@ def test_online_latency_default(self):
         )

         if is_in_ci():
-            assert res["median_e2e_latency_ms"] < 12000
-            assert res["median_ttft_ms"] < 80
-            assert res["median_itl_ms"] < 12
+            self.assertLess(res["median_e2e_latency_ms"], 12000)
+            self.assertLess(res["median_ttft_ms"], 80)
+            self.assertLess(res["median_itl_ms"], 11)

     def test_moe_offline_throughput_default(self):
         res = run_bench_serving(
@@ -108,7 +108,7 @@ def test_moe_offline_throughput_default(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 1850
+            self.assertGreater(res["output_throughput"], 1900)

     def test_moe_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
@@ -119,7 +119,7 @@ def test_moe_offline_throughput_without_radix_cache(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 1950
+            self.assertGreater(res["output_throughput"], 1950)


 if __name__ == "__main__":
