Skip to content

Commit 1a32448

Browse files
authored
Keep double-buffer reader for static mode (#49068)
1 parent 18f921e commit 1a32448

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

python/paddle/fluid/reader.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from .layers.io import (
5252
monkey_patch_reader_methods,
5353
_copy_reader_var_,
54+
__create_unshared_decorated_reader__,
5455
)
5556
from .unique_name import UniqueNameGenerator
5657
from .framework import _get_paddle_place, _get_paddle_place_list
@@ -1351,11 +1352,6 @@ def __init__(
13511352
self._use_double_buffer = use_double_buffer
13521353
self._capacity = capacity
13531354
if not self._iterable:
1354-
# Because layers.io.double_buffer is not supported anymore and that iterable is False and use_double_buffer
1355-
# is True is not spported, here if itrable is False, use_double_buffer will be
1356-
# forcely set False to avoid unexpected error.
1357-
# TODO: keep use_double_buffer
1358-
self._use_double_buffer = False
13591355
self._init_non_iterable()
13601356

13611357
def _wait_thread_ends(self):
@@ -1410,6 +1406,7 @@ def _init_non_iterable(self):
14101406
'lod_tensor_blocking_queue'
14111407
)
14121408
reader_name = data_loader_unique_name_generator('create_py_reader')
1409+
double_buffer_name = data_loader_unique_name_generator('double_buffer')
14131410

14141411
var = global_scope().var(queue_name)
14151412
self._queue = core.init_lod_tensor_blocking_queue(
@@ -1455,6 +1452,18 @@ def _init_non_iterable(self):
14551452

14561453
reader = monkey_patch_reader_methods(main_prog_var)
14571454

1455+
if self._use_double_buffer:
1456+
double_buffer_reader = __create_unshared_decorated_reader__(
1457+
'create_double_buffer_reader',
1458+
reader,
1459+
{},
1460+
name=double_buffer_name,
1461+
)
1462+
# we return a double buffer reader. However, the reset method comes from
1463+
# py_reader.
1464+
double_buffer_reader.reset = reader.reset
1465+
reader = double_buffer_reader
1466+
14581467
self._reader = reader
14591468

14601469
default_main_program().current_block().append_op(

0 commit comments

Comments
 (0)