Skip to content

More changes to docs #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
3bd7f1b
Update launcher.py
pmcurtin Oct 19, 2024
2ce3576
Merge branch 'main' into docs-2
apoorvkh Oct 19, 2024
2731d31
Merge branch 'main' into docs-2
apoorvkh Oct 20, 2024
0b9c1df
moved log_handlers into .run()
apoorvkh Oct 20, 2024
af8c829
update contributing
apoorvkh Oct 20, 2024
4ac384e
add tyro, remove setuptools from extras
apoorvkh Oct 20, 2024
cbf40b9
enabled linting for docs; clarified public/private functions
apoorvkh Oct 20, 2024
76aa20f
docs for utils.py
apoorvkh Oct 20, 2024
de93aaf
docs for logging_utils
apoorvkh Oct 20, 2024
e4977fd
Merge branch 'docs-2' of github.com:apoorvkh/torchrunx into worker-ex…
apoorvkh Oct 20, 2024
e697257
advanced docs
apoorvkh Oct 20, 2024
748c2b7
adding napoleon for google docs
apoorvkh Oct 21, 2024
24f4a98
linkcode
apoorvkh Oct 21, 2024
cb6620c
update linkcode
apoorvkh Oct 21, 2024
3eb297c
try again
apoorvkh Oct 21, 2024
e609f54
fix?
apoorvkh Oct 21, 2024
e88e320
now linkcode works
apoorvkh Oct 21, 2024
bef8b28
updates
apoorvkh Oct 21, 2024
86bb67b
automethod run for launcher
apoorvkh Oct 21, 2024
d80d822
maximum_signature_line_length
apoorvkh Oct 21, 2024
9950e96
switch to members?
apoorvkh Oct 21, 2024
8276abc
Merge branch 'main' of github.com:apoorvkh/torchrunx into docs-2
apoorvkh Oct 29, 2024
f335140
created utils/
apoorvkh Oct 29, 2024
0b5e316
moved functions to worker.py
apoorvkh Oct 29, 2024
084061f
renamed to worker_entrypoint
apoorvkh Oct 29, 2024
6cc9311
completed docs for utils
apoorvkh Oct 29, 2024
490f2a8
more launcher docs
apoorvkh Oct 29, 2024
e54a533
more updates to docs
apoorvkh Oct 29, 2024
455c3f3
switched LaunchResult to get
apoorvkh Oct 29, 2024
f967218
bump hash in pixi lock
apoorvkh Oct 29, 2024
3a68eb6
removed overloading from LaunchResult
apoorvkh Oct 29, 2024
9e2d5f4
update all docs
apoorvkh Oct 30, 2024
a29212e
fix
apoorvkh Oct 30, 2024
7bf9222
small edits
apoorvkh Oct 30, 2024
122febc
how it works
apoorvkh Oct 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more updates to docs
  • Loading branch information
apoorvkh committed Oct 29, 2024
commit e54a5338450192e40ef385b08dd882e59315d1ad
22 changes: 21 additions & 1 deletion src/torchrunx/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Primary logic for agent processes."""

from __future__ import annotations

__all__ = ["main"]
Expand All @@ -22,8 +24,21 @@


def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None:
"""Main function for agent processes (started on each node).

This function spawns local worker processes (which run the target function). All agents monitor
their worker statuses (including returned objects and raised exceptions) and communicate these
with each other (and launcher). All agents terminate if failure occurs in any agent.

