Skip to content

Commit 58f17b7

Browse files
committed
Apply code review comments
1 parent d36a899 commit 58f17b7

File tree

7 files changed

+40
-35
lines changed

7 files changed

+40
-35
lines changed

include/tvm/relay/transform.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ TVM_DLL Pass FoldConstant(bool fold_qnn = false);
120120
/*!
121121
* \brief Split function with huge number of arguments to smaller pieces.
122122
*
123-
* \param max_function_args Maximum number of function arguments. If it is 0 then SplitArgs won't
124-
* split function.
123+
* \param max_function_args Maximum number of function arguments. If it equals 0, SplitArgs
124+
* does not split the function.
125125
*
126126
* \return The pass.
127127
*/

python/tvm/relay/transform/transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,8 @@ def SplitArgs(max_function_args):
13791379
Parameters
13801380
----------
13811381
max_function_args: int
1382-
Maximum number of function arguments. If it is 0 then SplitArgs won't split function.
1382+
Maximum number of function arguments. If it equals 0, SplitArgs
1383+
does not split the function.
13831384
13841385
13851386
Returns

src/relay/analysis/graph_partitioner.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -243,33 +243,33 @@ size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src,
243243
ICHECK(gnode != nullptr);
244244
auto sum = gnode->args_num;
245245
visited_groups.insert(gnode->FindRoot());
246-
auto calcArgs = [this, src, &graph, &visited_groups,
247-
update_postpone](const relay::Expr& arg) -> size_t {
246+
auto calc_args_number = [this, src, &graph, &visited_groups,
247+
update_postpone](const relay::Expr& arg) -> size_t {
248248
if (arg.as<VarNode>()) return 0;
249249
auto* node = graph.node_map.at(arg.get());
250250
Group* prev_group = groups_[node->index]->FindRoot();
251251
if (visited_groups.count(prev_group) == 0) {
252252
visited_groups.insert(prev_group);
253253
if (prev_group->args_num > 0) {
254-
// Get number of arguments from group
254+
// Get the number of arguments from the group
255255
return prev_group->args_num;
256256
} else if (update_postpone) {
257-
// Update pointer to node which should be postponed for deferred fusing
257+
// Update pointer to the node which should be postponed for deferred fusing
258258
postpone_node_ = src;
259259
} else {
260-
// Calculate number of arguments for the node which wasn't processed before
260+
// Calculate the number of arguments for the node which wasn't processed before
261261
return CountArgs_(node, graph, update_postpone);
262262
}
263263
}
264264
return 0;
265265
};
266266
if (auto call_node = GetRef<ObjectRef>(src->ref).as<CallNode>()) {
267267
for (auto& it : call_node->args) {
268-
sum += calcArgs(it);
268+
sum += calc_args_number(it);
269269
}
270270
} else if (auto tuple_node = GetRef<ObjectRef>(src->ref).as<TupleNode>()) {
271271
for (auto& it : tuple_node->fields) {
272-
sum += calcArgs(it);
272+
sum += calc_args_number(it);
273273
}
274274
}
275275
return sum;
@@ -357,11 +357,11 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
357357
Group* group_node = groups_[nid];
358358
ICHECK(group_node != nullptr);
359359
postpone_node_ = nullptr;
360-
// Check if fusing of some inputs was postponed
360+
// Check if the fusing of some inputs was postponed
361361
if (postponed_fusing_map_.count(graph_node)) {
362362
auto range = postponed_fusing_map_.equal_range(graph_node);
363363
for (auto it = range.first; it != range.second; ++it) {
364-
// If number of arguments is less than limit then the input can be fused
364+
// If the number of arguments is less than the limit then the input can be fused
365365
if (CountArgs_(graph_node, graph, false) <= CountArgsLimit_(graph_node)) {
366366
auto* src = it->second;
367367
auto* snode = post_dom_tree.nodes[src->index]->parent->gnode;
@@ -381,7 +381,7 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
381381
// refuse the fusion if too many ops are going to be fused together
382382
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
383383
continue;
384-
// refuse the fusion if too many arguments are going to be in fused function
384+
// Refuse the fusion if too many arguments are going to be in the fused function
385385
if (max_function_args_ > 0) {
386386
auto limit = CountArgsLimit_(graph_node);
387387
if (limit > 0) {

src/relay/analysis/graph_partitioner.h

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ class GraphPartitioner {
224224
postponed_fusing_map_;
225225
/*!
226226
* \brief Fusing of this node should be postponed till all child nodes will be evaluated.
227-
* It is used to calculate number of arguments which will be passed to this node in
228-
* generated function.
227+
* It is used to calculate the number of arguments which will be passed to this node in
228+
* the generated function.
229229
*/
230230
const IndexedForwardGraph::Node* postpone_node_{nullptr};
231231
// Internal implementation of CheckPath
@@ -266,21 +266,21 @@ class GraphPartitioner {
266266
void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
267267

268268
size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
269-
// Count the number of additional arguments. In case of dynamic shape,
270-
// generated function takes several additional arguments, such as size of
271-
// dynamic dimension and strides.
272-
// This function calculates number of such additional arguments.
269+
// Count the number of additional arguments. In the case of dynamic shape,
270+
// the generated function takes several additional arguments, such as the sizes of
271+
// the dynamic dimensions and strides.
272+
// This function calculates the number of such additional arguments.
273273
size_t CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides = true);
274274
// Calculate the number of arguments for the node.
275275
size_t CountArgs_(IndexedForwardGraph::Node* src, const IndexedForwardGraph& graph,
276276
bool update_postpone = true);
277-
// Count actual limit of arguments for a generated function.
277+
// Count the actual limit of arguments for a generated function.
278278
// max_function_args_ specifies the number of maximum function arguments. But
279-
// usually, output tensors also passed to the function as arguments.
280-
// Additionally, in case of dynamic shape, it is necessary to take into
281-
// account the number of parameters which specifies the size of dynamic
282-
// dimension.
283-
// This function computes limit of arguments by the following formula:
279+
// usually, output tensors are also passed to the function as arguments.
280+
// Additionally, in the case of dynamic shape, it is necessary to take into
281+
// account the number of parameters which specifies the sizes of the dynamic
282+
// dimensions.
283+
// This function computes the limit of arguments by the following formula:
284284
// limit = max_function_args_ - output_args_count
285285
size_t CountArgsLimit_(const IndexedForwardGraph::Node* child);
286286

@@ -292,7 +292,9 @@ class GraphPartitioner {
292292
// is important for correct calculation.
293293
size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
294294
IndexedForwardGraph::Node* dom_parent);
295-
// Count the number of arguments in a fused subgraph if the output of child is additionally fused.
295+
// Count the number of arguments in a fused subgraph. This function also takes into account the
296+
// number of arguments of the child's output node. This helps to stop fusing before a node that
297+
// would cause the limit to be exceeded.
296298
size_t CountFusedArgs(const IndexedForwardGraph& graph, IndexedForwardGraph::Node* child);
297299

298300
// Initialize the groups.

src/relay/transforms/split_args.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,17 @@ class ArgumentSplitter : public ExprRewriter {
6161
return lastExpr;
6262
}
6363

64-
// In case of dynamic shape in tensor, size of any_dims and strides are passed as function args
64+
// In the case of dynamic shape in tensor, the sizes of any_dims and strides are passed as
65+
// function args
6566
size_t CalculateNumberOfAdditionalArgs_(const TensorTypeNode* arg, bool isOutput = false) {
6667
size_t num = 0;
6768
for (const auto& dim : arg->shape) {
6869
if (dim.as<AnyNode>()) {
6970
num++;
7071
}
7172
}
72-
// In case of dynamic shape also strides will be passed to function
73-
// as arguments. Number of strides equals to the rank of the tensor.
73+
// In the case of dynamic shape, strides are also passed to a function as arguments. The number
74+
// of strides equals the rank of the tensor.
7475
if (num > 0 && isOutput)
7576
return arg->shape.size();
7677
else if (num > 0)

src/target/target_kind.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,8 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
365365
.add_attr_option<Integer>("max_num_threads", Integer(256))
366366
.add_attr_option<Integer>("thread_warp_size", Integer(1))
367367
.add_attr_option<Integer>("texture_spatial_limit", Integer(16384))
368-
// Faced that Qualcomm OpenCL runtime was crashed without any error message in
369-
// case when the number of kernel arguments was pretty big. OpenCL doesn't
368+
// The Qualcomm OpenCL runtime was observed to crash without any error message
369+
// when the number of kernel arguments was large. OpenCL doesn't
370370
// specify any limitations on the number of kernel arguments. max_function_args
371371
// equals to 128 looks like a reasonable number of kernel arguments.
372372
.add_attr_option<Integer>("max_function_args", Integer(128))

tests/python/unittest/test_target_codegen_opencl.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,10 +357,11 @@ def ref_impl(inputs):
357357
@tvm.testing.requires_opencl
358358
def test_fuse_concat_max_num_args(executor_type, shape_type):
359359
"""
360-
In this test before concat we have an operation with 3 inputs. In the
361-
SplitArgs we cannot calculate these inputs as inputs to concat, because
362-
they will be added to the concat after fusing operation. So FuseOps pass
363-
should handle this case and stop fusing before concat.
360+
In this test, an operation with 3 inputs precedes the concat layer. In
361+
SplitArgs we cannot count these inputs as inputs to the concat layer,
362+
because they are added to the concat only after the fusing operation. So
363+
the FuseOps pass should handle this case and stop fusing before the concat
364+
layer.
364365
365366
The example:
366367
x y z x y z

0 commit comments

Comments
 (0)