Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions deepmd/env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
from typing import (
Tuple,
)

import numpy as np

Expand All @@ -26,3 +30,82 @@
"low. Please set precision with environmental variable "
"DP_INTERFACE_PREC." % dp_float_prec
)


def set_env_if_empty(key: str, value: str, verbose: bool = True):
    """Assign an environment variable unless it is already defined.

    Parameters
    ----------
    key : str
        name of the environment variable
    value : str
        value stored when the variable is not defined yet
    verbose : bool, optional
        if True a warning is logged when the default is applied,
        by default True
    """
    # A variable that is already present is left untouched.
    if os.environ.get(key) is not None:
        return
    os.environ[key] = value
    if not verbose:
        return
    logging.warning(
        f"Environment variable {key} is empty. Use the default value {value}"
    )


def set_default_nthreads():
    """Default the internal thread counts to automatic selection.

    A warning is logged when `OMP_NUM_THREADS` is missing, or when neither
    the `DP_` nor the legacy `TF_` variant of the intra-op / inter-op
    variable is present. Afterwards each `DP_` variable is defaulted to
    ``"0"`` unless its legacy `TF_` counterpart is defined.

    Notes
    -----
    `DP_INTRA_OP_PARALLELISM_THREADS` and `DP_INTER_OP_PARALLELISM_THREADS`
    control configuration of multithreading.
    """
    env = os.environ
    omp_missing = "OMP_NUM_THREADS" not in env
    # the TF_ names are accepted for backward compatibility
    intra_missing = not (
        "DP_INTRA_OP_PARALLELISM_THREADS" in env
        or "TF_INTRA_OP_PARALLELISM_THREADS" in env
    )
    inter_missing = not (
        "DP_INTER_OP_PARALLELISM_THREADS" in env
        or "TF_INTER_OP_PARALLELISM_THREADS" in env
    )
    if omp_missing or intra_missing or inter_missing:
        logging.warning(
            "To get the best performance, it is recommended to adjust "
            "the number of threads by setting the environment variables "
            "OMP_NUM_THREADS, DP_INTRA_OP_PARALLELISM_THREADS, and "
            "DP_INTER_OP_PARALLELISM_THREADS. See "
            "https://deepmd.rtfd.io/parallelism/ for more information."
        )
    # Do not shadow a legacy TF_ setting with a DP_ default.
    for kind in ("INTRA", "INTER"):
        if f"TF_{kind}_OP_PARALLELISM_THREADS" not in env:
            set_env_if_empty(f"DP_{kind}_OP_PARALLELISM_THREADS", "0", verbose=False)


def get_default_nthreads() -> Tuple[int, int]:
    """Get parallelism settings.

    The method will first read the environment variables with the prefix `DP_`.
    If not found, it will read the environment variables with the prefix `TF_`
    for backward compatibility. A variable that is found under neither prefix
    is reported as 0.

    Returns
    -------
    Tuple[int, int]
        number of `DP_INTRA_OP_PARALLELISM_THREADS` and
        `DP_INTER_OP_PARALLELISM_THREADS`, in that order

    Raises
    ------
    ValueError
        if a variable is set to a value that `int` cannot parse
    """
    intra = os.environ.get(
        "DP_INTRA_OP_PARALLELISM_THREADS",
        os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"),
    )
    inter = os.environ.get(
        "DP_INTER_OP_PARALLELISM_THREADS",
        # The legacy fallback must be the *inter*-op variable; reading the
        # intra-op one here would silently ignore TF_INTER_OP_PARALLELISM_THREADS.
        os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0"),
    )
    return int(intra), int(inter)
13 changes: 13 additions & 0 deletions deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import numpy as np
import torch

from deepmd.env import (
get_default_nthreads,
set_default_nthreads,
)

