Skip to content

Commit

Permalink
Add virtual device as a first class field to Relay nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies committed Dec 2, 2021
1 parent f435a13 commit 0fe8291
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 50 deletions.
27 changes: 27 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ namespace tvm {

using tvm::runtime::String;

// Forward-declare SEScope to avoid circular imports.
class SEScope;

/*!
* \brief Base type of all the expressions.
* \sa Expr
Expand Down Expand Up @@ -165,6 +168,30 @@ class RelayExprNode : public BaseExprNode {
template <typename TTypeNode>
inline const TTypeNode* type_as() const;

/*!
* \brief The virtual device (SEScope) for this node (the result of device planning).
* For first-order expressions (non functions), this describes where the result of evaluating the
* expression should be stored. Note that currently, all composite first-order values (tuples,
* references, ADTs) must be stored on the same virtual device. This means that it is not possible
* to store two tuple fields on different devices, so we only need one virtual device for these
* types.
*
* For expressions that have the function type, the virtual device describes where the result of
* the call to the function or closure is stored (instead of where the function itself is stored).
* The SEScope's Target field describes how the body of the function should be compiled.
*
* \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular
* import. We can forward-declare the SEScope type for the getter function, but not for the field
* itself.
*/
mutable ObjectRef virtual_device_;

/*!
* \return The virtual device (SEScope).
* If the virtual device is not defined, returns SEScope::FullyUnconstrained().
*/
SEScope virtual_device() const;

static constexpr const char* _type_key = "RelayExpr";
static constexpr const uint32_t _type_child_slots = 22;
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
Expand Down
102 changes: 62 additions & 40 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/ir/op.h>
#include <tvm/target/se_scope.h>

#include <functional>
#include <stack>
Expand Down Expand Up @@ -151,10 +152,14 @@ class Tuple : public Expr {
* \param tuple The tuple to copy
* \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields =
* tuple->fields.
* \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span.
* \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none,
* ret_tuple->virtual_device = tuple->virtual_device.
* \param opt_span The (optional) span for the copied tuple. If none,
* ret_tuple->span = tuple->span.
*/
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
Optional<Span> opt_span = Optional<Span>(nullptr));
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
* \brief Local variables used in the let expression.
Expand Down Expand Up @@ -240,14 +245,16 @@ class Var : public Expr {
* \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid.
* \param opt_type_annotation The (optional) type_annotation for the copied var. If none,
* ret_var->type_annotation = var->type_annotation.
* \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span.
* \return If all properties are null or the same as the property in the input var
* (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise,
* we return a copy of call with the different fields overwritten. (i.e., if
* opt_vid.value() != var->vid, then ret_var->vid = opt_.value()).
* \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none,
* ret_tuple->virtual_device = tuple->virtual_device. \param opt_span The (optional) span for the
* copied var. If none, ret_var->span = var->span. \return If all properties are null or the same as
* the property in the input var (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then
* we return var. Otherwise, we return a copy of call with the different fields overwritten. (i.e.,
* if opt_vid.value() != var->vid, then ret_var->vid = opt_.value()).
*/
Var WithFields(Var var, Optional<Id> opt_vid = Optional<Id>(),
Optional<Type> opt_type_annotation = Optional<Type>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
Expand Down Expand Up @@ -362,16 +369,18 @@ class Call : public Expr {
* call->attrs.
* \param opt_type_args The (optional) type args for the copied call. If none,
* ret_call->type_args = call->type_args.
* \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span.
* \return If all properties are null or the same as the property in the input call
* (i.e., opt_op is null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we
* return a copy of call with the different fields overwritten. (i.e., if opt_op.value() !=
* call->op, then ret_call->op = opt_op.value()).
* \param opt_virtual_device The (optional) virtual_device for the copied call. If none,
* ret_call->virtual_device = call->virtual_device. \param opt_span The (optional) span for the
* copied call. If none, ret_call->span = call->span. \return If all properties are null or the same
* as the property in the input call (i.e., opt_op is null or opt_op.value() == call->op, etc.),
* then we return call. Otherwise, we return a copy of call with the different fields overwritten.
* (i.e., if opt_op.value() != call->op, then ret_call->op = opt_op.value()).
*/
Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(),
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
Optional<Attrs> opt_attrs = Optional<Attrs>(),
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
Expand Down Expand Up @@ -456,15 +465,17 @@ class Let : public Expr {
* \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op.
* \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args.
* \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs.
* \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span.
* \return If all properties are null or the same as the property in the input let (i.e., opt_var is
* null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of
* let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then
* ret_let->var = opt_var.value()).
* \param opt_virtual_device The (optional) virtual_device for the copied let. If none,
* ret_let->virtual_device = let->virtual_device. \param opt_span The (optional) span for the copied
* let. If none, ret_let->span = let->span. \return If all properties are null or the same as the
* property in the input let (i.e., opt_var is null or opt_var.value() == let->var, etc.), then we
* return let. Otherwise, we return a copy of let with the different fields overwritten. (i.e., if
* opt_var.value() != let->var, then ret_let->var = opt_var.value()).
*/
Let WithFields(Let let, Optional<Var> opt_var = Optional<Var>(),
Optional<Expr> opt_value = Optional<Expr>(),
Optional<Expr> opt_body = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
Expand Down Expand Up @@ -539,17 +550,19 @@ class If : public Expr {
* ret_if->true_branch = ret_if->false_branch.
* \param opt_false_branch The (optional) false_branch
* for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch.
* \param opt_span
* The (optional) span for the copied if_expr. If none, ret_if->span = if_expr->span.
* \return If all
* properties are null or the same as the property in the input if_expr (i.e., opt_cond is null or
* opt_cond.value() == if_expr->cond, etc.), then we return if_expr. Otherwise, we return a copy of
* if_expr with the different fields overwritten. (i.e., if opt_cond.value() != if_expr->cond, then
* ret_if->cond = opt_cond.value()).
* \param opt_virtual_device The (optional) virtual_device for the copied if_expr. If none,
* ret_if->virtual_device = if_expr->virtual_device.
* \param opt_span The (optional) span for the copied if_expr. If none,
* ret_if->span = if_expr->span.
* \return If all properties are null or the same as the property in
* the input if_expr (i.e., opt_cond is null or opt_cond.value() == if_expr->cond, etc.), then we
* return if_expr. Otherwise, we return a copy of if_expr with the different fields overwritten.
* (i.e., if opt_cond.value() != if_expr->cond, then ret_if->cond = opt_cond.value()).
*/
If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
Optional<Expr> opt_true_branch = Optional<Expr>(),
Optional<Expr> opt_false_branch = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Get index-th field out of a tuple. */
Expand Down Expand Up @@ -603,8 +616,9 @@ class TupleGetItem : public Expr {
* ret_tuple_get_item->tuple = tuple_get_item->tuple.
* \param opt_index The (optional) index for the copied tuple_get_item. If none,
* ret_tuple_get_item->index = tuple_get_item->index.
* \param
* opt_span The (optional) span for the copied tuple_get_item. If none,
* \param opt_virtual_device The (optional) virtual_device for the copied tuple_get_item.
* If none, ret_tuple_get_item->virtual_device = tuple_get_item->virtual_device.
* \param opt_span The (optional) span for the copied tuple_get_item. If none,
* ret_tuple_get_item->span = tuple_get_item->span.
* \return If all properties are null or the same as the property in the input tuple_get_item
* (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return
Expand All @@ -614,6 +628,7 @@ class TupleGetItem : public Expr {
*/
TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple = Optional<Expr>(),
Optional<Integer> opt_index = Optional<Integer>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Create a new Reference out of initial value. */
Expand Down Expand Up @@ -663,6 +678,8 @@ class RefCreate : public Expr {
* \param ref_create The ref_create to copy.
* \param opt_value The (optional) value for the copied ref_create. If none,
* ret_ref_create->value = ref_create->value.
* \param opt_virtual_device The (optional) virtual_device for the copied ref_create. If none,
* ret_ref_create->virtual_device = ref_create->virtual_device.
* \param opt_span The (optional) span for the copied ref_create. If none,
* ret_ref_create->span = ref_create->span.
* \return If all properties are null or the same as the property in the input ref_create
Expand All @@ -672,6 +689,7 @@ class RefCreate : public Expr {
* ret_ref_create->value = opt_value.value()).
*/
RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Get value out of Reference. */
Expand Down Expand Up @@ -720,15 +738,17 @@ class RefRead : public Expr {
* \param ref_read The ref_read to copy.
* \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref =
* ref_read->ref.
* \param opt_span
* The (optional) span for the copied ref_read. If none, ret_ref_read->span = ref_read->span.
* \return If all properties are null or the same as the property in the input ref_read
* (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return ref_read.
* Otherwise, we return a copy of ref_read with the different fields overwritten.
* (i.e., if opt_ref.value() != ref_read->ref, then
* ret_ref_read->ref = opt_ref.value()).
* \param opt_virtual_device
* The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device =
* ref_read->virtual_device. \param opt_span The (optional) span for the copied ref_read. If none,
* ret_ref_read->span = ref_read->span. \return If all properties are null or the same as the
* property in the input ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.),
* then we return ref_read. Otherwise, we return a copy of ref_read with the different fields
* overwritten. (i.e., if opt_ref.value() != ref_read->ref, then ret_ref_read->ref =
* opt_ref.value()).
*/
RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
Expand Down Expand Up @@ -784,16 +804,18 @@ class RefWrite : public Expr {
* ret_ref_write->ref = ref_write->ref.
* \param opt_value The (optional) value for the copied ref_write. If none,
* ret_ref_write->value = ref_write->value.
* \param opt_span
* The (optional) span for the copied ref_write. If none, ret_ref_write->span = ref_write->span.
* \return If all properties are null or the same as the property in the input ref_write
* (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write.
* Otherwise, we return a copy of ref_write with the different fields overwritten.
* (i.e., if ref_write.value() != ref_write->ref, then
* ret_ref_write->ref = opt_ref.value()).
* \param opt_virtual_device
* The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device =
* ref_write->virtual_device. \param opt_span The (optional) span for the copied ref_write. If none,
* ret_ref_write->span = ref_write->span. \return If all properties are null or the same as the
* property in the input ref_write (i.e., opt_ref is null or opt_ref.value() == ref_write->ref,
* etc.), then we return ref_write. Otherwise, we return a copy of ref_write with the different
* fields overwritten. (i.e., if ref_write.value() != ref_write->ref, then ret_ref_write->ref =
* opt_ref.value()).
*/
RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref = Optional<Expr>(),
Optional<Expr> opt_value = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class Function : public BaseFunc {
* \param opt_attrs
* The (optional) attributes for the copied function. If none,
* ret_function->attrs = function->attrs.
* \param opt_virtual_device The (optional) virtual_device for the copied function. If none,
* ret_function->virtual_device = function->virtual_device.
* \param opt_span The (optional) span for the copied function. If none,
* ret_function->span = function->span.
* \return If all properties are null or the same as the property in the input function
Expand All @@ -146,6 +148,7 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params = Optiona
Optional<Type> opt_ret_type = Optional<Type>(),
Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(),
Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*
Expand Down
Loading

0 comments on commit 0fe8291

Please sign in to comment.