Skip to content

Commit 8a89a92

Browse files
[ExecScope] Add cur scope name to ScopeIdDef (apache#7)
1 parent d54b8fd commit 8a89a92

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

include/tvm/tir/exec_scope.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,32 +49,36 @@ class ScopeIdDefNode : public Object {
4949
Array<ScopeId> def_ids;
5050
/*! \brief The extents of the ScopeId */
5151
Array<PrimExpr> extents;
52-
/*! \brief Parent ExecScope name*/
52+
/*! \brief Parent ExecScope name */
5353
String parent;
54+
/*! \brief Current ExecScope name */
55+
String cur;
5456

5557
void VisitAttrs(AttrVisitor* v) {
5658
v->Visit("def_ids", &def_ids);
5759
v->Visit("extents", &extents);
5860
v->Visit("parent", &parent);
61+
v->Visit("cur", &cur);
5962
}
6063

6164
bool SEqualReduce(const ScopeIdDefNode* other, SEqualReducer equal) const {
6265
return equal(def_ids, other->def_ids) && equal(extents, other->extents) &&
63-
equal(parent, other->parent);
66+
equal(parent, other->parent) && equal(cur, other->cur);
6467
}
6568

6669
void SHashReduce(SHashReducer hash_reduce) const {
6770
hash_reduce(def_ids);
6871
hash_reduce(extents);
6972
hash_reduce(parent);
73+
hash_reduce(cur);
7074
}
7175
static constexpr const char* _type_key = "tir.ScopeIdDef";
7276
TVM_DECLARE_FINAL_OBJECT_INFO(ScopeIdDefNode, Object);
7377
};
7478

7579
class ScopeIdDef : public ObjectRef {
7680
public:
77-
TVM_DLL ScopeIdDef(Array<ScopeId> def_ids, Array<PrimExpr> extents, String parent);
81+
TVM_DLL ScopeIdDef(Array<ScopeId> def_ids, Array<PrimExpr> extents, String parent, String cur);
7882

7983
TVM_DEFINE_OBJECT_REF_METHODS(ScopeIdDef, ObjectRef, ScopeIdDefNode);
8084
};

python/tvm/tir/exec_scope.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ class ScopeIdDef(Object):
3636
def_ids: List[ScopeId]
3737
extents: List[PrimExpr]
3838
parent: str
39+
cur: str
3940

40-
def __init__(self, def_ids: List[ScopeId], extents: List[PrimExpr], parent: str):
41-
self.__init_handle_by_constructor__(_ffi_api.ScopeIdDef, def_ids, extents, parent)
41+
def __init__(self, def_ids: List[ScopeId], extents: List[PrimExpr], parent: str, cur: str):
42+
self.__init_handle_by_constructor__(_ffi_api.ScopeIdDef, def_ids, extents, parent, cur)
4243

4344

4445
@register_object("tir.ExecScope")

src/tir/ir/exec_scope.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,21 @@ TVM_REGISTER_NODE_TYPE(ScopeIdNode);
3737
TVM_REGISTER_GLOBAL("tir.ScopeId").set_body_typed([](String name) { return ScopeId(name); });
3838

3939
// ScopeIdDef
40-
ScopeIdDef::ScopeIdDef(Array<ScopeId> ids, Array<PrimExpr> extents, String parent) {
40+
ScopeIdDef::ScopeIdDef(Array<ScopeId> ids, Array<PrimExpr> extents, String parent, String cur) {
4141
auto n = make_object<ScopeIdDefNode>();
4242
ICHECK_EQ(ids.size(), extents.size()) << "Number of dimensions must match";
4343
n->def_ids = std::move(ids);
4444
n->extents = std::move(extents);
4545
n->parent = std::move(parent);
46+
n->cur = std::move(cur);
4647
data_ = std::move(n);
4748
}
4849

4950
TVM_REGISTER_NODE_TYPE(ScopeIdDefNode);
5051

5152
TVM_REGISTER_GLOBAL("tir.ScopeIdDef")
52-
.set_body_typed([](Array<ScopeId> vars, Array<PrimExpr> extents, String parent) {
53-
return ScopeIdDef(vars, extents, parent);
53+
.set_body_typed([](Array<ScopeId> vars, Array<PrimExpr> extents, String parent, String cur) {
54+
return ScopeIdDef(vars, extents, parent, cur);
5455
});
5556

5657
// ExecScope

0 commit comments

Comments
 (0)