Commit ac1fcf2

Resolves coredump caused by tf.data.experimental.save with prefetch
Repeat and prefetch in combination cause the snapshot reader's Initialize function to be invoked multiple times. On the very last iteration, however, there is nothing left to prefetch, so prefetch issues a CancelThreads call while the snapshot thread is still trying to initialize. See https://github.com/tensorflow/tensorflow/blob/6446dda92eaadf11d22377e2354307642d739d73/tensorflow/core/kernels/data/prefetch_dataset_op.cc#L151

Currently the dataset reference counting is asymmetric: the reference increment happens at the end of initialization, whereas the reference decrement happens in the destructor. When prefetch cancels the snapshot thread, Initialize errors out before it ever reaches the reference increment. The decrement still runs regardless, because the destructor is always invoked during cleanup. The result is an attempt to Unref a null dataset pointer, and therefore a segmentation fault. This differs from all other dataset ops, where the reference increment happens in the constructor and the decrement happens in the destructor, keeping the two symmetric.

The fix is to always initialize the dataset reference to nullptr and to check for null before decrementing it.
1 parent 15d5b93 · commit ac1fcf2
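To make the failure mode concrete, the following is a minimal, standalone sketch of the pattern described above, not the actual TensorFlow code: FakeDataset and FakeIterator are hypothetical stand-ins for DatasetBase and the snapshot reader's iterator. It shows why an early error in Initialize combined with an unconditional Unref() in the destructor crashes, and how the nullptr default plus the null check avoids it.

#include <cassert>
#include <cstdio>

// Hypothetical stand-in for DatasetBase's intrusive reference counting.
class FakeDataset {
 public:
  void Ref() { ++refcount_; }
  void Unref() {
    if (--refcount_ == 0) delete this;
  }

 private:
  int refcount_ = 1;  // the creator holds the initial reference
};

// Hypothetical stand-in for the snapshot reader's Iterator.
class FakeIterator {
 public:
  ~FakeIterator() {
    // The fix: guard the decrement, since Initialize() may have bailed out
    // before ever taking a reference.
    if (input_) {
      input_->Unref();
    }
  }

  // Returns false when cancelled mid-initialization (as prefetch does via
  // CancelThreads); in that case the reference increment is never reached.
  bool Initialize(FakeDataset* dataset, bool cancelled) {
    if (cancelled) return false;
    input_ = dataset;
    input_->Ref();  // increment happens only at the end of initialization
    return true;
  }

 private:
  FakeDataset* input_ = nullptr;  // the fix: always start from nullptr
};

int main() {
  FakeDataset* dataset = new FakeDataset;
  {
    FakeIterator it;
    // Simulate prefetch cancelling the snapshot thread during Initialize().
    bool ok = it.Initialize(dataset, /*cancelled=*/true);
    assert(!ok);
    (void)ok;
  }  // destructor runs during cleanup; the null check keeps this safe
  dataset->Unref();  // drop the creator's reference
  std::puts("no segfault");
  return 0;
}

Without the nullptr default and the if (input_) guard, the destructor would call Unref() through a pointer that was never set after the cancelled Initialize(), which is the segmentation fault this commit resolves.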


2 files changed: 21 additions and 2 deletions

tensorflow/core/kernels/data/experimental/io_ops.cc

Lines changed: 6 additions & 2 deletions
@@ -253,7 +253,11 @@ class LoadDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params) {}
 
-    ~Iterator() override { input_->Unref(); }
+    ~Iterator() override {
+      if (input_) {
+        input_->Unref();
+      }
+    }
 
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(mu_);
@@ -331,7 +335,7 @@ class LoadDatasetOp::Dataset : public DatasetBase {
     }
 
     mutex mu_;
-    DatasetBase* input_ TF_GUARDED_BY(mu_);
+    DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
     std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
   };

tensorflow/python/data/experimental/kernel_tests/io_test.py

Lines changed: 15 additions & 0 deletions
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import os
 import shutil
 
@@ -111,6 +112,20 @@ def testOptionalElementSpec(self):
     dataset_loaded = io.load(self._test_dir)
     self.assertDatasetsEqual(dataset, dataset_loaded)
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testRepeatAndPrefetch(self):
+    """This test reproduces github.com/tensorflow/tensorflow/issues/49165"""
+    dataset1 = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
+    io.save(dataset1, self._test_dir)
+    dataset = io.load(self._test_dir)
+    dataset = dataset.shuffle(buffer_size=16)
+    dataset = dataset.batch(16)
+    dataset = dataset.repeat()
+    dataset = dataset.prefetch(1)
+    next_element = self.getNext(dataset)
+    for _ in range(30):
+      self.evaluate(next_element())
+
 
 if __name__ == "__main__":
   test.main()