# Floating-point precision name, taken from the PRECISION environment
# variable; falls back to "float64" when the variable is not set.
PRECISION = os.environ.get("PRECISION", "float64")
# NumPy and PyTorch attributes looked up under that same name.
GLOBAL_NP_FLOAT_PRECISION = getattr(np, PRECISION)
GLOBAL_PT_FLOAT_PRECISION = getattr(torch, PRECISION)
Expand Down Expand Up @@ -42,3 +47,11 @@
"double": torch.float64,
}
DEFAULT_PRECISION = "float64"

# throw warnings if threads not set
set_default_nthreads()
# get_default_nthreads() returns (intra, inter) in that order; unpacking it
# the other way round would hand each value to the wrong torch setter.
intra_nthreads, inter_nthreads = get_default_nthreads()
if inter_nthreads > 0:  # the behavior of 0 is not documented
    torch.set_num_interop_threads(inter_nthreads)
if intra_nthreads > 0:
    torch.set_num_threads(intra_nthreads)
67 changes: 7 additions & 60 deletions deepmd/tf/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"""Module that sets tensorflow working environment and exports important constants."""

import ctypes
import logging
import os
import platform
from configparser import (
Expand All @@ -19,7 +18,6 @@
TYPE_CHECKING,
Any,
Dict,
Tuple,
)

import numpy as np
Expand All @@ -31,8 +29,15 @@
from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.env import get_default_nthreads as get_tf_default_nthreads
from deepmd.env import (
global_float_prec,
)
from deepmd.env import set_default_nthreads as set_tf_default_nthreads
from deepmd.env import (
set_env_if_empty,
)

if TYPE_CHECKING:
from types import (
Expand Down Expand Up @@ -216,26 +221,6 @@ def dlopen_library(module: str, filename: str):
}


def set_env_if_empty(key: str, value: str, verbose: bool = True):
    """Give an environment variable a default when it is not defined.

    Parameters
    ----------
    key : str
        env variable name
    value : str
        default stored under `key` when the variable is not defined
    verbose : bool, optional
        if True the assignment is reported with a warning, by default True
    """
    already_defined = os.environ.get(key) is not None
    if already_defined:
        # keep whatever the user exported
        return
    os.environ[key] = value
    if verbose:
        logging.warning(
            f"Environment variable {key} is empty. Use the default value {value}"
        )


def set_mkl():
"""Tuning MKL for the best performance.

Expand Down Expand Up @@ -270,44 +255,6 @@ def set_mkl():
reload(np)


def set_tf_default_nthreads():
    """Default the TF internal thread counts to automatic selection.

    A warning is logged when any of `OMP_NUM_THREADS`,
    `TF_INTRA_OP_PARALLELISM_THREADS` or `TF_INTER_OP_PARALLELISM_THREADS`
    is missing; the two `TF_` variables are then defaulted to ``"0"``
    unless they are already defined.

    Notes
    -----
    `TF_INTRA_OP_PARALLELISM_THREADS` and `TF_INTER_OP_PARALLELISM_THREADS`
    control TF configuration of multithreading.
    """
    tf_thread_vars = (
        "TF_INTRA_OP_PARALLELISM_THREADS",
        "TF_INTER_OP_PARALLELISM_THREADS",
    )
    checked = ("OMP_NUM_THREADS", *tf_thread_vars)
    if any(name not in os.environ for name in checked):
        logging.warning(
            "To get the best performance, it is recommended to adjust "
            "the number of threads by setting the environment variables "
            "OMP_NUM_THREADS, TF_INTRA_OP_PARALLELISM_THREADS, and "
            "TF_INTER_OP_PARALLELISM_THREADS. See "
            "https://deepmd.rtfd.io/parallelism/ for more information."
        )
    for name in tf_thread_vars:
        set_env_if_empty(name, "0", verbose=False)


def get_tf_default_nthreads() -> Tuple[int, int]:
    """Read the TF parallelism settings from the environment.

    A variable that is not set is reported as 0.

    Returns
    -------
    Tuple[int, int]
        number of `TF_INTRA_OP_PARALLELISM_THREADS` and
        `TF_INTER_OP_PARALLELISM_THREADS`
    """
    intra_raw = os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0")
    intra = int(intra_raw)
    inter_raw = os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0")
    inter = int(inter_raw)
    return intra, inter


def get_tf_session_config() -> Any:
"""Configure tensorflow session.

