Skip to content

Commit

Permalink
Modify LSFEnvironment to use more reliable environment variable (#10825)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Andrew Tritt <andrew.j.tritt@gmail.com>
  • Loading branch information
5 people authored and lexierule committed Jan 19, 2022
1 parent 8cffc0f commit e95d8b1
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 116 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Skip testing with PyTorch 1.7 and Python 3.9 on Ubuntu ([#11217](https://github.com/PyTorchLightning/pytorch-lightning/pull/11217))
- Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))

### Changed

- Changed `LSFEnvironment` to use `LSB_DJOB_RANKFILE` environment variable instead of `LSB_HOSTS` for determining node rank and main address ([#10825](https://github.com/PyTorchLightning/pytorch-lightning/pull/10825))


## [1.5.8] - 2022-01-05

Expand Down
188 changes: 112 additions & 76 deletions pytorch_lightning/plugins/environments/lsf_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@

import os
import socket
from typing import Dict, List

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.cloud_io import get_filesystem


class LSFEnvironment(ClusterEnvironment):
Expand All @@ -25,128 +28,161 @@ class LSFEnvironment(ClusterEnvironment):
It is expected that any execution using this ClusterEnvironment was executed
using the Job Step Manager i.e. ``jsrun``.
This plugin expects the following environment variables.
This plugin expects the following environment variables:
LSB_JOBID:
The LSF assigned job ID
``LSB_JOBID``
The LSF assigned job ID
LSB_HOSTS:
The hosts used in the job. This string is expected to have the format "batch <rank_0_host> ...."
``LSB_DJOB_RANKFILE``
The OpenMPI compatible rank file for the LSF job
JSM_NAMESPACE_LOCAL_RANK:
The node local rank for the task. This environment variable is set by jsrun
``JSM_NAMESPACE_LOCAL_RANK``
The node local rank for the task. This environment variable is set by ``jsrun``
JSM_NAMESPACE_SIZE:
The world size for the task. This environment variable is set by jsrun
"""
``JSM_NAMESPACE_SIZE``
The world size for the task. This environment variable is set by ``jsrun``
def __init__(self):
self._master_address = self._get_master_address()
self._master_port = self._get_master_port()
log.debug(f"MASTER_ADDR: {self._master_address}")
log.debug(f"MASTER_PORT: {self._master_port}")
``JSM_NAMESPACE_RANK``
The global rank for the task. This environment variable is set by ``jsrun``
"""

@staticmethod
def is_using_lsf() -> bool:
"""Returns ``True`` if the current process was launched using the jsrun command."""
required_env_vars = ("LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE")
return all(v in os.environ for v in required_env_vars)
def __init__(self) -> None:
    """Resolve and cache the main address, main port and node rank of the LSF job.

    The three values are computed once by the private ``_get_*`` helpers and stored on the
    instance. ``MASTER_ADDR`` and ``MASTER_PORT`` are then exported through
    ``_set_init_progress_group_env_vars``.
    """
    super().__init__()
    # TODO: remove in 1.7
    # ``LSFEnvironment`` itself still defines ``is_using_lsf``, so a plain ``hasattr``/``callable``
    # check is always true and would emit the warning on every instantiation. Only warn when a
    # subclass supplies its own implementation.
    is_using_lsf = type(self).is_using_lsf
    if callable(is_using_lsf) and is_using_lsf is not LSFEnvironment.is_using_lsf:
        rank_zero_deprecation(
            f"`{self.__class__.__name__}.is_using_lsf` has been deprecated in v1.6 and will be removed in v1.7."
            " Implement the static method `detect()` instead (do not forget to add the `@staticmethod` decorator)."
        )
    self._main_address = self._get_main_address()
    self._main_port = self._get_main_port()
    self._node_rank = self._get_node_rank()
    self._set_init_progress_group_env_vars()

def _set_init_progress_group_env_vars(self) -> None:
    """Export the cached main address and port as ``MASTER_ADDR`` and ``MASTER_PORT``.

    Both values are written to ``os.environ`` as strings and echoed at debug level.
    """
    # these variables are needed for initializing the torch distributed process group
    address = str(self._main_address)
    os.environ["MASTER_ADDR"] = address
    log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

    port = str(self._main_port)
    os.environ["MASTER_PORT"] = port
    log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

@property
def creates_processes_externally(self) -> bool:
    """LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them.

    Returns:
        Always ``True``.
    """
    return True

def master_address(self):
"""The master address is read from a list of hosts contained in the environment variable `LSB_HOSTS`."""
return self._master_address
def master_address(self) -> str:
    """Return the main address cached at construction time.

    The value is resolved once in ``__init__`` by ``_get_main_address``, which takes it from
    the OpenMPI host rank file referenced by ``LSB_DJOB_RANKFILE``.
    """
    main_address = self._main_address
    return main_address

def master_port(self) -> int:
    """Return the main port cached at construction time.

    The value is resolved once in ``__init__`` by ``_get_main_port``.
    """
    main_port = self._main_port
    return main_port

def master_port(self):
"""THe master port gets calculated from the LSF job ID."""
return self._master_port
@staticmethod
def is_using_lsf() -> bool:
    """Return ``True`` when every environment variable required by this plugin is set.

    The variables checked are ``LSB_JOBID``, ``LSB_DJOB_RANKFILE``, ``JSM_NAMESPACE_LOCAL_RANK``
    and ``JSM_NAMESPACE_SIZE``; their presence is taken to mean the process was launched
    with ``jsrun``.
    """
    required_env_vars = ("LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE")
    return all(name in os.environ for name in required_env_vars)

def world_size(self):
"""The world size is read from the environment variable `JSM_NAMESPACE_SIZE`."""
var = "JSM_NAMESPACE_SIZE"
world_size = os.environ.get(var)
def world_size(self) -> int:
"""The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
world_size = os.environ.get("JSM_NAMESPACE_SIZE")
if world_size is None:
raise ValueError(
f"Cannot determine world size from environment variable {var}."
" Make sure you run your executable with `jsrun`"
"Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
"Make sure you run your executable with `jsrun`."
)
return int(world_size)

def set_world_size(self, size: int) -> None:
    """Ignore the request to change the world size.

    The argument ``size`` is not used; the method only logs a debug message saying that the
    call was ignored.
    """
    log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self):
"""The world size is read from the environment variable `JSM_NAMESPACE_RANK`."""
var = "JSM_NAMESPACE_RANK"
global_rank = os.environ.get(var)
def global_rank(self) -> int:
"""The global rank is read from the environment variable ``JSM_NAMESPACE_RANK``."""
global_rank = os.environ.get("JSM_NAMESPACE_RANK")
if global_rank is None:
raise ValueError(
f"Cannot determine global rank from environment variable {var}."
" Make sure you run your executable with `jsrun`"
"Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
"Make sure you run your executable with `jsrun`."
)
return int(global_rank)

def set_global_rank(self, rank: int) -> None:
    """Ignore the request to change the global rank.

    The argument ``rank`` is not used; the method only logs a debug message saying that the
    call was ignored.
    """
    log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self):
def local_rank(self) -> int:
"""The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
var = "JSM_NAMESPACE_LOCAL_RANK"
local_rank = os.environ.get(var)
local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
if local_rank is None:
raise ValueError(
f"Cannot determine local rank from environment variable {var}."
" Make sure you run your executable with `jsrun`"
"Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
"Make sure you run your executable with `jsrun`."
)
return int(local_rank)

def node_rank(self):
"""The node rank is determined by the position of the current hostname in the list of hosts stored in the
environment variable `LSB_HOSTS`."""
def node_rank(self) -> int:
    """Return the node rank cached at construction time.

    The value is resolved once in ``__init__`` by ``_get_node_rank``, which looks up the
    current hostname among the hosts read from the file named by ``LSB_DJOB_RANKFILE``.
    """
    rank = self._node_rank
    return rank

def _get_node_rank(self) -> int:
"""A helper method for getting the node rank.
The node rank is determined by the position of the current node in the list of hosts used in the job. This is
calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
"""
hosts = self._read_hosts()
count = {}
count: Dict[str, int] = {}
for host in hosts:
if "batch" in host or "login" in host:
continue
if host not in count:
count[host] = len(count)
return count[socket.gethostname()]

@staticmethod
def _read_hosts():
hosts = os.environ.get("LSB_HOSTS")
if not hosts:
raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
hosts = hosts.split()
if len(hosts) < 2:
raise ValueError(
'Cannot parse hosts from LSB_HOSTS environment variable. Expected format: "batch <rank_0_host> ..."'
)
return hosts
def _read_hosts() -> List[str]:
"""Read compute hosts that are a part of the compute job.
def _get_master_address(self):
LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
Each job is assigned a launch node. This launch node will be the first node in the list contained in
``LSB_DJOB_RANKFILE``.
"""
var = "LSB_DJOB_RANKFILE"
rankfile = os.environ.get(var)
if rankfile is None:
raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
if not rankfile:
raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")

fs = get_filesystem(rankfile)
with fs.open(rankfile, "r") as f:
ret = [line.strip() for line in f]
# remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
return ret[1:]

def _get_main_address(self) -> str:
"""A helper for getting the main address.
The main address is assigned to the first node in the list of nodes used for the job.
"""
hosts = self._read_hosts()
return hosts[1]
return hosts[0]

@staticmethod
def _get_master_port():
"""A helper function for accessing the master port.
def _get_main_port() -> int:
"""A helper function for accessing the main port.
Uses the LSF job ID so all ranks can compute the master port.
Uses the LSF job ID so all ranks can compute the main port.
"""
# check for user-specified master port
port = os.environ.get("MASTER_PORT")
if not port:
jobid = os.environ.get("LSB_JOBID")
if not jobid:
raise ValueError("Could not find job id in environment variable LSB_JOBID")
port = int(jobid)
# check for user-specified main port
if "MASTER_PORT" in os.environ:
log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
return int(os.environ["MASTER_PORT"])
if "LSB_JOBID" in os.environ:
port = int(os.environ["LSB_JOBID"])
# all ports should be in the 10k+ range
port = int(port) % 1000 + 10000
log.debug(f"calculated LSF master port: {port}")
else:
log.debug(f"using externally specified master port: {port}")
return int(port)
port = port % 1000 + 10000
log.debug(f"calculated LSF main port: {port}")
return port
raise ValueError("Could not find job id in environment variable LSB_JOBID")
Loading

0 comments on commit e95d8b1

Please sign in to comment.