|
51 | 51 | from .layers.io import ( |
52 | 52 | monkey_patch_reader_methods, |
53 | 53 | _copy_reader_var_, |
| 54 | + __create_unshared_decorated_reader__, |
54 | 55 | ) |
55 | 56 | from .unique_name import UniqueNameGenerator |
56 | 57 | from .framework import _get_paddle_place, _get_paddle_place_list |
@@ -1351,11 +1352,6 @@ def __init__( |
1351 | 1352 | self._use_double_buffer = use_double_buffer |
1352 | 1353 | self._capacity = capacity |
1353 | 1354 | if not self._iterable: |
1354 | | - # Because layers.io.double_buffer is not supported anymore and that iterable is False and use_double_buffer |
1355 | | - # is True is not spported, here if itrable is False, use_double_buffer will be |
1356 | | - # forcely set False to avoid unexpected error. |
1357 | | - # TODO: keep use_double_buffer |
1358 | | - self._use_double_buffer = False |
1359 | 1355 | self._init_non_iterable() |
1360 | 1356 |
|
1361 | 1357 | def _wait_thread_ends(self): |
@@ -1410,6 +1406,7 @@ def _init_non_iterable(self): |
1410 | 1406 | 'lod_tensor_blocking_queue' |
1411 | 1407 | ) |
1412 | 1408 | reader_name = data_loader_unique_name_generator('create_py_reader') |
| 1409 | + double_buffer_name = data_loader_unique_name_generator('double_buffer') |
1413 | 1410 |
|
1414 | 1411 | var = global_scope().var(queue_name) |
1415 | 1412 | self._queue = core.init_lod_tensor_blocking_queue( |
@@ -1455,6 +1452,18 @@ def _init_non_iterable(self): |
1455 | 1452 |
|
1456 | 1453 | reader = monkey_patch_reader_methods(main_prog_var) |
1457 | 1454 |
|
| 1455 | + if self._use_double_buffer: |
| 1456 | + double_buffer_reader = __create_unshared_decorated_reader__( |
| 1457 | + 'create_double_buffer_reader', |
| 1458 | + reader, |
| 1459 | + {}, |
| 1460 | + name=double_buffer_name, |
| 1461 | + ) |
| 1462 | + # we return a double buffer reader. However, the reset method comes from |
| 1463 | + # py_reader. |
| 1464 | + double_buffer_reader.reset = reader.reset |
| 1465 | + reader = double_buffer_reader |
| 1466 | + |
1458 | 1467 | self._reader = reader |
1459 | 1468 |
|
1460 | 1469 | default_main_program().current_block().append_op( |
|
0 commit comments