
Commit d832dd6

address comments
1 parent 42434bb commit d832dd6

File tree

13 files changed: +20 additions, -87 deletions

include/tvm/relay/op_attr_types.h

Lines changed: 6 additions & 56 deletions
@@ -249,15 +249,6 @@ class OpImplementNode : public Object {
  */
 class OpImplement : public ObjectRef {
  public:
-  /*! \brief default constructor */
-  OpImplement() {}
-  /*! \brief constructor from node pointer */
-  explicit OpImplement(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const OpImplementNode* operator->() const;
   /*!
    * \brief Invoke the operator compute function.
    * \param attrs The attribute of the primitive
@@ -278,6 +269,8 @@ class OpImplement : public ObjectRef {
   te::Schedule Schedule(const Attrs& attrs,
                         const Array<te::Tensor>& outs,
                         const Target& target);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(OpImplement, ObjectRef, OpImplementNode);
 };
 
 /*!
@@ -305,18 +298,6 @@ class OpSpecializationNode : public Object {
  */
 class OpSpecialization : public ObjectRef {
  public:
-  OpSpecialization() {}
-  explicit OpSpecialization(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const OpSpecializationNode* operator->() const;
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline OpSpecializationNode* operator->();
   /*!
    * \brief Add an implementation.
    * \param compute Compute function
@@ -325,6 +306,8 @@ class OpSpecialization : public ObjectRef {
   */
  void AddImplement(FTVMCompute fcompute, FTVMSchedule fschedule,
                    int plevel);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
 };
 
 /*!
@@ -348,49 +331,16 @@ class OpStrategyNode : public Object {
  */
 class OpStrategy : public ObjectRef {
  public:
-  /*! \brief default constructor */
-  OpStrategy() {}
-  /*! \brief constructor from node pointer */
-  explicit OpStrategy(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const OpStrategyNode* operator->() const;
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline OpStrategyNode* operator->();
   /*!
    * \brief Add an implementation.
    * \param compute Compute function
    * \param schedule Schedule function
    * \param plevel Priority level of this implementation.
    */
   void AddImplement(FTVMCompute fcompute, FTVMSchedule fschedule, int plevel);
-};
-
-// implementations
-inline const OpImplementNode* OpImplement::operator->() const {
-  return static_cast<const OpImplementNode*>(get());
-}
 
-inline const OpSpecializationNode* OpSpecialization::operator->() const {
-  return static_cast<const OpSpecializationNode*>(get());
-}
-
-inline OpSpecializationNode* OpSpecialization::operator->() {
-  return static_cast<OpSpecializationNode*>(get_mutable());
-}
-
-inline const OpStrategyNode* OpStrategy::operator->() const {
-  return static_cast<const OpStrategyNode*>(get());
-}
-
-inline OpStrategyNode* OpStrategy::operator->() {
-  return static_cast<OpStrategyNode*>(get_mutable());
-}
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
+};
 
 }  // namespace relay
 }  // namespace tvm

python/tvm/expr.py

Lines changed: 1 addition & 1 deletion
@@ -864,7 +864,7 @@ def __init__(self, var, value, body):
             _make.Let, var, value, body)
 
 
-@register_object
+@tvm._ffi.register_object
 class Any(PrimExpr):
     """Any node.
     """

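For context on the decorator change above, here is a minimal sketch of how tvm._ffi.register_object binds a Python class to its C++ node. The comment about the class name being used as the type key when no key is passed is an assumption about register_object's behavior, not something stated in this commit.

import tvm
from tvm.expr import PrimExpr

# With no explicit type key, register_object is assumed to use the class
# name ("Any") to look up the C++ node type, so FFI calls that return an
# Any node construct instances of this Python class.
@tvm._ffi.register_object
class Any(PrimExpr):
    """Any node."""
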
python/tvm/relay/backend/compile_engine.py

Lines changed: 1 addition & 2 deletions
@@ -25,7 +25,6 @@
 from ..base import register_relay_node, Object
 from ... import _api_internal
 from ... import target as _target
-from ..._ffi.function import register_func
 from ... import autotvm
 from .. import expr as _expr
 from .. import op as _op
@@ -389,7 +388,7 @@ def visit_tuple_getitem(self, t):
         return [tup[t.index]]
 
 
-@register_func("relay.backend.create_schedule")
+@tvm._ffi.register_func("relay.backend.create_schedule")
 def create_schedule(src_func, target):
     return ScheduleGetter(target).create(src_func)
 

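The function-registration change follows the same pattern: the relative register_func import is dropped in favor of the fully qualified tvm._ffi.register_func. A hedged sketch of registering and calling back a global packed function; the name "relay.backend.example" is hypothetical, only "relay.backend.create_schedule" appears in this commit.

import tvm

# Register a Python callback in the global packed-function registry under
# a dotted name; it can then be looked up from Python (or from C++ via
# tvm::runtime::Registry::Get) by that name.
@tvm._ffi.register_func("relay.backend.example")  # hypothetical name
def example(x):
    return x + 1

f = tvm.get_global_func("relay.backend.example")
assert f(1) == 2
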
python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
     assert groups == 1, "only support groups == 1 for now"
     strategy = _op.OpStrategy()
     strategy.add_implement(
-        wrap_comptue_conv2d_transpose(topi.arm_cpu.conv2d_transpose_nchw),
+        wrap_compute_conv2d_transpose(topi.arm_cpu.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_transpose_nchw))
     return strategy

python/tvm/relay/op/strategy/cuda.py

Lines changed: 1 addition & 7 deletions
@@ -71,12 +71,6 @@ def schedule_lrn_cuda(attrs, outs, target):
     with target:
         return topi.cuda.schedule_lrn(outs)
 
-@schedule_l2_normalize.register(["cuda", "gpu"])
-def schedule_l2_normalize_cuda(attrs, outs, target):
-    """schedule L2 normalize for cuda"""
-    with target:
-        return topi.cuda.schedule_l2_normalize(outs)
-
 @conv2d_strategy.register(["cuda", "gpu"])
 def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     """conv2d cuda strategy"""
@@ -197,7 +191,7 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     assert groups == 1, "only support groups == 1 for now"
     strategy = _op.OpStrategy()
     strategy.add_implement(
-        wrap_comptue_conv2d_transpose(topi.cuda.conv2d_transpose_nchw),
+        wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw))
     return strategy

python/tvm/relay/op/strategy/generic.py

Lines changed: 2 additions & 9 deletions
@@ -120,13 +120,6 @@ def schedule_lrn(attrs, outs, target):
     with target:
         return topi.generic.schedule_lrn(outs)
 
-# l2_normalize
-@generic_func
-def schedule_l2_normalize(attrs, outs, target):
-    """Schedule L2 normalize op"""
-    with target:
-        return topi.generic.schedule_l2_normalize(outs)
-
 # bitpack
 @generic_func
 def schedule_bitpack(attrs, outs, target):
@@ -283,7 +276,7 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, target):
     return strategy
 
 # conv2d_transpose
-def wrap_comptue_conv2d_transpose(topi_compute):
+def wrap_compute_conv2d_transpose(topi_compute):
     """wrap conv2d_transpose topi compute"""
     def compute_conv2d_transpose(attrs, inputs, out_dtype):
         """Compute definition of conv2d_transpose"""
@@ -311,7 +304,7 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
     assert groups == 1, "only support groups == 1 for now"
     strategy = _op.OpStrategy()
     strategy.add_implement(
-        wrap_comptue_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
+        wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw))
     return strategy

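The renamed wrap_compute_conv2d_transpose is consumed the same way by each per-target strategy file below. A hedged sketch of that registration pattern for a hypothetical "mytarget" backend, reusing the wrappers and decorators defined in this generic.py module; only the groups assertion is taken from the diff, real strategy functions carry additional layout checks.

@conv2d_transpose_strategy.register("mytarget")  # hypothetical target name
def conv2d_transpose_strategy_mytarget(attrs, inputs, out_type, target):
    """conv2d_transpose strategy for a hypothetical target"""
    assert attrs.groups == 1, "only support groups == 1 for now"
    # Build a strategy with one implementation: a wrapped topi compute plus
    # a wrapped topi schedule, exactly as the per-target files below do.
    strategy = _op.OpStrategy()
    strategy.add_implement(
        wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
        wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw))
    return strategy
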
python/tvm/relay/op/strategy/hls.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target):
     assert groups == 1, "only support groups == 1 for now"
     strategy = _op.OpStrategy()
     strategy.add_implement(
-        wrap_comptue_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
+        wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.hls.schedule_conv2d_transpose_nchw))
     return strategy

python/tvm/relay/op/strategy/rocm.py

Lines changed: 0 additions & 6 deletions
@@ -28,12 +28,6 @@ def schedule_lrn_rocm(attrs, outs, target):
     with target:
         return topi.rocm.schedule_lrn(outs)
 
-@schedule_l2_normalize.register("rocm")
-def schedule_l2_normalize_rocm(attrs, outs, target):
-    """schedule L2 normalize for rocm"""
-    with target:
-        return topi.rocm.schedule_l2_normalize(outs)
-
 @conv2d_strategy.register("rocm")
 def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     """conv2d cuda strategy"""

python/tvm/relay/op/strategy/x86.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
     assert groups == 1, "only support groups == 1 for now"
     strategy = _op.OpStrategy()
     strategy.add_implement(
-        wrap_comptue_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
+        wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
         wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw))
     return strategy

python/tvm/schedule.py

Lines changed: 1 addition & 1 deletion
@@ -650,7 +650,7 @@ def opengl(self):
         """
         _api_internal._StageOpenGL(self)
 
-@register_object
+@tvm._ffi.register_object
 class SpecializedCondition(Object):
     """Specialized condition to enable op specialization."""
     def __init__(self, conditions):
