
Commit 6cfa70d

initial
0 parents · commit 6cfa70d


64 files changed: +32575 -0 lines

.gitignore

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

out*.txt
*.nvvp
mask.txt
*mem*.txt*
logs/
temp/
lib/
build/

LICENSE.md

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2017 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Makefile

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
TARGET=./build

.PHONY: all compile clean
all: compile

compile: blocksparse/blocksparse_ops.so
	python setup.py bdist_wheel --universal

clean:
	rm -vfr $(TARGET)

release: compile
	BRANCH=$(shell git rev-parse --abbrev-ref HEAD); if [ "$$BRANCH" != "master" ]; then echo "--- ERROR: refusing to build non-master branch"; exit 1; fi
	@git diff-index --quiet HEAD -- || ( echo '--- ERROR: will not build while git tree is dirty! please commit your changes. ---' && exit 1 )
	# hacky way to get the version from the wheel name
	$(eval VERSION := $(shell ls -th dist/*.whl | head -1 | awk '{split($$1,r,"-");print r[2]}')) # '
	git tag v${VERSION}
	git push origin v${VERSION}

	# Upload the binary wheel to PyPI. Needs `twine` installed and configured with your PyPI credentials.
	twine upload $(shell ls -th dist/*.whl | head -1)

CUDA_HOME?=/usr/local/cuda
NV_INC?=$(CUDA_HOME)/include
NV_LIB?=$(CUDA_HOME)/lib64
TF_INC=$(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_LIB=$(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
NVCCFLAGS=-DGOOGLE_CUDA=1 -D_GLIBCXX_USE_CXX11_ABI=0 -arch=sm_61 -gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_61,code=compute_61 -O3 -Xcompiler -fPIC
CCFLAGS=-std=c++11 -O3 -DGOOGLE_CUDA=1 -D_GLIBCXX_USE_CXX11_ABI=0 -I$(TARGET) -I$(NV_INC) -I$(TF_INC) -I$(TF_INC)/external/nsync/public -fPIC

OBJS=\
	$(TARGET)/batch_norm_op.o \
	$(TARGET)/blocksparse_conv_op.o \
	$(TARGET)/blocksparse_kernels.o \
	$(TARGET)/blocksparse_l2_norm_op.o \
	$(TARGET)/blocksparse_matmul_op.o \
	$(TARGET)/cwise_linear_op.o \
	$(TARGET)/edge_bias_op.o \
	$(TARGET)/ew_op.o \
	$(TARGET)/gpu_types.o \
	$(TARGET)/layer_norm_op.o \

CU_OBJS=\
	$(TARGET)/batch_norm_op_gpu.cu.o \
	$(TARGET)/blocksparse_l2_norm_op_gpu.cu.o \
	$(TARGET)/blocksparse_matmul_op_gpu.cu.o \
	$(TARGET)/cwise_linear_op_gpu.cu.o \
	$(TARGET)/edge_bias_op_gpu.cu.o \
	$(TARGET)/ew_op_gpu.cu.o \
	$(TARGET)/layer_norm_cn_op_gpu.cu.o \
	$(TARGET)/layer_norm_nc_op_gpu.cu.o \

$(TARGET)/blocksparse_kernels.h: src/sass/*.sass
	mkdir -p $(shell dirname $@)
	python generate_kernels.py

blocksparse/blocksparse_ops.so: $(OBJS) $(CU_OBJS)
	g++ $^ -shared -o $@ -L$(TF_LIB) -L$(NV_LIB) -ltensorflow_framework -lcudart -lcuda

$(TARGET)/%.cu.o: src/%.cu $(TARGET)/blocksparse_kernels.h
	mkdir -p $(shell dirname $@)
	nvcc $(NVCCFLAGS) -c $< -o $@

$(TARGET)/%.o: src/%.cc src/*.h $(TARGET)/blocksparse_kernels.h
	mkdir -p $(shell dirname $@)
	g++ $(CCFLAGS) -c $< -o $@

README.md

Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
# Blocksparse

The `blocksparse` package contains TensorFlow Ops and corresponding GPU kernels for block-sparse matrix multiplication. Also included are related ops like edge bias, sparse weight norm and layer norm.

To learn more, see [the launch post on the OpenAI blog](https://blog.openai.com/block-sparse-gpu-kernels/).

## Prerequisites

First, you need at least one Nvidia GPU. For best performance, we recommend using a Pascal or Maxwell generation GPU -- this is the full list of features by GPU type:

| GPU Family | BSMatMul-ASM | BSMatMul-CudaC | BSConv |
|------------|--------------|----------------|--------|
| Kepler     | -            | X              | -      |
| Maxwell    | X (fastest)  | X              | X      |
| Pascal     | X (fastest)  | X              | X      |
| Volta      | -            | X (fastest)    | -      |

Note that BSMatMul-CudaC **only supports `feature_axis=0`**, while BSMatMul-ASM only supports `feature_axis=1`.

Additionally, you need:

- A working Linux installation (we run Ubuntu 16.04) with the Nvidia drivers for your GPU.
- CUDA 8 (in `/usr/local/cuda`)
- Python 3.5 or newer, or 2.7 or newer
- TensorFlow 1.4.0 or newer, [with GPU support](https://www.tensorflow.org/install/install_linux#install_tensorflow) (e.g. `pip install tensorflow-gpu`)
- CUDA 9 and Volta will work if you update the build targets (`-gencode=arch=compute_70,code=sm_70`) and also build TensorFlow from source.

## Installation

```
pip install blocksparse
```

## Usage

This example performs a block-sparse matrix multiplication:
```
from blocksparse.matmul import BlocksparseMatMul
import tensorflow as tf
import numpy as np

hidden_size = 4096
block_size = 32
minibatch_size = 64

# Create a (random) sparsity pattern
sparsity = np.random.randint(2, size=(hidden_size//block_size, hidden_size//block_size))

# Initialize the sparse matrix multiplication object
bsmm = BlocksparseMatMul(sparsity, block_size=block_size)

# Input to graph
x = tf.placeholder(tf.float32, shape=[None, hidden_size])

# Initialize block-sparse weights
w = tf.get_variable("w", bsmm.w_shape, dtype=tf.float32)

# Block-sparse matrix multiplication
y = bsmm(x, w)

# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([y], feed_dict={x: np.ones((minibatch_size, hidden_size), dtype='float32')})
print(result)
```
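
Since the CudaC kernels only support `feature_axis=0` (see the table above), here is a hedged variant of the same multiply with features on the leading axis. It is a sketch, not from the original example; the shape it uses follows from the `i_shape` definition in the API documentation below:
```
# Sketch: with feature_axis=0 the features live on the leading axis, so the
# input is shaped (hidden_size, minibatch_size) instead of the reverse.
bsmm0 = BlocksparseMatMul(sparsity, block_size=block_size, feature_axis=0)
x0 = tf.placeholder(tf.float32, shape=bsmm0.i_shape(minibatch_size))  # (C, N)
w0 = tf.get_variable("w0", bsmm0.w_shape, dtype=tf.float32)
y0 = bsmm0(x0, w0)
```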

For a more involved example using block-sparse ops to train a language model, see [`examples/`](./examples/).

## Development

If you're interested in hacking on the ops and kernels, go ahead and build from source:

    git clone git@github.com:openai/blocksparse.git
    cd blocksparse

    make compile
    pip install dist/*.whl

    # test it if you like
    test/blocksparse_matmul_test.py
    test/blocksparse_conv_test.py

If your CUDA is not in `/usr/local/cuda`, or you have several versions installed (e.g. both `/usr/local/cuda-8.0` and `/usr/local/cuda-9.0`), set `CUDA_HOME` to the base path to use when compiling, e.g. `CUDA_HOME=/usr/local/cuda-9.0 make compile`.


## API Documentation


### blocksparse.matmul

    class BlocksparseMatMul(object)

        def __init__(self, layout, block_size=32, feature_axis=1, name=None)

        def i_shape(self, N): return (N, self.C) if self.axis else (self.C, N)
        def o_shape(self, N): return (N, self.K) if self.axis else (self.K, N)

        # return the coordinate in the layout that corresponds to a given block id
        def block_coord(self, block): return self.updat_list[block]

        def ortho_init(self)

        def identity_init(self, gpu=False)

        def l2_normalize(self, W, gain=None, epsilon=1e-12, dtype=np.float32)

        def __call__(self, I, W, dw_dtype=tf.float32, name=None, bench=0)

    def group_param_grads(param_grad, group_size=8, cast32=True)


    class SparseProj(object):

        def __init__(self, nhidden, nproj=None, proj_stride=None, block_size=32, gather_lut=None, name=None)

        def gather(self, x)
        def scatter(self, x)
        def scatter_add(self, x, y)
        def scatter_mul(self, x, y)
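
A minimal `SparseProj` sketch follows. The semantics are inferred from the signatures above and not documented here, so treat the `proj_stride` interpretation (keep every Nth hidden unit) and the feature-major input layout as assumptions:
```
import tensorflow as tf
from blocksparse.matmul import SparseProj

nhidden = 4096

# Hypothetical: project a dense hidden state down to every 16th unit.
proj = SparseProj(nhidden, proj_stride=16)

x = tf.placeholder(tf.float32, shape=[nhidden, 64])  # assumed feature-major layout
h = proj.gather(x)   # dense -> projected subset
y = proj.scatter(h)  # projected subset -> back to dense
```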


### blocksparse.conv

    class BlocksparseConv(object):
        """
        BCK: ( # block(B)/input(C)/output(K) feature dims
                ( (c0, c1, c2, ...), (k0, k1, k2, ...) ), # block 0
                ( (c0, c1, c2, ...), (k0, k1, k2, ...) ), # block 1
                ( (c0, c1, c2, ...), (k0, k1, k2, ...) ), # block 2 ...
             )
        TRS: (T,R,S) or (R,S) or (S,) - filter spatial size dims
        DHW: (D,H,W) or (H,W) or (W,) - input image spatial size dims
        MPQ: (M,P,Q) or (P,Q) or (Q,) or None - output image spatial size dims (used for ambiguous dims in strided transpose conv)
        strides: (1,1,1) or (1,1) or (1,)
        dilates: (1,1,1) or (1,1) or (1,)
        padding: (1,1,1) or (1,1) or (1,) or "SAME" or "VALID"
        edge_bias: True/False
        """
        def __init__(self, BCK, TRS, DHW, MPQ=None, strides=(1,1,1), dilates=(1,1,1), padding="SAME", edge_bias=False, debug=False, deconv=False)

        def edge_bias_shape(self)

        def i_shape(self, N)
        def o_shape(self, N)
        def f_shape(self, block=None)

        def __call__(self, F, I, edge_bias=None)

        def l2_normalize(self, F, gain=None, epsilon=1e-12, dtype=np.float32)

    class BlocksparseDeconv(BlocksparseConv)

        def __init__(self, BCK, TRS, DHW, MPQ=None, strides=(1,1,1), dilates=(1,1,1), padding="SAME", edge_bias=False, debug=False)

    def cwise_linear(x, a=None, b=None)
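
A hedged construction sketch, not from the original README: the channel groupings in `BCK` below are illustrative, and passing the filters `F` as a list of per-block variables (one per `BCK` entry, shaped via `f_shape`) is an assumption:
```
import tensorflow as tf
from blocksparse.conv import BlocksparseConv

# Hypothetical 2D conv with two blocks of 4 input/output channels each.
BCK = (
    ((0, 1, 2, 3), (0, 1, 2, 3)),  # block 0
    ((4, 5, 6, 7), (4, 5, 6, 7)),  # block 1
)
conv = BlocksparseConv(BCK, TRS=(3, 3), DHW=(32, 32), strides=(1, 1), dilates=(1, 1), padding="SAME")

# Per-block filters and an input for minibatch size N=64, shaped via the API above.
F = [tf.get_variable("f%d" % b, conv.f_shape(b), dtype=tf.float32) for b in range(len(BCK))]
I = tf.placeholder(tf.float32, shape=conv.i_shape(64))
O = conv(F, I)
```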



### blocksparse.ew

    def add(x, y, name=None)
    def multiply(x, y, name=None)
    def subtract(x, y, name=None)
    def divide(x, y, name=None)
    def maximum(x, y, name=None)
    def minimum(x, y, name=None)

    def negative(x, name=None)
    def reciprocal(x, name=None)
    def square(x, name=None)
    def sqrt(x, name=None)
    def exp(x, name=None)
    def log(x, name=None)
    def sigmoid(x, name=None)
    def tanh(x, name=None)
    def relu(x, name=None)

    def elu(x, alpha=1.0, name=None)

    def fused_lstm_gates(c, *args, name=None)

    def split4(x)
    def concat4(x0, x1, x2, x3)

    def float_cast(x, dtype, dx_dtype=None)

    def dropout(x, keep_prob=0.8, mask=None)

    def add_n8(xs, name=None)
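
A hedged sketch, assuming these ops behave as drop-in replacements for their `tf` counterparts backed by fused GPU kernels (the module path `blocksparse.ew` is inferred from the heading above):
```
import tensorflow as tf
import blocksparse.ew as ew

x = tf.placeholder(tf.float32, [64, 4096])
y = tf.placeholder(tf.float32, [64, 4096])

# Compose an elementwise graph from the ops listed above.
z = ew.add(ew.relu(x), ew.multiply(x, y))
h = ew.dropout(ew.tanh(z), keep_prob=0.9)
```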




### blocksparse.norms

    def layer_norm(x, g, b, axis=1, epsilon=1e-6, relu=False, bench=0)

    def batch_norm(x, g, b, epsilon=1e-6)

    def batch_norm_inference(x, g, b, m, v, epsilon=1e-6)
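
A hedged usage sketch for `layer_norm`, mirroring the signature above: `g` and `b` are assumed to be learned gain and bias vectors over the feature axis:
```
import tensorflow as tf
from blocksparse.norms import layer_norm

hidden_size = 4096
x = tf.placeholder(tf.float32, [64, hidden_size])
g = tf.get_variable("g", [hidden_size], initializer=tf.ones_initializer())
b = tf.get_variable("b", [hidden_size], initializer=tf.zeros_initializer())

# Normalize over the feature axis (axis=1); a relu can be fused per the signature.
y = layer_norm(x, g, b, axis=1)
```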
