Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Token seq id #5964

Merged
merged 33 commits into from
Aug 25, 2021
Merged
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7b92fd4
bugfix: data_transport_token_per_placement
lixinqi Aug 13, 2021
9bcda09
refactor TransportToken::NewDataTransportToken
lixinqi Aug 15, 2021
472bae4
Merge branch 'master' into bugfix_data_transport_token_per_placement
lixinqi Aug 15, 2021
cd32787
NewDataTransportToken(parallel_desc)
lixinqi Aug 16, 2021
e7b86a2
Merge branch 'master' into bugfix_data_transport_token_per_placement
lixinqi Aug 16, 2021
3bb7877
fix bugs
lixinqi Aug 17, 2021
757d023
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into bugfix_d…
lixinqi Aug 17, 2021
5b52ec5
Merge branch 'master' into bugfix_data_transport_token_per_placement
lixinqi Aug 17, 2021
765b62a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into bugfix_d…
lixinqi Aug 17, 2021
500e384
rename UniqueConsistentIdStorage to ConsistentIdStorage
lixinqi Aug 17, 2021
2749d2c
refactor TransportToken
lixinqi Aug 19, 2021
65fb518
add TransportToken::src_rank and TransportToken::dst_rank
lixinqi Aug 19, 2021
970f379
merge data_token_per_thread
lixinqi Aug 19, 2021
39a447e
harded coded thread_consistent_id
lixinqi Aug 19, 2021
650fb0c
StreamType::SupportingTransportInstructions()
lixinqi Aug 20, 2021
3776ab5
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into token_se…
lixinqi Aug 20, 2021
d73fb5e
merge master
lixinqi Aug 20, 2021
125c150
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into token_se…
lixinqi Aug 21, 2021
04405e8
not thread_consistent_id in single-client mode
lixinqi Aug 21, 2021
42db5fb
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into token_se…
lixinqi Aug 21, 2021
efa6ac8
explicit initialize all fields of TransportToken
lixinqi Aug 21, 2021
47eb7c3
Merge branch 'master' into token_seq_id
lixinqi Aug 24, 2021
50a1e09
remove unused field TransportToken::rank_group_level_
lixinqi Aug 25, 2021
4600365
more bits for TransportToken::seq_id_
lixinqi Aug 25, 2021
881b09f
merge master
lixinqi Aug 25, 2021
57e337c
Merge branch 'master' into token_seq_id
lixinqi Aug 25, 2021
d39c8c2
Merge branch 'master' into token_seq_id
oneflow-ci-bot Aug 25, 2021
fd9961b
fix oneflow_testexe error
clackhan Aug 25, 2021
40313b3
Merge branch 'master' into token_seq_id
clackhan Aug 25, 2021
6771337
Merge branch 'master' into token_seq_id
oneflow-ci-bot Aug 25, 2021
d41f37a
Merge branch 'master' into token_seq_id
oneflow-ci-bot Aug 25, 2021
c96859f
Merge branch 'master' into token_seq_id
oneflow-ci-bot Aug 25, 2021
2197112
Merge branch 'master' into token_seq_id
oneflow-ci-bot Aug 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
NewDataTransportToken(parallel_desc)
lixinqi committed Aug 16, 2021
commit cd32787c19a452f6019b74762525a13a62acd305
2 changes: 1 addition & 1 deletion oneflow/core/ccl/ccl.cpp
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ Maybe<void> Broadcast<DeviceType::kCPU>(const void* in, void* out, size_t elem_c
CHECK_EQ_OR_RETURN(parallel_desc->device_type(), DeviceType::kCPU);
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
TransportToken transport_token = TransportToken::NewDataTransportToken();
TransportToken transport_token = JUST(TransportToken::NewDataTransportToken(parallel_desc));
return CpuBroadcast(in, out, buffer_size, root, parallel_desc, transport_token);
}

