Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][Frontend] NNModule tensor_ir_op support #16278

Merged
merged 1 commit into from
Dec 26, 2023

Conversation

Hzfengsy
Copy link
Member

@Hzfengsy Hzfengsy commented Dec 26, 2023

This PR adds support for tensor_ir_op in NNModule, which enables us to call TensorIR function in NNModule.

Also this PR adds a test case for extern op.

cc @junrushao @MasterJH5574

This PR adds support for `tensor_ir_op` in NNModule, which enables us to
call TensorIR function in NNModule.

Also this PR adds a test case for extern op.
Copy link
Contributor

@MasterJH5574 MasterJH5574 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you @Hzfengsy!

@@ -1461,13 +1461,87 @@ def _convert(arg):
OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])


def tensor_ir_op(
func: _tir.PrimFunc,
name_hint: str,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There’s a bit of complication here: if the PrimFunc provided is a public function (has “global_symbol” field in its attrs), Relax is not allowed to rename it, and in this case, it’s not a name hint but a name instead. Therefore, we will have to check symbol duplication and potentially throw an error if it happens.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably leave this logic to future work, but let’s rename name_hint to name to better reflect this point

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree and thanks for pointing it out. However, the current Python interface AddFunction also treats it as name_hint, which may be renamed if conflicts exist.

It would be an independent problem out of the scope of this PR.

@@ -508,5 +509,134 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten
tvm.ir.assert_structural_equal(irmodule, Expected)


def test_tensor_ir_op():
num_q_heads, num_kv_heads, head_dim = 8, 8, 16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This unittest is a bit more complicated than I expected :)) in the simplest case, we could probably just supply a “B = A + 1”-style TIR

Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merging this in for now, and please follow up with my comments in subsequent PRs

@junrushao junrushao merged commit 889d2f6 into apache:unity Dec 26, 2023
17 of 18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants