Skip to content

Commit ec53a55

Browse files
Pre stable 1 (PaddlePaddle#53)
* pull sparse-ptr async
* fix curand bug
* fix mem oom

Co-authored-by: liaoxiaochao <liaoxiaochao@baidu.com>
1 parent 8d486a6 commit ec53a55

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

paddle/fluid/framework/data_feed.h

Lines changed: 29 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -265,6 +265,7 @@ class SlotObjPool {
265265
}
266266
disable_pool_ = false;
267267
count_ = 0;
268+
pending_ins_.store(0);
268269
}
269270
~SlotObjPool() {
270271
ins_chan_->Close();
@@ -280,22 +281,34 @@ class SlotObjPool {
280281
}
281282
void get(SlotRecord* output, int n) {
282283
int size = 0;
283-
mutex_.lock();
284-
int left = static_cast<int>(alloc_.capacity());
285-
if (left > 0) {
286-
size = (left >= n) ? n : left;
287-
for (int i = 0; i < size; ++i) {
288-
output[i] = alloc_.acquire();
284+
do {
285+
mutex_.lock();
286+
int left = static_cast<int>(alloc_.capacity());
287+
if (left > 0) {
288+
int tmp_size = (left >= n - size) ? n - size : left;
289+
290+
for (int i = size; i < size + tmp_size; ++i) {
291+
output[i] = alloc_.acquire();
292+
}
293+
size += tmp_size;
289294
}
290-
}
291-
mutex_.unlock();
295+
mutex_.unlock();
296+
if (pending_ins_.load() >= 200000) {
297+
usleep(1000);
298+
continue;
299+
} else {
300+
break;
301+
}
302+
} while (true);
303+
304+
292305
count_ += n;
293-
if (size == n) {
294-
return;
295-
}
296306
for (int i = size; i < n; ++i) {
297307
output[i] = make_slotrecord();
298308
}
309+
for (int i = 0; i < n; ++i) {
310+
output[i]->clear(true);
311+
}
299312
}
300313
void put(std::vector<SlotRecord>* input) {
301314
size_t size = input->size();
@@ -306,6 +319,7 @@ class SlotObjPool {
306319
input->clear();
307320
}
308321
void put(SlotRecord* input, size_t size) {
322+
pending_ins_.fetch_add(size);
309323
CHECK(ins_chan_->WriteMove(size, input) == size);
310324
}
311325
void run(void) {
@@ -314,10 +328,12 @@ class SlotObjPool {
314328
if (input.empty()) {
315329
continue;
316330
}
331+
pending_ins_.fetch_sub(input.size());
317332
// over max capacity
318333
size_t n = input.size();
319334
count_ -= n;
320-
if (disable_pool_ || n + capacity() > max_capacity_) {
335+
// if (disable_pool_ || n + capacity() > max_capacity_) {
336+
if (disable_pool_) {
321337
for (auto& t : input) {
322338
free_slotrecord(t);
323339
}
@@ -365,6 +381,7 @@ class SlotObjPool {
365381
SlotObjAllocator<SlotRecordObject> alloc_;
366382
bool disable_pool_;
367383
std::atomic<long> count_; // NOLINT
384+
std::atomic<uint64_t> pending_ins_;
368385
};
369386

370387
inline SlotObjPool& SlotRecordPool() {

0 commit comments

Comments
 (0)