
More changes to docs #73


Merged
merged 35 commits into from Oct 30, 2024
Changes from 1 commit
Commits
35 commits
3bd7f1b
Update launcher.py
pmcurtin Oct 19, 2024
2ce3576
Merge branch 'main' into docs-2
apoorvkh Oct 19, 2024
2731d31
Merge branch 'main' into docs-2
apoorvkh Oct 20, 2024
0b9c1df
moved log_handlers into .run()
apoorvkh Oct 20, 2024
af8c829
update contributing
apoorvkh Oct 20, 2024
4ac384e
add tyro, remove setuptools from extras
apoorvkh Oct 20, 2024
cbf40b9
enabled linting for docs; clarified public/private functions
apoorvkh Oct 20, 2024
76aa20f
docs for utils.py
apoorvkh Oct 20, 2024
de93aaf
docs for logging_utils
apoorvkh Oct 20, 2024
e4977fd
Merge branch 'docs-2' of github.com:apoorvkh/torchrunx into worker-ex…
apoorvkh Oct 20, 2024
e697257
advanced docs
apoorvkh Oct 20, 2024
748c2b7
adding napoleon for google docs
apoorvkh Oct 21, 2024
24f4a98
linkcode
apoorvkh Oct 21, 2024
cb6620c
update linkcode
apoorvkh Oct 21, 2024
3eb297c
try again
apoorvkh Oct 21, 2024
e609f54
fix?
apoorvkh Oct 21, 2024
e88e320
now linkcode works
apoorvkh Oct 21, 2024
bef8b28
updates
apoorvkh Oct 21, 2024
86bb67b
automethod run for launcher
apoorvkh Oct 21, 2024
d80d822
maximum_signature_line_length
apoorvkh Oct 21, 2024
9950e96
switch to members?
apoorvkh Oct 21, 2024
8276abc
Merge branch 'main' of github.com:apoorvkh/torchrunx into docs-2
apoorvkh Oct 29, 2024
f335140
created utils/
apoorvkh Oct 29, 2024
0b5e316
moved functions to worker.py
apoorvkh Oct 29, 2024
084061f
renamed to worker_entrypoint
apoorvkh Oct 29, 2024
6cc9311
completed docs for utils
apoorvkh Oct 29, 2024
490f2a8
more launcher docs
apoorvkh Oct 29, 2024
e54a533
more updates to docs
apoorvkh Oct 29, 2024
455c3f3
switched LaunchResult to get
apoorvkh Oct 29, 2024
f967218
bump hash in pixi lock
apoorvkh Oct 29, 2024
3a68eb6
removed overloading from LaunchResult
apoorvkh Oct 29, 2024
9e2d5f4
update all docs
apoorvkh Oct 30, 2024
a29212e
fix
apoorvkh Oct 30, 2024
7bf9222
small edits
apoorvkh Oct 30, 2024
122febc
how it works
apoorvkh Oct 30, 2024
enabled linting for docs; clarified public/private functions
apoorvkh committed Oct 20, 2024
commit cbf40b9f0c4c547a5a56e1e5e2fe0121dce1afbd
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -41,7 +41,6 @@ src = ["src", "tests"]
[tool.ruff.lint]
select = ["ALL"]
ignore = [
"D", # documentation
"ANN101", "ANN102", "ANN401", # self / cls / Any annotations
"BLE001", # blind exceptions
"TD", # todo syntax
@@ -54,9 +53,12 @@ ignore = [
]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
"D",
"S101", # allow asserts
"T201" # allow prints
]
[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.pyright]
include = ["src", "tests"]
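With "D" no longer ignored and the pydocstyle convention set to google, public functions are now linted for Google-style docstrings. A minimal sketch of what ruff accepts under this config (the function below is illustrative only, not part of this diff):

```python
import os


def slurm_node_count(default: int = 1) -> int:
    """Return the number of nodes in the current Slurm allocation.

    Args:
        default: Value returned when not running inside a Slurm job.

    Returns:
        The node count from ``SLURM_JOB_NUM_NODES``, or ``default``.
    """
    return int(os.environ.get("SLURM_JOB_NUM_NODES", default))
```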
158 changes: 80 additions & 78 deletions src/torchrunx/agent.py
@@ -1,5 +1,7 @@
from __future__ import annotations

__all__ = ["main"]

import datetime
import logging
import os
@@ -25,83 +27,6 @@
)


@dataclass
class WorkerArgs:
function: Callable
logger_hostname: str
logger_port: int
main_agent_hostname: str
main_agent_port: int
backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
rank: int
local_rank: int
local_world_size: int
world_size: int
hostname: str
timeout: int

def serialize(self) -> SerializedWorkerArgs:
return SerializedWorkerArgs(worker_args=self)


class SerializedWorkerArgs:
def __init__(self, worker_args: WorkerArgs) -> None:
self.bytes = cloudpickle.dumps(worker_args)

def deserialize(self) -> WorkerArgs:
return cloudpickle.loads(self.bytes)


def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
worker_args: WorkerArgs = serialized_worker_args.deserialize()

logger = logging.getLogger()

log_records_to_socket(
logger=logger,
hostname=worker_args.hostname,
worker_rank=worker_args.local_rank,
logger_hostname=worker_args.logger_hostname,
logger_port=worker_args.logger_port,
)

redirect_stdio_to_logger(logger)

os.environ["RANK"] = str(worker_args.rank)
os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
os.environ["WORLD_SIZE"] = str(worker_args.world_size)
os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)

if worker_args.backend is not None:
backend = worker_args.backend
if backend == "auto":
backend = "nccl" if torch.cuda.is_available() else "gloo"

dist.init_process_group(
backend=backend,
world_size=worker_args.world_size,
rank=worker_args.rank,
store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage]
host_name=worker_args.main_agent_hostname,
port=worker_args.main_agent_port,
world_size=worker_args.world_size,
is_master=(worker_args.rank == 0),
),
timeout=datetime.timedelta(seconds=worker_args.timeout),
)

try:
return worker_args.function()
except Exception as e:
traceback.print_exc()
return WorkerException(exception=e)
finally:
sys.stdout.flush()
sys.stderr.flush()


def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None:
agent_rank = launcher_agent_group.rank - 1

@@ -135,7 +60,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_

ctx = dist_mp.start_processes(
name=f"{hostname}_",
entrypoint=entrypoint,
entrypoint=_entrypoint,
args={
i: (
WorkerArgs(
@@ -179,3 +104,80 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
ctx.close()
sys.stdout.flush()
sys.stderr.flush()


@dataclass
class WorkerArgs:
function: Callable
logger_hostname: str
logger_port: int
main_agent_hostname: str
main_agent_port: int
backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
rank: int
local_rank: int
local_world_size: int
world_size: int
hostname: str
timeout: int

def serialize(self) -> SerializedWorkerArgs:
return SerializedWorkerArgs(worker_args=self)


class SerializedWorkerArgs:
def __init__(self, worker_args: WorkerArgs) -> None:
self.bytes = cloudpickle.dumps(worker_args)

def deserialize(self) -> WorkerArgs:
return cloudpickle.loads(self.bytes)


def _entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
worker_args: WorkerArgs = serialized_worker_args.deserialize()

logger = logging.getLogger()

log_records_to_socket(
logger=logger,
hostname=worker_args.hostname,
worker_rank=worker_args.local_rank,
logger_hostname=worker_args.logger_hostname,
logger_port=worker_args.logger_port,
)

redirect_stdio_to_logger(logger)

os.environ["RANK"] = str(worker_args.rank)
os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
os.environ["WORLD_SIZE"] = str(worker_args.world_size)
os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)

if worker_args.backend is not None:
backend = worker_args.backend
if backend == "auto":
backend = "nccl" if torch.cuda.is_available() else "gloo"

dist.init_process_group(
backend=backend,
world_size=worker_args.world_size,
rank=worker_args.rank,
store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage]
host_name=worker_args.main_agent_hostname,
port=worker_args.main_agent_port,
world_size=worker_args.world_size,
is_master=(worker_args.rank == 0),
),
timeout=datetime.timedelta(seconds=worker_args.timeout),
)

try:
return worker_args.function()
except Exception as e:
traceback.print_exc()
return WorkerException(exception=e)
finally:
sys.stdout.flush()
sys.stderr.flush()
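agent.py now declares `__all__ = ["main"]`, and the worker bootstrap was moved below `main` and renamed to `_entrypoint`, marking it private. A small sketch of the convention applied here (the names are illustrative, not the module's real contents):

```python
"""Example module following the public/private convention used in this PR."""

__all__ = ["run_agent"]  # only names listed here are documented public API


def run_agent() -> None:
    """Public entry point: appears in the docs and is linted for docstrings."""
    _bootstrap()


def _bootstrap() -> None:
    # Leading underscore: private helper, omitted from ``__all__`` and the docs.
    pass
```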
11 changes: 5 additions & 6 deletions src/torchrunx/environment.py
@@ -1,5 +1,7 @@
from __future__ import annotations

__all__ = ["in_slurm_job", "slurm_hosts", "slurm_workers", "auto_hosts", "auto_workers"]

import os
import subprocess

@@ -29,8 +31,7 @@ def slurm_hosts() -> list[str]:


def slurm_workers() -> int:
"""
| Determines number of workers per node in current Slurm allocation using
"""| Determines number of workers per node in current Slurm allocation using
| the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables.

:return: The implied number of workers per node
@@ -52,8 +53,7 @@ def slurm_workers() -> int:


def auto_hosts() -> list[str]:
"""
Automatically determine hostname list
"""Automatically determine hostname list

:return: Hostnames in Slurm allocation, or ['localhost']
:rtype: list[str]
@@ -65,8 +65,7 @@ def auto_hosts() -> list[str]:


def auto_workers() -> int:
"""
Automatically determine number of workers per host
"""Automatically determine number of workers per host

:return: Workers per host
:rtype: int
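These helpers inspect Slurm environment variables to size a job automatically. A hedged usage sketch, assuming the functions are importable from `torchrunx.environment` exactly as declared in `__all__` above:

```python
from torchrunx.environment import auto_hosts, auto_workers, in_slurm_job

# Inside a Slurm allocation, hosts come from the job's node list and workers
# from SLURM_JOB_GPUS / SLURM_CPUS_ON_NODE; outside Slurm, hosts fall back to
# ["localhost"] and workers are inferred from the local machine.
hosts = auto_hosts()      # e.g. ["node001", "node002"], or ["localhost"]
workers = auto_workers()  # implied number of workers per host

print(f"in_slurm_job={in_slurm_job()}, hosts={hosts}, workers_per_host={workers}")
```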
38 changes: 18 additions & 20 deletions src/torchrunx/launcher.py
@@ -1,5 +1,7 @@
from __future__ import annotations

__all__ = ["AgentKilledError", "Launcher", "launch", "LaunchResult"]

import fnmatch
import ipaddress
import itertools
@@ -54,14 +56,14 @@ def run(  # noqa: C901, PLR0912
func: Callable,
func_args: tuple[Any] | None = None,
func_kwargs: dict[str, Any] | None = None,
log_handlers: list[Handler] | Literal["auto"] | None = "auto"
log_handlers: list[Handler] | Literal["auto"] | None = "auto",
) -> LaunchResult:
if not dist.is_available():
msg = "The torch.distributed package is not available."
raise RuntimeError(msg)

hostnames = resolve_hostnames(self.hostnames)
workers_per_host = resolve_workers_per_host(self.workers_per_host, len(hostnames))
hostnames = _resolve_hostnames(self.hostnames)
workers_per_host = _resolve_workers_per_host(self.workers_per_host, len(hostnames))

launcher_hostname = socket.getfqdn()
launcher_port = get_open_port()
@@ -75,7 +77,7 @@ def run(  # noqa: C901, PLR0912
try:
# start logging server

log_receiver = build_logging_server(
log_receiver = _build_logging_server(
log_handlers=log_handlers,
launcher_hostname=launcher_hostname,
hostnames=hostnames,
@@ -94,8 +96,8 @@ def run(  # noqa: C901, PLR0912
# start agents on each node

for i, hostname in enumerate(hostnames):
execute_command(
command=build_launch_command(
_execute_command(
command=_build_launch_command(
launcher_hostname=launcher_hostname,
launcher_port=launcher_port,
logger_port=log_receiver.port,
@@ -168,7 +170,7 @@ def run(  # noqa: C901, PLR0912
# cleanup: SIGTERM all agents
if agent_payloads is not None:
for agent_payload, agent_hostname in zip(agent_payloads, hostnames):
execute_command(
_execute_command(
command=f"kill {agent_payload.process_id}",
hostname=agent_hostname,
ssh_config_file=self.ssh_config_file,
@@ -200,8 +202,7 @@ def launch(
env_file: str | os.PathLike | None = None,
log_handlers: list[Handler] | Literal["auto"] | None = "auto",
) -> LaunchResult:
"""
Launch a distributed PyTorch function on the specified nodes.
"""Launch a distributed PyTorch function on the specified nodes.

:param func:
:param func_args:
@@ -249,8 +250,7 @@ def all(self, by: Literal["rank"]) -> list[Any]:
pass

def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]:
"""
Get all worker return values by rank or hostname.
"""Get all worker return values by rank or hostname.

:param by: Whether to aggregate all return values by hostname, or just output all of them \
in order of rank, defaults to ``'hostname'``
@@ -264,17 +264,15 @@ def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[An
raise TypeError(msg)

def values(self, hostname: str) -> list[Any]:
"""
Get worker return values for host ``hostname``.
"""Get worker return values for host ``hostname``.

:param hostname: The host to get return values from
"""
host_idx = self.hostnames.index(hostname)
return self.return_values[host_idx]

def value(self, rank: int) -> Any:
"""
Get worker return value from global rank ``rank``.
"""Get worker return value from global rank ``rank``.

:param rank: Global worker rank to get return value from
"""
@@ -292,15 +290,15 @@ def value(self, rank: int) -> Any:
raise ValueError(msg)


def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
if hostnames == "auto":
return auto_hosts()
if hostnames == "slurm":
return slurm_hosts()
return hostnames


def resolve_workers_per_host(
def _resolve_workers_per_host(
workers_per_host: int | list[int] | Literal["auto", "slurm"],
num_hosts: int,
) -> list[int]:
@@ -318,7 +316,7 @@ def resolve_workers_per_host(
return workers_per_host


def build_logging_server(
def _build_logging_server(
log_handlers: list[Handler] | Literal["auto"] | None,
launcher_hostname: str,
hostnames: list[str],
@@ -343,7 +341,7 @@ def build_logging_server(
)


def build_launch_command(
def _build_launch_command(
launcher_hostname: str,
launcher_port: int,
logger_port: int,
@@ -385,7 +383,7 @@ def build_launch_command(
return " && ".join(commands)


def execute_command(
def _execute_command(
command: str,
hostname: str,
ssh_config_file: str | os.PathLike | None = None,
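Taken together, launcher.py now exposes `launch()`, `Launcher`, and `LaunchResult` via `__all__`, while the resolver and command helpers become private. An end-to-end sketch of that documented surface; the training function, hostnames, and keyword names below are assumptions drawn from the signatures in this diff, not a verbatim torchrunx example:

```python
import os

import torchrunx


def train() -> int:
    # launch() sets RANK, LOCAL_RANK, WORLD_SIZE, etc. in each worker process.
    return int(os.environ["RANK"])


result = torchrunx.launch(
    func=train,
    hostnames=["localhost"],   # or "auto" / "slurm", per _resolve_hostnames
    workers_per_host=2,
    log_handlers="auto",       # default: logging server is built automatically
)

print(result.all(by="rank"))                # all return values, ordered by global rank
print(result.value(rank=0))                 # a single worker's return value
print(result.values(hostname="localhost"))  # return values for one host
```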