Skip to content

Commit

Permalink
[TVMScript][Parser] B1: Dataflow block (apache#252)
Browse files Browse the repository at this point in the history
This PR features the parser support for Relax dataflow blocks.
  • Loading branch information
MasterJH5574 authored Sep 22, 2022
1 parent 574178b commit f320672
Show file tree
Hide file tree
Showing 12 changed files with 655 additions and 123 deletions.
68 changes: 37 additions & 31 deletions include/tvm/script/ir_builder/relax/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,6 @@ class RelaxFrame : public IRBuilderFrame {
RelaxFrame() = default;
};

/*! \brief The ir_builder frame for relax binding blocks. */
class BlockFrameNode : public RelaxFrameNode {
public:
/*! \brief The flag that indicates whether the block is a dataflow block. */
bool is_dataflow;

void VisitAttrs(tvm::AttrVisitor* v) {
RelaxFrameNode::VisitAttrs(v);
v->Visit("is_dataflow", &is_dataflow);
}

static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode);

public:
void EnterWithScope() final;
void ExitWithScope() final;
};

class BlockFrame : public RelaxFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode);
};

/*! \brief The ir_builder frame for the relax function. */
class FunctionFrameNode : public RelaxFrameNode {
public:
Expand All @@ -94,12 +70,10 @@ class FunctionFrameNode : public RelaxFrameNode {
Map<String, ObjectRef> attrs;
/*! \brief The binding blocks inside the function. */
Array<tvm::relax::BindingBlock> binding_blocks;
/*! \brief The function output expr. */
Array<tvm::relax::Expr> outputs;
/*! \brief The function output expr. `NullOpt` when undefined. */
Optional<tvm::relax::Expr> output;
/*! \brief The block builder to create Relax function. */
tvm::relax::BlockBuilder block_builder;
/*! \brief The default binding block frame of the function. */
BlockFrame default_binding_block_frame{nullptr};

void VisitAttrs(tvm::AttrVisitor* v) {
RelaxFrameNode::VisitAttrs(v);
Expand All @@ -108,16 +82,14 @@ class FunctionFrameNode : public RelaxFrameNode {
v->Visit("ret_type", &ret_type);
v->Visit("attrs", &attrs);
v->Visit("binding_blocks", &binding_blocks);
v->Visit("outputs", &outputs);
v->Visit("output", &output);
// `block_builder` is not visited.
// `default_binding_block_frame` is not visited.
}

static constexpr const char* _type_key = "script.ir_builder.relax.FunctionFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, RelaxFrameNode);

public:
void EnterWithScope() final;
void ExitWithScope() final;
};

Expand All @@ -126,6 +98,40 @@ class FunctionFrame : public RelaxFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, RelaxFrame, FunctionFrameNode);
};

/*! \brief The ir_builder frame for relax binding blocks. */
class BlockFrameNode : public RelaxFrameNode {
public:
/*! \brief The flag that indicates whether the block is a dataflow block. */
bool is_dataflow;
/*! \brief The variables emitted in this block. */
Array<tvm::relax::Var> emitted_vars;
/*!
* \brief (Only used for a dataflow block.) A boolean indicating if the dataflow block is ended of
* construction. If it is true, any new binding trying to be emitted into this block will cause an
* error.
*/
bool block_ended;

void VisitAttrs(tvm::AttrVisitor* v) {
RelaxFrameNode::VisitAttrs(v);
v->Visit("is_dataflow", &is_dataflow);
v->Visit("emitted_vars", &emitted_vars);
v->Visit("block_ended", &block_ended);
}

static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode);

public:
void EnterWithScope() final;
void ExitWithScope() final;
};

class BlockFrame : public RelaxFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode);
};

} // namespace relax
} // namespace ir_builder
} // namespace script
Expand Down
23 changes: 17 additions & 6 deletions include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,25 +117,36 @@ TVM_DLL BlockFrame BindingBlock();
*/
TVM_DLL BlockFrame Dataflow();

/*!
* \brief Expose the dataflow block output variables as global ones
* \param vars The output variables of a dataflow block
*/
TVM_DLL void DataflowBlockOutput(const Array<tvm::relax::Var>& vars);

////////////////////////////// Bindings ////////////////////////////////

/*!
* \brief Emit a binding to the last binding block frame.
* \param value The right side value of the bindings to be emitted.
* \param is_dataflow_var A boolean indicating if the emitted binding variable is a dataflow
* variable.
* \return The left side var of the emitted binding.
*/
TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value);
TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value, bool is_dataflow_var);