Arguments:
launcher_agent_group: The communication group between launcher and all agents.
logger_hostname: The hostname of the launcher (for logging).
logger_port: The port of the launcher (for logging).
"""
agent_rank = launcher_agent_group.rank - 1

# Communicate initial payloads between launcher/agents

payload = AgentPayload(
hostname=socket.getfqdn(),
port=get_open_port(),
Expand All @@ -38,6 +53,8 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
worker_global_ranks = launcher_payload.worker_global_ranks[agent_rank]
num_workers = len(worker_global_ranks)

# Stream logs to logging server

logger = logging.getLogger()

log_records_to_socket(
Expand All @@ -50,7 +67,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_

redirect_stdio_to_logger(logger)

# spawn workers
# Spawn worker processes

ctx = dist_mp.start_processes(
name=f"{hostname}_",
Expand Down Expand Up @@ -84,6 +101,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
), # pyright: ignore [reportArgumentType]
)

# Monitor and communicate agent statuses
# Terminate gracefully upon failure

try:
status = None
while True:
Expand Down
54 changes: 32 additions & 22 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def run( # noqa: C901, PLR0912
agent_payloads = None

try:
# start logging server
# Start logging server (receives LogRecords from agents/workers)

log_receiver = _build_logging_server(
log_handlers=log_handlers,
Expand All @@ -105,7 +105,7 @@ def run( # noqa: C901, PLR0912

log_process.start()

# start agents on each node
# Start agents on each node

for i, hostname in enumerate(hostnames):
_execute_command(
Expand All @@ -122,7 +122,7 @@ def run( # noqa: C901, PLR0912
ssh_config_file=self.ssh_config_file,
)

# initialize launcher-agent process group
# Initialize launcher-agent process group
# ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1])

launcher_agent_group = LauncherAgentGroup(
Expand All @@ -132,7 +132,7 @@ def run( # noqa: C901, PLR0912
rank=0,
)

# build and sync payloads between launcher and agents
# Sync initial payloads between launcher and agents

_cumulative_workers = [0, *itertools.accumulate(workers_per_host)]

Expand All @@ -152,7 +152,7 @@ def run( # noqa: C901, PLR0912

launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)

# loop to monitor agent statuses (until failed or done)
# Monitor agent statuses (until failed or done)

while True:
# could raise AgentFailedError
Expand Down Expand Up @@ -187,6 +187,7 @@ def run( # noqa: C901, PLR0912
ssh_config_file=self.ssh_config_file,
)

# if launch is successful: return objects from workers
return_values = [s.return_values for s in agent_statuses]
return LaunchResult(hostnames=hostnames, return_values=return_values)

Expand Down Expand Up @@ -216,23 +217,32 @@ def launch(
) -> LaunchResult:
"""Launch a distributed PyTorch function on the specified nodes.

:param func:
:param func_args:
:param func_kwargs:
:param hostnames: Nodes to launch the function on. Default infers from a SLURM environment or runs on localhost.
:param workers_per_host: Number of processes to run per node. Can define per node with :type:`list[int]`.
:param ssh_config_file: An SSH configuration file for connecting to nodes, by default loads ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
:param backend: `Backend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_ to initialize worker process group with. Default uses NCCL (if GPUs available) or GLOO. Disabled by ``None``.
:param timeout: Worker process group timeout (seconds).
:param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax.
:param extra_env_vars: Additional, user-specified variables to copy.
:param env_file: A file (like ``.env``) with additional environment variables to copy.
:param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme.
:raises RuntimeError: Due to various misconfigurations.
:raises AgentFailedError: If any agent fails (e.g. due to signal from OS).
:raises WorkerFailedError: If any worker fails (e.g. due to segmentation faults).
:raises Exception: Propagates exceptions raised in worker processes.
""" # noqa: E501
Arguments:
func: Function to run on each worker.
func_args: Positional arguments for ``func``.
func_kwargs: Keyword arguments for ``func``.
hostnames: Nodes on which to launch the function.
Defaults to nodes inferred from a SLURM environment or localhost.
workers_per_host: Number of processes to run per node.
Can specify different counts per node with a list.
ssh_config_file: Path to an SSH configuration file for connecting to nodes.
Defaults to ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
backend: `Backend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_
for worker process group. Defaults to NCCL (GPU) or GLOO (CPU). Set `None` to disable.
timeout: Worker process group timeout (seconds).
default_env_vars: Environment variables to copy from the launcher process to workers.
Supports bash pattern matching syntax.
extra_env_vars: Additional user-specified environment variables to copy.
env_file: Path to a file (e.g., `.env`) with additional environment variables to copy.
log_handlers: Handlers to manage agent and worker logs.
Defaults to an automatic basic logging scheme.

Raises:
RuntimeError: If there are configuration issues.
AgentFailedError: If an agent fails, e.g. from an OS signal.
WorkerFailedError: If a worker fails, e.g. from a segmentation fault.
Exception: Any exception raised in a worker process is propagated.
"""
return Launcher(
hostnames=hostnames,
workers_per_host=workers_per_host,
Expand Down
17 changes: 12 additions & 5 deletions src/torchrunx/utils/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __post_init__(self) -> None:
"""Initialize process group.

Raises:
torch.distributed.DistStoreError: if group initialization times out.
torch.distributed.DistStoreError: if group initialization times out.
"""
self.group = dist.init_process_group(
backend="gloo",
Expand All @@ -69,7 +69,11 @@ def _deserialize(self, serialized: bytes) -> Any:
return cloudpickle.loads(serialized)

def _all_gather(self, obj: Any) -> list:
"""Gather object from every rank to list on every rank."""
"""Gather object from every rank to list on every rank.

