Commit

Resolves coredump caused by tf.data.experimental.save with prefetch
Repeat and prefetch in combination cause the snapshot reader Initialize function to be invoked multiple times.
However, there is nothing to prefetch on the very last iteration. This results in Prefetch issuing a CancelThreads call while the snapshot thread is trying to initialize. See https://github.com/tensorflow/tensorflow/blob/6446dda92eaadf11d22377e2354307642d739d73/tensorflow/core/kernels/data/prefetch_dataset_op.cc#L151

Currently the dataset reference counting is done asymmetrically. The reference increment happens at the end of initialization, whereas the reference decrement happens in the destructor. When prefetch cancels the snapshot thread, the initialization function returns an error before it reaches the reference increment. The reference decrement still happens, because the destructor always runs during cleanup. This results in an attempt to call Unref() through a dataset pointer that was never set, and therefore a segmentation fault.
This is different from all other dataset ops, where the dataset reference increment happens in the constructor and the decrement happens in the destructor, which are symmetric.
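
To illustrate the symmetric pattern described above, here is a minimal sketch. `CountedDataset`, `SymmetricIterator`, and `CountAfterCancelledInit` are hypothetical stand-ins invented for this example; they are not TensorFlow's `DatasetBase` or any real dataset op. Because the reference is taken in the constructor, a failed or cancelled `Initialize` cannot leave the destructor with nothing to release.

```cpp
#include <cassert>

// Hypothetical stand-in for a reference-counted dataset (not TensorFlow's DatasetBase).
class CountedDataset {
 public:
  void Ref() { ++count_; }
  void Unref() { --count_; }
  int count() const { return count_; }

 private:
  int count_ = 1;  // the creator holds the first reference
};

// Symmetric ownership: Ref() in the constructor, Unref() in the destructor.
class SymmetricIterator {
 public:
  explicit SymmetricIterator(CountedDataset* input) : input_(input) {
    input_->Ref();
  }
  ~SymmetricIterator() { input_->Unref(); }
  SymmetricIterator(const SymmetricIterator&) = delete;
  SymmetricIterator& operator=(const SymmetricIterator&) = delete;

  // Initialization may fail (for example on cancellation), but it never
  // touches the reference count, so the destructor is always balanced.
  bool Initialize(bool cancelled) { return !cancelled; }

 private:
  CountedDataset* input_;
};

// Returns the dataset's reference count after an iterator whose
// initialization was cancelled has been destroyed.
int CountAfterCancelledInit() {
  CountedDataset dataset;
  {
    SymmetricIterator it(&dataset);
    (void)it.Initialize(/*cancelled=*/true);
  }
  return dataset.count();
}
```

In this sketch the count returns to its starting value whether or not initialization succeeds, which is the property the asymmetric version lacked.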

The solution to this is to ensure that the dataset reference is always initialized to nullptr, and to check for null when decrementing the dataset reference.
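
The fix can be sketched in isolation as follows. This is a simplified model under stated assumptions, not the actual `LoadDatasetOp` code: `FakeDataset`, `GuardedIterator`, and `CountAfterLifetime` are hypothetical names, and a `cancelled` flag stands in for the cancellation triggered by prefetch. The pointer starts as `nullptr`, is only assigned once initialization gets far enough to take the reference, and the destructor releases it only if it was assigned.

```cpp
#include <cassert>

// Hypothetical stand-in for a reference-counted dataset (not TensorFlow's DatasetBase).
class FakeDataset {
 public:
  void Ref() { ++count_; }
  void Unref() { --count_; }
  int count() const { return count_; }

 private:
  int count_ = 1;  // the creator holds the first reference
};

class GuardedIterator {
 public:
  explicit GuardedIterator(FakeDataset* source) : source_(source) {}

  // Release the reference only if Initialize() actually acquired it.
  ~GuardedIterator() {
    if (input_) {
      input_->Unref();
    }
  }

  // The reference is taken at the end of initialization. If initialization
  // is cancelled first, input_ keeps its nullptr default.
  bool Initialize(bool cancelled) {
    if (cancelled) return false;
    input_ = source_;
    input_->Ref();
    return true;
  }

 private:
  FakeDataset* source_;
  FakeDataset* input_ = nullptr;  // always initialized, so the null check is reliable
};

// Returns the dataset's reference count after the iterator has been destroyed.
int CountAfterLifetime(bool cancelled) {
  FakeDataset dataset;
  {
    GuardedIterator it(&dataset);
    (void)it.Initialize(cancelled);
  }
  return dataset.count();
}
```

Both halves of the change matter in this model: without the `= nullptr` default the null check would read an indeterminate pointer, and without the check a cancelled initialization would dereference null in the destructor.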
ashahab committed May 20, 2021
1 parent 15d5b93 commit ac1fcf2
Showing 2 changed files with 21 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tensorflow/core/kernels/data/experimental/io_ops.cc
@@ -253,7 +253,11 @@ class LoadDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params) {}
 
-    ~Iterator() override { input_->Unref(); }
+    ~Iterator() override {
+      if (input_) {
+        input_->Unref();
+      }
+    }
 
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(mu_);
@@ -331,7 +335,7 @@ class LoadDatasetOp::Dataset : public DatasetBase {
     }
 
     mutex mu_;
-    DatasetBase* input_ TF_GUARDED_BY(mu_);
+    DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
     std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
   };
15 changes: 15 additions & 0 deletions tensorflow/python/data/experimental/kernel_tests/io_test.py
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import os
 import shutil
 
@@ -111,6 +112,20 @@ def testOptionalElementSpec(self):
     dataset_loaded = io.load(self._test_dir)
     self.assertDatasetsEqual(dataset, dataset_loaded)
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testRepeatAndPrefetch(self):
+    """This test reproduces github.com/tensorflow/tensorflow/issues/49165"""
+    dataset1 = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
+    io.save(dataset1, self._test_dir)
+    dataset = io.load(self._test_dir)
+    dataset = dataset.shuffle(buffer_size=16)
+    dataset = dataset.batch(16)
+    dataset = dataset.repeat()
+    dataset = dataset.prefetch(1)
+    next_element = self.getNext(dataset)
+    for _ in range(30):
+      self.evaluate(next_element())
+
 
 if __name__ == "__main__":
   test.main()
