Skip to content

Commit

Permalink
[DOCS] Re-structured documentation hierarchy
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed Jul 27, 2021
1 parent ca04da3 commit 92242ac
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 222 deletions.
4 changes: 3 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@

# Sphinx gallery
extensions += ['sphinx_gallery.gen_gallery']
from sphinx_gallery.sorting import FileNameSortKey
sphinx_gallery_conf = {
'examples_dirs': '../python/tutorials/',
'gallery_dirs': 'tutorials',
'gallery_dirs': 'getting-started/tutorials',
'filename_pattern': '',
'ignore_pattern': r'__init__\.py',
'within_subsection_order': FileNameSortKey,
}

# Add any paths that contain templates here, relative to this directory.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
==============
From Source
Installation
==============

--------------
With Pip
--------------

Triton can be installed directly from pip with the following command

.. code-block:: python
pip install triton
--------------
From Source
--------------

+++++++++++++++
Python Package
+++++++++++++++
Expand Down
22 changes: 10 additions & 12 deletions docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
.. Triton documentation master file, created by
sphinx-quickstart on Mon Feb 10 01:01:37 2020.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to Triton's documentation!
==================================

.. toctree::
:maxdepth: 1
:caption: Installation Instructions
Triton is an imperative language and compiler for parallel programming. It aims to provide a programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.

Getting Started
---------------

installation/packaged-binaries
installation/from-source
- Follow the :doc:`installation instructions <getting-started/installation>` for your platform of choice.
- Take a look at the :doc:`tutorials <getting-started/tutorials/index>` to learn how to write your first Triton program.

.. toctree::
:maxdepth: 1
:caption: Installation Instructions
:caption: Getting Started
:hidden:

tutorials/index
getting-started/installation
getting-started/tutorials/index
8 changes: 0 additions & 8 deletions docs/installation/packaged-binaries.rst

This file was deleted.

158 changes: 0 additions & 158 deletions docs/tutorials/01-vector-add.ipynb

This file was deleted.

44 changes: 18 additions & 26 deletions python/tutorials/01-vector-add.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
Vector Addition
=================
In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:
* The basic syntax of the Triton programming language
* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API
* The best practices for validating and benchmarking custom ops against native reference implementations
In this tutorial, you will write a simple, high-performance vector addition using Triton and learn about:
- The basic syntax of the Triton programming language
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
- The best practices for validating and benchmarking custom ops against native reference implementations
"""

# %%
# Writing the Compute Kernel
# Compute Kernel
# --------------------------
#
# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
Expand Down Expand Up @@ -49,23 +50,20 @@
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.

# %%
# Writing the Torch bindings
# Torch bindings
# --------------------------
# The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.
# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
#
# To create a `triton.kernel`, you only need three things:
# - `source: string`: the source-code of the kernel you want to create
# - `device: torch.device`: the device you want to compile this code for
# - `defines: dict`: the set of macros that you want the pre-processor to `#define` for you
# - :code:`source: string`: the source-code of the kernel you want to create
# - :code:`device: torch.device`: the device you want to compile this code for
# - :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you

import torch
import triton

# %%
# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.

_src = """
__global__ void add(float* z, float* x, float* y, int N){
// program id
Expand All @@ -82,13 +80,10 @@
}
"""

# %%
# This function returns a callable `triton.kernel` object
# created from the above source code.

# This function returns a callable `triton.kernel` object created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`
# We compile the kernel with -DBLOCK=1024


def make_add_kernel(device):
cache = make_add_kernel.cache
if device not in cache:
Expand All @@ -99,12 +94,9 @@ def make_add_kernel(device):

make_add_kernel.cache = dict()

# %%
# This is a standard torch custom autograd Function
# The only difference is that we can now use the above kernel
# in the `forward` and `backward` functions.`


# This is a standard torch custom autograd Function;
# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`
class _add(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
Expand All @@ -127,11 +119,11 @@ def forward(ctx, x, y):
return z


# Just like we standard PyTorch ops We use the `.apply` method to create a callable object for our function
# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function
add = _add.apply

# %%
# Writing a Unit Test
# Unit Test
# --------------------------
torch.manual_seed(0)
x = torch.rand(98432, device='cuda')
Expand All @@ -143,7 +135,7 @@ def forward(ctx, x, y):
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')

# %%
# Writing a Benchmark
# Benchmarking
# --------------------------
# We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does

Expand Down
Loading

0 comments on commit 92242ac

Please sign in to comment.