[STYLE] run autopep8 and isort (triton-lang#421)
Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
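
The `.isort.cfg` itself lands in that separate change, but for orientation, a minimal config consistent with the import grouping visible in the diffs below (stdlib first, then third-party such as `torch` and `pytest`, then `triton` as first-party) might look like the following sketch. This is an illustration, not the committed file:
```
# Hypothetical .isort.cfg, for illustration only -- the actual config file
# ships in a separate PR. These settings produce the stdlib / third-party /
# first-party grouping seen in the diffs below.
[settings]
known_first_party = triton
line_length = 120
```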
madeleineth authored Jan 6, 2022
1 parent 120cda0 commit 8bf551a
Showing 30 changed files with 747 additions and 628 deletions.
53 changes: 27 additions & 26 deletions python/bench/bench_blocksparse.py
```diff
@@ -1,4 +1,5 @@
 import torch
+
 import triton
 
 # -------------------------------
@@ -8,26 +9,26 @@
 nt = {False: 'n', True: 't'}
 square_confs = [
     triton.testing.Benchmark(
-        x_names = ['M', 'N', 'K'],
-        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        line_arg = 'block',
-        line_vals = [16, 32, 64, 128],
-        line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
-        ylabel = 'TFLOPS',
-        plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
-        args = {'layout_mode': layout_mode, 'op_mode': op_mode,
-                'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
-    )\
-    for AT in [False] for BT in [False] \
-    for op_mode in ['dsd'] for layout_mode in ['dense']
+        x_names=['M', 'N', 'K'],
+        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
+        line_arg='block',
+        line_vals=[16, 32, 64, 128],
+        line_names=['Block16', 'Block32', 'Block64', 'Block128'],
+        ylabel='TFLOPS',
+        plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
+        args={'layout_mode': layout_mode, 'op_mode': op_mode,
+              'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
+    )
+    for AT in [False] for BT in [False]
+    for op_mode in ['dsd'] for layout_mode in ['dense']
 ]
 
 
 @triton.testing.perf_report(square_confs)
 def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
     Z, H = 1, 1
     make_layout = {
-        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
+        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
         'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
     }[layout_mode]
     # create layout
@@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
     b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
     mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
     num_flops = {
-        'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
-        'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
+        'sdd': 2 * Z * K * float(layout.sum()) * block * block,
+        'dsd': 2 * Z * N * float(layout.sum()) * block * block,
         'dds': 2 * Z * M * float(layout.sum()) * block * block
-    }[op_mode]*1e-12
+    }[op_mode] * 1e-12
     return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
 
 
@@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
 
 square_confs = [
     triton.testing.Benchmark(
-        x_names = ['M', 'N'],
-        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        line_arg = 'block',
-        line_vals = [16, 32, 64],
-        line_names = ['Block16', 'Block32', 'Block64'],
-        ylabel = 'GBPS',
-        plot_name = f'{layout_mode}-square',
-        args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
-    )\
+        x_names=['M', 'N'],
+        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
+        line_arg='block',
+        line_vals=[16, 32, 64],
+        line_names=['Block16', 'Block32', 'Block64'],
+        ylabel='GBPS',
+        plot_name=f'{layout_mode}-square',
+        args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
+    )
     for layout_mode in ['dense', 'tril']
 ]
 
@@ -88,4 +89,4 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
     return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
 
 
-bench_matmul.run(print_data=True, show_plots=True)
\ No newline at end of file
+bench_matmul.run(print_data=True, show_plots=True)
```
25 changes: 13 additions & 12 deletions python/bench/bench_cross_entropy.py
```diff
@@ -1,17 +1,18 @@
 import torch
+
 import triton
 
 confs = [
     triton.testing.Benchmark(
-        x_names = ['N'],
-        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
-        line_arg = 'provider',
-        line_vals = ['triton', 'torch'],
-        line_names = ['Triton', 'Torch'],
-        ylabel = 'GBPS',
-        plot_name = f'{mode}-2048',
-        args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
-    )\
+        x_names=['N'],
+        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
+        line_arg='provider',
+        line_vals=['triton', 'torch'],
+        line_names=['Triton', 'Torch'],
+        ylabel='GBPS',
+        plot_name=f'{mode}-2048',
+        args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
+    )
     for mode in ['forward', 'backward']
 ]
@@ -24,8 +25,8 @@ def bench_op(M, N, dtype, mode, provider):
     num_gb = (2 * x.numel() * x.element_size() * 1e-9)
     gbps = lambda ms: num_gb / ms * 1e3
     # forward pass
-    op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
-          'triton': triton.ops.cross_entropy}[provider]
+    op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
+          'triton': triton.ops.cross_entropy}[provider]
     if mode == 'forward':
         mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
     if mode == 'backward':
@@ -37,4 +38,4 @@ def bench_op(M, N, dtype, mode, provider):
 
 
 if __name__ == '__main__':
-    bench_op.run(print_data=True)
\ No newline at end of file
+    bench_op.run(print_data=True)
```
22 changes: 12 additions & 10 deletions python/bench/bench_matmul.py
```diff
@@ -1,6 +1,6 @@
-import triton
 import torch
-import os
+
+import triton
 
 
 def rounded_linspace(low, high, steps, div):
@@ -29,25 +29,27 @@ def rounded_linspace(low, high, steps, div):
 transformer_confs = [
     triton.testing.Benchmark(
         x_names=[x],
-        x_vals = rounded_linspace(NK//16, NK, 32, 128),
+        x_vals=rounded_linspace(NK // 16, NK, 32, 128),
         line_arg="provider",
         line_vals=["cublas", "triton", "cutlass"],
         line_names=["cuBLAS", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
-        args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
-    ) for NK in [12288]\
-    for i, x in enumerate(["N", "K"])\
-    for M in [2048]
+        args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
+    ) for NK in [12288]
+    for i, x in enumerate(["N", "K"])
+    for M in [2048]
 ]
 
 
 @triton.testing.perf_report(square_confs)
 def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
-    if AT: a = a.t()
-    if BT: b = b.t()
+    if AT:
+        a = a.t()
+    if BT:
+        b = b.t()
     num_flops = 2 * M * N * K
     tflops = lambda ms: 2. * M * N * K / ms * 1e-9
     if provider == "cublas":
@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
     try:
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
         return tflops(ms), tflops(max_ms), tflops(min_ms)
-    except:
+    except Exception:
         return None
     return None
```
5 changes: 3 additions & 2 deletions python/bench/run.py
```diff
@@ -1,7 +1,8 @@
 import argparse
-import sys
-import os
 import inspect
+import os
+import sys
+
 import triton
 
 
```
29 changes: 14 additions & 15 deletions python/setup.py
```diff
@@ -1,29 +1,28 @@
+import distutils
+import distutils.spawn
 import os
-import re
-import sys
-import sysconfig
 import platform
+import re
+import shutil
 import subprocess
-import distutils
-import glob
+import sys
+import tarfile
 import tempfile
-import shutil
+import urllib.request
 from distutils.version import LooseVersion
-from setuptools import setup, Extension, find_packages
+
+from setuptools import Extension, setup
 from setuptools.command.build_ext import build_ext
 from setuptools.command.test import test as TestCommand
-import distutils.spawn
-import urllib.request
-import tarfile
 
 
 def get_llvm():
     # tries to find system LLVM
-    versions = ['-11.0', '-11', '-11-64']
+    versions = ['-11.0', '-11', '-11-64']
     supported = ['llvm-config{v}'.format(v=v) for v in versions]
     paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
     paths = [p for p in paths if p is not None]
     if paths:
-        return '', ''
+        return '', ''
     # download if nothing is installed
     name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
     dir = '/tmp'
@@ -32,7 +31,7 @@ def get_llvm():
     if not os.path.exists(llvm_library_dir):
         try:
             shutil.rmtree(os.path.join(dir, name))
-        except:
+        except Exception:
             pass
     url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
     print('downloading and extracting ' + url + '...')
@@ -96,7 +95,7 @@ def build_extension(self, ext):
             "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
             "-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
             #'-DPYTHON_EXECUTABLE=' + sys.executable,
-            #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
+            # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
             "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
             "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
         ]
```
83 changes: 45 additions & 38 deletions python/test/regression/test_performance.py
```diff
@@ -1,14 +1,18 @@
-from numpy import record
-import torch
-import triton
-import triton.language as tl
 import subprocess
 import sys
+
 import pytest
+import torch
+from numpy import record
+
+import triton
+import triton.language as tl
+
 #######################
 # Utilities
 #######################
 
+
 def nvsmi(attrs):
     attrs = ','.join(attrs)
     cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -23,48 +27,51 @@ def nvsmi(attrs):
 #######################
 
 matmul_data = {
-    # square
-    (256 , 256 , 256 ) : {'v100': 0.027},
-    (512 , 512 , 512 ) : {'v100': 0.158},
-    (1024, 1024, 1024 ) : {'v100': 0.466},
-    (2048, 2048, 2048 ) : {'v100': 0.680},
-    (4096, 4096, 4096 ) : {'v100': 0.831},
-    (8192, 8192, 8192 ) : {'v100': 0.849},
-    # tall-skinny
-    (16 , 1024, 1024 ) : {'v100': 0.0128},
-    (16 , 4096, 4096 ) : {'v100': 0.0883},
-    (16 , 8192, 8192 ) : {'v100': 0.101},
-    (64 , 1024, 1024 ) : {'v100': 0.073},
-    (64 , 4096, 4096 ) : {'v100': 0.270},
-    (64 , 8192, 8192 ) : {'v100': 0.360},
-    (1024, 64 , 1024 ) : {'v100': 0.0692},
-    (4096, 64 , 4096 ) : {'v100': 0.264},
-    (8192, 64 , 8192 ) : {'v100': 0.323},
-    # # deep reductions
-    # (64 , 64 , 16384) : {'v100': 0.},
-    # (64 , 64 , 65536) : {'v100': 0.},
-    # (256 , 256 , 8192 ) : {'v100': 0.},
-    # (256 , 256 , 32768) : {'v100': 0.},
+    # square
+    (256, 256, 256): {'v100': 0.027},
+    (512, 512, 512): {'v100': 0.158},
+    (1024, 1024, 1024): {'v100': 0.466},
+    (2048, 2048, 2048): {'v100': 0.680},
+    (4096, 4096, 4096): {'v100': 0.831},
+    (8192, 8192, 8192): {'v100': 0.849},
+    # tall-skinny
+    (16, 1024, 1024): {'v100': 0.0128},
+    (16, 4096, 4096): {'v100': 0.0883},
+    (16, 8192, 8192): {'v100': 0.101},
+    (64, 1024, 1024): {'v100': 0.073},
+    (64, 4096, 4096): {'v100': 0.270},
+    (64, 8192, 8192): {'v100': 0.360},
+    (1024, 64, 1024): {'v100': 0.0692},
+    (4096, 64, 4096): {'v100': 0.264},
+    (8192, 64, 8192): {'v100': 0.323},
+    # # deep reductions
+    # (64 , 64 , 16384) : {'v100': 0.},
+    # (64 , 64 , 65536) : {'v100': 0.},
+    # (256 , 256 , 8192 ) : {'v100': 0.},
+    # (256 , 256 , 32768) : {'v100': 0.},
 }
 
 
 @pytest.mark.parametrize('M, N, K', matmul_data.keys())
 def test_matmul(M, N, K):
     ref_gpu_util = matmul_data[(M, N, K)]['v100']
     cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
     ref_sm_clock = 1350
-    max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
+    max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
     assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
     a = torch.randn((M, K), dtype=torch.float16, device='cuda')
     b = torch.randn((K, N), dtype=torch.float16, device='cuda')
     fn = lambda: triton.ops.matmul(a, b)
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
-    cur_gpu_perf = 2.*M*N*K/ms * 1e-9
+    cur_gpu_perf = 2. * M * N * K / ms * 1e-9
     cur_gpu_util = cur_gpu_perf / max_gpu_perf
     triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+
+
 #######################
 # Element-Wise
 #######################
-import triton.language as tl
+
+
 @triton.jit
 def _add(x_ptr, y_ptr, output_ptr, n_elements,
@@ -80,29 +87,29 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
 
 
 elementwise_data = {
-    1024*16 : {'v100': 0.0219},
-    1024*64 : {'v100': 0.0791},
-    1024*256 : {'v100': 0.243},
-    1024*1024 : {'v100': 0.534},
-    1024*4096 : {'v100': 0.796},
-    1024*16384: {'v100': 0.905},
-    1024*65536: {'v100': 0.939},
+    1024 * 16: {'v100': 0.0219},
+    1024 * 64: {'v100': 0.0791},
+    1024 * 256: {'v100': 0.243},
+    1024 * 1024: {'v100': 0.534},
+    1024 * 4096: {'v100': 0.796},
+    1024 * 16384: {'v100': 0.905},
+    1024 * 65536: {'v100': 0.939},
 }
 
 
 @pytest.mark.parametrize('N', elementwise_data.keys())
 def test_elementwise(N):
     ref_gpu_util = elementwise_data[N]['v100']
     cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
     ref_mem_clock = 877
-    max_gpu_perf = 512*2*ref_mem_clock*1e-3
+    max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
     assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
     z = torch.empty((N, ), dtype=torch.float16, device='cuda')
     x = torch.randn_like(z)
     y = torch.randn_like(z)
     grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
     fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
-    cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
+    cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
     cur_gpu_util = cur_gpu_perf / max_gpu_perf
     triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
```
(Diffs for the remaining 24 of the 30 changed files are not shown here.)
