# Blocksparse

The `blocksparse` package contains TensorFlow Ops and corresponding GPU kernels for block-sparse matrix multiplication. Also included are related ops like edge bias, sparse weight norm and layer norm.

To learn more, see [the launch post on the OpenAI blog](https://blog.openai.com/block-sparse-gpu-kernels/).

## Prerequisites

First, you need at least one Nvidia GPU. For best performance, we recommend using a Pascal or Maxwell generation GPU -- this is the full list of features by GPU type:

| GPU Family | BSMatMul-ASM | BSMatMul-CudaC | BSConv |
|------------|--------------|----------------|--------|
| Kepler     | -            | X              | -      |
| Maxwell    | X (fastest)  | X              | X      |
| Pascal     | X (fastest)  | X              | X      |
| Volta      | -            | X (fastest)    | -      |

Note that BSMatMul-CudaC **only supports `feature_axis=0`**, while BSMatMul-ASM only supports `feature_axis=1`.

Additionally, you need:

- A working Linux installation (we run Ubuntu 16.04) with the Nvidia drivers for your GPU.
- CUDA 8 (in `/usr/local/cuda`)
- Python 3.5+ (or 2.7+)
- TensorFlow 1.4.0 or newer, [with GPU support](https://www.tensorflow.org/install/install_linux#install_tensorflow) (e.g. `pip install tensorflow-gpu`)
- CUDA 9 and Volta will work if you update the build targets (`-gencode=arch=compute_70,code=sm_70`) and also build TensorFlow from source.

## Installation

```
pip install blocksparse
```
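
To check that the install picked up a usable build, a quick sanity check (this assumes a CUDA-enabled TensorFlow install and a visible GPU):

```
from blocksparse.matmul import BlocksparseMatMul
print(BlocksparseMatMul)  # a clean import is a good sign the compiled kernels were found
```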

## Usage

This example performs a block-sparse matrix multiplication:
```
from blocksparse.matmul import BlocksparseMatMul
import tensorflow as tf
import numpy as np

hidden_size = 4096
block_size = 32
minibatch_size = 64

# Create a (random) sparsity pattern
sparsity = np.random.randint(2, size=(hidden_size//block_size, hidden_size//block_size))

# Initialize the sparse matrix multiplication object
bsmm = BlocksparseMatMul(sparsity, block_size=block_size)

# Input to graph
x = tf.placeholder(tf.float32, shape=[None, hidden_size])

# Initialize block-sparse weights
w = tf.get_variable("w", bsmm.w_shape, dtype=tf.float32)

# Block-sparse matrix multiplication
y = bsmm(x, w)

# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([y], feed_dict={x: np.ones((minibatch_size, hidden_size), dtype='float32')})
print(result)
```
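
As noted in the prerequisites, the CUDA-C kernels (the only ones available on Volta) require `feature_axis=0`, which puts the feature dimension first. A minimal sketch of the same multiply in that layout, reusing the names above (the input shape comes from `bsmm.i_shape`):

```
# Features on axis 0: inputs are [hidden_size, minibatch] instead of [minibatch, hidden_size]
bsmm0 = BlocksparseMatMul(sparsity, block_size=block_size, feature_axis=0)

x0 = tf.placeholder(tf.float32, shape=bsmm0.i_shape(minibatch_size))
w0 = tf.get_variable("w0", bsmm0.w_shape, dtype=tf.float32)
y0 = bsmm0(x0, w0)
```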

For a more involved example using block-sparse ops to train a language model, see [`examples/`](./examples/).

## Development

If you're interested in hacking on the ops and kernels, go ahead and build from source:

    git clone git@github.com:openai/blocksparse.git
    cd blocksparse

    make compile
    pip install dist/*.whl

    # test it if you like
    test/blocksparse_matmul_test.py
    test/blocksparse_conv_test.py

If your CUDA is not in `/usr/local/cuda`, or you have several versions (e.g. both `/usr/local/cuda-8.0` and `/usr/local/cuda-9.0`), set `CUDA_HOME` to the base path to use when running `make compile`.


## API Documentation

### blocksparse.matmul

    class BlocksparseMatMul(object)

        def __init__(self, layout, block_size=32, feature_axis=1, name=None)

        def i_shape(self, N): return (N, self.C) if self.axis else (self.C, N)
        def o_shape(self, N): return (N, self.K) if self.axis else (self.K, N)

        # return the coordinate in the layout that corresponds to a given block id
        def block_coord(self, block): return self.updat_list[block]

        def ortho_init(self)

        def identity_init(self, gpu=False)

        def l2_normalize(self, W, gain=None, epsilon=1e-12, dtype=np.float32)

        def __call__(self, I, W, dw_dtype=tf.float32, name=None, bench=0)

    def group_param_grads(param_grad, group_size=8, cast32=True)

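For example, a minimal sketch of initializing block-sparse weights with the helpers above (this assumes `identity_init()` returns an initializer usable with `tf.get_variable`; `sparsity` is the layout from the usage example):

```
bsmm = BlocksparseMatMul(sparsity, block_size=32)

# Initialize the sparse weights with the identity initializer
w = tf.get_variable("w", bsmm.w_shape, dtype=tf.float32,
                    initializer=bsmm.identity_init())
```
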
    class SparseProj(object):

        def __init__(self, nhidden, nproj=None, proj_stride=None, block_size=32, gather_lut=None, name=None)

        def gather(self, x)
        def scatter(self, x)
        def scatter_add(self, x, y)
        def scatter_mul(self, x, y)
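
A hedged sketch of gathering a projected subset of hidden units and scattering the result back to the full hidden dimension (the sizes and the features-first `[nhidden, minibatch]` layout here are assumptions for illustration):

```
proj = SparseProj(nhidden=4096, nproj=1024)

h   = tf.placeholder(tf.float32, shape=[4096, 64])
sub = proj.gather(h)     # pick out the projected subset of units
out = proj.scatter(sub)  # scatter the subset back into the full hidden dimension
```
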


### blocksparse.conv

    class BlocksparseConv(object):
        """
        BCK: ( # block(B)/input(C)/output(K) feature dims
            ( (c0, c1, c2, ...), (k0, k1, k2, ...) ), # block 0
            ( (c0, c1, c2, ...), (k0, k1, k2, ...) ), # block 1
            ( (c0, c1, c2, ...), (k0, k1, k2, ...) ), # block 2 ...
        )
        TRS: (T,R,S) or (R,S) or (S,) - filter spatial size dims
        DHW: (D,H,W) or (H,W) or (W,) - input image spatial size dims
        MPQ: (M,P,Q) or (P,Q) or (Q,) or None - output image spatial size dims (used for ambiguous dims in strided transpose conv)
        strides: (1,1,1) or (1,1) or (1,)
        dilates: (1,1,1) or (1,1) or (1,)
        padding: (1,1,1) or (1,1) or (1,) or "SAME" or "VALID"
        edge_bias: True/False
        """
        def __init__(self, BCK, TRS, DHW, MPQ=None, strides=(1,1,1), dilates=(1,1,1), padding="SAME", edge_bias=False, debug=False, deconv=False)

        def edge_bias_shape(self)

        def i_shape(self, N)
        def o_shape(self, N)
        def f_shape(self, block=None)

        def __call__(self, F, I, edge_bias=None):

        def l2_normalize(self, F, gain=None, epsilon=1e-12, dtype=np.float32):
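
A minimal sketch of constructing a block-sparse convolution and querying its shapes (the two feature blocks, 3x3 filter and 32x32 input below are illustrative values, not defaults):

```
# Each block maps a tuple of input channels to a tuple of output channels
BCK = (
    ( (0, 1, 2, 3), (0, 1, 2, 3) ),  # block 0
    ( (4, 5, 6, 7), (4, 5, 6, 7) ),  # block 1
)
conv = BlocksparseConv(BCK, TRS=(3, 3), DHW=(32, 32))

N = 64  # minibatch size
print(conv.i_shape(N), conv.o_shape(N), conv.f_shape(block=0))
```
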

    class BlocksparseDeconv(BlocksparseConv)

        def __init__(self, BCK, TRS, DHW, MPQ=None, strides=(1,1,1), dilates=(1,1,1), padding="SAME", edge_bias=False, debug=False)

    def cwise_linear(x, a=None, b=None)

### blocksparse.ew

    def add(x, y, name=None)
    def multiply(x, y, name=None)
    def subtract(x, y, name=None)
    def divide(x, y, name=None)
    def maximum(x, y, name=None)
    def minimum(x, y, name=None)

    def negative(x, name=None)
    def reciprocal(x, name=None)
    def square(x, name=None)
    def sqrt(x, name=None)
    def exp(x, name=None)
    def log(x, name=None)
    def sigmoid(x, name=None)
    def tanh(x, name=None)
    def relu(x, name=None)

    def elu(x, alpha=1.0, name=None)

    def fused_lstm_gates(c, *args, name=None)

    def split4(x)
    def concat4(x0, x1, x2, x3)

    def float_cast(x, dtype, dx_dtype=None)

    def dropout(x, keep_prob=0.8, mask=None)

    def add_n8(xs, name=None)
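
Judging by the signatures above, these are drop-in counterparts of the corresponding TensorFlow elementwise ops; a minimal sketch chaining a few of them (shapes are illustrative):

```
import blocksparse.ew as ew

a = tf.placeholder(tf.float32, shape=[64, 4096])
b = tf.placeholder(tf.float32, shape=[64, 4096])

h   = ew.relu(ew.add(a, b))               # elementwise add followed by relu
h   = ew.dropout(h, keep_prob=0.9)        # dropout with an internally generated mask
h16 = ew.float_cast(h, dtype=tf.float16)  # cast activations; dx_dtype controls the gradient dtype
```
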


### blocksparse.norms

    def layer_norm(x, g, b, axis=1, epsilon=1e-6, relu=False, bench=0)

    def batch_norm(x, g, b, epsilon=1e-6)

    def batch_norm_inference(x, g, b, m, v, epsilon=1e-6)
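
A minimal sketch of the layer norm op above (the `[minibatch, features]` layout and per-feature gain/bias shapes are assumptions matching `axis=1`):

```
from blocksparse.norms import layer_norm

x = tf.placeholder(tf.float32, shape=[64, 4096])
g = tf.get_variable("g", [4096], initializer=tf.ones_initializer())
b = tf.get_variable("b", [4096], initializer=tf.zeros_initializer())

y = layer_norm(x, g, b, axis=1, relu=True)  # layer norm with the optional relu applied
```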