Skip to content

Commit 72e7f49

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into auto_code_gen_pr_2
2 parents 8992801 + 72241a6 commit 72e7f49

File tree

279 files changed

+2749
-668
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

279 files changed

+2749
-668
lines changed

cmake/external/xxhash.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,12 @@ ENDIF()
3131
3232
if (WIN32)
3333
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/xxhash.lib")
34+
set(XXHASH_CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4710 /wd4711")
35+
set(XXHASH_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4710 /wd4711")
3436
else()
3537
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")
38+
set(XXHASH_CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
39+
set(XXHASH_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
3640
endif ()
3741
3842
if(WIN32)
@@ -55,6 +59,12 @@ if(WIN32)
5559
-DCMAKE_GENERATOR=${CMAKE_GENERATOR}
5660
-DCMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}
5761
-DBUILD_SHARED_LIBS=OFF
62+
-DCMAKE_CXX_FLAGS=${XXHASH_CMAKE_CXX_FLAGS}
63+
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
64+
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
65+
-DCMAKE_C_FLAGS=${XXHASH_CMAKE_C_FLAGS}
66+
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
67+
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
5868
${OPTIONAL_CACHE_ARGS}
5969
TEST_COMMAND ""
6070
BUILD_BYPRODUCTS ${XXHASH_LIBRARIES}

paddle/fluid/distributed/fleet.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,6 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
570570
ret.wait();
571571
if (ret.get() != 0) {
572572
LOG(ERROR) << "load model from path:" << path << " failed";
573-
sleep(sleep_seconds_before_fail_exit_);
574-
exit(-1);
575573
}
576574
}
577575

@@ -596,8 +594,6 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
596594
int32_t feasign_cnt = ret.get();
597595
if (feasign_cnt == -1) {
598596
LOG(ERROR) << "save model failed";
599-
sleep(sleep_seconds_before_fail_exit_);
600-
exit(-1);
601597
}
602598
}
603599

paddle/fluid/distributed/fleet_executor/carrier.cc

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,16 @@ bool Carrier::EnqueueInterceptorMessage(
4848
// handle control message
4949
return true;
5050
} else {
51-
if (creating_interceptors_) {
52-
// Cannot handle the message to interceptor since interceptors
53-
// are still under creating. Will enqueue into a tmp stack.
54-
VLOG(3) << "Receiving message while creating interceptors.";
55-
message_tmp_.emplace_back(interceptor_message);
56-
return true;
51+
{
52+
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
53+
if (creating_interceptors_) {
54+
std::unique_lock<std::mutex> lock_message(tmp_message_mutex_);
55+
// Cannot handle the message to interceptor since interceptors
56+
// are still under creating. Will enqueue into a tmp stack.
57+
VLOG(3) << "Receiving message while creating interceptors.";
58+
message_tmp_.emplace_back(interceptor_message);
59+
return true;
60+
}
5761
}
5862
int64_t dst_id = interceptor_message.dst_id();
5963
Interceptor* dst_interceptor = GetInterceptor(dst_id);
@@ -112,16 +116,24 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
112116

113117
void Carrier::SetCreatingFlag(bool flag) {
114118
// set the creating flag
119+
creating_flag_mutex_.lock();
115120
VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_
116121
<< " to " << flag << ".";
117122
creating_interceptors_ = flag;
123+
creating_flag_mutex_.unlock();
118124
if (!flag) {
119125
// finish create interceptors outside, handle tmp messsages
120126
HandleTmpMessages();
121127
}
122128
}
123129

124130
void Carrier::HandleTmpMessages() {
131+
// NOTE: It's OK to lock tmp_message_mutex_ here. When this
132+
// `HandleTmpMessages` method is entered, the creating_interceptors_ flag
133+
// must already be false, so there is no conflict with the
134+
// lock on tmp_message_mutex_ inside `EnqueueInterceptorMessage`
135+
// on the same thread.
136+
std::unique_lock<std::mutex> lock(tmp_message_mutex_);
125137
VLOG(3) << "Carrier has received " << message_tmp_.size()
126138
<< " messages during creating interceptors.";
127139
for (const auto& msg : message_tmp_) {
@@ -147,7 +159,9 @@ void Carrier::CreateInterceptors() {
147159
}
148160
// The carrier will always be waiting for an outside initializer,
149161
// since no interceptor has been created during auto init
162+
creating_flag_mutex_.lock();
150163
creating_interceptors_ = false;
164+
creating_flag_mutex_.unlock();
151165
HandleTmpMessages();
152166
}
153167
}

paddle/fluid/distributed/fleet_executor/carrier.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <memory>
18+
#include <mutex>
1819
#include <string>
1920
#include <unordered_map>
2021
#include <vector>
@@ -78,7 +79,9 @@ class Carrier final {
7879
interceptor_idx_to_interceptor_;
7980

8081
std::vector<InterceptorMessage> message_tmp_{};
82+
std::mutex tmp_message_mutex_;
8183
bool creating_interceptors_{true};
84+
std::mutex creating_flag_mutex_;
8285
bool is_init_{false};
8386
};
8487

paddle/fluid/distributed/fleet_executor/message_bus.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,11 @@ void MessageBus::Init(
5151
#endif
5252

5353
ListenPort();
54-
55-
std::call_once(once_flag_, []() {
56-
std::atexit([]() { MessageBus::Instance().Release(); });
57-
});
5854
}
5955

6056
bool MessageBus::IsInit() const { return is_init_; }
6157

62-
void MessageBus::Release() {
58+
MessageBus::~MessageBus() {
6359
VLOG(3) << "Message bus releases resource.";
6460
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
6561
!defined(PADDLE_WITH_ASCEND_CL)

paddle/fluid/distributed/fleet_executor/message_bus.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ class MessageBus final {
5050

5151
bool IsInit() const;
5252

53-
void Release();
54-
5553
// called by Interceptor, send InterceptorMessage to dst
5654
bool Send(const InterceptorMessage& interceptor_message);
5755

56+
~MessageBus();
57+
5858
DISABLE_COPY_AND_ASSIGN(MessageBus);
5959

6060
private:

paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@ set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLA
22
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
33
cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS})
44
cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
5+
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
6+
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
7+
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
8+
endif()
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <netinet/in.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <time.h>
#include <unistd.h>
#include <iostream>
#include <string>
#include <unordered_map>

#include "gtest/gtest.h"

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
25+
26+
namespace paddle {
27+
namespace distributed {
28+
29+
class PingPongInterceptor : public Interceptor {
30+
public:
31+
PingPongInterceptor(int64_t interceptor_id, TaskNode* node)
32+
: Interceptor(interceptor_id, node) {
33+
RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); });
34+
}
35+
36+
void PingPong(const InterceptorMessage& msg) {
37+
std::cout << GetInterceptorId() << " recv msg, count=" << count_
38+
<< std::endl;
39+
++count_;
40+
if (count_ == 20 && GetInterceptorId() == 0) {
41+
InterceptorMessage stop;
42+
stop.set_message_type(STOP);
43+
Send(0, stop);
44+
Send(1, stop);
45+
return;
46+
}
47+
48+
InterceptorMessage resp;
49+
int64_t dst = GetInterceptorId() == 0 ? 1 : 0;
50+
Send(dst, resp);
51+
}
52+
53+
private:
54+
int count_{0};
55+
};
56+
57+
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
58+
59+
// Probes for a TCP port that can currently be bound, starting at `first_port`
// and moving upward one port at a time.
//
// Returns the first port for which bind() succeeds, or -1 if the probe socket
// cannot be created or no port up to 65535 can be bound. The probe socket is
// closed before returning, so the port is only known to be free at probe time.
static int FindBindablePort(int first_port) {
  constexpr int kMaxPort = 65535;

  const int server_fd = socket(AF_INET, SOCK_STREAM, 0);
  if (server_fd < 0) {
    // With an invalid descriptor every bind() below would fail, so bail out
    // instead of scanning ports pointlessly.
    return -1;
  }

  int opt = 1;
  linger ling;
  ling.l_onoff = 1;
  ling.l_linger = 0;
  setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling));
  setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

  struct sockaddr_in address = {};  // zero-initialize, including sin_zero
  address.sin_family = AF_INET;
  address.sin_addr.s_addr = INADDR_ANY;

  int bound_port = -1;
  // Bounded scan: never step past the largest valid TCP port.
  for (int port = first_port; port <= kMaxPort; ++port) {
    address.sin_port = htons(port);
    if (bind(server_fd, reinterpret_cast<struct sockaddr*>(&address),
             sizeof(address)) == 0) {
      bound_port = port;
      break;
    }
  }
  close(server_fd);
  return bound_port;
}

// Forks into two processes, each of which initializes the message bus with
// its own address and registers one PingPong interceptor; the child process
// then sends the first message from interceptor 0 to interceptor 1.
TEST(InterceptorTest, PingPong) {
  // std::endl flushes std::cout, so nothing buffered here is duplicated into
  // the child process by fork() below.
  std::cout << "Ping pong test through brpc" << std::endl;
  unsigned int seed = time(0);

  // Start probing from a randomly generated port in [6000, 9000); the second
  // port is probed starting right after the first one.
  const int port0 = FindBindablePort(6000 + rand_r(&seed) % 3000);
  ASSERT_NE(port0, -1) << "cannot find an available port for ip0";
  const int port1 = FindBindablePort(port0 + 1);
  ASSERT_NE(port1, -1) << "cannot find an available port for ip1";

  std::string ip0 = "127.0.0.1:" + std::to_string(port0);
  std::string ip1 = "127.0.0.1:" + std::to_string(port1);
  std::cout << "ip0: " << ip0 << std::endl;
  std::cout << "ip1: " << ip1 << std::endl;

  int pid = fork();
  // A failed fork() must not be mistaken for the parent branch.
  ASSERT_GE(pid, 0) << "fork failed";
  // NOTE(review): the parent never waits for the child process; confirm
  // whether a waitpid() is wanted once the carrier exposes a way to block
  // until the ping-pong exchange has finished.
  if (pid == 0) {
    // Child process: message bus bound to ip0, hosts interceptor 0 and
    // kicks off the exchange.
    MessageBus& msg_bus = MessageBus::Instance();
    msg_bus.Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0);

    Carrier& carrier = Carrier::Instance();

    Interceptor* a = carrier.SetInterceptor(
        0, InterceptorFactory::Create("PingPong", 0, nullptr));
    carrier.SetCreatingFlag(false);

    InterceptorMessage msg;
    a->Send(1, msg);
  } else {
    // Parent process: message bus bound to ip1, hosts interceptor 1.
    MessageBus& msg_bus = MessageBus::Instance();
    msg_bus.Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1);

    Carrier& carrier = Carrier::Instance();

    carrier.SetInterceptor(1,
                           InterceptorFactory::Create("PingPong", 1, nullptr));
    carrier.SetCreatingFlag(false);
  }
}
126+
127+
} // namespace distributed
128+
} // namespace paddle
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
proto_library(index_dataset_proto SRCS index_dataset.proto)
22
cc_library(index_wrapper SRCS index_wrapper.cc DEPS index_dataset_proto fs)
3-
cc_library(index_sampler SRCS index_sampler.cc DEPS index_wrapper)
4-
3+
if(WITH_MKLDNN)
4+
cc_library(index_sampler SRCS index_sampler.cc DEPS xxhash index_wrapper mkldnn)
5+
else()
6+
cc_library(index_sampler SRCS index_sampler.cc DEPS xxhash index_wrapper)
7+
endif()
58
if(WITH_PYTHON)
69
py_proto_compile(index_dataset_py_proto SRCS index_dataset.proto)
710
endif()

paddle/fluid/distributed/index_dataset/index_dataset.proto

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ message IndexNode {
1919
required uint64 id = 1;
2020
required bool is_leaf = 2;
2121
required float probability = 3;
22+
optional string item_name = 4;
2223
}
2324

2425
message TreeMeta {
@@ -29,4 +30,4 @@ message TreeMeta {
2930
message KVItem {
3031
required bytes key = 1;
3132
required bytes value = 2;
32-
}
33+
}

0 commit comments

Comments
 (0)