
Add tracing support for mgrid and advanced tensor indexing #111



Open
wants to merge 8 commits into base: main

Conversation

aqjune-aws


This patch adds Python-tracing support for `mgrid`, as well as for tensor indexing that uses the output of `mgrid`.

https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing

Given index tensors ind_1, ind_2, ..., ind_N, the expression x[ind_1, ind_2, ..., ind_N]
performs advanced indexing over the elements of x.
The result of the access is:

```
result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
                          ..., ind_N[i_1, ..., i_M]]
```

In NumPy, mixing advanced indexing and basic indexing is allowed. In NKI, however,
only one of the two forms may be used in a given index tuple.
Refer to
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/programming_model.html:
"Note that currently NKI does not support mixing Basic and Advanced Tensor
Indexing in the same Index tuple."
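
To make the indexing rule above concrete, here is a minimal sketch of the 2-D case in plain Lean, over lists rather than TensorLib tensors; `advIndex2` is a name invented for this illustration, not project API.

```
/- Sketch only: the advanced-indexing rule for the 2-D case,
   result[k] = x[ind₁[k], ind₂[k]], over plain lists. -/
def advIndex2 (x : List (List Int)) (ind₁ ind₂ : List Nat) : List Int :=
  (ind₁.zip ind₂).map fun (i, j) => (x.getD i []).getD j 0

-- Picking the anti-diagonal of a 2×2 tensor:
#guard advIndex2 [[1, 2], [3, 4]] [0, 1] [1, 0] == [2, 3]
```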
@aqjune-aws
Author

CI is failing, but this will be fixed once BEq is added to the Tensor structure in TensorLib (leanprover/TensorLib#67).

My pull request does not have a unit test for `mgrid` because I have no idea how to write one. I would appreciate any guidance :)

@seanmcl
Collaborator

seanmcl commented May 1, 2025

You can use `#guard` and `plausible`. See, for example, some simple Base64 tests.

Start with `#guard` and `#guard_msgs`. You can use `#eval` to get the test right before you switch to `#guard`. Happy to show you this irl as well :)
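
For instance, the `#eval`-then-`#guard` workflow sketched below uses only generic Lean commands, nothing project-specific:

```
#eval 2 + 3        -- inspect the value interactively first
#guard 2 + 3 == 5  -- then freeze it as a check that runs on every build

/-- info: 5 -/
#guard_msgs in #eval 2 + 3  -- or pin the exact message output
```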

Collaborator

@govereau govereau left a comment


Nice job!
There are some details we need to think through a bit.
Perhaps you and Sean and I can sync up on this?

```
Returns: ⟨ t[0,0,...], steps ⟩
For the above example, the return value is ⟨ 10, [5, 20] ⟩
-/
def decomposeLinearIntTensor (t:TensorLib.Tensor)
```
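
As a hypothetical sketch of the property this decomposition targets, written over a plain function rather than TensorLib.Tensor (`affine2` is a name made up for this illustration), the values match the ⟨ 10, [5, 20] ⟩ example quoted above:

```
-- Sketch: a "linear" 2-D tensor is its value at [0, 0] plus one step per axis,
-- i.e. t[i, j] = start + steps[0] * i + steps[1] * j.
def affine2 (start : Int) (steps : List Int) (i j : Nat) : Int :=
  start + steps.getD 0 0 * (i : Int) + steps.getD 1 0 * (j : Int)

#guard affine2 10 [5, 20] 0 0 == 10   -- t[0, 0]
#guard affine2 10 [5, 20] 1 2 == 55   -- 10 + 5*1 + 20*2
```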
Collaborator


Nice.
It would be good to add some basic `#guard` checks for testing.
Also, I wonder if this should be in TensorLib. @seanmcl ?

Author

@aqjune-aws aqjune-aws May 5, 2025


Added `#guard` checks. The guards will pass once leanprover/TensorLib#74 is merged :)

```
@@ -294,16 +468,42 @@ def assignExpr (e : Core.Expr) (t : Term) : Trace Unit := do
-- Unpack an RValue, must be a list or tuple
def unpack : Term -> Trace (List Term)
  | .tuple l | .list l => return l
  | .tensor t =>
    -- Unpack tensor to a list of subtensors
```
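
For context, a hypothetical sketch of what unpacking a tensor along its first axis means over a flattened buffer (plain lists, not the TensorLib API; `unpackFirstAxis` is an invented name):

```
-- Sketch: a tensor with shape n × rowSize, stored flat, unpacks into n rows.
def unpackFirstAxis (flat : List Int) (n rowSize : Nat) : List (List Int) :=
  (List.range n).map fun i => (flat.drop (i * rowSize)).take rowSize

#guard unpackFirstAxis [1, 2, 3, 4, 5, 6] 2 3 == [[1, 2, 3], [4, 5, 6]]
```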
Collaborator


This seems too complex? You should be able to use TensorLib like numpy.
TensorLib can compute, e.g. t[i,...], which you seem to be doing in a fairly low-level way.
@seanmcl ?

Author


Yes, in fact I could not find the right function. I could have missed one.

aqjune pushed a commit to aqjune-aws/KLR that referenced this pull request May 7, 2025
This patch allows the Tensor.store API (which is connected to `nki.language.store`) to accept the more
generic `Core.Value` type.

The motivation is tracing interop/test/examples/matmul.py, specifically the `nki_matmul_basic_` function.

After applying leanprover#111, tracing the Python function raised the following error:

```
error:
line 44:
  nl.store(result[i_out_p, i_out_f], value=result_sbuf)
  ^-- expecting tensor access
```

This is because its `value` keyword argument held the following expression:
```
KLR.Trace.Term.expr (KLR.Core.Expr.value (KLR.Core.Value.var "5")) (KLR.Trace.TermType.obj `object)
```
which could not be converted to an Access through the FromNKI typeclass.

The "5" temporary variable was emerging from the right hand side of the definition of `result_sbuf`:

```
result_sbuf = nl.copy(result_psum, dtype=result.dtype)
```

To convert the value of "5", it seems we would need to walk the generated trace and find the assignment to "5",
because:

```
def RValue : Term -> Trace Term
...
  | .expr e@(.call ..) ty => do
       let v := (<- genName).toString
       add_stmt (.assign v e)
       return .expr (.value $ .var v) ty
```

Here `add_stmt` just adds a Core statement to `State.body`.

Scanning `State.body` to find this assignment to "5" did not seem like something we want to do inside Tensor.store,
so I chose a more conservative approach and simply removed the shape check.

But any other reasonable option is still fine with me.
Collaborator

@govereau govereau left a comment


Cool!
