Skip to content

Commit

Permalink
[Core] add an option to log every function call to for debugging hang…
Browse files Browse the repository at this point in the history
…/crash in distributed inference (vllm-project#4079)

Co-authored-by: Simon Mo <simon.mo@hey.com>
  • Loading branch information
youkaichao and simon-mo authored Apr 18, 2024
1 parent 8f9c28f commit 8a7a3e4
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py

- label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py

- label: Entrypoints Test
commands:
Expand Down
2 changes: 2 additions & 0 deletions .github/ISSUE_TEMPLATE/400-bug report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ body:
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
placeholder: |
A clear and concise description of what the bug is.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import sys
import tempfile

from vllm.logger import enable_trace_function_call


def f1(x):
return f2(x)


def f2(x):
return x


def test_trace_function_call():
fd, path = tempfile.mkstemp()
cur_dir = os.path.dirname(__file__)
enable_trace_function_call(path, cur_dir)
f1(1)
with open(path, 'r') as f:
content = f.read()

assert "f1" in content
assert "f2" in content
sys.settrace(None)
os.remove(path)
12 changes: 9 additions & 3 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
get_vllm_instance_id, make_async)

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
Expand Down Expand Up @@ -133,12 +133,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)

# Set CUDA_VISIBLE_DEVICES for the driver and workers.
VLLM_INSTANCE_ID = get_vllm_instance_id()

# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = []
for (node_id, _) in worker_node_and_gpu_ids:
all_args_to_update_environment_variables.append([{
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id]))
",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
os.getenv("VLLM_TRACE_FUNCTION", "0"),
}])
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
Expand Down
52 changes: 52 additions & 0 deletions vllm/logger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import datetime
import logging
import os
import sys
from functools import partial
from typing import Optional

VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
Expand Down Expand Up @@ -65,3 +67,53 @@ def init_logger(name: str):
logger.addHandler(_default_handler)
logger.propagate = False
return logger


logger = init_logger(__name__)


def _trace_calls(log_path, root_dir, frame, event, arg=None):
if event in ['call', 'return']:
# Extract the filename, line number, function name, and the code object
filename = frame.f_code.co_filename
lineno = frame.f_lineno
func_name = frame.f_code.co_name
if not filename.startswith(root_dir):
# only log the functions in the vllm root_dir
return
# Log every function call or return
try:
with open(log_path, 'a') as f:
if event == 'call':
f.write(f"{datetime.datetime.now()} Call to"
f" {func_name} in {filename}:{lineno}\n")
else:
f.write(f"{datetime.datetime.now()} Return from"
f" {func_name} in {filename}:{lineno}\n")
except NameError:
# modules are deleted during shutdown
pass
return partial(_trace_calls, log_path, root_dir)


def enable_trace_function_call(log_file_path: str,
root_dir: Optional[str] = None):
"""
Enable tracing of every function call in code under `root_dir`.
This is useful for debugging hangs or crashes.
`log_file_path` is the path to the log file.
`root_dir` is the root directory of the code to trace. If None, it is the
vllm root directory.
Note that this call is thread-level, any threads calling this function
will have the trace enabled. Other threads will not be affected.
"""
logger.warning(
"VLLM_TRACE_FUNCTION is enabled. It will record every"
" function executed by Python. This will slow down the code. It "
"is suggested to be used for debugging hang or crashes only.")
logger.info(f"Trace frame log is saved to {log_file_path}")
if root_dir is None:
# by default, this is the vllm root directory
root_dir = os.path.dirname(os.path.dirname(__file__))
sys.settrace(partial(_trace_calls, log_file_path, root_dir))
13 changes: 12 additions & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,17 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)


@lru_cache(maxsize=None)
def get_vllm_instance_id():
"""
If the environment variable VLLM_INSTANCE_ID is set, return it.
Otherwise, return a random UUID.
Instance id represents an instance of the VLLM. All processes in the same
instance should have the same instance id.
"""
return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}")


@lru_cache(maxsize=None)
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
Expand Down Expand Up @@ -274,7 +285,7 @@ def get_open_port() -> int:

def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items():
if k in os.environ:
if k in os.environ and os.environ[k] != v:
logger.warning(f"Overwriting environment variable {k} "
f"from '{os.environ[k]}' to '{v}'")
os.environ[k] = v
Expand Down
20 changes: 17 additions & 3 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import datetime
import importlib
import os
import tempfile
import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple

from vllm.logger import init_logger
from vllm.logger import enable_trace_function_call, init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import update_environment_variables
from vllm.utils import get_vllm_instance_id, update_environment_variables

logger = init_logger(__name__)

Expand Down Expand Up @@ -115,9 +118,20 @@ def update_environment_variables(self, envs: Dict[str, str]) -> None:

def init_worker(self, *args, **kwargs):
"""
Actual initialization of the worker class.
Actual initialization of the worker class, and set up
function tracing if required.
Arguments are passed to the worker class constructor.
"""
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)

mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs)
Expand Down

0 comments on commit 8a7a3e4

Please sign in to comment.