/*!
* \brief Emit a match_shape binding to the last binding block frame.
* \param value The value of the MatchShape to be emitted.
* \param pattern The pattern of the MatchShape to be emitted.
* \param emit_var The flag that indicate if the match_shape contains the emitted var.
* \return The emitted var if `emit_var` is true, otherwise, `NullOpt`.
* \param emit_var A boolean indicating if the MatchShape contains the emitted variable.
* \param is_dataflow_var A boolean indicating if the emitted variable is a dataflow variable when
* `emit_var` is true. When `emit_var` is false, the value of this flag will be ignored.
* \return The emitted var if `emit_var` is true. Otherwise, return `NullOpt`.
*/
TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value,
const Array<PrimExpr>& pattern,
bool emit_var = true);
TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value, //
const Array<PrimExpr>& pattern, //
bool emit_var, //
bool is_dataflow_var);

} // namespace relax
} // namespace ir_builder
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
"""
return _ffi_api.BlockBuilderEmitMatchShape(self, value, pattern)

def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> Var:
"""Emit output for the current dataflow block or function.
Parameters
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/script/ir_builder/relax/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ class RelaxFrame(IRBuilderFrame):
"""The base ir_builder frame for the relax dialect."""


@_register_object("script.ir_builder.relax.BlockFrame")
class BlockFrame(RelaxFrame):
"""The ir_builder frame for relax binding blocks."""


@_register_object("script.ir_builder.relax.FunctionFrame")
class FunctionFrame(RelaxFrame):
"""The ir_builder frame for the relax function."""


@_register_object("script.ir_builder.relax.BlockFrame")
class BlockFrame(RelaxFrame):
"""The ir_builder frame for relax binding blocks."""
58 changes: 44 additions & 14 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=redefined-builtin, wrong-import-order
"""IRBuilder for Relax dialect"""

from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

from tvm._ffi import register_object as _register_object
from tvm.ir import Type
Expand Down Expand Up @@ -158,50 +158,80 @@ def func_ret_value(value: Expr) -> None:
############################# BindingBlock ##############################


def binding_block() -> frame.BlockFrame:
"""Start a binding block frame.
def dataflow() -> frame.BlockFrame:
"""Start a dataflow binding block frame.
Returns
-------
frame: frame.BlockFrame
The created ir_builder Block frame.
"""
return _ffi_api.BindingBlock() # pylint: disable=no-member # type: ignore
return _ffi_api.Dataflow() # pylint: disable=no-member # type: ignore


def dataflow() -> frame.BlockFrame:
"""Start a dataflow binding block frame.
def output(*vars: Tuple[Var]) -> Tuple[Var]:
"""Expose the dataflow block output variables as global ones.
Parameters
----------
vars: Tuple[Var]
The output variables of a dataflow block.
Returns
-------
frame: frame.BlockFrame
The created ir_builder Block frame.
vars: Tuple[Var]
The output variables of a dataflow block. Return the input variables to parser side for
followup process
"""
return _ffi_api.Dataflow() # pylint: disable=no-member # type: ignore
_ffi_api.DataflowBlockOutput(vars) # pylint: disable=no-member # type: ignore
return vars


############################### Bindings ###############################


def emit(value: Expr) -> Var:
def emit(value: Expr, is_dataflow_var: bool) -> Var:
"""Emit a binding to the last binding block frame.
Parameters
----------
value: Expr
The right side value of the bindings to be emitted.
is_dataflow_var: bool
A boolean indicating if the emitted binding variable is a dataflow variable.
Returns
-------
var: Var
The left side var of the emitted binding.
"""
return _ffi_api.Emit(value) # type: ignore
return _ffi_api.Emit(value, is_dataflow_var) # pylint: disable=no-member # type: ignore


def emit_match_shape(
value: Expr, pattern: List[PrimExpr], emit_var: bool, is_dataflow_var: bool
) -> Optional[Var]:
"""Emit a match_shape binding to the last binding block frame.
def emit_match_shape(value: Expr, pattern: List[PrimExpr], emit_var: bool = True) -> Var:
return _ffi_api.EmitMatchShape(value, pattern, emit_var) # type: ignore
Parameters
----------
value: Expr
The value of the MatchShape to be emitted.
pattern: List[PrimExpr]
The pattern of the MatchShape to be emitted.
emit_var: bool
A boolean indicating if the MatchShape contains the emitted variable.
is_dataflow_var: bool
A boolean indicating if the emitted variable is a dataflow variable when `emit_var` is True.
When `emit_var` is False, the value of this flag will be ignored.
Returns
-------
var: Optional[Var]
The emitted var if `emit_var` is True. Otherwise, return `None`.
"""
return _ffi_api.EmitMatchShape(value, pattern, emit_var, is_dataflow_var) # type: ignore


############################### Importer ###############################
Expand All @@ -210,7 +240,6 @@ def emit_match_shape(value: Expr, pattern: List[PrimExpr], emit_var: bool = True
"TensorType",
"add",
"arg",
"binding_block",
"builtin",
"call_tir",
"dataflow",
Expand All @@ -224,6 +253,7 @@ def emit_match_shape(value: Expr, pattern: List[PrimExpr], emit_var: bool = True
"invoke_closure",
"make_closure",
"multiply",
"output",
"unique",
"shape_of",
"tensor",
Expand Down
Loading

0 comments on commit f320672

Please sign in to comment.