Commit 8b99271

init
drop changes
1 parent 776fd6b commit 8b99271

File tree: 4 files changed, +107 -15 lines changed

python/tvm/relay/op/_tensor_grad.py

Lines changed: 12 additions & 9 deletions
@@ -320,6 +320,18 @@ def dense_grad(orig, grad):
     return [collapse_sum_like(transpose(grad) * weight, data),
             collapse_sum_like(data * transpose(grad), weight)]

+# UNTESTED
+@register_gradient("take")
+def take_grad(orig, grad):
+    x, y = orig.args
+    return [zeros_like(x), zeros_like(y)]
+    return [Call(op_get("take_grad"), [x, y, grad], orig.attrs), zeros_like(y)]
+
+
+@register_gradient("shape_of")
+def shape_of_grad(orig, grad):
+    return [zeros_like(orig.args[0])]
+

 @register_gradient("reshape")
 def reshape_grad(orig, grad):
@@ -365,12 +377,3 @@ def sum_grad(orig, grad):
     """Returns grad broadcasted to data dims"""
     data = orig.args[0]
     return [broadcast_to_like(grad, data)]
-
-
-@register_gradient("nn.cross_entropy")
-def cross_entropy_grad(orig, grad):
-    x, y = orig.args
-    shape = shape_of(x)
-    batch_size = take(shape, const(0, dtype='int32'), axis=0)
-    grad = grad / batch_size.astype('float32')
-    return [-grad * y / x, -grad * log(x)]
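
Note that take_grad as committed is a stub: it returns zero gradients, and the second return, which would dispatch to a dedicated "take_grad" scatter op, is unreachable dead code. For context, a minimal sketch (not part of this commit; it uses only helpers already imported by the test file below) of how any registered gradient is exercised through the pass:

    from tvm import relay
    from tvm.relay.transform import gradient
    from tvm.relay.testing import run_infer_type

    x = relay.var("x", shape=(3, 4))
    fwd = run_infer_type(relay.Function([x], relay.sum(x)))
    # gradient() builds a function returning (forward value, (d/dx,));
    # per sum_grad above, d/dx is broadcast_to_like(grad, x), i.e. ones here.
    bwd = run_infer_type(gradient(fwd))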

python/tvm/relay/op/transform.py

Lines changed: 36 additions & 2 deletions
@@ -42,12 +42,15 @@ def cast(data, dtype):

 def cast_like(data, dtype_like):
     """Cast input tensor to data type of another tensor.
+
     Parameters
     ----------
     data : relay.Expr
         The input data to the operator.
+
     dtype_like: relay.Expr
         The tensor to cast to.
+
     Returns
     -------
     result : relay.Expr
@@ -249,8 +252,8 @@ def reshape_like(data, shape_like):
     data : relay.Expr
         The input data to the operator.

-    shape_like : tuple of int
-        The new shape. Should be compatible with the original shape.
+    shape_like : relay.Expr
+        The tensor to reshape to. Should be compatible with the original shape.

     Returns
     -------
@@ -260,6 +263,37 @@ def reshape_like(data, shape_like):
     return _make.reshape_like(data, shape_like)


+def embed_like(data, indices, type_like, axis=None, mode="clip"):
+    """Embed elements into a tensor with the type of another (the inverse of take).
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The source array.
+
+    indices : relay.Expr
+        The indices of the values to extract.
+
+    type_like : relay.Expr
+        The tensor that provides the type to embed into.
+
+    axis : int, optional
+        The axis over which to select values. By default,
+        the flattened input array is used.
+
+    mode : str, optional
+        Specifies how out-of-bound indices will behave [clip, wrap, fast].
+        clip: clip to the range (default).
+        wrap: wrap around the indices.
+        fast: no clip or wrap around (user must make sure indices are in-bound).
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.embed_like(data, indices, type_like, axis, mode)
+
 def take(data, indices, axis=None, mode="clip"):
     """Take elements from an array along an axis.

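A hedged usage sketch for the new binding (variable names are illustrative and assume embed_like is re-exported at the relay top level like its transform.py neighbors):

    from tvm import relay

    table = relay.var("table", shape=(10, 4))            # provides type and shape
    idx = relay.var("idx", shape=(3,), dtype="int32")
    vals = relay.var("vals", shape=(3, 4))               # e.g. the grad of take(table, idx, axis=0)
    out = relay.embed_like(vals, idx, table, axis=0)
    # out carries the type of `table`; see EmbedLikeRel below, which assigns types[3] = types[2].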
src/relay/op/tensor/transform.cc

Lines changed: 38 additions & 0 deletions
@@ -187,6 +187,7 @@ RELAY_REGISTER_OP("reinterpret")
 .set_attr<TOpPattern>("TOpPattern", kElemWise)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

+
 // relay.expand_dims
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);

@@ -958,6 +959,43 @@ Examples::
 .set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);

+bool EmbedLikeRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  // `types` contains: [data, indices, type_like, result]
+  CHECK_EQ(types.size(), 4);
+  reporter->Assign(types[3], types[2]);
+  return TakeRel({types[2], types[1], types[0]}, 2, attrs, reporter);
+}
+
+Expr MakeEmbedLike(Expr data,
+                   Expr indices,
+                   Expr type_like,
+                   Integer axis,
+                   std::string mode) {
+  auto attrs = make_node<TakeAttrs>();
+  attrs->axis = std::move(axis);
+  attrs->mode = std::move(mode);
+  static const Op& op = Op::Get("embed_like");
+  return CallNode::make(op, {data, indices, type_like}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.embed_like")
+.set_body_typed(MakeEmbedLike);
+
+RELAY_REGISTER_OP("embed_like")
+.describe(R"code(The inverse of take.)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.TakeAttrs")
+.set_num_inputs(3)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("indices", "Tensor", "The indices tensor.")
+.add_argument("type_like", "Tensor", "The tensor that provides the type and shape to embed into.")
+.set_support_level(3)
+.add_type_rel("EmbedLike", EmbedLikeRel)
+.set_attr<FTVMCompute>("FTVMCompute", TakeCompute)  // implement this at python side?
+.set_attr<TOpPattern>("TOpPattern", kInjective);
+

 // Init ops
 TVM_REGISTER_NODE_TYPE(InitOpAttrs);
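
On the "// implement this at python side?" note: relay does allow attaching an op's compute from Python instead of C++. A hedged sketch, assuming the relay.op API of this era; the placeholder body only materializes zeros of the inferred output type, where a real kernel would scatter inputs[0] into it at inputs[1]:

    import topi
    from tvm.relay.op import register_compute

    @register_compute("embed_like", level=11)  # outrank the borrowed TakeCompute
    def embed_like_compute(attrs, inputs, out_type, target):
        # inputs = [data, indices, type_like]; out_type matches type_like
        return [topi.full(out_type.shape, out_type.dtype, 0.0)]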

tests/python/relay/test_pass_gradient.py

Lines changed: 21 additions & 4 deletions
@@ -23,9 +23,14 @@
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
 from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
+from tvm.relay.testing import resnet, inception_v3, squeezenet, densenet, lstm
 import tvm.relay.op as op


+def rand(dtype='float32', *shape):
+    return tvm.nd.array(np.random.rand(*shape).astype(dtype))
+
+
 def test_id():
     shape = (10, 10)
     dtype = 'float32'
@@ -198,10 +203,9 @@ def test_pow():
     double = relay.Function([x], x + x)
     i = relay.var("i", t)
     func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
-    mod["main"] = func
-    mod["main"] = gradient(mod["main"], mod=mod)
-    m = transform.InferType()(mod)
-    back_func = m["main"]
+    mod["func"] = func
+    mod["back_func"] = gradient(mod["func"], mod=mod)
+    back_func = mod["back_func"]
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     i_nd = rand(dtype, *shape)
     ex = create_executor(mod=mod)
@@ -295,6 +299,19 @@ def test_concat():
     # no value validation as concatenate has dummy gradient right now.


+def rand_from_type(t):
+    assert isinstance(t, relay.ty.TensorType)
+    return rand(t.dtype, *[int(s) for s in t.shape])
+
+
+def test_resnet():
+    x, _ = densenet.get_workload()
+    x = gradient(x["main"])
+    args = [rand_from_type(e.checked_type) for e in x.params]
+    ex = create_executor()
+    ex.evaluate(x)(*args)
+
+
 if __name__ == "__main__":
     test_id()
     test_add()
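
The suite also imports check_grad, which compares a registered gradient against finite differences. A hedged sketch of a direct numeric check (not in this commit; float-only inputs keep the finite-difference comparison valid):

    def test_mul_grad():
        x = relay.var("x", shape=(4,))
        y = relay.var("y", shape=(4,))
        check_grad(relay.Function([x, y], x * y))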