26 changes: 18 additions & 8 deletions oneflow/core/framework/transport_token.cpp
Original file line number Diff line number Diff line change
@@ -31,13 +31,20 @@ class DataTransportTokenView final {
return reinterpret_cast<DataTransportTokenView*>(transport_token);
}

Maybe<void> set_thread_consistent_id(int32_t val) {
CHECK_LT_OR_RETURN(val, (1 << kDataTransportTokenThreadConsistentUIdBit));
thread_consistent_id_ = val;
return Maybe<void>::Ok();
}

void set_data_seq_id(int64_t seq_id) { data_seq_id_ = seq_id; }

private:
uint16_t src_rank_;
uint16_t dst_rank_;
uint32_t type_ : 2; // TransportTokenType
uint32_t data_seq_id_ : 30;
uint8_t type_ : kTransportTokenTypeBit; // TransportTokenType
uint8_t thread_consistent_id_;
uint16_t data_seq_id_;
};
static_assert(sizeof(DataTransportTokenView) == sizeof(uint64_t), "");

@@ -58,7 +65,7 @@ class MetaTransportTokenView final {

Maybe<void> set_thread_consistent_unique_id(int8_t val) {
CHECK_GE_OR_RETURN(val, 0);
CHECK_LT_OR_RETURN(val, 1 << kTransportTokenThreadConsistentUIdBit);
CHECK_LT_OR_RETURN(val, 1 << kCtrlTransportTokenThreadConsistentUIdBit);
thread_consistent_unique_id_ = val;
return Maybe<void>::Ok();
}
@@ -80,7 +87,7 @@ class MetaTransportTokenView final {
uint16_t src_rank_;
uint16_t dst_rank_;
uint8_t type_ : 2; // TransportTokenType
uint8_t thread_consistent_unique_id_ : kTransportTokenThreadConsistentUIdBit;
uint8_t thread_consistent_unique_id_ : kCtrlTransportTokenThreadConsistentUIdBit;
uint8_t rank_group_level_ : kTransportTokenRankGroupLevelBit;
uint8_t high_meta_seq_id_;
uint16_t low_meta_seq_id_;
@@ -104,7 +111,7 @@ class CtrlTransportTokenView final {

Maybe<void> set_thread_consistent_unique_id(int8_t val) {
CHECK_GE_OR_RETURN(val, 0);
CHECK_LT_OR_RETURN(val, 1 << kTransportTokenThreadConsistentUIdBit);
CHECK_LT_OR_RETURN(val, 1 << kCtrlTransportTokenThreadConsistentUIdBit);
thread_consistent_unique_id_ = val;
return Maybe<void>::Ok();
}
@@ -128,7 +135,7 @@ class CtrlTransportTokenView final {
uint16_t src_rank_;
uint16_t dst_rank_;
uint8_t type_ : 2; // TransportTokenType
uint8_t thread_consistent_unique_id_ : kTransportTokenThreadConsistentUIdBit;
uint8_t thread_consistent_unique_id_ : kCtrlTransportTokenThreadConsistentUIdBit;
uint8_t rank_group_level_ : kTransportTokenRankGroupLevelBit;
uint8_t cmd_;
uint16_t ctrl_seq_id_;
@@ -143,11 +150,14 @@ TransportToken::TransportToken(TransportTokenType type) {
type_ = type;
}

/*static*/ TransportToken TransportToken::NewDataTransportToken(Symbol<ParallelDesc> parallel_desc) {
/*static*/ Maybe<TransportToken> TransportToken::NewDataTransportToken(Symbol<ParallelDesc> parallel_desc) {
int32_t thread_consistent_unique_id = JUST(GetThisThreadConsistentUniqueId());
static thread_local HashMap<Symbol<ParallelDesc>, int64_t> parallel_desc2seq_id;
auto* seq_id = parallel_desc2seq_id[parallel_desc];
TransportToken transport_token(kDataTransportTokenType);
CHECK_JUST(DataTransportTokenView::MutCast(&transport_token))->set_data_seq_id(++*seq_id);
auto* data_token_view = JUST(DataTransportTokenView::MutCast(&transport_token));
JUST(data_token_view->set_thread_consistent_id(thread_consistent_unique_id));
data_token_view->set_data_seq_id(++*seq_id);
return transport_token;
}

8 changes: 5 additions & 3 deletions oneflow/core/framework/transport_token.h
Original file line number Diff line number Diff line change
@@ -25,9 +25,11 @@ namespace oneflow {
class ParallelDesc;

const static int kTransportTokenTypeBit = 2;
const static int kTransportTokenThreadConsistentUIdBit = 3;
const static int kCtrlTransportTokenThreadConsistentUIdBit = 3;
const static int kTransportTokenRankGroupLevelBit = 3;

const static int kDataTransportTokenThreadConsistentUIdBit = 8;

enum TransportTokenType {
// Begin
kInvalidTransportTokenType = 0,
@@ -67,14 +69,14 @@ class TransportToken final {
TransportToken(TransportToken&) = default;
~TransportToken() = default;

static TransportToken NewDataTransportToken(Symbol<ParallelDesc> parallel_desc);
static Maybe<TransportToken> NewDataTransportToken(Symbol<ParallelDesc> parallel_desc);
static Maybe<TransportToken> NewMetaTransportToken();
static Maybe<TransportToken> AcquireCtrlTransportToken(RankGroupCtrlCmd cmd);
Maybe<void> TryAcquireCtrlTransportTokenLock() const;
Maybe<void> TryReleaseCtrlTransportTokenLock() const;

static constexpr size_t MaxNumberOfThreadConsistentUId() {
return (1 << kTransportTokenThreadConsistentUIdBit);
return (1 << kCtrlTransportTokenThreadConsistentUIdBit);
}

// Getters
56 changes: 53 additions & 3 deletions oneflow/core/vm/oneflow_vm.cpp
Original file line number Diff line number Diff line change
@@ -13,24 +13,73 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <typeinfo>
#include "oneflow/core/vm/oneflow_vm.h"
#include "oneflow/core/vm/instruction.msg.h"
#include "oneflow/core/vm/no_arg_cb_phy_instr_operand.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/thread/consistent_unqiue_id.h"

namespace oneflow {

namespace {

static constexpr int kSchedulerThreadConsistentUniqueId = (1 << kCtrlTransportTokenThreadConsistentUIdBit);

void GetSchedulerThreadInitializer(std::function<void()>* Initializer) {
*Initializer = [](){
SetThisThreadConsistentUniqueId(kSchedulerThreadConsistentUniqueId, "scheduler");
};
}

std::type_index GetStreamTypeIndex(const vm::ThreadCtx* thread_ctx) {
const auto& stream_rt_desc = thread_ctx->stream_rt_desc();
const auto& stream_type_id = stream_rt_desc.stream_type_id();
const auto& stream_type = stream_type_id.stream_type();
return typeid(stream_type);
}

// Threads with the same stream_type share a thread_consistent_id.
// e.g.
// Given there are 8 gpu thread in a single process.
// thread #0 is active in process #0, while others are not.
// thread #1 is active in process #1, while others are not.
// ...
// thread #7 is active in process #7, while others are not.
// to make them communicate with each other, we can allocate thread_consistent_id 1 to all those gpu threads in all processes.
void GetWorkerThreadInitializer(ObjectMsgPtr<vm::VirtualMachine> vm,
std::function<void(vm::ThreadCtx*)>* Initializer) {
int64_t thread_consistent_id = kSchedulerThreadConsistentUniqueId + 1;
HashMap<std::type_index, int64_t> stream_type_index2consistent_id;
OBJECT_MSG_LIST_UNSAFE_FOR_EACH_PTR(vm->mut_thread_ctx_list(), thread_ctx) {
const auto& stream_type_index = GetStreamTypeIndex(thread_ctx);
if (stream_type_index2consistent_id.count(stream_type_index) > 0) { continue; }
stream_type_index2consistent_id[stream_type_index] = thread_consistent_id++;
}
*Initializer = [stream_type_index2consistent_id](vm::ThreadCtx* thread_ctx){
const auto& stream_type_index = GetStreamTypeIndex(thread_ctx);
int64_t thread_consistent_id = stream_type_index2consistent_id.at(stream_type_index);
SetThisThreadConsistentUniqueId(thread_consistent_id, stream_type_index.name());
};
}

}

OneflowVM::OneflowVM(const Resource& resource, int64_t this_machine_id)
: vm_(ObjectMsgPtr<vm::VirtualMachine>::New(vm::MakeVmDesc(resource, this_machine_id).Get())) {
std::function<void(vm::ThreadCtx*)> WorkerInitializer;
GetWorkerThreadInitializer(vm_, &WorkerInitializer);
OBJECT_MSG_LIST_UNSAFE_FOR_EACH_PTR(vm_->mut_thread_ctx_list(), thread_ctx) {
auto thread = std::make_unique<std::thread>(&vm::ThreadCtx::LoopRun, thread_ctx);
auto thread = std::make_unique<std::thread>(&vm::ThreadCtx::LoopRun, thread_ctx, WorkerInitializer);
worker_threads_.push_back(std::move(thread));
}
exiting_ = false;
schedule_thread_ = std::thread(&OneflowVM::Loop, this);
std::function<void()> SchedulerInitializer;
GetSchedulerThreadInitializer(&SchedulerInitializer);
schedule_thread_ = std::thread(&OneflowVM::Loop, this, SchedulerInitializer);
}

namespace {
@@ -61,7 +110,8 @@ OneflowVM::~OneflowVM() {
CHECK(!vm_);
}

void OneflowVM::Loop() {
void OneflowVM::Loop(const std::function<void()>& Initializer) {
Initializer();
auto* vm = mut_vm();
while (!exiting_) { vm->Schedule(); }
while (!mut_vm()->Empty()) { vm->Schedule(); }
2 changes: 1 addition & 1 deletion oneflow/core/vm/oneflow_vm.h
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ class OneflowVM final {
const vm::VirtualMachine& vm() const { return *vm_; }

private:
void Loop();
void Loop(const std::function<void()>& Initializer);

ObjectMsgPtr<vm::VirtualMachine> vm_;
// for asynchronized execution
3 changes: 2 additions & 1 deletion oneflow/core/vm/thread_ctx.cpp
Original file line number Diff line number Diff line change
@@ -19,7 +19,8 @@ limitations under the License.
namespace oneflow {
namespace vm {

void ThreadCtx::LoopRun() {
void ThreadCtx::LoopRun(const std::function<void(ThreadCtx*)>& Initializer) {
Initializer(this);
while (ReceiveAndRun() == kObjectMsgConditionListStatusSuccess) {}
}

3 changes: 2 additions & 1 deletion oneflow/core/vm/thread_ctx.msg.h
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_VM_THREAD_MSG_H_
#define ONEFLOW_CORE_VM_THREAD_MSG_H_

#include <functional>
#include "oneflow/core/vm/stream.msg.h"
#include "oneflow/core/vm/stream_runtime_desc.msg.h"

@@ -28,7 +29,7 @@ OBJECT_MSG_BEGIN(ThreadCtx);
OF_PUBLIC void __Init__(const StreamRtDesc& stream_rt_desc) {
set_stream_rt_desc(&stream_rt_desc);
}
OF_PUBLIC void LoopRun();
OF_PUBLIC void LoopRun(const std::function<void(ThreadCtx*)>& Initializer);
// fields
OBJECT_MSG_DEFINE_PTR(const StreamRtDesc, stream_rt_desc);