diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 892fe1845..f11c7d490 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -55,6 +55,7 @@ add_library(marian STATIC
     training/graph_group_sync.cpp
     training/graph_group_singleton.cpp
     training/graph_group_multinode.cpp
+    training/graph_group_multinode_sync.cpp
     training/validator.cpp
     $
diff --git a/src/command/marian.cpp b/src/command/marian.cpp
index 92b7be7a4..e90ec92e2 100644
--- a/src/command/marian.cpp
+++ b/src/command/marian.cpp
@@ -2,6 +2,7 @@
 #include "training/graph_group_async.h"
 #include "training/graph_group_multinode.h"
+#include "training/graph_group_multinode_sync.h"
 #include "training/graph_group_singleton.h"
 #include "training/graph_group_sync.h"
 #include "training/training.h"
@@ -22,7 +23,11 @@ int main(int argc, char** argv) {
     ABORT_IF(!configureMPI(argc, argv), "MPI not found.");
     LOG(warn, "[experimental] Running multi-node training");
-    New<Train<MultiNodeGraphGroup>>(options)->run();
+    if(!options->get<bool>("sync-sgd")) {
+      New<Train<MultiNodeGraphGroup>>(options)->run();
+    } else {
+      New<Train<MultiNodeGraphGroupSync>>(options)->run();
+    }
   } else {
     if(devices.size() == 1) {
       New<Train<SingletonGraph>>(options)->run();
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 6f44a4747..465808123 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -569,10 +569,6 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) {
     ("multi-node-overlap", po::value<bool>()
      ->default_value(true),
      "Overlap model computations with MPI communication")
-    ("multi-node-local-optimizers", po::value<bool>()
-     ->zero_tokens()
-     ->default_value(false),
-     "Enable local optimizers with multi-node. Requires optimizer delay to be turned on.")
   ;
   // clang-format on
   desc.add(training);
@@ -937,7 +933,6 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
     SET_OPTION("multi-node", bool);
     SET_OPTION("multi-node-overlap", bool);
-    SET_OPTION("multi-node-local-optimizers", bool);
   }
 
   if(mode_ == ConfigMode::rescoring) {
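The new file below implements the synchronous path selected above: every GPU on every node accumulates its gradient, the per-node sums are reduced onto MPI rank 0, the optimizer step runs there, and the updated parameters are broadcast back. As a minimal standalone sketch of that cycle (illustrative names only, plain SGD standing in for Marian's configured optimizer, not the Marian API):

    #include <mpi.h>
    #include <vector>

    // Sketch of one synchronous update: sum gradients on rank 0, update the
    // parameters there, then broadcast them so all nodes stay identical.
    void syncStep(std::vector<float>& grads, std::vector<float>& params) {
      int rank;
      MPI_Comm_rank(MPI_COMM_WORLD, &rank);

      std::vector<float> summed(grads.size(), 0.f);
      MPI_Reduce(grads.data(), summed.data(), (int)grads.size(),
                 MPI_FLOAT, MPI_SUM, 0, MPI_COMM_WORLD);

      if(rank == 0) {
        const float lr = 1e-4f;  // stand-in for the real optimizer step
        for(size_t i = 0; i < params.size(); ++i)
          params[i] -= lr * summed[i];
      }

      // Every node leaves the step with identical parameters.
      MPI_Bcast(params.data(), (int)params.size(), MPI_FLOAT, 0, MPI_COMM_WORLD);
    }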
diff --git a/src/training/graph_group_multinode_sync.cpp b/src/training/graph_group_multinode_sync.cpp
new file mode 100644
index 000000000..89f0a0842
--- /dev/null
+++ b/src/training/graph_group_multinode_sync.cpp
@@ -0,0 +1,306 @@
+#include "training/graph_group_multinode_sync.h"
+#include "functional/functional.h"
+#include "tensors/tensor_operators.h"
+
+namespace marian {
+
+/**
+ * Set given scheduler to register training observers on the optimizer.
+ */
+void MultiNodeGraphGroupSync::setScheduler(Ptr<Scheduler> scheduler) {
+  scheduler_ = scheduler;
+  // optimizer has to be registered last to see a change of learning rate
+  scheduler_->registerTrainingObserver(scheduler_);
+  scheduler_->registerTrainingObserver(syncOptimizer_);
+}
+
+/**
+ * Allocate new tensor on given GPU and store allocator.
+ */
+Tensor MultiNodeGraphGroupSync::newTensor(int size, Ptr<Backend> backend) {
+  Tensor t;
+  Ptr<TensorAllocator> allocator = New<TensorAllocator>(backend);
+  allocator->reserveExact(size * sizeof(float));
+  allocator->allocate(t, {1, size});
+  allocators_.push_back(allocator);
+  return t;
+}
+
+/**
+ * Set up the training environment: clients and the buffers used for the
+ * synchronous gradient exchange.
+ */
+void MultiNodeGraphGroupSync::init(Ptr<data::Batch> batch) {
+  // Setup clients and shards
+  setupClients(batch);
+
+  // Setup sync-SGD storage; the summed gradient is kept on node 0
+  accGradientsSync = newTensor(clientGraphs_[0]->params()->vals()->size(),
+                               clientGraphs_[0]->getBackend());
+  accGradientsSync->set(0);
+}
+
+/**
+ * Initialize the CPU arrays with pinned memory for faster cudaMemcpy operations.
+ * Requires the graphs to be initialized first so we know their size.
+ */
+void MultiNodeGraphGroupSync::initCPUArrays() {
+  CUDA_CHECK(cudaMallocHost(&accGradientsSync_cpu,
+                            clientGraphs_[0]->params()->vals()->size() * sizeof(float)));
+  CUDA_CHECK(cudaMallocHost(&receiveBuffer_cpu,
+                            clientGraphs_[0]->params()->vals()->size() * sizeof(float)));
+  std::memset(accGradientsSync_cpu, 0,
+              clientGraphs_[0]->params()->vals()->size() * sizeof(float));
+  std::memset(receiveBuffer_cpu, 0,
+              clientGraphs_[0]->params()->vals()->size() * sizeof(float));
+}
+
+/**
+ * Setup MPI world size and rank of this node.
+ */
+void MultiNodeGraphGroupSync::setupMPI() {
+#if MPI_FOUND
+  MPI_Comm_size(MPI_COMM_WORLD, &mpi_comm_world_size_);
+  MPI_Comm_rank(MPI_COMM_WORLD, &mpi_my_rank_);
+#endif
+}
+
+/**
+ * Setup clients that will compute gradients and take part in the synchronous
+ * update. There is one client per GPU.
+ */
+void MultiNodeGraphGroupSync::setupClients(Ptr<data::Batch> batch) {
+  runBatchThroughClientGraphs(batch);
+  initCPUArrays();
+
+  clientThreadPool_ = new marian::ThreadPool(devices_.size(), devices_.size());
+}
+
+/**
+ * Initialize the graphs (models) of all clients on this node with the given
+ * batch.
+ */
+void MultiNodeGraphGroupSync::runBatchThroughClientGraphs(Ptr<data::Batch> batch) {
+  for(int i = 0; i < devices_.size(); i++) {
+    THREAD_GUARD(clientBuilders_[i]->build(clientGraphs_[i], batch);
+                 clientGraphs_[i]->forward();
+                 clientGraphs_[i]->getBackend()->synchronize(););
+  }
+}
+
+/**
+ * Sum a client's gradient into the accumulation buffer, taking care of
+ * locking.
+ */
+void MultiNodeGraphGroupSync::sumGRAD(Tensor gradient) {
+  std::lock_guard<std::mutex> guard(sumGradientMutex_);
+  using namespace functional;  //@TODO makes more sense to do this on the CPU i think
+  Element(_1 += _2, accGradientsSync, gradient);
+}
+
+/**
+ * Does the MPI communication and the parameter update: gradients from all
+ * nodes are reduced onto rank 0, the optimizer step runs there, and the
+ * updated parameters are broadcast back to every node. Make sure this is
+ * only called from device 0 of each node.
+ */
+void MultiNodeGraphGroupSync::sendReceiveUpdateSync() {
+#if MPI_FOUND
+  // Copy the accumulated gradient to the CPU
+  CUDA_CHECK(cudaMemcpy(accGradientsSync_cpu,
+                        accGradientsSync->data(),
+                        accGradientsSync->size() * sizeof(float),
+                        cudaMemcpyDeviceToHost));
+
+  int reduce_result = MPI_Reduce(accGradientsSync_cpu,  // CPU buffers
+                                 receiveBuffer_cpu,
+                                 accGradientsSync->size(),
+                                 MPI_FLOAT,
+                                 MPI_SUM,
+                                 0,  // Rank that receives the sum, in this case node 0
+                                 MPI_COMM_WORLD);
+
+  if(reduce_result != MPI_SUCCESS) {
+    LOG(critical, "Error: MPI_Reduce failed with error {}.", reduce_result);
+    std::abort();
+  }
+
+  if(mpi_my_rank_ == 0) {
+    // Copy the summed gradient back to the GPU and perform the optimizer step
+    CUDA_CHECK(cudaMemcpy(accGradientsSync->data(),
+                          receiveBuffer_cpu,
+                          accGradientsSync->size() * sizeof(float),
+                          cudaMemcpyHostToDevice));
+
+    syncOptimizer_->update(clientGraphs_[0]->params()->vals(),
+                           accGradientsSync);
+
+    // Copy the data back to the host for the broadcast
+    CUDA_CHECK(cudaMemcpy(accGradientsSync_cpu,  // This now holds the updated params
+                          clientGraphs_[0]->params()->vals()->data(),
+                          accGradientsSync->size() * sizeof(float),
+                          cudaMemcpyDeviceToHost));
+  }
+
+  int bcast_result = MPI_Bcast(accGradientsSync_cpu,  // The updated params
+                               accGradientsSync->size(),
+                               MPI_FLOAT,
+                               0,  // Root process
+                               MPI_COMM_WORLD);
+
+  if(bcast_result != MPI_SUCCESS) {
+    LOG(critical, "Error: MPI_Bcast failed with error {}.", bcast_result);
+    std::abort();
+  }
+
+  if(mpi_my_rank_ != 0) {
+    // Copy the broadcast parameters to the GPU
+    CUDA_CHECK(cudaMemcpy(clientGraphs_[0]->params()->vals()->data(),
+                          accGradientsSync_cpu,
+                          accGradientsSync->size() * sizeof(float),
+                          cudaMemcpyHostToDevice));
+  }
+
+  // Distribute the updated parameters to the rest of the devices
+  std::vector<std::thread> threads;
+  for(int idx = 1; idx < devices_.size(); idx++) {
+    threads.emplace_back(std::thread(
+        [=](int idx) {
+          // If NVLink is not available, copying from the CPU is faster
+          // because we avoid the device -> host -> device round trip.
+          CUDA_CHECK(cudaMemcpy(clientGraphs_[idx]->params()->vals()->data(),
+                                accGradientsSync_cpu,
+                                accGradientsSync->size() * sizeof(float),
+                                cudaMemcpyHostToDevice));
+        },
+        idx));
+  }
+  for(auto&& t : threads) {
+    t.join();
+  }
+
+  // Reset the accumulation buffers to zero
+  accGradientsSync->set(0);
+  std::memset(accGradientsSync_cpu, 0,
+              clientGraphs_[0]->params()->vals()->size() * sizeof(float));
+  std::memset(receiveBuffer_cpu, 0,
+              clientGraphs_[0]->params()->vals()->size() * sizeof(float));
+#endif
+}
+
+/**
+ * Execute given batch on this node: run forward and backward on one client,
+ * accumulate its gradient, and perform the synchronous multi-node update
+ * once all clients have contributed.
+ *
+ * @param batch Batch on which to perform forward and backward passes.
+ */
+void MultiNodeGraphGroupSync::execute(Ptr<data::Batch> batch) {
+  if(!initialized_) {
+    init(batch);
+    initialized_ = true;
+  }
+
+  auto task = [this](Ptr<data::Batch> batch) {
+    static size_t i = 0;
+    thread_local Ptr<ExpressionGraph> graph;
+    thread_local Ptr<models::ModelBase> builder;
+    thread_local size_t my_id = 0;
+    thread_local size_t t = 0;
+    // only for scheduler statistics
+    thread_local float cost = 0;
+    thread_local size_t num_seen_words = 0;
+    thread_local size_t num_seen_sentences = 0;
+
+    if(!graph) {
+      std::lock_guard<std::mutex> lock(mutexClientInit_);
+      my_id = i;
+      graph = clientGraphs_[i];
+      builder = clientBuilders_[i++];
+    }
+
+    auto costNode = builder->build(graph, batch);
+
+    if(t == 0) {
+      if(my_id != 0)
+        graph->params()->vals()->copyFrom(clientGraphs_[0]->params()->vals());
+    }
+
+    graph->forward();
+    cost += costNode->scalar();
+    num_seen_words += batch->words();
+    num_seen_sentences += batch->size();
+    graph->backward();
+
+    t++;
+
+    graph->getBackend()->synchronize();  //@Alham do you know why we need this here?
+
+    sumGRAD(graph->params()->grads());
+
+    // Lock here and send/receive gradients.
+    // @TODO I am really not sure this is correct for more than one thread.
+    {
+      std::unique_lock<std::mutex> lock(updateParamsMutex_);
+      clientThreadPool_->wait_for_one(lock);  // Only one thread will do the next step. @TODO correct?
+      if(!synchronization_happened) {
+        sendReceiveUpdateSync();
+        synchronization_happened = true;
+      }
+      clientThreadPool_->wait_for_others(lock);
+      synchronization_happened = false;
+      clientThreadPool_->notify_others();
+    }
+
+    // Run scheduler (if enabled)
+    if(t % tau_ == 0 && scheduler_) {
+      std::unique_lock<std::mutex> lock(schedulerMutex_);
+
+      // Wait until the thread that wants to do validation is finished.
+      clientThreadPool_->wait_for_one(lock);
+
+      if(options_->get<std::string>("cost-type") != "ce-sum")
+        cost /= tau_;
+
+      if(tau_ > 1) {
+        std::vector<size_t> fakeLength = {1, 1};
+        auto fb = data::CorpusBatch::fakeBatch(fakeLength,
+                                               num_seen_sentences,
+                                               NULL);
+        fb->front()->setWords(num_seen_words);
+        scheduler_->update(cost, fb);
+      } else {
+        scheduler_->update(cost, batch);
+      }
+
+      num_seen_words = 0;
+      num_seen_sentences = 0;
+      cost = 0;
+
+      if((scheduler_->saving() || scheduler_->validating())) {
+        // Wait with validation or saving until all other threads are done
+        // with the update.
+        // We want to reuse the graphs for validation, so they need to be in
+        // a safe state.
+        clientThreadPool_->wait_for_others(lock);
+#if MPI_FOUND
+        // wait until other nodes are ready
+        MPI_Barrier(MPI_COMM_WORLD);
+
+        // TODO: Saving is broken
+        // if(mpi_my_rank_ == 0 && scheduler_->saving())
+        //   this->save(graph);
+
+        if(mpi_my_rank_ == 0 && scheduler_->validating())
+          scheduler_->validate(clientGraphs_);
+
+        // inform other nodes to continue
+        MPI_Barrier(MPI_COMM_WORLD);
+#endif
+        // Validation or saving is done, tell other threads to continue work.
+        clientThreadPool_->notify_others();
+      }
+    }
+  };
+
+  clientThreadPool_->enqueue(task, batch);
+}
+}  // namespace marian
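The single-update handshake in execute() above leans on the ThreadPool's wait_for_one / wait_for_others / notify_others helpers, and the @TODO notes that its behaviour with more than one thread is not fully verified. As a rough standalone illustration of the intended invariant only (exactly one worker per step performs the global update while the rest wait for it to finish), written with a plain condition variable rather than the ThreadPool API:

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>
    #include <vector>

    int main() {
      const int numWorkers = 4;  // one worker per GPU, illustrative
      std::mutex m;
      std::condition_variable cv;
      int arrived = 0;
      bool updated = false;

      auto worker = [&](int id) {
        // ... local forward/backward and gradient accumulation happen here ...
        std::unique_lock<std::mutex> lock(m);
        if(++arrived == numWorkers) {
          // The last worker to arrive performs the global update exactly once.
          std::printf("worker %d runs the synchronous update\n", id);
          updated = true;
          cv.notify_all();
        } else {
          // Everyone else blocks until the update has been applied.
          cv.wait(lock, [&] { return updated; });
        }
      };

      std::vector<std::thread> threads;
      for(int id = 0; id < numWorkers; ++id)
        threads.emplace_back(worker, id);
      for(auto& t : threads)
        t.join();
      return 0;
    }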
diff --git a/src/training/graph_group_multinode_sync.h b/src/training/graph_group_multinode_sync.h
new file mode 100644
index 000000000..6e0e68ba4
--- /dev/null
+++ b/src/training/graph_group_multinode_sync.h
@@ -0,0 +1,313 @@
+#pragma once
+
+#if MPI_FOUND
+#include "mpi.h"
+#include "cuda_runtime.h"
+
+#define CUDA_CHECK(ans) \
+  { gpuAssert2((ans), __FILE__, __LINE__); }
+
+#endif
+
+#include <condition_variable>
+#include <future>
+#include <thread>
+
+#include <boost/filesystem/operations.hpp>
+#include <boost/filesystem/path.hpp>
+#include <boost/thread/locks.hpp>
+
+#include "3rd_party/threadpool.h"
+#include "training/graph_group.h"
+
+#if CUDA_FOUND
+inline void gpuAssert2(cudaError_t code,
+                       const char* file,
+                       int line,
+                       bool abort = true) {
+  if(code != cudaSuccess) {
+    LOG(critical, "Error: {} - {}:{}", cudaGetErrorString(code), file, line);
+    std::abort();
+  }
+}
+#endif
+
+namespace marian {
+
+/**
+ * Multi-node graph group for synchronous training over multiple
+ * machines, each with one or multiple GPUs.
+ */
+class MultiNodeGraphGroupSync : public GraphGroup {
+public:
+  virtual void setScheduler(Ptr<Scheduler> scheduler);
+
+protected:
+  ////////////////////////////////////////////////////////////////////////////
+  // General variables.
+
+  /** Number of clients on nodes in MPI world (cluster). */
+  std::vector<int> numberClientsOfNodes_;  //@TODO not used for now, but might be useful maybe?
+
+  /** Whether graph group has been properly initialized with a first batch. */
+  bool initialized_{false};
+
+  /** Memory allocators for tensors (GPUs). */
+  std::vector<Ptr<TensorAllocator>> allocators_;
+
+  ////////////////////////////////////////////////////////////////////////////
+  // Client variables.
+
+  /** Thread pool to enable clients to run concurrently. */
+  ThreadPool* clientThreadPool_;
+
+  /** Graph builders for clients (which run forward and backward passes). */
+  std::vector<Ptr<models::ModelBase>> clientBuilders_;
+
+  /** Graphs of clients. */
+  std::vector<Ptr<ExpressionGraph>> clientGraphs_;
+
+  /** Devices (GPUs) on this node. */
+  std::vector<size_t> devices_;
+
+  /** Mutex to ensure clients are uniquely assigned to graphs and builders. */
+  std::mutex mutexClientInit_;
+
+  /** Mutex to avoid race conditions in scheduler. */
+  std::mutex schedulerMutex_;
+
+  /**
+   * Batch number counter used for evenly distributing mini-batches across
+   * nodes.
+   */
+  size_t batchIter_ = 0;
+
+  ////////////////////////////////////////////////////////////////////////////
+  // Communication variables.
+
+  /** MPI rank of this node. */
+  int mpi_my_rank_{0};
+
+  /** Number of nodes in MPI world (cluster). */
+  int mpi_comm_world_size_{1};
+
+  /**
+   * Variables for optimizer delay and synchronous SGD
+   */
+  size_t tau_{1};
+  std::mutex sumGradientMutex_;
+  std::mutex updateParamsMutex_;
+  Tensor accGradientsSync;
+  float* accGradientsSync_cpu;
+  float* receiveBuffer_cpu;
+  bool synchronization_happened{false};
+
+  Ptr<OptimizerBase> syncOptimizer_;
+
+  std::vector<std::mutex> optDelayMutex_;
+  std::vector<size_t> delay_count;
+  std::vector<int> totalBatchWords;
+  std::vector<Tensor> accGradients, accGradientBuffer;
+
+  /**
+   * Allocate new tensor on given GPU and store allocator.
+   */
+  Tensor newTensor(int size, Ptr<Backend> backend);
+
+  /**
+   * Set up the training environment: MPI, clients and the buffers used for
+   * the synchronous gradient exchange.
+   */
+  virtual void init(Ptr<data::Batch> batch);
+
+  /**
+   * Setup MPI world size and rank of this node.
+   */
+  void setupMPI();
+
+  /**
+   * Setup clients that will compute gradients and take part in the
+   * synchronous update. There is one client per GPU.
+   */
+  void setupClients(Ptr<data::Batch> batch);
+
+  /**
+   * Initialize the graphs (models) of all clients on this node with the given
+   * batch.
+   */
+  void runBatchThroughClientGraphs(Ptr<data::Batch> batch);
+
+  /**
+   * Initialize the CPU arrays with pinned memory for faster cudaMemcpy
+   * operations.
+   */
+  void initCPUArrays();
+
+  /**
+   * Sums the gradients from a client into the accumulation buffer, taking
+   * care of locking.
+   * @param gradient - the gradient
+   */
+  void sumGRAD(Tensor gradient);
+
+  /**
+   * Does the MPI communication, the parameter update and the copying back of
+   * parameters.
+   * @TODO ALHAM. God function too godly?
+   */
+  void sendReceiveUpdateSync();
+
+  void execute(Ptr<data::Batch> batch);
+
+  /**
+   * Load the GPU configuration of this node (i.e. which GPUs to use) and the
+   * number of GPUs on the other nodes.
+   */
+  void loadDeviceConfig(std::vector<size_t> deviceConfig) {
+    size_t index = 0, node = 0, nClientsSeen = 0;
+    numberClientsOfNodes_ = std::vector<int>(mpi_comm_world_size_, 0);
+    while(index < deviceConfig.size()) {
+      if(numberClientsOfNodes_[node] == 0) {
+        numberClientsOfNodes_[node] = deviceConfig[index];
+        nClientsSeen = 0;
+      } else if(nClientsSeen < numberClientsOfNodes_[node]) {
+        if(node == mpi_my_rank_) {
+          devices_.push_back(deviceConfig[index]);
+        }
+        nClientsSeen++;
+      } else {
+        node++;
+        index--;
+      }
+      index++;
+    }
+  }
+
+public:
+  /**
+   * (Constructor) Call super class and initialize client graphs and builders.
+   */
+  MultiNodeGraphGroupSync(Ptr<Config> options)
+      : GraphGroup(options),
+        tau_{options_->get<size_t>("optimizer-delay")},
+        syncOptimizer_{Optimizer(options_)} {
+    // Set up devices for this node
+    setupMPI();  // Setup MPI before creating device vectors
+    std::vector<size_t> devices;
+    for(auto& d : options_->getDevices())
+      devices.push_back(d.no);
+    loadDeviceConfig(devices);
+
+    // Create builders and graphs for clients.
+    for(size_t i = 0; i < devices_.size(); i++) {
+      clientGraphs_.push_back(New<ExpressionGraph>());
+      clientGraphs_[i]->setDevice({devices_[i], DeviceType::gpu});
+      clientGraphs_[i]->reserveWorkspaceMB(options_->get<size_t>("workspace"));
+      clientBuilders_.push_back(
+          models::from_config(options_, models::usage::training));
+    }
+  }
+
+  /**
+   * (Destructor) Shut down the client thread pool and free the pinned host
+   * buffers.
+   */
+  virtual ~MultiNodeGraphGroupSync() {
+    //@TODO merge with finalize method
+    delete clientThreadPool_;
+    CUDA_CHECK(cudaFreeHost(accGradientsSync_cpu));
+    CUDA_CHECK(cudaFreeHost(receiveBuffer_cpu));
+  }
+
+  /**
+   * Update any client model with given batch if batch is assigned to this
+   * node.
+   */
+  void update(Ptr<data::Batch> batch) {
+    ABORT_IF(finalized_, "Training has already finished.");
+    if(batchIter_ % mpi_comm_world_size_
+       == mpi_my_rank_) {  // Only take batch assigned to this node
+      execute(batch);
+    }
+    batchIter_++;
+  }
+
+  /**
+   * Load models from disk if file exists and setting is not disabled.
+   */
+  void load() {
+    if(!options_->get<bool>("no-reload")) {
+      std::string name = options_->get<std::string>("model");
+
+      if(boost::filesystem::exists(name)) {
+        if(scheduler_)
+          scheduler_->load(name);
+        size_t i = 0;
+        for(auto graph : clientGraphs_)
+          clientBuilders_[i++]->load(graph, name);
+      } else if(options_->has("pretrained-model")) {
+        std::string init = options_->get<std::string>("pretrained-model");
+        LOG(info,
+            "Initialize model weights with the pre-trained model {}",
+            init);
+        size_t i = 0;
+        for(auto graph : clientGraphs_)
+          clientBuilders_[i++]->load(graph, init, false);
+      }
+    }
+  }
+
+  /**
+   * Save model of first client's graph to disk.
+   */
+  void save(bool final = false) { save(clientGraphs_[0], final); }
+
+  /**
+   * Save model of given graph to disk.
+   */
+  void save(Ptr<ExpressionGraph> graph, bool final = false) {
+    int idx = 0;
+    for(int i = 0; i < clientGraphs_.size(); ++i) {
+      if(graph == clientGraphs_[i]) {
+        idx = i;
+        break;
+      }
+    }
+
+    if(options_->get<bool>("overwrite")) {
+      std::string name = options_->get<std::string>("model");
+
+      clientBuilders_[idx]->save(clientGraphs_[idx], name, true);
+      if(scheduler_)
+        scheduler_->save(name);
+    } else {
+      std::string name = options_->get<std::string>("model");
+
+      if(!final) {
+        std::string numberOfBatches
+            = scheduler_ ? std::to_string(scheduler_->numberOfBatches())
+                         : "unknown";
+        std::string nameOverwrite = name;
+        nameOverwrite.replace(
+            name.size() - 4, 4, ".iter" + numberOfBatches + ".npz");
+        clientBuilders_[idx]->save(clientGraphs_[idx], nameOverwrite);
+      }
+
+      clientBuilders_[idx]->save(clientGraphs_[idx], name, true);
+      if(scheduler_)
+        scheduler_->save(name);
+    }
+  }
+
+  /**
+   * Collect statistics from first client's graph.
+   */
+  Ptr<data::BatchStats> collectStats() {
+    return GraphGroup::collectStats(clientGraphs_[0], clientBuilders_[0]);
+  }
+
+  virtual void finalize() {
+    finalized_ = true;
+  }
+};
+}  // namespace marian
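A note on the pinned host buffers used above: initCPUArrays() allocates them with cudaMallocHost so that the device-to-host and host-to-device copies around the MPI calls can use DMA transfers, and such allocations must be released with cudaFreeHost rather than cudaFree. A minimal self-contained sketch of that allocation pattern (sizes and names are illustrative, not taken from the patch):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      const size_t n = 1 << 20;  // number of floats, illustrative

      // Page-locked (pinned) host memory: copies to/from the GPU can be done
      // by DMA and are typically faster than from pageable memory.
      float* staging = nullptr;
      if(cudaMallocHost(&staging, n * sizeof(float)) != cudaSuccess)
        return 1;

      float* device = nullptr;
      if(cudaMalloc(&device, n * sizeof(float)) != cudaSuccess)
        return 1;

      cudaMemcpy(device, staging, n * sizeof(float), cudaMemcpyHostToDevice);
      // ... an MPI exchange of the staging buffer would happen here ...
      cudaMemcpy(staging, device, n * sizeof(float), cudaMemcpyDeviceToHost);

      cudaFree(device);       // device allocations are freed with cudaFree
      cudaFreeHost(staging);  // pinned host allocations need cudaFreeHost
      std::printf("done\n");
      return 0;
    }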