
Commit a962c96

init
drop changes
1 parent 0f4c151 commit a962c96

7 files changed, +251 -22 lines changed

python/tvm/relay/op/_tensor_grad.py

Lines changed: 49 additions & 12 deletions
@@ -269,18 +269,6 @@ def conv2d_grad(orig, grad):
     return [backward_data, backward_weight]
 
 
-@register_gradient("max")
-def max_grad(orig, grad):
-    """Returns the gradient of max"""
-    # Only support axis=0, since broadcasting orig to x behaves incorrectly
-    x, axis = orig.args[0], orig.attrs.axis
-    assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
-    orig = broadcast_to_like(orig, x)
-    grad = broadcast_to_like(grad, x)
-    indicators = cast_like(equal(orig, x), grad)
-    return [indicators * grad]
-
-
 @register_gradient("nn.softmax")
 def softmax_grad(orig, grad):
     """Gradient of softmax"""
@@ -302,6 +290,24 @@ def dense_grad(orig, grad):
     return [collapse_sum_like(transpose(grad) * weight, data),
             collapse_sum_like(data * transpose(grad), weight)]
 
+# UNTESTED
+@register_gradient("reshape")
+def reshape_grad(orig, grad):
+    return [reshape_like(grad, orig.args[0])]
+
+
+# UNTESTED
+@register_gradient("take")
+def take_grad(orig, grad):
+    x, y = orig.args
+    return [zeros_like(x), zeros_like(y)]
+    return [Call(op_get("take_grad"), [x, y, grad], orig.attrs), zeros_like(y)]
+
+
+@register_gradient("shape_of")
+def shape_of_grad(orig, grad):
+    return [zeros_like(orig.args[0])]
+
 
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
@@ -347,3 +353,34 @@ def sum_grad(orig, grad):
     """Returns grad broadcasted to data dims"""
     data = orig.args[0]
     return [broadcast_to_like(grad, data)]
+
+
+@register_gradient("nn.global_avg_pool2d")
+def global_avg_pool2d_grad(orig, grad):
+    """repeat the h w dimension"""
+    # focus on conv2d right now; this is a placeholder, not a real gradient.
+    return [orig.args[0]]
+
+
+@register_gradient("nn.batch_norm")
+def batch_norm_grad(orig, grad):
+    """multiply some stuff"""
+    # batch_norm has a problematic API, so we will not spend time implementing it.
+    a, b, c, d, e = orig.args
+    return [a, b, c, d, e]
+
+
+@register_gradient("split")
+def split_grad(orig, grad):
+    # return zero
+    return [orig.args[0]]
+
+
+@register_gradient("nn.cross_entropy")
+def cross_entropy_grad(orig, grad):
+    x, y = orig.args
+    sm = softmax(x)
+    shape = shape_of(x)
+    batch_size = take(shape, const(0, dtype='int32'), axis=0)
+    grad = grad / batch_size.astype('float32')
+    return [reduce_sum(y, axis=1) * grad * (sm - y), -grad * log(sm)]
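
For one-hot targets (each row of y sums to one), the x-gradient registered above reduces to grad * (softmax(x) - y), with grad already divided by the batch size. A standalone numpy sanity check of that identity against finite differences (an illustrative sketch, not part of this commit):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(x, y):
    # forward definition matching the registered compute:
    # -sum(log_softmax(x) * y) / batch_size
    return -np.sum(np.log(softmax(x)) * y) / x.shape[0]

rng = np.random.RandomState(0)
x = rng.randn(4, 10)
y = np.eye(10)[rng.randint(0, 10, size=4)]  # one-hot rows

analytic = (softmax(x) - y) / x.shape[0]  # gradient w.r.t. x for out-grad 1
numeric = np.zeros_like(x)
eps = 1e-6
for idx in np.ndindex(*x.shape):
    d = np.zeros_like(x)
    d[idx] = eps
    numeric[idx] = (cross_entropy(x + d, y) - cross_entropy(x - d, y)) / (2 * eps)

np.testing.assert_allclose(analytic, numeric, rtol=1e-4, atol=1e-8)
```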

python/tvm/relay/op/nn/_nn.py

Lines changed: 10 additions & 0 deletions
@@ -717,3 +717,13 @@ def schedule_bitserial_dense(attrs, outputs, target):
 
 
 reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+reg.register_schedule("nn.cross_entropy", schedule_injective)
+
+reg.register_pattern("nn.cross_entropy",
+                     OpPattern.OPAQUE)
+
+@reg.register_compute("nn.cross_entropy")
+def compute_cross_entropy(attrs, inputs, out_dtype, target):
+    x, y = inputs
+    return [-topi.sum(topi.nn.log_softmax(x) * y / x.shape[0])]
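
Note that the compute divides y by the batch size inside the reduction; by linearity this equals dividing the summed loss. A small numpy sketch of that equivalence (log_softmax here is a hand-rolled stand-in for topi.nn.log_softmax):

```python
import numpy as np

def log_softmax(x):
    # numerically stable log-softmax over the last axis
    shifted = x - x.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

x = np.random.randn(4, 10)
y = np.eye(10)[np.random.randint(0, 10, size=4)]

a = -np.sum(log_softmax(x) * y / x.shape[0])  # as registered above
b = -np.sum(log_softmax(x) * y) / x.shape[0]  # divide after the sum
assert np.allclose(a, b)
```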

python/tvm/relay/op/nn/nn.py

Lines changed: 4 additions & 0 deletions
@@ -1621,3 +1621,7 @@ def bitserial_dense(data,
     """
     return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
                                  pack_dtype, out_dtype, unipolar)
+
+
+def cross_entropy(predictions, targets):
+    return _make.cross_entropy(predictions, targets)
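
A minimal usage sketch for the new wrapper (this assumes a build of this commit; run_infer_type is the helper the tests below import):

```python
from tvm import relay
from tvm.relay.testing import run_infer_type

x = relay.var("x", shape=(4, 10))
y = relay.var("y", shape=(4, 10))
z = run_infer_type(relay.nn.cross_entropy(x, y))
# CrossEntropyRel (added in nn.cc below) assigns a scalar result type
assert z.checked_type == relay.TensorType((), "float32")
```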

python/tvm/relay/op/transform.py

Lines changed: 36 additions & 2 deletions
@@ -42,12 +42,15 @@ def cast(data, dtype):
 
 def cast_like(data, dtype_like):
     """Cast input tensor to data type of another tensor.
+
     Parameters
     ----------
     data : relay.Expr
         The input data to the operator.
+
     dtype_like: relay.Expr
         The tensor to cast to.
+
     Returns
     -------
     result : relay.Expr
@@ -228,8 +231,8 @@ def reshape_like(data, shape_like):
     data : relay.Expr
         The input data to the operator.
 
-    shape_like : tuple of int
-        The new shape. Should be compatible with the original shape.
+    shape_like : relay.Expr
+        The tensor to reshape to. Should be compatible with the original shape.
 
     Returns
     -------
@@ -239,6 +242,37 @@ def reshape_like(data, shape_like):
     return _make.reshape_like(data, shape_like)
 
 
+def embed_like(data, indices, type_like, axis=None, mode="clip"):
+    """Embed elements of an array at the given indices into a tensor
+    typed like `type_like`.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The source array.
+
+    indices : relay.Expr
+        The indices of the values to extract.
+
+    type_like : relay.Expr
+        The tensor that provides the type to embed into.
+
+    axis : int, optional
+        The axis over which to select values. By default,
+        the flattened input array is used.
+
+    mode : str, optional
+        Specifies how out-of-bound indices will behave [clip, wrap, fast].
+        clip: clip to the range (default).
+        wrap: wrap around the indices.
+        fast: no clip or wrap around (user must make sure indices are in-bound).
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.embed_like(data, indices, type_like, axis, mode)
+
 def take(data, indices, axis=None, mode="clip"):
     """Take elements from an array along an axis.
 
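
Since embed_like is registered below (in transform.cc) as "the inverse of take", its intended semantics resemble scattering data back into a tensor shaped like type_like. A numpy analogue for the flattened (axis=None) case, shown only to illustrate the intent; the registered compute currently reuses TakeCompute:

```python
import numpy as np

like = np.zeros(6)               # plays the role of type_like
indices = np.array([1, 4])
values = np.array([10.0, 20.0])  # plays the role of data

# take (gather): values == source[indices]
# embed_like (scatter): place values at indices in a tensor shaped like `like`
embedded = np.zeros_like(like)
np.add.at(embedded, indices, values)  # scatter-add, the adjoint of a gather
print(embedded)                       # [ 0. 10.  0.  0. 20.  0.]
```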

src/relay/op/nn/nn.cc

Lines changed: 93 additions & 4 deletions
@@ -103,6 +103,50 @@ RELAY_REGISTER_OP("nn.bias_add")
 // relay.nn.dense
 TVM_REGISTER_NODE_TYPE(DenseAttrs);
 
+
+bool DenseRel(const Array<Type>& types,
+              int num_inputs,
+              const Attrs& attrs,
+              const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const DenseAttrs* param = attrs.as<DenseAttrs>();
+  CHECK(param != nullptr);
+
+  CHECK(static_cast<int>(data->shape.size()) != 0);
+
+  Array<tvm::Expr> oshape = data->shape;
+  if (param->units.defined()) {
+    Array<tvm::Expr> dshape = data->shape;
+    // validate the weight shape is proper if defined
+    // Assign weight type
+    Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
+    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+    oshape.Set((oshape.size() - 1), param->units);
+  } else {
+    if (weight == nullptr) return false;
+    Array<tvm::Expr> wshape = weight->shape;
+    CHECK(static_cast<int>(weight->shape.size()) == 2);
+    CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
+        << "DenseRel: input dimension doesn't match,"
+        << " data shape=" << data->shape
+        << ", weight shape=" << weight->shape;
+    oshape.Set((oshape.size() - 1), wshape[0]);
+  }
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  // assign output type
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+
 // Positional relay function to create dense operator used by frontend FFI.
 Expr MakeDense(Expr data,
                Expr weight,
@@ -698,11 +742,11 @@ bool BatchMatmulRel(const Array<Type>& types,
   if (x == nullptr || y == nullptr) return false;
   CHECK(x->shape.size() == 3 && y->shape.size() == 3);
   CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
-      << "BatchDot: batch dimension doesn't match, "
-      << " x shape=" << x->shape
-      << ", y shape=" << y->shape;
+      << "BatchDot: batch dimension doesn't match,"
+      << " x shape=" << x->shape
+      << ", y shape=" << y->shape;
   CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
-      << "BatchDot: shapes of x and y is inconsistent, "
+      << "BatchDot: shapes of x and y is inconsistent,"
       << " x shape=" << x->shape
       << ", y shape=" << y->shape;
 
@@ -746,6 +790,51 @@ are data in batch.
 .set_support_level(10)
 .add_type_rel("BatchMatmul", BatchMatmulRel);
 
+// relay.nn.cross_entropy
+bool CrossEntropyRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* x = types[0].as<TensorTypeNode>();
+  const auto* y = types[1].as<TensorTypeNode>();
+  if (x == nullptr || y == nullptr) return false;
+  CHECK(x->shape.size() == 2 && y->shape.size() == 2)
+      << "CrossEntropy: shapes of x and y is inconsistent,"
+      << " x shape=" << x->shape
+      << ", y shape=" << y->shape;
+  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
+      << "CrossEntropy: shapes of x and y is inconsistent,"
+      << " x shape=" << x->shape
+      << ", y shape=" << y->shape;
+  CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
+      << "CrossEntropy: shapes of x and y is inconsistent,"
+      << " x shape=" << x->shape
+      << ", y shape=" << y->shape;
+  // assign output type
+  reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
+  return true;
+}
+
+// Positional relay function to create cross_entropy operator used by frontend FFI.
+Expr MakeCrossEntropy(Expr predictions, Expr targets) {
+  static const Op& op = Op::Get("nn.cross_entropy");
+  return CallNode::make(op, {predictions, targets}, Attrs(), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.cross_entropy")
+.set_body_typed(MakeCrossEntropy);
+
+
+RELAY_REGISTER_OP("nn.cross_entropy")
+.describe(R"code(Computes cross entropy given predictions and targets.)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("x", "2D Tensor", "Predictions.")
+.add_argument("y", "2D Tensor", "Targets.")
+.set_support_level(10)
+.add_type_rel("CrossEntropy", CrossEntropyRel);
+
 
 }  // namespace relay
 }  // namespace tvm
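
For reference, DenseRel derives the output shape by replacing the last axis of data with units (or with weight's first axis): data of shape (d1, ..., dk, in_dim) against weight of shape (units, in_dim) yields (d1, ..., dk, units). A quick numpy check of that contract (an illustration, independent of the C++ above):

```python
import numpy as np

data = np.random.randn(2, 3, 8)  # (d1, d2, in_dim)
weight = np.random.randn(5, 8)   # (units, in_dim)

out = data @ weight.T            # dense computes x . W^T
assert out.shape == (2, 3, 5)    # last axis replaced by units
```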

src/relay/op/tensor/transform.cc

Lines changed: 38 additions & 0 deletions
@@ -186,6 +186,7 @@ RELAY_REGISTER_OP("reinterpret")
 .set_attr<TOpPattern>("TOpPattern", kElemWise)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
+
 // relay.expand_dims
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
 
@@ -914,6 +915,43 @@ Examples::
 .set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+bool EmbedLikeRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  // `types` contains: [data, indices, type_like, result]
+  CHECK_EQ(types.size(), 4);
+  reporter->Assign(types[3], types[2]);
+  return TakeRel({types[2], types[1], types[0]}, 2, attrs, reporter);
+}
+
+Expr MakeEmbedLike(Expr data,
+                   Expr indices,
+                   Expr type_like,
+                   Integer axis,
+                   std::string mode) {
+  auto attrs = make_node<TakeAttrs>();
+  attrs->axis = std::move(axis);
+  attrs->mode = std::move(mode);
+  static const Op& op = Op::Get("embed_like");
+  return CallNode::make(op, {data, indices, type_like}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.embed_like")
+.set_body_typed(MakeEmbedLike);
+
+RELAY_REGISTER_OP("embed_like")
+.describe(R"code(The inverse of take.)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.TakeAttrs")
+.set_num_inputs(3)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("indices", "Tensor", "The indices tensor.")
+.add_argument("type_like", "Tensor", "The tensor that provides the type and shape to embed into.")
+.set_support_level(3)
+.add_type_rel("EmbedLike", EmbedLikeRel)
+.set_attr<FTVMCompute>("FTVMCompute", TakeCompute)  // implement this at python side?
+.set_attr<TOpPattern>("TOpPattern", kInjective);
+
 
 // Init ops
 TVM_REGISTER_NODE_TYPE(InitOpAttrs);
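
EmbedLikeRel assigns the result the type of type_like, then re-checks consistency by calling TakeRel with the operands swapped: taking from type_like at indices must reproduce data's type. A hedged Python-level sketch of what that implies (it only runs against a build of this commit, and assumes the embed_like wrapper is re-exported like take):

```python
from tvm import relay
from tvm.relay.testing import run_infer_type

like = relay.var("like", shape=(3, 4))  # type_like
indices = relay.var("indices", shape=(2,), dtype="int32")
data = relay.var("data", shape=(2, 4))  # same type as take(like, indices, axis=0)

out = run_infer_type(relay.embed_like(data, indices, like, axis=0))
# the result is typed like `like`, per EmbedLikeRel
assert out.checked_type == relay.TensorType((3, 4), "float32")
```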

tests/python/relay/test_pass_gradient.py

Lines changed: 21 additions & 4 deletions
@@ -23,9 +23,14 @@
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
 from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
+from tvm.relay.testing import resnet, inception_v3, squeezenet, densenet, lstm
 import tvm.relay.op as op
 
 
+def rand(dtype='float32', *shape):
+    return tvm.nd.array(np.random.rand(*shape).astype(dtype))
+
+
 def test_id():
     shape = (10, 10)
     dtype = 'float32'
@@ -198,10 +203,9 @@ def test_pow():
     double = relay.Function([x], x + x)
     i = relay.var("i", t)
     func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
-    mod["main"] = func
-    mod["main"] = gradient(mod["main"], mod=mod)
-    m = transform.InferType()(mod)
-    back_func = m["main"]
+    mod["func"] = func
+    mod["back_func"] = gradient(mod["func"], mod=mod)
+    back_func = mod["back_func"]
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     i_nd = rand(dtype, *shape)
     ex = create_executor(mod=mod)
@@ -295,6 +299,19 @@ def test_concat():
     # no value validation as concatenate has dummy gradient right now.
 
 
+def rand_from_type(t):
+    assert isinstance(t, relay.ty.TensorType)
+    return rand(t.dtype, *[int(s) for s in t.shape])
+
+
+def test_resnet():
+    x, _ = densenet.get_workload()
+    x = gradient(x["main"])
+    args = [rand_from_type(e.checked_type) for e in x.params]
+    ex = create_executor()
+    ex.evaluate(x)(*args)
+
+
 if __name__ == "__main__":
     test_id()
     test_add()
