use num_gpu instead of num_gpus
shiyu1994 committed Dec 24, 2024
1 parent 0cf1062 commit ae4cce6
Showing 11 changed files with 53 additions and 59 deletions.
2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_lgb.Booster.R
@@ -1107,7 +1107,7 @@ test_that("all parameters are stored correctly with save_model_to_string()", {
, "[gpu_platform_id: -1]"
, "[gpu_device_id: -1]"
, "[gpu_use_dp: 0]"
, "[num_gpus: 1]"
, "[num_gpu: 1]"
)
all_param_entries <- c(non_default_param_entries, default_param_entries)

8 changes: 2 additions & 6 deletions docs/Parameters.rst
@@ -1371,12 +1371,6 @@ GPU Parameters

- **Note**: refer to `GPU Targets <./GPU-Targets.rst#query-opencl-devices-in-your-system>`__ for more details

- ``num_gpus`` :raw-html:`<a id="num_gpus" title="Permalink to this parameter" href="#num_gpus">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int

- Number of GPUs to use for training, used with device_type=cuda

- When <= 0, only 1 GPU will be used

- ``gpu_device_id_list`` :raw-html:`<a id="gpu_device_id_list" title="Permalink to this parameter" href="#gpu_device_id_list">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string

- List of CUDA device IDs used when device_type=cuda
@@ -1395,6 +1389,8 @@ GPU Parameters

- **Note**: can be used only in CUDA implementation (``device_type="cuda"``)

- When <= 0, only 1 GPU will be used

.. end params list
Others
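The Parameters.rst hunk above documents the renamed num_gpu parameter alongside gpu_device_id_list. The snippet below is a minimal usage sketch, not taken from the repository: it assumes a LightGBM build compiled with CUDA/NCCL support, and the data and parameter values are purely illustrative. Per the boosting.cpp change later in this commit, device_type="cuda" with num_gpu > 1 is what selects the NCCL-based multi-GPU path.

import numpy as np
import lightgbm as lgb

# Synthetic data purely for illustration.
X = np.random.rand(10_000, 20)
y = np.random.rand(10_000)
train_set = lgb.Dataset(X, label=y)

params = {
    "objective": "regression",
    "device_type": "cuda",   # num_gpu only applies to the CUDA implementation
    "num_gpu": 2,            # values <= 0 fall back to a single GPU
    # Optional: pin training to specific devices; when non-empty, the length of
    # this list takes precedence over num_gpu (see the NCCLTopology hunk below).
    "gpu_device_id_list": "0,1",
}

booster = lgb.train(params, train_set, num_boost_round=10)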
5 changes: 1 addition & 4 deletions include/LightGBM/config.h
@@ -1125,10 +1125,6 @@ struct Config {
// desc = **Note**: refer to `GPU Targets <./GPU-Targets.rst#query-opencl-devices-in-your-system>`__ for more details
int gpu_device_id = -1;

// desc = Number of GPUs to use for training, used with device_type=cuda
// desc = When <= 0, only 1 GPU will be used
int num_gpus = 1;

// desc = List of CUDA device IDs used when device_type=cuda
// desc = When empty, the devices with the smallest IDs will be used
std::string gpu_device_id_list = "";
@@ -1140,6 +1136,7 @@ struct Config {
// check = >0
// desc = number of GPUs
// desc = **Note**: can be used only in CUDA implementation (``device_type="cuda"``)
// desc = When <= 0, only 1 GPU will be used
int num_gpu = 1;

#ifndef __NVCC__
64 changes: 32 additions & 32 deletions include/LightGBM/cuda/cuda_nccl_topology.hpp
@@ -20,8 +20,8 @@ namespace LightGBM {

class NCCLTopology {
public:
NCCLTopology(const int master_gpu_device_id, const int num_gpus, const std::string& gpu_device_id_list, const data_size_t global_num_data) {
num_gpus_ = num_gpus;
NCCLTopology(const int master_gpu_device_id, const int num_gpu, const std::string& gpu_device_id_list, const data_size_t global_num_data) {
num_gpu_ = num_gpu;
master_gpu_device_id_ = master_gpu_device_id;
global_num_data_ = global_num_data;
int max_num_gpu = 0;
@@ -42,10 +42,10 @@ class NCCLTopology {
gpu_list_.push_back(gpu_id);
}
}
if (!gpu_list_.empty() && num_gpus_ != static_cast<int>(gpu_list_.size())) {
Log::Warning("num_gpus_ = %d is different from the number of valid device IDs in gpu_device_list (%d), using %d GPUs instead.", \
num_gpus_, static_cast<int>(gpu_list_.size()), static_cast<int>(gpu_list_.size()));
num_gpus_ = static_cast<int>(gpu_list_.size());
if (!gpu_list_.empty() && num_gpu_ != static_cast<int>(gpu_list_.size())) {
Log::Warning("num_gpu_ = %d is different from the number of valid device IDs in gpu_device_list (%d), using %d GPUs instead.", \
num_gpu_, static_cast<int>(gpu_list_.size()), static_cast<int>(gpu_list_.size()));
num_gpu_ = static_cast<int>(gpu_list_.size());
}

if (!gpu_list_.empty()) {
@@ -64,37 +64,37 @@ class NCCLTopology {
master_gpu_index_ = 0;
}
} else {
if (num_gpus_ <= 0) {
num_gpus_ = 1;
} else if (num_gpus_ > max_num_gpu) {
if (num_gpu_ <= 0) {
num_gpu_ = 1;
} else if (num_gpu_ > max_num_gpu) {
Log::Warning("Only %d GPUs available, using num_gpu = %d.", max_num_gpu, max_num_gpu);
num_gpus_ = max_num_gpu;
num_gpu_ = max_num_gpu;
}
if (master_gpu_device_id_ < 0 || master_gpu_device_id_ >= num_gpus_) {
if (master_gpu_device_id_ < 0 || master_gpu_device_id_ >= num_gpu_) {
Log::Warning("Invalid gpu_device_id = %d for master GPU index, using gpu_device_id = 0 instead.", master_gpu_device_id_);
master_gpu_device_id_ = 0;
master_gpu_index_ = 0;
}
for (int i = 0; i < num_gpus_; ++i) {
for (int i = 0; i < num_gpu_; ++i) {
gpu_list_.push_back(i);
}
}

Log::Info("Using GPU devices %s, and local master GPU device %d.", Common::Join<int>(gpu_list_, ",").c_str(), master_gpu_device_id_);

const int num_threads = OMP_NUM_THREADS();
if (num_gpus_ > num_threads) {
Log::Fatal("Number of GPUs %d is greater than the number of threads %d. Please use more threads.", num_gpus_, num_threads);
if (num_gpu_ > num_threads) {
Log::Fatal("Number of GPUs %d is greater than the number of threads %d. Please use more threads.", num_gpu_, num_threads);
}

host_threads_.resize(num_gpus_);
host_threads_.resize(num_gpu_);
}

~NCCLTopology() {}

void InitNCCL() {
nccl_gpu_rank_.resize(num_gpus_, -1);
nccl_communicators_.resize(num_gpus_);
nccl_gpu_rank_.resize(num_gpu_, -1);
nccl_communicators_.resize(num_gpu_);
ncclUniqueId nccl_unique_id;
if (Network::num_machines() == 1 || Network::rank() == 0) {
NCCLCHECK(ncclGetUniqueId(&nccl_unique_id));
@@ -113,26 +113,26 @@ class NCCLTopology {
if (Network::num_machines() > 1) {
node_rank_offset_.resize(Network::num_machines() + 1, 0);
Network::Allgather(
reinterpret_cast<char*>(&num_gpus_),
reinterpret_cast<char*>(&num_gpu_),
sizeof(int) / sizeof(char),
reinterpret_cast<char*>(node_rank_offset_.data() + 1));
for (int rank = 1; rank < Network::num_machines() + 1; ++rank) {
node_rank_offset_[rank] += node_rank_offset_[rank - 1];
}
CHECK_EQ(node_rank_offset_[Network::rank() + 1] - node_rank_offset_[Network::rank()], num_gpus_);
CHECK_EQ(node_rank_offset_[Network::rank() + 1] - node_rank_offset_[Network::rank()], num_gpu_);
NCCLCHECK(ncclGroupStart());
for (int gpu_index = 0; gpu_index < num_gpus_; ++gpu_index) {
for (int gpu_index = 0; gpu_index < num_gpu_; ++gpu_index) {
SetCUDADevice(gpu_list_[gpu_index], __FILE__, __LINE__);
nccl_gpu_rank_[gpu_index] = gpu_index + node_rank_offset_[Network::rank()];
NCCLCHECK(ncclCommInitRank(&nccl_communicators_[gpu_index], node_rank_offset_.back(), nccl_unique_id, nccl_gpu_rank_[gpu_index]));
}
NCCLCHECK(ncclGroupEnd());
} else {
NCCLCHECK(ncclGroupStart());
for (int gpu_index = 0; gpu_index < num_gpus_; ++gpu_index) {
for (int gpu_index = 0; gpu_index < num_gpu_; ++gpu_index) {
SetCUDADevice(gpu_list_[gpu_index], __FILE__, __LINE__);
nccl_gpu_rank_[gpu_index] = gpu_index;
NCCLCHECK(ncclCommInitRank(&nccl_communicators_[gpu_index], num_gpus_, nccl_unique_id, gpu_index));
NCCLCHECK(ncclCommInitRank(&nccl_communicators_[gpu_index], num_gpu_, nccl_unique_id, gpu_index));
}
NCCLCHECK(ncclGroupEnd());
}
@@ -143,8 +143,8 @@ class NCCLTopology {

template <typename ARG_T, typename RET_T>
void RunPerDevice(const std::vector<std::unique_ptr<ARG_T>>& objs, const std::function<RET_T(ARG_T*)>& func) {
#pragma omp parallel for schedule(static) num_threads(num_gpus_)
for (int i = 0; i < num_gpus_; ++i) {
#pragma omp parallel for schedule(static) num_threads(num_gpu_)
for (int i = 0; i < num_gpu_; ++i) {
CUDASUCCESS_OR_FATAL(cudaSetDevice(gpu_list_[i]));
func(objs[i].get());
}
@@ -153,9 +153,9 @@ class NCCLTopology {

template <typename RET_T>
void InitPerDevice(std::vector<std::unique_ptr<RET_T>>* vec) {
vec->resize(num_gpus_);
#pragma omp parallel for schedule(static) num_threads(num_gpus_)
for (int i = 0; i < num_gpus_; ++i) {
vec->resize(num_gpu_);
#pragma omp parallel for schedule(static) num_threads(num_gpu_)
for (int i = 0; i < num_gpu_; ++i) {
CUDASUCCESS_OR_FATAL(cudaSetDevice(gpu_list_[i]));
RET_T* nccl_info = new RET_T();
nccl_info->SetNCCLInfo(nccl_communicators_[i], nccl_gpu_rank_[i], i, gpu_list_[i], global_num_data_);
@@ -166,13 +166,13 @@ class NCCLTopology {

template <typename ARG_T>
void DispatchPerDevice(std::vector<std::unique_ptr<ARG_T>>* objs, const std::function<void(ARG_T*)>& func) {
for (int i = 0; i < num_gpus_; ++i) {
for (int i = 0; i < num_gpu_; ++i) {
host_threads_[i] = std::thread([this, i, &func, objs] () {
CUDASUCCESS_OR_FATAL(cudaSetDevice(gpu_list_[i]))
func(objs->operator[](i).get());
});
}
for (int i = 0; i < num_gpus_; ++i) {
for (int i = 0; i < num_gpu_; ++i) {
host_threads_[i].join();
}
CUDASUCCESS_OR_FATAL(cudaSetDevice(master_gpu_device_id_));
@@ -186,7 +186,7 @@ class NCCLTopology {

template <typename ARG_T, typename RET_T>
void RunOnNonMasterDevice(const std::vector<std::unique_ptr<ARG_T>>& objs, const std::function<RET_T(ARG_T*)>& func) {
for (int i = 0; i < num_gpus_; ++i) {
for (int i = 0; i < num_gpu_; ++i) {
if (i != master_gpu_index_) {
CUDASUCCESS_OR_FATAL(cudaSetDevice(gpu_list_[i]));
func(objs[i].get());
@@ -195,7 +195,7 @@ class NCCLTopology {
CUDASUCCESS_OR_FATAL(cudaSetDevice(master_gpu_device_id_));
}

int num_gpus() const { return num_gpus_; }
int num_gpu() const { return num_gpu_; }

int master_gpu_index() const { return master_gpu_index_; }

@@ -204,7 +204,7 @@ class NCCLTopology {
const std::vector<int>& gpu_list() const { return gpu_list_; }

private:
int num_gpus_;
int num_gpu_;
int master_gpu_index_;
int master_gpu_device_id_;
std::vector<int> gpu_list_;
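To summarize the device-selection logic in the NCCLTopology constructor above, here is an illustrative Python mirror of it. This is not part of the LightGBM API; max_num_gpu stands in for the device count queried from the CUDA runtime, and the per-ID validity filtering in the collapsed hunk is omitted.

def resolve_gpu_list(num_gpu, gpu_device_id_list, max_num_gpu):
    # Parse the comma-separated device list (validity filtering omitted here).
    gpu_list = [int(s) for s in gpu_device_id_list.split(",") if s != ""]
    if gpu_list and num_gpu != len(gpu_list):
        # Mirrors the Log::Warning branch: an explicit device list wins.
        print(f"num_gpu = {num_gpu} differs from the device list length "
              f"({len(gpu_list)}), using {len(gpu_list)} GPUs instead.")
        num_gpu = len(gpu_list)
    if not gpu_list:
        if num_gpu <= 0:
            num_gpu = 1            # "When <= 0, only 1 GPU will be used"
        elif num_gpu > max_num_gpu:
            num_gpu = max_num_gpu  # clamp to the GPUs actually present
        gpu_list = list(range(num_gpu))
    return num_gpu, gpu_list

print(resolve_gpu_list(4, "0,1", max_num_gpu=8))  # -> (2, [0, 1])
print(resolve_gpu_list(0, "", max_num_gpu=8))     # -> (1, [0])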
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_objective_function.hpp
@@ -22,7 +22,7 @@ template <typename HOST_OBJECTIVE>
class CUDAObjectiveInterface: public HOST_OBJECTIVE, public NCCLInfo {
public:
explicit CUDAObjectiveInterface(const Config& config): HOST_OBJECTIVE(config) {
if (config.num_gpus <= 1) {
if (config.num_gpu <= 1) {
const int gpu_device_id = config.gpu_device_id >= 0 ? config.gpu_device_id : 0;
SetCUDADevice(gpu_device_id, __FILE__, __LINE__);
}
6 changes: 3 additions & 3 deletions src/application/application.cpp
@@ -181,7 +181,7 @@ void Application::InitTrain() {
// create boosting
boosting_.reset(
Boosting::CreateBoosting(config_.boosting,
config_.input_model.c_str(), config_.device_type, config_.num_gpus));
config_.input_model.c_str(), config_.device_type, config_.num_gpu));
// create objective function
objective_fun_.reset(
ObjectiveFunction::CreateObjectiveFunction(config_.objective,
@@ -274,13 +274,13 @@ void Application::Predict() {

void Application::InitPredict() {
boosting_.reset(
Boosting::CreateBoosting("gbdt", config_.input_model.c_str(), config_.device_type, config_.num_gpus));
Boosting::CreateBoosting("gbdt", config_.input_model.c_str(), config_.device_type, config_.num_gpu));
Log::Info("Finished initializing prediction, total used %d iterations", boosting_->GetCurrentIteration());
}

void Application::ConvertModel() {
boosting_.reset(
Boosting::CreateBoosting(config_.boosting, config_.input_model.c_str(), config_.device_type, config_.num_gpus));
Boosting::CreateBoosting(config_.boosting, config_.input_model.c_str(), config_.device_type, config_.num_gpu));
boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
}

6 changes: 3 additions & 3 deletions src/boosting/boosting.cpp
@@ -42,13 +42,13 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
#endif // USE_CUDA
, const int
#ifdef USE_CUDA
num_gpus
num_gpu
#endif // USE_CUDA
) {
if (filename == nullptr || filename[0] == '\0') {
if (type == std::string("gbdt")) {
#ifdef USE_CUDA
if (device_type == std::string("cuda") && num_gpus > 1) {
if (device_type == std::string("cuda") && num_gpu > 1) {
return new NCCLGBDT<GBDT>();
} else {
#endif // USE_CUDA
@@ -70,7 +70,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
if (GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
if (type == std::string("gbdt")) {
#ifdef USE_CUDA
if (device_type == std::string("cuda") && num_gpus > 1) {
if (device_type == std::string("cuda") && num_gpu > 1) {
ret.reset(new NCCLGBDT<GBDT>());
} else {
#endif // USE_CUDA
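The dispatch above is where num_gpu changes runtime behaviour: only device_type="cuda" combined with num_gpu > 1 selects NCCLGBDT. A tiny, purely illustrative Python mirror of that condition follows; the single-GPU/CPU branch is collapsed in the diff and is represented here simply as "GBDT".

def pick_gbdt_backend(device_type: str, num_gpu: int) -> str:
    # Mirrors Boosting::CreateBoosting: the NCCL path needs CUDA and more than one GPU.
    if device_type == "cuda" and num_gpu > 1:
        return "NCCLGBDT<GBDT>"
    return "GBDT"

assert pick_gbdt_backend("cuda", 2) == "NCCLGBDT<GBDT>"
assert pick_gbdt_backend("cuda", 1) == "GBDT"
assert pick_gbdt_backend("cpu", 4) == "GBDT"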
2 changes: 1 addition & 1 deletion src/boosting/cuda/nccl_gbdt.cpp
@@ -30,7 +30,7 @@ void NCCLGBDT<GBDT_T>::Init(

this->tree_learner_.reset();

nccl_topology_.reset(new NCCLTopology(this->config_->gpu_device_id, this->config_->num_gpus, this->config_->gpu_device_id_list, train_data->num_data()));
nccl_topology_.reset(new NCCLTopology(this->config_->gpu_device_id, this->config_->num_gpu, this->config_->gpu_device_id_list, train_data->num_data()));

nccl_topology_->InitNCCL();

2 changes: 1 addition & 1 deletion src/c_api.cpp
@@ -177,7 +177,7 @@ class Booster {
"please use continued train with input score");
}

boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr, config_.device_type, config_.num_gpus));
boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr, config_.device_type, config_.num_gpu));

train_data_ = train_data;
CreateObjectiveAndMetrics();
13 changes: 7 additions & 6 deletions src/io/config_auto.cpp
@@ -320,9 +320,9 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"machines",
"gpu_platform_id",
"gpu_device_id",
"num_gpus",
"gpu_device_id_list",
"gpu_use_dp",
"num_gpu",
});
return params;
}
@@ -664,11 +664,12 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetInt(params, "gpu_device_id", &gpu_device_id);

GetInt(params, "num_gpus", &num_gpus);

GetString(params, "gpu_device_id_list", &gpu_device_id_list);

GetBool(params, "gpu_use_dp", &gpu_use_dp);

GetInt(params, "num_gpu", &num_gpu);
CHECK_GT(num_gpu, 0);
}

std::string Config::SaveMembersToString() const {
@@ -786,9 +787,9 @@ std::string Config::SaveMembersToString() const {
str_buf << "[machines: " << machines << "]\n";
str_buf << "[gpu_platform_id: " << gpu_platform_id << "]\n";
str_buf << "[gpu_device_id: " << gpu_device_id << "]\n";
str_buf << "[num_gpus: " << num_gpus << "]\n";
str_buf << "[gpu_device_id_list: " << gpu_device_id_list << "]\n";
str_buf << "[gpu_use_dp: " << gpu_use_dp << "]\n";
str_buf << "[num_gpu: " << num_gpu << "]\n";
return str_buf.str();
}

@@ -932,9 +933,9 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
{"machines", {"workers", "nodes"}},
{"gpu_platform_id", {}},
{"gpu_device_id", {}},
{"num_gpus", {}},
{"gpu_device_id_list", {}},
{"gpu_use_dp", {}},
{"num_gpu", {}},
});
return map;
}
@@ -1078,9 +1079,9 @@ const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
{"machines", "string"},
{"gpu_platform_id", "int"},
{"gpu_device_id", "int"},
{"num_gpus", "int"},
{"gpu_device_id_list", "string"},
{"gpu_use_dp", "bool"},
{"num_gpu", "int"},
});
return map;
}
2 changes: 1 addition & 1 deletion tests/python_package_test/test_engine.py
@@ -1684,7 +1684,7 @@ def test_all_expected_params_are_written_out_to_model_text(tmp_path):
"[machines: ]",
"[gpu_platform_id: -1]",
"[gpu_device_id: -1]",
"[num_gpus: 1]",
"[num_gpu: 1]",
]
all_param_entries = non_default_param_entries + default_param_entries

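The R and Python tests above pin the renamed key in the serialized model text. A quick standalone check along the same lines, assuming a default (CPU or single-GPU) build and a small synthetic dataset; Booster.model_to_string() should return the same serialized text that save_model() writes to disk.

import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 5)
y = np.random.rand(100)
booster = lgb.train({"objective": "regression", "verbose": -1},
                    lgb.Dataset(X, label=y), num_boost_round=2)

model_text = booster.model_to_string()
# After this commit the default entry is written with the new name.
assert "[num_gpu: 1]" in model_text
assert "[num_gpus: 1]" not in model_text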
