Add tracing support for mgrid and advanced tensor indexing #111

Open · wants to merge 9 commits into main
6 changes: 6 additions & 0 deletions KLR/Trace/Basic.lean
@@ -57,6 +57,7 @@ def Term.isTrue : Term -> Err Bool
| .tuple []
| .list [] => return false
| .module _
| .mgrid
| .builtin ..
| .source _
| .string _
@@ -67,6 +68,11 @@ def Term.isTrue : Term -> Err Bool
| .store ..
| .pointer .. => return true
| .expr (.value v) _ => return v.isTrue
| .tensor _ =>
-- Using ndarray as bool in Python raises:
-- "ValueError: The truth value of an array with more than one element
-- is ambiguous. Use a.any() or a.all()"
throw "tensor cannot be evaluated as bool"
| .expr _ _ => throw "non-constant expression"

def Term.isFalse (t : Term) : Err Bool :=
4 changes: 4 additions & 0 deletions KLR/Trace/FromNKI.lean
@@ -45,6 +45,10 @@ instance : FromNKI Expr where
| .module _ => err "module"
| .builtin n .. => return .value (.var n.toString)
| .source _ => err "function"
-- tensor and mgrid must not survive tracing, so there is no corresponding
-- expression in KLR.Core.Expr.
| .tensor _ => err "tensor"
| .mgrid => err "mgrid"
| .none => err "none"
| .string _ => err "string"
| .tuple _ => err "tuple"
1 change: 1 addition & 0 deletions KLR/Trace/NKI.lean
@@ -42,5 +42,6 @@ def NKIEnv : List (Name × Term) :=
, const_var (nl "hbm")
, const_var (nl "sbuf")
, const_var (nl "psum")
, (nl "mgrid", .mgrid)
]
++ NKIBuiltins.map fun (x,_) => (x, .builtin x (.obj x) none)
284 changes: 280 additions & 4 deletions KLR/Trace/Python.lean
@@ -9,6 +9,7 @@ import KLR.Python
import KLR.Trace.Types
import KLR.Trace.Builtin
import KLR.Trace.Basic
import TensorLib

namespace KLR.Trace
open KLR.Python
@@ -98,6 +99,221 @@ def listAccess (name : String) (l : List Term) : Term -> Err Term
else throw "index out of bounds"
| _ => throw s!"{name} indices must be integers or slices"


/-
Given an int tensor t of N dimensions, find 'steps', a list of N integers
such that for every index (i1, i2, ...),
t[i1,i2,...] = t[0,0,...] + (i1,i2,...) ⬝ steps
where ⬝ is the dot product.

For example, if t is [[10, 15], [30, 35]], valAtZero is 10 and
steps = [20, 5]:
t[1,1] = 35 = 10 + (1,1) ⬝ (20,5)

Returns: ⟨ t[0,0,...], steps ⟩
For the above example, the return value is ⟨ 10, [20, 5] ⟩
-/

Review comment (Collaborator):
Nice.
It would be good to add some basic #guard checks for testing.
Also, I wonder if this should be in TensorLib. @seanmcl ?

Reply (Author, @aqjune-aws, May 5, 2025):
Added #guard. The guards will pass once leanprover/TensorLib#74 is merged :)

def decomposeLinearIntTensor (t:TensorLib.Tensor)
: Err (Int × List (Core.APPair)) := do
match t.dtype with
| TensorLib.Dtype.uint64 | TensorLib.Dtype.int64 =>
let dimsize := t.shape.val.length
let zerocoord := List.replicate dimsize 0
-- Get the value at t[0,..,0]
let valAtZero <- t.intAtDimIndex zerocoord
-- And compute 't[0,..,1,..,0] - t[0,..,0,..,0]' to get the step size per axis
let steps <- List.mapM (fun i => do
if List.getD t.shape.val i 0 ≤ 1
then return 0 -- t is too small in this axis to calculate the step
else
let nextcoord := List.set zerocoord i 1
let valAtOne ← t.intAtDimIndex nextcoord
return valAtOne - valAtZero
) (List.range dimsize)
-- Check that t[idx] = valAtZero + idx ⬝ steps for all indices
let isLinear <- check [] dimsize steps valAtZero
if !isLinear then throw "tensor is not linear in its indices"
return ⟨
valAtZero,
List.map (fun (sz, step) => { step := step, num := sz : Core.APPair })
(List.zip t.shape.val steps)
⟩
| _ => throw s!"supports uint64 or int64 only, but got {repr t.dtype}"
where
check (idx:List Nat) (cnt:Nat) (steps:List Int) (valAtZero:Int)
: Err Bool := do
let l := idx.length
if cnt > 0 then
let r := List.range (List.getD t.shape.val l 0)
List.allM (fun k => check (idx ++ [k]) (cnt - 1) steps valAtZero) r
else
let mult: List Int := List.map
(fun (a,b) => (Int.ofNat a) * b) (List.zip idx steps)
let dot: Int := mult.foldl (fun a b => a + b) 0
let lhs: Int <- t.intAtDimIndex idx
return decide (lhs = valAtZero + dot)


#guard match (decomposeLinearIntTensor
(TensorLib.Tensor.zeros TensorLib.Dtype.int64
{ val := [10, 10] : TensorLib.Shape})) with
| .error _ => false
| .ok (valAtZero, steps) =>
valAtZero = 0 ∧
steps == [
{ step := 0, num := 10 : Core.APPair },
{ step := 0, num := 10 : Core.APPair }]

#guard match (do
let tensor <- TensorLib.Tensor.ofIntList TensorLib.Dtype.int64
[10, 20, 30, 40, 50, 60]
let tensor2d <- tensor.reshape (TensorLib.Shape.mk [2, 3])
decomposeLinearIntTensor tensor2d) with
| .error _ => false
| .ok (valAtZero, steps) =>
valAtZero = 10 ∧ steps == [
{ step := 30, num := 2 : Core.APPair },
{ step := 10, num := 3 : Core.APPair }
]
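
-- Illustrative negative test, not part of the original PR: assuming the
-- linearity check in decomposeLinearIntTensor throws on non-linear data
-- (as implemented above), a tensor whose entries are not affine in the
-- index should be rejected. Here t[2] = 35 ≠ 10 + 2*10, so no 'steps' exists.
#guard match (do
let tensor <- TensorLib.Tensor.ofIntList TensorLib.Dtype.int64 [10, 20, 35]
decomposeLinearIntTensor tensor) with
| .error _ => true
| .ok _ => false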

/-
# Advanced tensor indexing

https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing

Given index tensors ind_1, ..., ind_N, the expression x[ind_1, ..., ind_N]
performs advanced indexing over the elements of x.
The result of the access is:

result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
..., ind_N[i_1, ..., i_M]]

In NumPy, mixing advanced indexing and basic indexing is allowed. However,
in NKI, only one of the two forms is allowed.
Refer to:
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/programming_model.html
"Note that currently NKI does not support mixing Basic and Advanced Tensor
Indexing in the same Index tuple."

## Q: Is the result of advanced indexing a view of the tensor or a copy?
For NumPy, the doc above says that advanced indexing always returns a copy
of the data (in contrast with basic slicing, which returns a view).
However, we cannot assume the same for NKI, because NKI code typically does
the following (excerpted from a matmul example):

i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:64]
...
nl.store(result[i_out_p, i_out_f], value=result_sbuf)
^^^^^^^^^^^^^^^^^^^^^^^^
this is doing advanced indexing.

If advanced indexing returned a copy of the data, the store statement would
not make sense. Therefore, advanced indexing in NKI must have view semantics.
-/
def advancedAccessPattern (tensor : Core.TensorName) : Term -> Err Core.AccessPattern
| .tuple l | .list l => mkAccessPattern tensor l
| t => mkAccessPattern tensor [t]
where
mkAccessPattern (tensor : Core.TensorName) (inds: List Term) : Err Core.AccessPattern
:= do
let sizes := tensor.shape.toList
if sizes.length ≠ inds.length
then throw "unimplemented: number of dimensions mismatch"
else if sizes.isEmpty
then throw "empty indices"
else
-- numElems[i] = sizes[i] * sizes[i+1] * ... * sizes[-1]
-- numElems[sizes.length] = 1
let numElems := sizes.foldr (fun sz l => match l with
| [] => [sz]
| h::l' => (sz*h)::h::l') [1]
-- The goal is to build AccessPattern that does the following:
-- result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
-- ..., ind_N[i_1, ..., i_M]]
let mut accessPatterns : List Core.AccessPattern := []
for inds_i in inds do
match inds_i with
| .tensor t =>
-- Create AccessPattern for each ind_j.
let (valAtZero, steps) <- decomposeLinearIntTensor t
-- To create AccessPattern, freePattern's steps must be
-- multiplied by the number of elements in the lower dimensions.
-- For example, if tensor t has shape 3x4, t[i,j] is flat element i * 4 + j
let steps: List Core.APPair := List.map
(fun (ap,i) =>
{ step := ap.step * Int.ofNat (numElems.getD (i+1) 1),
num := ap.num })
(steps.zipIdx)
accessPatterns := accessPatterns ++ [{
tensor := tensor,
parNum := 1, -- This will be filled later.
freePattern := steps,
offset := Int.toNat valAtZero
}]
| _ => throw "NKI doesn't allow mixing tensor index + basic index"
-- Accumulate AccessPattern of ind_js and create one large AccessPattern
match accessPatterns with
| pat1::pat' =>
let mut res : Core.AccessPattern := pat1
for ap in pat' do
let mut fp: List Core.APPair := []
for (p1,p2) in (List.zip res.freePattern ap.freePattern) do
if p1.num ≠ p2.num then
throw "APPair num mismatch"
else
fp := fp ++ [{
step := p1.step + p2.step,
num := p1.num
}]
res := {
tensor := res.tensor,
parNum := res.parNum,
freePattern := fp,
offset := res.offset }
-- Check the partition index & fill in the partition number
match res.freePattern with
| fp0::fp' =>
let numPartitions := fp0.num
if not (fp0.step = numElems.getD 1 0)
then throw "nontrivial step for partition index"
else return {
tensor := res.tensor,
parNum := numPartitions,
freePattern := fp',
offset := res.offset
}
| _ => throw "insufficient indices"
| _ => throw "insufficient indices"


#guard
match do
let shape <- Core.Shape.fromList [/-parnum-/2,3,4]
Core.TensorName.make "x" Core.Dtype.int8 shape none with
| .ok (tensor:Core.TensorName) =>
let mk (ls:List Int): TensorLib.Tensor :=
let t := TensorLib.Tensor.ofIntList! TensorLib.Dtype.int64 ls
let t3d := t.reshape! (TensorLib.Shape.mk [2,3,4])
t3d
-- a,b,c = numpy.mgrid[0:2,0:3,0:4]
let a : Term := .tensor (mk [
0,0,0,0, 0,0,0,0, 0,0,0,0,
1,1,1,1, 1,1,1,1, 1,1,1,1])
let b : Term := .tensor (mk [
0,0,0,0, 1,1,1,1, 2,2,2,2,
0,0,0,0, 1,1,1,1, 2,2,2,2
])
let c : Term := .tensor (mk [
0,1,2,3, 0,1,2,3, 0,1,2,3,
0,1,2,3, 0,1,2,3, 0,1,2,3
])
let res := advancedAccessPattern tensor (.tuple [a,b,c])
res == .ok
(Core.AccessPattern.mk tensor 2 [
{ step := 4, num := 3},
{ step := 1, num := 4},
] 0)
| .error _ => false
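
/-
Worked arithmetic behind the expected result above (illustrative, not part
of the original PR). For shape [2,3,4], numElems = [24, 12, 4, 1].
Decomposing the three mgrid tensors gives per-axis steps [1,0,0], [0,1,0]
and [0,0,1]; scaling each axis i by numElems[i+1] yields [12,0,0], [0,4,0]
and [0,0,1], and summing gives the pairs (step 12, num 2), (4, 3), (1, 4).
The leading pair has step 12 = numElems[1], so it is consumed as the
partition dimension (parNum = 2), leaving freePattern = [(4,3), (1,4)]
and offset 0.
-/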

/-
Access to pointer types (a.k.a. Address)
NKI users can define memory regions by using slices on other memory regions.
@@ -194,15 +410,44 @@ def access (t : Term) (i : Term) : Err Term := do
| .slice ..
| .store .. => throw "subscript not supported"
| .string _ => throw "string subscript not implemented"
| .tensor _ => throw "subscript of a constant tensor unimplemented"
| .tuple l => listAccess "tuple" l i
| .list l => listAccess "list" l i
| .pointer addr => pointerAccess addr i
| .mgrid => do
let slice_convert (s:Term): Err TensorLib.Slice :=
match s with
| .slice b e st => TensorLib.Slice.make b e st
| _ => throw "not .slice"
let slices : List TensorLib.Slice <-
List.mapM slice_convert (match i with
| .tuple l => l | t => [t])
-- Use TensorLib's mgrid semantics, which follows NumPy's mgrid semantics;
-- its return type (ndarray) is slightly different from the return type of
-- NKI's mgrid, but the NKI API is designed to be analogous to the NumPy
-- API anyway.
let res <- TensorLib.mgrid slices
-- Note: this does not support '.p' and '.x' in NKI because a generic
-- tensor does not have such fields.
return .tensor res
| .expr .. => do
let tensor : Core.TensorName <- fromNKI? t

-- Try basic indexing first
tryCatch
(do
let indices <- termToIndex tensor.shape.toList i
let access <- Core.Access.mkBasic tensor indices
let shape <- Tensor.inferShape access
return .expr (.value (.access access)) (.tensor tensor.dtype shape))
(fun _ => do
-- Try advanced indexing
let ap <- advancedAccessPattern tensor i
let access := Core.Access.pattern ap
let shape <- Tensor.inferShape access
return .expr (.value (.access access)) (.tensor tensor.dtype shape))
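
-- Illustrative of the fallback order, not part of the original PR: for a
-- tensor x of shape 128x64, x[0:128, 0:64] converts via termToIndex, so the
-- basic path succeeds; for i_p, i_f = nl.mgrid[0:128, 0:64], the index terms
-- are .tensor values, termToIndex fails, and x[i_p, i_f] is handled by
-- advancedAccessPattern instead.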


--
/-
# Handling of assignment statements

@@ -270,6 +515,7 @@ def RValue : Term -> Trace Term
| .source f => return .source f
| .none => return .none
| .string s => return .string s
| .tensor t => return .tensor t
| .tuple es => return .tuple (<- es.attach.mapM fun ⟨ e, _ ⟩ => RValue e)
| .list es => return .tuple (<- es.attach.mapM fun ⟨ e, _ ⟩ => RValue e)
| .ellipsis => return .ellipsis
@@ -284,6 +530,10 @@
add_stmt (.assign v e)
return .expr (.value $ .var v) ty
| .expr e ty => return .expr e ty
| .mgrid =>
-- Assume that people do not write code where mgrid appears bare, without
-- a subscript, on the RHS of an assignment...
throw "unimplemented"

-- Create an assignment to a Core Expr, this must be a variable
def assignExpr (e : Core.Expr) (t : Term) : Trace Unit := do
@@ -294,16 +544,42 @@ def assignExpr (e : Core.Expr) (t : Term) : Trace Unit := do

Review comment (Collaborator):
This seems too complex? You should be able to use TensorLib like numpy.
TensorLib can compute, e.g. t[i,...], which you seem to be doing in a fairly low-level way.
@seanmcl ?

Reply (Author):
Yes, in fact I could not find the right function. I could have missed one.
-- Unpack an RValue, must be a list or tuple
def unpack : Term -> Trace (List Term)
| .tuple l | .list l => return l
| .tensor t =>
-- Unpack tensor to a list of subtensors
match t.shape.val with
| [] => return []
| nTensors::shapes' =>
-- Make a tuple whose number of elements is nTensors
let newShape := TensorLib.Shape.mk shapes'
-- `extract i` returns the subtensor t[i].
let extract (i:Nat): Trace Term := do
if t.startIndex ≠ 0 ∨ t.unitStrides ≠ t.shape.unitStrides then
throw "Don't know how to extract i'th subarray of t"
else
let subtensorSz := t.size / nTensors
-- ByteArray.extract takes start and stop byte offsets
let startByte := subtensorSz * i * t.itemsize
let extractedData := t.data.extract startByte (startByte + subtensorSz * t.itemsize)
return .tensor ({
dtype := t.dtype,
shape := newShape,
data := extractedData
: TensorLib.Tensor
})

(List.range nTensors).mapM extract

| t => throw s!"cannot unpack non-iterable object {repr t}"
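
/-
Worked example for `extract` (illustrative, not part of the original PR):
unpacking the result of nl.mgrid[0:2, 0:3], an int64 tensor of shape
[2, 2, 3]. Then nTensors = 2, newShape = [2, 3], size = 12, itemsize = 8,
so subtensorSz = 6 and each subtensor occupies 48 bytes:
extract 0 copies bytes [0, 48)   -- the first coordinate grid
extract 1 copies bytes [48, 96)  -- the second coordinate grid
-/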

-- Assign to a term, handling unpacking for tuples and lists
def assignTerm (x : Term) (e : Term) : Trace Unit := do
match x with
| .module name => throw s!"cannot assign to {name}"
| .mgrid => throw s!"cannot assign to mgrid"
| .builtin name .. => throw s!"cannot assign to {name}"
| .source _ => throw "cannot assign to function"
| .none => throw "cannot assign to None"
| .string _ => throw "cannot assign to a string literal"
| .tensor _ => throw "cannot assign to a constant tensor"
| .tuple l
| .list l => assignList l (<- unpack e)
| .ellipsis => throw "cannot assign to ellipsis"