
UmerHA/triton_util


Make Triton easier 🔱 😊

Utility functions for OpenAI Triton

Writing fast GPU kernels is easier with Triton than with CUDA, but there is still a lot of tedious index juggling. It doesn't have to be that way.

Triton-util provides simple higher-level abstractions for frequent but repetitive tasks. This allows you to write code that is closer to how you actually think.

Example: Say you have a 2d matrix of shape (max0,max1) and stride (stride0,stride1), which you have chunked along both axes. Each chunk has size (sz0,sz1), and you want to load the (n0,n1)th chunk. With triton-util, you write

load_2d(ptr, sz0, sz1, n0, n1, max0, max1, stride0) # stride1 defaults to 1

instead of

offs0 = n0 * sz0 + tl.arange(0, sz0)
offs1 = n1 * sz1 + tl.arange(0, sz1)
offs = offs0[:,None] * stride0 + offs1[None,:] * stride1
mask = (offs0[:,None] < max0) & (offs1[None,:] < max1)
tl.load(ptr + offs, mask) 
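To see concretely what the manual version computes, here is a CPU-only emulation of the same index arithmetic in plain Python (no Triton or GPU needed). A flat list stands in for GPU memory and ptr is an index into it. The helper names are hypothetical and this is not triton-util's implementation. Masked-out lanes are filled with 0 here, which in Triton corresponds to passing other=0 to tl.load; without other, their values are undefined.

```python
def chunk_offsets_and_mask(sz0, sz1, n0, n1, max0, max1, stride0, stride1=1):
    """Mirror the Triton code above: offs and mask as sz0 x sz1 nested lists."""
    offs0 = [n0 * sz0 + i for i in range(sz0)]  # n0 * sz0 + tl.arange(0, sz0)
    offs1 = [n1 * sz1 + j for j in range(sz1)]  # n1 * sz1 + tl.arange(0, sz1)
    # offs0[:, None] * stride0 + offs1[None, :] * stride1, broadcast by hand
    offs = [[o0 * stride0 + o1 * stride1 for o1 in offs1] for o0 in offs0]
    # (offs0[:, None] < max0) & (offs1[None, :] < max1)
    mask = [[o0 < max0 and o1 < max1 for o1 in offs1] for o0 in offs0]
    return offs, mask


def masked_load(mem, ptr, offs, mask, other=0):
    """Like tl.load(ptr + offs, mask, other): masked-out lanes never touch memory."""
    return [[mem[ptr + o] if m else other for o, m in zip(row_offs, row_mask)]
            for row_offs, row_mask in zip(offs, mask)]


# A 3x5 row-major matrix holding 0..14, cut into 2x2 chunks; load chunk (1, 2).
mem = list(range(15))
offs, mask = chunk_offsets_and_mask(2, 2, 1, 2, max0=3, max1=5, stride0=5)
print(offs)                             # [[14, 15], [19, 20]]
print(mask)                             # [[True, False], [False, False]]
print(masked_load(mem, 0, offs, mask))  # [[14, 0], [0, 0]]
```

Offsets 15, 19 and 20 lie past the end of the 15-element matrix; the mask is what keeps the load from reading them.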

Additionally, triton-util provides handy utility functions that make debugging easier. Want to print txt only in the first program instance (all program ids 0)? Write print_once(txt). That's it!

Finally, triton-util is progressive, i.e. you can use as little or as much of it as you want. It's fully interoperable with Triton (it is, in fact, pure Triton).


Installing

pip install triton-util


Debugging utils

print_once(txt)

  • Print txt only in the first program instance (i.e. all pids = 0)

breakpoint_once()

  • Enter breakpoint only in the first program instance (i.e. all pids = 0)

print_if(txt, conds)

  • Print txt if the condition on the pids is fulfilled
  • E.g. print_if(txt, '=0,>1') prints if pid_0 = 0 and pid_1 > 1; pid_2 is arbitrary

breakpoint_if(conds)

  • Enter breakpoint if the condition on the pids is fulfilled
  • E.g. breakpoint_if('=0,>1') stops if pid_0 = 0 and pid_1 > 1; pid_2 is arbitrary
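One way to read a condition string such as '=0,>1': split on commas, apply the i-th comparison to pid_i, and leave axes without an entry unconstrained. The plain-Python helper below is a hypothetical sketch of those semantics, not triton-util's parser. Supporting '<' (and only '=', '<', '>') is an assumption, since only '=' and '>' appear above. Under this reading, the print_once case (all pids = 0) is the condition '=0,=0,=0'.

```python
import operator

# Assumed operator set; only '=' and '>' are documented above.
_OPS = {"=": operator.eq, "<": operator.lt, ">": operator.gt}


def pids_match(conds, pids):
    """Return True if each comma-separated condition holds for the matching pid.

    conds: e.g. '=0,>1' means pid_0 == 0 and pid_1 > 1; pid_2 is unconstrained.
    pids:  tuple of program ids, e.g. (pid_0, pid_1, pid_2).
    """
    # zip stops at the shorter sequence, so missing conditions impose nothing
    for cond, pid in zip(conds.split(","), pids):
        cond = cond.strip()
        op, value = cond[0], int(cond[1:])
        if not _OPS[op](pid, value):
            return False
    return True


print(pids_match("=0,>1", (0, 2, 7)))     # True
print(pids_match("=0,>1", (0, 1, 0)))     # False: pid_1 is not > 1
print(pids_match("=0,=0,=0", (0, 0, 0)))  # True: the print_once case
```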

assert_tensors_gpu_ready(*tensors)

  • Assert that all tensors are contiguous and on the GPU (unless the environment variable TRITON_INTERPRET is '1')

Coding utils

cdiv(a,b)

  • Ceiling division, i.e. ⌈a/b⌉
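Ceiling division is what you need to size a launch grid: with n elements and blocks of size sz, you need cdiv(n, sz) programs so that the last, partially filled block is still covered. A plain-Python sketch for positive integers (Triton itself also provides triton.cdiv for use on the host):

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive ints: the smallest q with q * b >= a."""
    return (a + b - 1) // b


n, block = 1000, 128
grid = cdiv(n, block)
print(grid)               # 8: 7 blocks would cover only 896 elements
print(grid * block >= n)  # True
```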

offset_1d(sz, n_prev_chunks=0)

  • Return 1d offsets to (n_prev_chunks+1)th chunk of size sz
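In Triton terms, the offsets to the (n_prev_chunks+1)th chunk are n_prev_chunks * sz + tl.arange(0, sz). The plain-Python emulation below returns them as a list; it is a sketch of the documented behaviour, not the library source.

```python
def offset_1d(sz, n_prev_chunks=0):
    """Emulate n_prev_chunks * sz + tl.arange(0, sz) as a Python list."""
    return [n_prev_chunks * sz + i for i in range(sz)]


print(offset_1d(4))     # [0, 1, 2, 3]    (1st chunk)
print(offset_1d(4, 2))  # [8, 9, 10, 11]  (3rd chunk: 2 chunks precede it)
```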

offset_2d(offs_0, offs_1, stride_0, stride_1=1)

  • Create 2d offsets from two 1d offsets
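The 2d offsets are an outer sum: row offsets scaled by stride_0, broadcast against column offsets scaled by stride_1, i.e. offs_0[:, None] * stride_0 + offs_1[None, :] * stride_1 as in the manual example above. A plain-Python emulation of that broadcast (a sketch, not the library code) also shows the role of stride_1: the same 2x3 block laid out row-major versus column-major.

```python
def offset_2d(offs_0, offs_1, stride_0, stride_1=1):
    """Emulate offs_0[:, None] * stride_0 + offs_1[None, :] * stride_1."""
    return [[o0 * stride_0 + o1 * stride_1 for o1 in offs_1] for o0 in offs_0]


rows, cols = [0, 1], [0, 1, 2]
# Row-major 2x3 matrix: moving down one row skips 3 elements.
print(offset_2d(rows, cols, stride_0=3))              # [[0, 1, 2], [3, 4, 5]]
# Column-major 2x3 matrix: moving right one column skips 2 elements.
print(offset_2d(rows, cols, stride_0=1, stride_1=2))  # [[0, 2, 4], [1, 3, 5]]
```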

mask_1d(offs, max)

  • Create a 1d mask from a 1d offset and a max value

mask_2d(offs_0, offs_1, max_0, max_1)

  • Create a 2d mask from two 1d offsets and max values
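Both masks are elementwise bound checks on the offsets; the 2d mask is the logical AND of the row check and the column check, broadcast to the block's shape. A plain-Python sketch (max_ is spelled with an underscore to avoid shadowing Python's builtin; otherwise the parameters follow the signatures above, and this is not the library code):

```python
def mask_1d(offs, max_):
    """Emulate offs < max: True where the offset is in bounds."""
    return [o < max_ for o in offs]


def mask_2d(offs_0, offs_1, max_0, max_1):
    """Emulate (offs_0[:, None] < max_0) & (offs_1[None, :] < max_1)."""
    return [[o0 < max_0 and o1 < max_1 for o1 in offs_1] for o0 in offs_0]


# Last chunk of a 10-element vector cut into chunks of 4: lanes 10, 11 are out of bounds.
print(mask_1d([8, 9, 10, 11], 10))    # [True, True, False, False]
# Chunk (1, 2) of a 3x5 matrix cut into 2x2 chunks.
print(mask_2d([2, 3], [4, 5], 3, 5))  # [[True, False], [False, False]]
```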

load_1d(ptr, sz, n, max, stride=1)

  • Chunk 1d vector (defined by ptr) into 1d grid, where each chunk has size sz, and load the nth chunk.
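The nth chunk covers elements n*sz through (n+1)*sz - 1; lanes at or beyond max are masked so the load never reads past the end of the vector. The CPU emulation below uses a flat list for memory and fills masked lanes with 0 (tl.load leaves them undefined unless you pass other). It assumes max counts elements rather than memory offsets, consistent with the manual load_2d example above, and is a sketch, not triton-util's code.

```python
def load_1d(mem, ptr, sz, n, max_, stride=1, other=0):
    """Emulate loading the nth size-sz chunk of a (possibly strided) vector."""
    offs = [n * sz + i for i in range(sz)]  # element indices of the chunk
    # stride turns element indices into memory offsets; the mask uses element indices
    return [mem[ptr + o * stride] if o < max_ else other for o in offs]


vec = [10 * i for i in range(10)]           # 10-element vector
print(load_1d(vec, 0, 4, 2, 10))            # [80, 90, 0, 0]
# Every other element of a 20-element buffer, viewed as a 10-element vector.
buf = list(range(20))
print(load_1d(buf, 0, 4, 1, 10, stride=2))  # [8, 10, 12, 14]
```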

load_full_1d(ptr, sz, stride=1)

  • Load 1d block of size sz

load_2d(ptr, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1)

  • Chunk 2d matrix (defined by ptr) into 2d grid, where each chunk has size (sz0,sz1), and load the (n0,n1)th chunk.
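Because the chunks form a grid, launching cdiv(max0, sz0) x cdiv(max1, sz1) programs, each loading its own (n0, n1) chunk, covers the whole matrix, with masks trimming the ragged edges. The CPU emulation below loops over that grid in plain Python and checks that the chunks reassemble the original matrix. The loops stand in for parallel programs, and the helpers are illustrative, not triton-util's code.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def load_2d(mem, ptr, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1, other=0):
    """Emulate loading the (n0, n1)th sz0 x sz1 chunk; masked lanes get `other`."""
    offs0 = [n0 * sz0 + i for i in range(sz0)]
    offs1 = [n1 * sz1 + j for j in range(sz1)]
    return [[mem[ptr + o0 * stride0 + o1 * stride1]
             if o0 < max0 and o1 < max1 else other
             for o1 in offs1] for o0 in offs0]


max0, max1 = 3, 5               # 3x5 row-major matrix stored flat
mem = list(range(max0 * max1))  # element (r, c) lives at r * max1 + c
sz0, sz1 = 2, 2

rebuilt = [[None] * max1 for _ in range(max0)]
# Each (n0, n1) iteration plays the role of one program in a 2 x 3 grid.
for n0 in range(cdiv(max0, sz0)):
    for n1 in range(cdiv(max1, sz1)):
        chunk = load_2d(mem, 0, sz0, sz1, n0, n1, max0, max1, stride0=max1)
        for i in range(sz0):
            for j in range(sz1):
                r, c = n0 * sz0 + i, n1 * sz1 + j
                if r < max0 and c < max1:  # skip masked (padding) lanes
                    rebuilt[r][c] = chunk[i][j]

print(rebuilt == [mem[r * max1:(r + 1) * max1] for r in range(max0)])  # True
```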

load_full_2d(ptr, sz0, sz1, stride0, stride1=1)

  • Load 2d block of size sz0 x sz1

store_1d(vals, ptr, sz, n, max, stride=1)

  • Store 1d block into nth chunk of vector (defined by ptr), where each chunk has size sz
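Storing mirrors loading: the same offsets and mask are computed, and masked-out lanes are simply not written, so a partial last chunk cannot overwrite memory past the end of the vector. A plain-Python sketch with a list as memory, assuming max counts elements; it is not triton-util's code.

```python
def store_1d(vals, mem, ptr, sz, n, max_, stride=1):
    """Emulate a masked store of `vals` into the nth size-sz chunk of a vector."""
    for i in range(sz):
        o = n * sz + i    # element index of lane i
        if o < max_:      # masked-out lanes are skipped, not written
            mem[ptr + o * stride] = vals[i]


# A 10-element vector inside a 12-slot buffer; slots 10 and 11 must stay untouched.
buf = [0] * 12
store_1d([1, 2, 3, 4], buf, 0, 4, 2, 10)
print(buf)  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0]
```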

store_full_1d(vals, ptr, sz, stride=1)

  • Store 1d block into vector (defined by ptr)

store_2d(vals, ptr, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1)

  • Store 2d block into (n0,n1)th chunk of matrix (defined by ptr), where each chunk has size (sz0, sz1)
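Since loads and stores take independent strides, a tiled transpose falls out naturally: load a chunk of a row-major M x N matrix, then store it into the output with stride0=1 and stride1=M, which places element (r, c) at c * M + r, i.e. in an N x M row-major matrix. The plain-Python emulation below runs this over the whole chunk grid. It illustrates the stride arithmetic with hypothetical helpers and is neither triton-util code nor a tuned kernel.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def load_2d(mem, ptr, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1, other=0):
    """Emulate a masked load of the (n0, n1)th chunk."""
    return [[mem[ptr + r * stride0 + c * stride1]
             if r < max0 and c < max1 else other
             for c in range(n1 * sz1, (n1 + 1) * sz1)]
            for r in range(n0 * sz0, (n0 + 1) * sz0)]


def store_2d(vals, mem, ptr, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1):
    """Emulate a masked store into the (n0, n1)th chunk; masked lanes are skipped."""
    for i in range(sz0):
        for j in range(sz1):
            r, c = n0 * sz0 + i, n1 * sz1 + j
            if r < max0 and c < max1:
                mem[ptr + r * stride0 + c * stride1] = vals[i][j]


M, N, sz0, sz1 = 3, 5, 2, 2
src = list(range(M * N))  # M x N row-major: (r, c) at r * N + c
dst = [None] * (M * N)    # will hold the N x M row-major transpose

for n0 in range(cdiv(M, sz0)):
    for n1 in range(cdiv(N, sz1)):
        chunk = load_2d(src, 0, sz0, sz1, n0, n1, M, N, stride0=N)
        # Same logical (r, c), but written to position c * M + r
        store_2d(chunk, dst, 0, sz0, sz1, n0, n1, M, N, stride0=1, stride1=M)

print(dst[:M])  # first row of the transpose: [0, 5, 10]
```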

store_full_2d(vals, ptr, sz0, sz1, stride0, stride1=1)

  • Store 2d block into matrix (defined by ptr)


Other resources: Looking for ...


Brought to you by Umer ❤️