Expand Down
37 changes: 26 additions & 11 deletions doc/troubleshooting/howtoset_num_nodes.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,46 @@ Sometimes, `$num_nodes` and the nodes information can be directly given by the H

## Parallelism between independent operators

For CPU devices, TensorFlow uses multiple streams to run independent operators (OP).
For CPU devices, TensorFlow and PyTorch use multiple streams to run independent operators (OP).

```bash
export TF_INTER_OP_PARALLELISM_THREADS=3
export DP_INTER_OP_PARALLELISM_THREADS=3
```

However, for GPU devices, TensorFlow uses only one compute stream and multiple copy streams.
Note that some of DeePMD-kit OPs do not have GPU support, so it is still encouraged to set environmental variables even if one has a GPU.

## Parallelism within an individual operator

For CPU devices, `TF_INTRA_OP_PARALLELISM_THREADS` controls parallelism within TensorFlow native OPs when TensorFlow is built against Eigen.
For CPU devices, `DP_INTRA_OP_PARALLELISM_THREADS` controls parallelism within TensorFlow (when TensorFlow is built against Eigen) and PyTorch native OPs.

```bash
export TF_INTRA_OP_PARALLELISM_THREADS=2
export DP_INTRA_OP_PARALLELISM_THREADS=2
```

`OMP_NUM_THREADS` is threads for OpenMP parallelism. It controls parallelism within TensorFlow native OPs when TensorFlow is built by Intel OneDNN and DeePMD-kit custom CPU OPs.
It may also control parallelism for NumPy when NumPy is built against OpenMP, so one who uses GPUs for training should also care about this environmental variable.
`OMP_NUM_THREADS` is the number of threads for OpenMP parallelism.
It controls parallelism within TensorFlow (when TensorFlow is built upon Intel OneDNN) and PyTorch (when PyTorch is built upon OpenMP) native OPs and DeePMD-kit custom CPU OPs.
It may also control parallelism for NumPy when NumPy is built against OpenMP, so one who uses GPUs for training should also care about this environmental variable.

```bash
export OMP_NUM_THREADS=2
```

There are several other environmental variables for OpenMP, such as `KMP_BLOCKTIME`. See [Intel documentation](https://www.intel.com/content/www/us/en/developer/articles/technical/maximize-tensorflow-performance-on-cpu-considerations-and-recommendations-for-inference.html) for detailed information.
There are several other environmental variables for OpenMP, such as `KMP_BLOCKTIME`.

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}

See [Intel documentation](https://www.intel.com/content/www/us/en/developer/articles/technical/maximize-tensorflow-performance-on-cpu-considerations-and-recommendations-for-inference.html) for detailed information.

:::
:::{tab-item} PyTorch {{ pytorch_icon }}

See [PyTorch documentation](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html) for detailed information.

:::
::::

## Tune the performance

Expand All @@ -56,17 +71,17 @@ Here are some empirical examples.
If you wish to use 3 cores of 2 CPUs on one node, you may set the environmental variables and run DeePMD-kit as follows:
```bash
export OMP_NUM_THREADS=3
export TF_INTRA_OP_PARALLELISM_THREADS=3
export TF_INTER_OP_PARALLELISM_THREADS=2
export DP_INTRA_OP_PARALLELISM_THREADS=3
export DP_INTER_OP_PARALLELISM_THREADS=2
dp train input.json
```

For a node with 128 cores, it is recommended to start with the following variables:

```bash
export OMP_NUM_THREADS=16
export TF_INTRA_OP_PARALLELISM_THREADS=16
export TF_INTER_OP_PARALLELISM_THREADS=8
export DP_INTRA_OP_PARALLELISM_THREADS=16
export DP_INTER_OP_PARALLELISM_THREADS=8
```

Again, in general, one should make sure the product of the parallel numbers is less than or equal to the number of cores available.
Expand Down
4 changes: 2 additions & 2 deletions source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ void select_map_inv(typename std::vector<VT>::iterator out,
* @brief Get the number of threads from the environment variable.
* @details A warning will be thrown if environmental variables are not set.
* @param[out] num_intra_nthreads The number of intra threads. Read from
*TF_INTRA_OP_PARALLELISM_THREADS.
*DP_INTRA_OP_PARALLELISM_THREADS.
* @param[out] num_inter_nthreads The number of inter threads. Read from
*TF_INTER_OP_PARALLELISM_THREADS.
*DP_INTER_OP_PARALLELISM_THREADS.
**/
void get_env_nthreads(int& num_intra_nthreads, int& num_inter_nthreads);

Expand Down
19 changes: 16 additions & 3 deletions source/api_cc/src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,23 +330,36 @@ void deepmd::get_env_nthreads(int& num_intra_nthreads,
num_intra_nthreads = 0;
num_inter_nthreads = 0;
const char* env_intra_nthreads =
std::getenv("TF_INTRA_OP_PARALLELISM_THREADS");
std::getenv("DP_INTRA_OP_PARALLELISM_THREADS");
const char* env_inter_nthreads =
std::getenv("DP_INTER_OP_PARALLELISM_THREADS");
// backward compatibility
const char* env_intra_nthreads_tf =
std::getenv("TF_INTRA_OP_PARALLELISM_THREADS");
const char* env_inter_nthreads_tf =
std::getenv("TF_INTER_OP_PARALLELISM_THREADS");
const char* env_omp_nthreads = std::getenv("OMP_NUM_THREADS");
if (env_intra_nthreads &&
std::string(env_intra_nthreads) != std::string("") &&
atoi(env_intra_nthreads) >= 0) {
num_intra_nthreads = atoi(env_intra_nthreads);
} else if (env_intra_nthreads_tf &&
std::string(env_intra_nthreads_tf) != std::string("") &&
atoi(env_intra_nthreads_tf) >= 0) {
num_intra_nthreads = atoi(env_intra_nthreads_tf);
} else {
throw_env_not_set_warning("TF_INTRA_OP_PARALLELISM_THREADS");
throw_env_not_set_warning("DP_INTRA_OP_PARALLELISM_THREADS");
}
if (env_inter_nthreads &&
std::string(env_inter_nthreads) != std::string("") &&
atoi(env_inter_nthreads) >= 0) {
num_inter_nthreads = atoi(env_inter_nthreads);
} else if (env_inter_nthreads_tf &&
std::string(env_inter_nthreads_tf) != std::string("") &&
atoi(env_inter_nthreads_tf) >= 0) {
num_inter_nthreads = atoi(env_inter_nthreads_tf);
} else {
throw_env_not_set_warning("TF_INTER_OP_PARALLELISM_THREADS");
throw_env_not_set_warning("DP_INTER_OP_PARALLELISM_THREADS");
}
if (!(env_omp_nthreads && std::string(env_omp_nthreads) != std::string("") &&
atoi(env_omp_nthreads) >= 0)) {
Expand Down
4 changes: 2 additions & 2 deletions source/tests/tf/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_empty(self):
@mock.patch.dict(
"os.environ",
values={
"TF_INTRA_OP_PARALLELISM_THREADS": "5",
"TF_INTER_OP_PARALLELISM_THREADS": "3",
"DP_INTRA_OP_PARALLELISM_THREADS": "5",
"DP_INTER_OP_PARALLELISM_THREADS": "3",
},
)
def test_given(self):
Expand Down