Skip to content

Commit

Permalink
[Runtime] Fix duplicate references to const tensor (DeepRec-AI#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
yitongh authored Dec 7, 2021
1 parent 4e7dbf3 commit a8adcc9
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 16 deletions.
4 changes: 2 additions & 2 deletions tensorflow/core/common_runtime/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ void ExecutorState<PropagatorStateType>::ProcessConstTensor(
nodestats::SetOpEnd(stats);
Entry& output = (*outputs)[0];
output.state = Entry::State::HAS_CONST_TENSOR;
output.const_tensor = item.const_tensor;
output.const_tensor = item.kernel->const_tensor();
output.alloc_attr = item.output_attrs()[0];
}

Expand Down Expand Up @@ -800,7 +800,7 @@ void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
} else if (TF_PREDICT_FALSE(item.is_noop)) {
ProcessNoop(stats);
} else if (item.const_tensor != nullptr && !params.track_allocations) {
} else if (item.kernel->const_tensor() != nullptr && !params.track_allocations) {
ProcessConstTensor(item, &outputs, stats);
} else {
// Prepares inputs.
Expand Down
3 changes: 0 additions & 3 deletions tensorflow/core/common_runtime/graph_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ struct NodeItem {
// The kernel for this node.
OpKernel* kernel = nullptr;

// If the kernel is a Const op, this containts points to the constant tensor.
const Tensor* const_tensor = nullptr;

// Cached values of node->num_inputs() and node->num_outputs(), to
// avoid levels of indirection.
int num_inputs;
Expand Down
8 changes: 0 additions & 8 deletions tensorflow/core/common_runtime/immutable_executor_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,6 @@ Status ImmutableExecutorState::Initialize() {
break;
}
}
const Tensor* const_tensor = item->kernel->const_tensor();
if (const_tensor) {
// Hold onto a shallow copy of the constant tensor in `*this` so that the
// reference count does not drop to 1. This prevents the constant tensor
// from being forwarded, and its buffer reused.
const_tensors_.emplace_back(*const_tensor);
}
item->const_tensor = const_tensor;
item->is_noop = (item->kernel->type_string() == "NoOp");
item->is_enter = IsEnter(n);
if (item->is_enter) {
Expand Down
3 changes: 0 additions & 3 deletions tensorflow/core/common_runtime/immutable_executor_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,6 @@ class ImmutableExecutorState {
// pending counts for the nodes in the graph, indexed by node ID.
std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_;

// Shallow copies of the constant tensors used in the graph.
std::vector<Tensor> const_tensors_;

TF_DISALLOW_COPY_AND_ASSIGN(ImmutableExecutorState);
};

Expand Down

0 comments on commit a8adcc9

Please sign in to comment.