Raises:
AgentFailedError: if any agent fails (observed by this communication).
"""
try:
object_bytes = self._serialize(obj)
object_list = [b""] * self.world_size
Expand Down Expand Up @@ -125,8 +129,8 @@ class AgentStatus:
"""Status of each agent (to be synchronized in LauncherAgentGroup).

Attributes:
state: Whether the agent is running, failed, or done.
return_values: Objects returned (or exceptions raised) by workers (indexed by local rank).
state: Whether the agent is running, failed, or done.
return_values: Objects returned (or exceptions raised) by workers (indexed by local rank).
"""

state: Literal["running", "failed", "done"]
Expand All @@ -139,10 +143,13 @@ def from_result(cls, result: RunProcsResult | None) -> Self:
"""Convert RunProcsResult (from polling worker process context) to AgentStatus."""
if result is None:
return cls(state="running")

for local_rank, failure in result.failures.items():
result.return_values[local_rank] = WorkerFailedError(failure.message)

return_values = list(result.return_values.values())
failed = any(isinstance(v, ExceptionFromWorker) for v in return_values)

failed = any(isinstance(v, (ExceptionFromWorker, WorkerFailedError)) for v in return_values)
state = "failed" if failed else "done"

return cls(
Expand Down
6 changes: 1 addition & 5 deletions src/torchrunx/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ def in_slurm_job() -> bool:


def slurm_hosts() -> list[str]:
"""Retrieves hostnames of Slurm-allocated nodes.

:return: Hostnames of nodes in current Slurm allocation
:rtype: list[str]
"""
"""Retrieves hostnames of Slurm-allocated nodes."""
# TODO: sanity check SLURM variables, commands
if not in_slurm_job():
msg = "Not in a SLURM job"
Expand Down
8 changes: 4 additions & 4 deletions src/torchrunx/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def add_filter_to_handler(
"""A filter for ``logging.Handler`` such that only specific agent/worker logs are handled.

Args:
handler: ``logging.Handler`` to be modified.
hostname: Name of specified host.
local_rank: Rank of specified worker (or ``None`` for agent).
log_level: Minimum log level to capture.
handler: ``logging.Handler`` to be modified.
hostname: Name of specified host.
local_rank: Rank of specified worker (or ``None`` for agent).
log_level: Minimum log level to capture.
"""

def _filter(record: WorkerLogRecord) -> bool:
Expand Down
21 changes: 21 additions & 0 deletions src/torchrunx/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Arguments and entrypoint for the worker processes."""

from __future__ import annotations

import datetime
Expand All @@ -20,6 +22,8 @@

@dataclass
class WorkerArgs:
"""Arguments passed from agent to spawned workers."""

function: Callable
logger_hostname: str
logger_port: int
Expand All @@ -34,10 +38,13 @@ class WorkerArgs:
timeout: int

def serialize(self) -> SerializedWorkerArgs:
"""Arguments must be serialized (to bytes) before passed to spawned workers."""
return SerializedWorkerArgs(worker_args=self)


class SerializedWorkerArgs:
"""We use cloudpickle as a serialization backend (as it supports nearly all Python types)."""

def __init__(self, worker_args: WorkerArgs) -> None:
self.bytes = cloudpickle.dumps(worker_args)

Expand All @@ -46,8 +53,16 @@ def deserialize(self) -> WorkerArgs:


def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker:
"""Function called by spawned worker processes.

Workers first prepare a process group (for communicating with all other workers).
They then invoke the user-provided function.
Logs are transmitted to the launcher process.
"""
worker_args: WorkerArgs = serialized_worker_args.deserialize()

# Start logging to the logging server (i.e. the launcher)

logger = logging.getLogger()

log_records_to_socket(
Expand All @@ -60,13 +75,17 @@ def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | Exc

redirect_stdio_to_logger(logger)

# Set rank/world environment variables

os.environ["RANK"] = str(worker_args.rank)
os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
os.environ["WORLD_SIZE"] = str(worker_args.world_size)
os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)

# Prepare the process group (e.g. for communication within the user's function)

if worker_args.backend is not None:
backend = worker_args.backend
if backend == "auto":
Expand All @@ -85,6 +104,8 @@ def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | Exc
timeout=datetime.timedelta(seconds=worker_args.timeout),
)

# Invoke the user's function on this worker

try:
return worker_args.function()
except Exception as e:
Expand Down
Loading