Skip to content

Commit

Permalink
[PT FE] Support aten::unsafe_chunk op
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed Oct 7, 2024
1 parent cf870cd commit d5abde0
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/frontends/pytorch/src/op/getitem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ OutputVector translate_getitem(const NodeContext& context) {
PYTORCH_OP_CONVERSION_CHECK(!idx_type.is<type::Str>(),
"String index in aten::__getitem__ means dict input, this is not supported.");
if (ov::as_type_ptr<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {
PYTORCH_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::split"),
"special case for aten::__getitem__");
PYTORCH_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::chunk"),
"special case for aten::__getitem__");
PYTORCH_OP_CONVERSION_CHECK(
!cast_fw_node(input.get_node_shared_ptr(),
std::vector<std::string>{"aten::split", "aten::chunk", "aten::unsafe_chunk"}),
"special case for aten::__getitem__");
const auto&& list_elems = get_list_as_outputs(input);
auto getitem_idx = context.const_input<int64_t>(1);
if (getitem_idx < 0) {
Expand Down
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
// aten::unbind - Supported in limited set of patterns
{"aten::unflatten", op::translate_unflatten},
{"aten::unfold", op::translate_unfold},
// aten::unsafe_chunk - Supported in limited set of patterns
{"aten::unsqueeze", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
{"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d},
{"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
auto gather = rg.make<v8::Gather>(input_concat, getitem_idx, zero);
replace_node(getitem, gather);
}
} else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
} else if (auto chunk =
cast_fw_node(input_node, std::vector<std::string>{"aten::chunk", "aten::unsafe_chunk"})) {
auto input_tensor = chunk->get_input_source_output(0);
auto chunks_i32 = chunk->get_input_source_output(1);
auto dim_i32 = chunk->get_input_source_output(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() {
size_t chunk_idx = 0;
auto loop_inputs = loop_op->input_values();
for (size_t i = 1; i < loop_inputs.size(); i++) {
if (cast_fw_node(loop_inputs.at(i).get_node_shared_ptr(), "aten::chunk")) {
if (cast_fw_node(loop_inputs.at(i).get_node_shared_ptr(),
std::vector<std::string>{"aten::chunk", "aten::unsafe_chunk"})) {
chunk_op = loop_inputs.at(i).get_node_shared_ptr();
chunk_idx = i;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
replace_node(list_unpack, split);

return true;
} else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
} else if (auto chunk =
cast_fw_node(input_node, std::vector<std::string>{"aten::chunk", "aten::unsafe_chunk"})) {
if (list_unpack->get_output_size() == 1) {
list_unpack->output(0).replace(input_node->input_value(0));
return true;
Expand Down
15 changes: 15 additions & 0 deletions src/frontends/pytorch/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,21 @@ std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node>
return fw_node;
}

// Casts `node` to a FrameworkNode if its op-type attribute matches any of `types`.
// Returns nullptr when the node is not a FrameworkNode, has no op-type attribute,
// or its op type is not in the accepted list. Multi-type counterpart of the
// single-string cast_fw_node overload above.
std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node,
                                                          const std::vector<std::string>& types) {
    auto fw_node = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(node);
    if (!fw_node) {
        return nullptr;
    }
    const auto& attrs = fw_node->get_attrs();
    // The attribute lookup is loop-invariant: perform it once instead of
    // re-querying the attribute map on every iteration.
    if (attrs.find(PtFrameworkNode::op_type_key) == attrs.end()) {
        return nullptr;
    }
    const auto& op_type = attrs.at(PtFrameworkNode::op_type_key);
    // Accept the node if its op type matches any requested type.
    // Iterate by const reference to avoid copying each std::string.
    for (const auto& type : types) {
        if (op_type == type) {
            return fw_node;
        }
    }
    return nullptr;
}

std::shared_ptr<ov::Node> make_list_construct(const ov::OutputVector& inputs) {
auto list_construct = std::make_shared<ov::op::util::FrameworkNode>(inputs, inputs.size());
ov::op::util::FrameworkNodeAttrs attrs;
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ OutputVector make_framework_node_ignore_bodies(const NodeContext& context, const
OutputVector make_framework_node(const NodeContext& context, const std::string& exception);

std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type);
std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node,
const std::vector<std::string>& types);

std::shared_ptr<Node> make_list_construct(const ov::OutputVector& inputs);

Expand Down

0 comments on commit d5abde0

Please sign in to comment.