Skip to content

Commit

Permalink
fix compilation windows
Browse files Browse the repository at this point in the history
  • Loading branch information
minhthuc2502 committed Feb 28, 2024
1 parent f46c6b3 commit 5e9dc28
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
22 changes: 7 additions & 15 deletions include/ctranslate2/devices.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <string>
#include <vector>
#ifdef CT2_WITH_TENSOR_PARALLEL
# include <unistd.h>
# include <nccl.h>
#endif

Expand Down Expand Up @@ -50,22 +49,18 @@ namespace ctranslate2 {
int _new_index;
};

extern int my_rank;
extern int local_rank;
extern int n_ranks;

class ScopedMPISetter {
public:
ScopedMPISetter();
~ScopedMPISetter();

static int getNRanks() {
return _n_ranks;
}

static int getCurRank() {
return _my_rank;
}

static int getLocalRank() {
return _local_rank;
}
static int getNRanks();
static int getCurRank();
static int getLocalRank();

#ifdef CT2_WITH_TENSOR_PARALLEL
static ncclComm_t getNcclComm();
Expand All @@ -79,8 +74,5 @@ namespace ctranslate2 {
static void getHostName(char *hostname, int maxlen);
static std::vector<ncclComm_t*> _nccl_comms;
#endif
static int _my_rank;
static int _local_rank;
static int _n_ranks;
};
}
35 changes: 23 additions & 12 deletions src/devices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#endif

#include "device_dispatch.h"
#include <iostream>

namespace ctranslate2 {

Expand Down Expand Up @@ -123,29 +122,29 @@ namespace ctranslate2 {
#ifdef CT2_WITH_TENSOR_PARALLEL
std::vector<ncclComm_t*> ScopedMPISetter::_nccl_comms;
#endif
int ScopedMPISetter::_my_rank = 0;
int ScopedMPISetter::_local_rank = 0 ;
int ScopedMPISetter::_n_ranks = 1;
int my_rank = 0;
int local_rank = 0;
int n_ranks = 1;

ScopedMPISetter::ScopedMPISetter() {
#ifdef CT2_WITH_TENSOR_PARALLEL
// initializing MPI
MPI_CHECK(MPI_Init(nullptr, nullptr));
MPI_CHECK(MPI_Comm_rank(STUB_MPI_COMM_WORLD, &_my_rank));
MPI_CHECK(MPI_Comm_size(STUB_MPI_COMM_WORLD, &_n_ranks));
MPI_CHECK(MPI_Comm_rank(STUB_MPI_COMM_WORLD, &my_rank));
MPI_CHECK(MPI_Comm_size(STUB_MPI_COMM_WORLD, &n_ranks));

uint64_t hostHashs[_n_ranks];
uint64_t hostHashs[n_ranks];
char hostname[1024];
getHostName(hostname, 1024);
hostHashs[_my_rank] = getHostHash(hostname);
hostHashs[my_rank] = getHostHash(hostname);
MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, STUB_MPI_DATATYPE_NULL,
hostHashs, sizeof(uint64_t), STUB_MPI_BYTE, STUB_MPI_COMM_WORLD));
for (int p = 0; p < _n_ranks; p++) {
if (p == _my_rank) {
for (int p = 0; p < n_ranks; p++) {
if (p == my_rank) {
break;
}
if (hostHashs[p] == hostHashs[_my_rank]) {
_local_rank++;
if (hostHashs[p] == hostHashs[my_rank]) {
local_rank++;
}
}
atexit(finalize);
Expand Down Expand Up @@ -204,4 +203,16 @@ namespace ctranslate2 {
MPI_CHECK(MPI_Finalize());
#endif
}

int ScopedMPISetter::getNRanks() {
return n_ranks;
}

int ScopedMPISetter::getCurRank() {
return my_rank;
}

int ScopedMPISetter::getLocalRank() {
return local_rank;
}
}

0 comments on commit 5e9dc28

Please sign in to comment.