Skip to content

Fix: resolve the bug where an uninitialized tensor cannot be found #1933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
const auto to_compile_sym = c10::Symbol::attr("to_compile");

for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant) {
if (isConstantOrUninitialized(n)) {
continue;
}

Expand Down Expand Up @@ -107,7 +107,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
q.pop();
// for every node that produces this fallback node's NonTensor input, they should fallback too
for (auto input : cur_node->inputs()) {
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
if (!isTensor(input) && !isConstantOrUninitialized(input->node()) &&
ctx->shouldNodeRunInTensorRT(input->node())) {
ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR);
q.push(input->node());
Expand All @@ -118,7 +118,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
if (!isTensor(output)) {
for (auto use : output->uses()) {
auto node = use.user;
if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
if (!isConstantOrUninitialized(node) && ctx->shouldNodeRunInTensorRT(node)) {
ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
q.push(node);
}
Expand All @@ -128,11 +128,13 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
}
}

std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
std::set<torch::jit::Node*> dependent_nodes;
std::set<torch::jit::Node*> getUserNodes(torch::jit::Node* n) {
std::set<torch::jit::Node*> user_nodes;
for (auto val : n->outputs()) {
for (auto use : val->uses()) {
dependent_nodes.insert(use.user);
if (use.user->owningBlock()->owningNode())
user_nodes.insert(use.user->owningBlock()->owningNode());
user_nodes.insert(use.user);
}
}
if (const auto* schema = n->maybeSchema()) {
Expand All @@ -142,13 +144,13 @@ std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
for (auto use : n->inputs()[i]->uses()) {
torch::jit::Node* use_node = use.user;
if (use_node->isAfter(n)) {
dependent_nodes.insert(use_node);
user_nodes.insert(use_node);
}
}
}
}
}
return dependent_nodes;
return user_nodes;
}

// Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size
Expand All @@ -158,14 +160,14 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
std::vector<torch::jit::Node*> min_block_fallback_nodes;
for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant) {
if (isConstantOrUninitialized(n)) {
continue;
}

// check if current node fallback or not
if (!ctx->shouldNodeRunInTorch(n)) {
cur_trt_nodes.push_back(n);
auto dependent_nodes = getDependentNodes(n);
auto dependent_nodes = getUserNodes(n);
cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
} else {
if (cur_trt_nodes_uses.count(n)) {
Expand Down Expand Up @@ -250,7 +252,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
auto cur_val = q.front();
q.pop();
auto node = cur_val->node();
if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
if (!isConstantOrUninitialized(node) && !visited.count(node)) {
visited.insert(node);
auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
Expand Down Expand Up @@ -454,10 +456,10 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
for (const auto n : nodes) {
// Skip constant nodes as they are resources for both kinds of modules
if (n->kind() == torch::jit::prim::Constant) {
if (isConstantOrUninitialized(n)) {
continue;
}
auto dependent_nodes = getDependentNodes(n);
auto dependent_nodes = getUserNodes(n);
// the outputs of trt subgraph shouldn't be collections
if (ctx->shouldNodeRunInTensorRT(n)) {
in_prog_trt_blk_nodes.push_back(n);
Expand Down
6 changes: 4 additions & 2 deletions core/partitioning/partitioningctx/PartitioningCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ PartitioningCtx::PartitioningCtx(torch::jit::Block* b, PartitioningInfo info)
}

void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) {
if (b->owningNode() && b->owningNode()->kind() == torch::jit::prim::Loop)
// won't load nodes if these nodes are in prim::loop or if these nodes are 2-level nested
if (b->owningNode() &&
(b->owningNode()->kind() == torch::jit::prim::Loop || b->owningNode()->owningBlock()->owningNode()))
return;

original_blocks.push_back(b);

for (const auto n : b->nodes()) {
if (n->kind() == torch::jit::prim::Constant) {
if (isConstantOrUninitialized(n)) {
continue;
}
node_executor_decision_map[n] = NodeExecutorDecision::kUNKNOWN;
Expand Down
4 changes: 4 additions & 0 deletions core/partitioning/partitioningctx/PartitioningCtx.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ struct PartitioningCtx {

std::ostream& operator<<(std::ostream& os, const PartitioningCtx& s);

// True when the node's kind is prim::Constant or prim::Uninitialized.
// Both kinds are treated the same way by the partitioner: they are skipped
// when loading/segmenting nodes rather than assigned an executor decision.
inline bool isConstantOrUninitialized(torch::jit::Node* n) {
  const auto node_kind = n->kind();
  if (node_kind == torch::jit::prim::Constant) {
    return true;
  }
  return node_kind == torch::jit::prim::Uninitialized;
}

} // namespace partitioning
} // namespace core
} // namespace torch_tensorrt
6 changes: 6 additions & 0 deletions core/partitioning/segmentedblock/SegmentedBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
old_to_new_[old_value] = new_const->output();
return new_const->output();
}
if (node->kind() == torch::jit::prim::Uninitialized) {
auto new_uninitialized = g_->createUninitialized(old_value->type());
g_->block()->prependNode(new_uninitialized);
old_to_new_[old_value] = new_uninitialized->output();
return new_uninitialized->output();
}
auto new_value = g_->block()->addInput();
// every time when we addInput, we push back the corresponding lowering graph torch::jit::Value to our raw_inputs
inputs_.push_back(old_value);
Expand Down
38 changes: 38 additions & 0 deletions tests/core/partitioning/test_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,44 @@ TEST(Partitioning, SegmentModelWithDependencyAwareness) {
checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 2, 4}, {1, 3, 5}, {6, 7}}));
}

// Regression test: a graph containing a prim::Uninitialized value (used as an
// If-block output alongside a prim::RaiseException path) must be segmented
// without errors, producing the expected number of Torch-executed blocks.
TEST(Partitioning, ContainUninitializedValueCorrectly) {
auto g = std::make_shared<torch::jit::Graph>();
auto x = g->insertInput(0, "x");
auto none_const_val = g->insertConstant(torch::jit::IValue());
auto ivalue_1 = g->insertConstant(torch::jit::IValue(1));
// NOTE(review): ivalue_2 is never used below — presumably left over; confirm or remove.
auto ivalue_2 = g->insertConstant(torch::jit::IValue(2));

// Placeholder value of Bool type; returned by the exception branch of the If.
auto uninitialized_node = g->createUninitialized(torch::jit::BoolType::get());
g->appendNode(uninitialized_node);

// x_dim = aten::dim(x) : int
auto x_dim = g->create(torch::jit::aten::dim, {x}, 1);
g->appendNode(x_dim);
x_dim->output()->setType(torch::jit::IntType::get());

// eq1 = aten::eq(1, x_dim) : bool — the If condition.
auto eq1 = g->create(torch::jit::aten::eq, {ivalue_1, x_dim->output()}, 1);
g->appendNode(eq1);
eq1->output()->setType(torch::jit::BoolType::get());

torch::jit::IValue except("EXCEPTION");
auto exception_val = g->insertConstant(except);
auto if_node = g->create(torch::jit::prim::If, {eq1->output()}, 1);
// Block 0 (condition true): raise, then yield the uninitialized placeholder
// as the (unreachable) block output — the pattern that previously broke partitioning.
auto if_block_0 = if_node->addBlock();
auto exception_node = g->create(torch::jit::prim::RaiseException, {exception_val, none_const_val}, 0);
if_block_0->appendNode(exception_node);
if_block_0->registerOutput(uninitialized_node->output());

// Block 1 (condition false): yield the comparison result.
auto if_block_1 = if_node->addBlock();
if_block_1->registerOutput(eq1->output());

g->insertNode(if_node);

PartitioningInfo partitioning_info;
partitioning_info.enabled = true;
PartitioningCtx ctx(g->block(), partitioning_info);
segmentGraph(&ctx, g->block());
// Expect exactly 2 Torch-executed segments for this graph.
ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2));
}

} // namespace tests
} // namespace partitioning
} // namespace core
Expand Down