Skip to content

Commit 9cc7874

Browse files
authored
[Relay][Params] Add APIs for storing and retrieving parameters from individual functions. (#4194)
* Add support for attaching params * Fix types * Fix test
1 parent 93d610a commit 9cc7874

File tree

4 files changed

+76
-3
lines changed

4 files changed

+76
-3
lines changed

include/tvm/relay/expr.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,19 @@ class FunctionNode : public ExprNode {
274274
tvm::Array<TypeVar> ty_params,
275275
tvm::Attrs attrs = Attrs());
276276

277+
/*!
278+
* \brief Attach the function's parameters to its attributes for use in analysis.
279+
* \return The function with its parameters attached.
280+
*/
281+
Function SetParams(const tvm::Map<Var, Constant>& parameters) const;
282+
283+
/*!
284+
* \brief Retrieve the function's parameters.
285+
*
286+
* \return The function's parameter.
287+
*/
288+
tvm::Map<Var, Constant> GetParams() const;
289+
277290
static constexpr const char* _type_key = "relay.Function";
278291
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
279292
};
@@ -284,7 +297,6 @@ RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
284297
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
285298
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);
286299

287-
288300
/*!
289301
* \brief Call corresponds to operator invocation.
290302
* Corresponds to the operator in computational graph terminology.

python/tvm/relay/expr.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .._ffi import base as _base
2828
from .. import nd as _nd
2929
from .. import convert
30+
from ..ndarray import NDArray
3031

3132
# will be registered afterwards
3233
_op_make = None
@@ -305,6 +306,17 @@ def __call__(self, *args):
305306
"""
306307
return Call(self, args, None, None)
307308

309+
def get_params(self):
310+
return _expr.FunctionGetParams(self)
311+
312+
def set_params(self, params):
313+
for key in params:
314+
value = params[key]
315+
if isinstance(value, NDArray):
316+
params[key] = Constant(value)
317+
318+
return _expr.FunctionSetParams(self, params)
319+
308320

309321
@register_relay_node
310322
class Call(Expr):

src/relay/ir/expr.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,26 @@ bool FunctionNode::IsPrimitive() const {
159159
return pval && pval->value != 0;
160160
}
161161

162+
Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
163+
return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
164+
}
165+
166+
TVM_REGISTER_API("relay._expr.FunctionSetParams")
167+
.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
168+
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
169+
return func->SetParams(parameters);
170+
});
171+
172+
tvm::Map<Var, Constant> FunctionNode::GetParams() const {
173+
auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
174+
return Downcast<tvm::Map<Var, Constant>>(node_ref);
175+
}
176+
177+
TVM_REGISTER_API("relay._expr.FunctionGetParams")
178+
.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
179+
return func->GetParams();
180+
});
181+
162182
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
163183
if (!func->attrs.defined()) { return NodeRef(); }
164184

tests/python/relay/test_ir_nodes.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tvm.expr import *
2121
from tvm.relay import op
2222
from tvm.relay.analysis import graph_equal
23-
23+
import numpy as np
2424

2525
def check_json_roundtrip(node):
2626
json_str = tvm.save_json(node)
@@ -160,7 +160,6 @@ def test_global_var():
160160
str(gv)
161161
check_json_roundtrip(gv)
162162

163-
164163
def test_function():
165164
param_names = ['a', 'b', 'c', 'd']
166165
params = tvm.convert([relay.Var(n) for n in param_names])
@@ -175,6 +174,34 @@ def test_function():
175174
str(fn)
176175
check_json_roundtrip(fn)
177176

177+
def test_function_attrs():
178+
param_names = ['a', 'b', 'c', 'd']
179+
params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
180+
ret_type = relay.TupleType(tvm.convert([]))
181+
body = relay.Tuple(tvm.convert([]))
182+
type_params = tvm.convert([])
183+
fn = relay.Function(params, body, ret_type, type_params)
184+
model_params = {}
185+
for param in params[:1]:
186+
cty = param.type_annotation
187+
tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
188+
model_params[param] = tvm.nd.array(tensor)
189+
fn = fn.set_params(model_params)
190+
assert fn.params == params
191+
assert fn.body == body
192+
assert fn.type_params == type_params
193+
assert fn.span == None
194+
str(fn)
195+
check_json_roundtrip(fn)
196+
json_str = tvm.save_json(fn)
197+
fn_after = tvm.load_json(json_str)
198+
model_params_after = fn_after.get_params()
199+
after_keys = [item[0] for item in model_params_after.items()]
200+
for key1, key2 in zip(model_params, after_keys):
201+
assert key1.name_hint == key2.name_hint
202+
p1 = model_params[key1]
203+
p2 = model_params_after[key2]
204+
np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())
178205

179206
def test_call():
180207
op = relay.Var('f')
@@ -257,9 +284,11 @@ def test_conv2d_attrs():
257284
test_local_var()
258285
test_global_var()
259286
test_function()
287+
test_function_attrs()
260288
test_call()
261289
test_let()
262290
test_if()
263291
test_tuple_get_item()
264292
test_op()
265293
test_conv2d_attrs()
294+

0 commit comments

Comments
 (0)