Add tracing support for mgrid and advanced tensor indexing #111
@@ -9,6 +9,7 @@ import KLR.Python
import KLR.Trace.Types
import KLR.Trace.Builtin
import KLR.Trace.Basic
import TensorLib

namespace KLR.Trace
open KLR.Python
@@ -98,6 +99,221 @@ def listAccess (name : String) (l : List Term) : Term -> Err Term
    else throw "index out of bounds"
  | _ => throw s!"{name} indices must be integers or slices"

/-
Given an int tensor t of N dimensions, find `steps`, a list of N integers
such that for every index (i1, i2, ...),
  t[i1, i2, ...] = t[0, 0, ...] + (i1, i2, ...) ⬝ steps
where + is elementwise addition and ⬝ is a dot product.

For example, if t is [[10, 15], [30, 35]], then valAtZero is 10 and
steps = [20, 5]:
  t[1,1] = 35 = 10 + (1,1) ⬝ (20,5)

Returns: ⟨ t[0,0,...], steps ⟩
For the above example, the return value is ⟨ 10, [20, 5] ⟩.
-/
def decomposeLinearIntTensor (t : TensorLib.Tensor)
    : Err (Int × List (Core.APPair)) := do
  match t.dtype with
  | TensorLib.Dtype.uint64 | TensorLib.Dtype.int64 =>
    let dimsize := t.shape.val.length
    let zerocoord := List.replicate dimsize 0
    -- Get the value at t[0,..,0]
    let valAtZero <- t.intAtDimIndex zerocoord
    -- Get 't[0,..,1,..,0] - t[0,..,0,..,0]' for each axis to obtain its step
    let steps <- List.mapM (fun i => do
        if List.getD t.shape.val i 0 ≤ 1
        then return 0 -- t is too small in this axis to calculate the step
        else
          let nextcoord := List.set zerocoord i 1
          let valAtOne ← t.intAtDimIndex nextcoord
          return valAtOne - valAtZero)
      (List.range dimsize)
    -- Check that t[idx] = valAtZero + idx ⬝ steps for all indices
    let ok <- check [] dimsize steps valAtZero
    if !ok then throw "tensor is not linear in its indices"
    return ⟨
      valAtZero,
      List.map (fun (sz, step) => { step := step, num := sz : Core.APPair })
        (List.zip t.shape.val steps)
    ⟩
  | _ => throw s!"supports uint64 or int64 only, but got {repr t.dtype}"
where
  check (idx : List Nat) (cnt : Nat) (steps : List Int) (valAtZero : Int)
      : Err Bool := do
    let l := idx.length
    if cnt > 0 then
      -- Recurse over every coordinate of the next axis
      let r := List.range (List.getD t.shape.val l 0)
      List.allM (fun k => check (idx ++ [k]) (cnt - 1) steps valAtZero) r
    else
      -- At a full index: compare t[idx] against the linear prediction
      let mult : List Int := List.map
        (fun (a, b) => (Int.ofNat a) * b) (List.zip idx steps)
      let dot : Int := mult.foldl (fun a b => a + b) 0
      let lhs : Int <- t.intAtDimIndex idx
      return decide (lhs = valAtZero + dot)

#guard match (decomposeLinearIntTensor
    (TensorLib.Tensor.zeros TensorLib.Dtype.int64
      { val := [10, 10] : TensorLib.Shape })) with
  | .error _ => false
  | .ok (valAtZero, steps) =>
    valAtZero = 0 ∧
    steps == [
      { step := 0, num := 10 : Core.APPair },
      { step := 0, num := 10 : Core.APPair }]

#guard match (do
    let tensor <- TensorLib.Tensor.ofIntList TensorLib.Dtype.int64
      [10, 20, 30, 40, 50, 60]
    let tensor2d <- tensor.reshape (TensorLib.Shape.mk [2, 3])
    decomposeLinearIntTensor tensor2d) with
  | .error _ => false
  | .ok (valAtZero, steps) =>
    valAtZero = 10 ∧ steps == [
      { step := 30, num := 2 : Core.APPair },
      { step := 10, num := 3 : Core.APPair }]
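
-- Conversely, a tensor that is not linear in its indices should be rejected
-- by the exhaustive check (an additional sketch in the style of the guards
-- above; [0, 1, 4] fails at index 2 because 0 + 2*1 ≠ 4):
#guard match (do
    let tensor <- TensorLib.Tensor.ofIntList TensorLib.Dtype.int64 [0, 1, 4]
    decomposeLinearIntTensor tensor) with
  | .error _ => true
  | .ok _ => false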

/-
# Implement advanced tensor indexing.

https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing

Given tensors ind_1, ind_2, ..., ind_N, the expression x[ind_1, ind_2, ..., ind_N]
performs advanced indexing over the elements of x.
The result of the access is:

  result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
                             ..., ind_N[i_1, ..., i_M]]

In NumPy, mixing advanced indexing and basic indexing is allowed. However,
in NKI, only one of the two forms is allowed. Refer to:
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/programming_model.html
"Note that currently NKI does not support mixing Basic and Advanced Tensor
Indexing in the same Index tuple."

## Q: Is the result of advanced indexing a view of the tensor or a copy of it?
In the case of NumPy, the doc above says that advanced indexing always
returns a copy of the data (in contrast with basic slicing, which returns a
view). However, we cannot assume the same for NKI, because NKI code typically
does the following (excerpted from a matmul example):

  i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:64]
  ...
  nl.store(result[i_out_p, i_out_f], value=result_sbuf)
           ^^^^^^^^^^^^^^^^^^^^^^^^
           this is doing advanced indexing.

If advanced indexing returned a copy of the data, the store statement would
not make sense. Therefore, advanced indexing in NKI must have view semantics.
-/
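
-- To make the equation above concrete, a hypothetical two-index illustration
-- on plain lists (`gather2` is illustration only, not part of this change):
--   result[i] = x[ind1[i]][ind2[i]]
def gather2 (x : List (List Int)) (ind1 ind2 : List Nat) : List Int :=
  List.zipWith (fun i j => (x.getD i []).getD j 0) ind1 ind2

-- With x = [[10, 20], [30, 40]], this picks x[0][1], x[1][0], x[1][1]
#guard gather2 [[10, 20], [30, 40]] [0, 1, 1] [1, 0, 1] == [20, 30, 40]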
def advancedAccessPattern (tensor : Core.TensorName) : Term -> Err Core.AccessPattern
  | .tuple l | .list l => mkAccessPattern tensor l
  | t => mkAccessPattern tensor [t]
where
  mkAccessPattern (tensor : Core.TensorName) (inds : List Term) : Err Core.AccessPattern
      := do
    let sizes := tensor.shape.toList
    if sizes.length ≠ inds.length
    then throw "unimplemented: number of dimensions mismatch"
    else if sizes.isEmpty
    then throw "empty indices"
    else
      -- numElems[i] = sizes[i] * sizes[i+1] * ... * sizes[-1]
      -- numElems[sizes.length] = 1
      let numElems := sizes.foldr (fun sz l => match l with
        | [] => [sz]
        | h :: l' => (sz * h) :: h :: l') [1]
      -- The goal is to build an AccessPattern that does the following:
      --   result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
      --                              ..., ind_N[i_1, ..., i_M]]
      let mut accessPatterns : List Core.AccessPattern := []
      for inds_i in inds do
        match inds_i with
        | .tensor t =>
          -- Create an AccessPattern for each ind_j.
          let (valAtZero, steps) <- decomposeLinearIntTensor t
          -- To create the AccessPattern, freePattern's steps must be
          -- multiplied by the number of elements in the lower dimensions.
          -- For example, if tensor t is 3x4, t[i,j] is t[i * 4 + j].
          let steps : List Core.APPair := List.map
            (fun (ap, i) =>
              { step := ap.step * Int.ofNat (numElems.getD (i + 1) 1),
                num := ap.num })
            (steps.zipIdx)
          accessPatterns := accessPatterns ++ [{
            tensor := tensor,
            parNum := 1, -- This will be filled in later.
            freePattern := steps,
            offset := Int.toNat valAtZero
          }]
        | _ => throw "NKI doesn't allow mixing tensor index + basic index"
      -- Accumulate the AccessPatterns of the ind_js into one large AccessPattern
      match accessPatterns with
      | pat1 :: pat' =>
        let mut res : Core.AccessPattern := pat1
        for ap in pat' do
          let mut fp : List Core.APPair := []
          for (p1, p2) in (List.zip res.freePattern ap.freePattern) do
            if p1.num ≠ p2.num then
              throw "APPair num mismatch"
            else
              fp := fp ++ [{
                step := p1.step + p2.step,
                num := p1.num
              }]
          res := {
            tensor := res.tensor,
            parNum := res.parNum,
            freePattern := fp,
            offset := res.offset }
        -- Check the partition index & fill in the partition number
        match res.freePattern with
        | fp0 :: fp' =>
          let numPartitions := fp0.num
          if not (fp0.step = numElems.getD 1 0)
          then throw "nontrivial step for partition index"
          else return {
            tensor := res.tensor,
            parNum := numPartitions,
            freePattern := fp',
            offset := res.offset
          }
        | _ => throw "insufficient indices"
      | _ => throw "insufficient indices"

#guard
  match (do
      let shape <- Core.Shape.fromList [/-parnum-/2, 3, 4]
      Core.TensorName.make "x" Core.Dtype.int8 shape none) with
  | .ok (tensor : Core.TensorName) =>
    let mk (ls : List Int) : TensorLib.Tensor :=
      let t := TensorLib.Tensor.ofIntList! TensorLib.Dtype.int64 ls
      let t3d := t.reshape! (TensorLib.Shape.mk [2, 3, 4])
      t3d
    -- a,b,c = numpy.mgrid[0:2,0:3,0:4]
    let a : Term := .tensor (mk [
      0,0,0,0, 0,0,0,0, 0,0,0,0,
      1,1,1,1, 1,1,1,1, 1,1,1,1])
    let b : Term := .tensor (mk [
      0,0,0,0, 1,1,1,1, 2,2,2,2,
      0,0,0,0, 1,1,1,1, 2,2,2,2])
    let c : Term := .tensor (mk [
      0,1,2,3, 0,1,2,3, 0,1,2,3,
      0,1,2,3, 0,1,2,3, 0,1,2,3])
    let res := advancedAccessPattern tensor (.tuple [a, b, c])
    res == .ok
      (Core.AccessPattern.mk tensor 2 [
        { step := 4, num := 3 },
        { step := 1, num := 4 }] 0)
  | .error _ => false
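
-- For reference, decomposing the first mgrid index tensor `a` on its own
-- gives a unit step along axis 0 and zero steps elsewhere;
-- `advancedAccessPattern` then scales that unit step by numElems[1] = 12
-- before the partition-index check (a sketch in the style of the guard above):
#guard match (do
    let t <- TensorLib.Tensor.ofIntList TensorLib.Dtype.int64 [
      0,0,0,0, 0,0,0,0, 0,0,0,0,
      1,1,1,1, 1,1,1,1, 1,1,1,1]
    let t3d <- t.reshape (TensorLib.Shape.mk [2, 3, 4])
    decomposeLinearIntTensor t3d) with
  | .error _ => false
  | .ok (valAtZero, steps) =>
    valAtZero = 0 ∧ steps == [
      { step := 1, num := 2 : Core.APPair },
      { step := 0, num := 3 : Core.APPair },
      { step := 0, num := 4 : Core.APPair }]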

/-
Access to pointer types (a.k.a. Address)
NKI users can define memory regions by using slices on other memory regions.
@@ -194,15 +410,44 @@ def access (t : Term) (i : Term) : Err Term := do
  | .slice ..
  | .store .. => throw "subscript not supported"
  | .string _ => throw "string subscript not implemented"
  | .tensor _ => throw "subscript of a constant tensor unimplemented"
  | .tuple l => listAccess "tuple" l i
  | .list l => listAccess "list" l i
  | .pointer addr => pointerAccess addr i
  | .mgrid => do
    let slice_convert (s : Term) : Err TensorLib.Slice :=
      match s with
      | .slice b e st => TensorLib.Slice.make b e st
      | _ => throw "not .slice"
    let slices : List TensorLib.Slice <-
      List.mapM slice_convert (match i with
        | .tuple l => l | t => [t])
    -- Use TensorLib's mgrid semantics. This naturally picks up NumPy's
    -- mgrid semantics, whose return type (ndarray) is slightly different
    -- from NKI's mgrid return type. The usage of the NKI API is designed to
    -- be analogous to that of the NumPy API anyway.
    let res <- TensorLib.mgrid slices
    -- Note: this does not support '.p' and '.x' from NKI because a generic
    -- tensor does not have such fields.
    return .tensor res
  | .expr .. => do
    let tensor : Core.TensorName <- fromNKI? t
    let access <- Core.Access.mkBasic tensor (<- termToIndex tensor.shape.toList i)
    let shape <- Tensor.inferShape access
    return .expr (.value (.access access)) (.tensor tensor.dtype shape)

    -- Try basic indexing first
    tryCatch
      (do
        let indices <- termToIndex tensor.shape.toList i
        let access <- Core.Access.mkBasic tensor indices
        let shape <- Tensor.inferShape access
        return .expr (.value (.access access)) (.tensor tensor.dtype shape))
      (fun _ => do
        -- Try advanced indexing
        let ap <- advancedAccessPattern tensor i
        let access := Core.Access.pattern ap
        let shape <- Tensor.inferShape access
        return .expr (.value (.access access)) (.tensor tensor.dtype shape))

--
/-
# Handling of assignment statements
@@ -270,6 +515,7 @@ def RValue : Term -> Trace Term
  | .source f => return .source f
  | .none => return .none
  | .string s => return .string s
  | .tensor t => return .tensor t
  | .tuple es => return .tuple (<- es.attach.mapM fun ⟨ e, _ ⟩ => RValue e)
  | .list es => return .tuple (<- es.attach.mapM fun ⟨ e, _ ⟩ => RValue e)
  | .ellipsis => return .ellipsis

@@ -284,6 +530,10 @@ def RValue : Term -> Trace Term
    add_stmt (.assign v e)
    return .expr (.value $ .var v) ty
  | .expr e ty => return .expr e ty
  | .mgrid =>
    -- Assume nobody writes code where mgrid appears on the RHS of an
    -- assignment without a subscript...
    throw "unimplemented"

-- Create an assignment to a Core Expr, this must be a variable
def assignExpr (e : Core.Expr) (t : Term) : Trace Unit := do
@@ -294,16 +544,42 @@ def assignExpr (e : Core.Expr) (t : Term) : Trace Unit := do

-- Unpack an RValue, must be a list or tuple
def unpack : Term -> Trace (List Term)
  | .tuple l | .list l => return l
  | .tensor t =>
    -- Unpack tensor to a list of subtensors

[Review comment] This seems too complex? You should be able to use TensorLib like numpy.
[Author reply] Yes, in fact I could not find the right function. I could have missed one.
    match t.shape.val with
    | [] => return []
    | nTensors :: shapes' =>
      -- Make a tuple whose number of elements is nTensors
      let newShape := TensorLib.Shape.mk shapes'
      -- `extract i` returns the subtensor t[i].
      let extract (i : Nat) : Trace Term := do
        if t.startIndex ≠ 0 ∨ t.unitStrides ≠ t.shape.unitStrides then
          throw "Don't know how to extract the i'th subarray of t"
        else
          let subtensorSz := t.size / nTensors
          -- Copy the bytes of subtensor i, i.e. the byte range
          -- [i * subtensorSz * itemsize, (i+1) * subtensorSz * itemsize)
          let extractedData := t.data.extract (subtensorSz * i * t.itemsize)
            (subtensorSz * (i + 1) * t.itemsize)
          return .tensor ({
            dtype := t.dtype,
            shape := newShape,
            data := extractedData
            : TensorLib.Tensor
          })
      (List.range nTensors).mapM extract
  | t => throw s!"cannot unpack non-iterable object {repr t}"
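
-- This unpacking is what makes `i_p, i_f = nl.mgrid[...]` work: the mgrid
-- case of `access` returns one stacked tensor and `unpack` splits it along
-- axis 0. A byte-level sketch of the extraction arithmetic against TensorLib
-- directly (assumes `data` is a row-major ByteArray and `ByteArray.extract`
-- takes begin/end byte offsets):
#guard match (do
    let t <- TensorLib.Tensor.ofIntList TensorLib.Dtype.int64
      [10, 20, 30, 40, 50, 60]
    let t2d <- t.reshape (TensorLib.Shape.mk [2, 3])
    -- Mirror `extract 1` above: row 1 occupies bytes [3*itemsize, 6*itemsize)
    let row1 := t2d.data.extract (3 * t2d.itemsize) (6 * t2d.itemsize)
    let sub : TensorLib.Tensor :=
      { dtype := t2d.dtype, shape := TensorLib.Shape.mk [3], data := row1 }
    sub.intAtDimIndex [0]) with
  | .ok v => v = 40
  | .error _ => false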

-- Assign to a term, handling unpacking for tuples and lists
def assignTerm (x : Term) (e : Term) : Trace Unit := do
  match x with
  | .module name => throw s!"cannot assign to {name}"
  | .mgrid => throw "cannot assign to mgrid"
  | .builtin name .. => throw s!"cannot assign to {name}"
  | .source _ => throw "cannot assign to function"
  | .none => throw "cannot assign to None"
  | .string _ => throw "cannot assign to a string literal"
  | .tensor _ => throw "cannot assign to a constant tensor"
  | .tuple l
  | .list l => assignList l (<- unpack e)
  | .ellipsis => throw "cannot assign to ellipsis"
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice.
It would be good to add some basic
#guard
checks for testing.Also, I wonder if this should be in
TensorLib
. @seanmcl ?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
#guard
. The guards will pass once leanprover/TensorLib#74 is merged :)