
Commit

working BF16
www committed Sep 11, 2023
1 parent adbc0f9 commit 9059cd0
Showing 3 changed files with 136 additions and 119 deletions.
148 changes: 73 additions & 75 deletions wkv5_bf16/cuda/wkv5_cuda_v1.cu
@@ -11,89 +11,87 @@ __global__ void kernel_forward(const int B, const int T, const int C, const int
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
-_w += h*N;
-_u += h*N;
+_w += h*_N_;
+_u += h*_N_;

-__shared__ float state[N * N];
-__shared__ F rr[N], kk[N];
+__shared__ float state[_N_ * _N_], rr[_N_], kk[_N_];

-for (int j = 0; j < N; ++j)
-state[j * N + i] = 0;
+for (int j = 0; j < _N_; ++j)
+state[j * _N_ + i] = 0;

-for (int _t = b*T*C + h*N + i, _tend = (b+1)*T*C + h*N + i; _t < _tend; _t += C)
+for (int _t = b*T*C + h*_N_ + i; _t < (b+1)*T*C + h*_N_ + i; _t += C)
{
__syncthreads();
-rr[i] = _r[_t];
-kk[i] = _k[_t];
+rr[i] = float(_r[_t]);
+kk[i] = float(_k[_t]);
__syncthreads();

-const F vv = _v[_t];
-F yy = 0;
+const float vv = _v[_t];
+float yy = 0;

-for (int j = 0; j < N; j++)
+for (int j = 0; j < _N_; j++)
{
-F x = kk[j] * vv;
+float x = kk[j] * vv;

-float s = state[j * N + i];
-state[j * N + i] = s * _w[j] + float(x);
+float s = state[j * _N_ + i];
+state[j * _N_ + i] = s * _w[j] + x;

-yy += rr[j] * (_u[j] * x + F(s));
+yy += rr[j] * (float(_u[j]) * x + s);
}
-_y[_t] = yy;
+_y[_t] = F(yy);
}
}
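In matrix form, per head with N = _N_ (the head size), r_t, k_t, v_t, w, u in R^N and state S in R^{N x N}, the forward loop above computes S_t = diag(w) S_{t-1} + k_t v_t^T and y_t = (diag(u) k_t v_t^T + S_{t-1})^T r_t; thread i owns column i of S (one value channel) and the j loop runs over key channels. This is a reading of the code, not text from the commit; the BF16 change only moves where values are widened to float and narrowed back, the recurrence itself is unchanged.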

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
const F *__restrict__ const r, const F *__restrict__ const k, const F *__restrict__ const v, const float *__restrict__ w, const float *__restrict__ wwww, const F *__restrict__ u, const F *__restrict__ const gy,
-F *__restrict__ const gr, F *__restrict__ const gk, F *__restrict__ const gv, F *__restrict__ gw, F *__restrict__ gu)
+F *__restrict__ const gr, F *__restrict__ const gk, F *__restrict__ const gv, float *__restrict__ gw, float *__restrict__ gu)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
-w += h*N;
-u += h*N;
-gu += h*N;
-gw += h*N;
-wwww += h*N;
+w += h*_N_;
+u += h*_N_;
+gu += h*_N_;
+gw += h*_N_;
+wwww += h*_N_;

-__shared__ float state[N * N];
-__shared__ F vv[N], rr[N], kk[N], gyy[N];
+__shared__ float state[_N_ * _N_], vv[_N_], rr[_N_], kk[_N_], gyy[_N_];

#pragma unroll
-for (int j = 0; j < N; ++j){
-state[j * N + i] = 0;
+for (int j = 0; j < _N_; ++j){
+state[j * _N_ + i] = 0;
}

const float ww = w[i];
-const F uu = u[i];
+const float uu = float(u[i]);
const float wwwww = wwww[i];
-float saaaa[N] = {0.0f}, sbbbb[N] = {0.0f};
+float saaaa[_N_] = {0.0f}, sbbbb[_N_] = {0.0f};

-for (int _t = b*T*C + h*N + i, _tend = (b+1)*T*C + h*N + i; _t < _tend; _t += C)
+for (int _t = b*T*C + h*_N_ + i, _tend = (b+1)*T*C + h*_N_ + i; _t < _tend; _t += C)
{
__syncthreads();
-vv[i] = v[_t];
-gyy[i] = gy[_t];
+vv[i] = float(v[_t]);
+gyy[i] = float(gy[_t]);
__syncthreads();

-const F kk = k[_t];
-const F rr = r[_t];
-F grr = 0;
-F guu = 0;
+const float kk = float(k[_t]);
+const float rr = float(r[_t]);
+float grr = 0;
+float guu = 0;

#pragma unroll
-for (int j = 0; j < N; j++)
+for (int j = 0; j < _N_; j++)
{
-F x = vv[j] * kk;
-float s = state[j * N + i];
-state[j * N + i] = s * ww + float(x);
+float x = vv[j] * kk;
+float s = state[j * _N_ + i];
+state[j * _N_ + i] = s * ww + x;

-grr += gyy[j] * (uu * x + F(s));
+grr += gyy[j] * (uu * x + s);
guu += rr * x * gyy[j];
}

-gr[_t] = grr;
+gr[_t] = F(grr);
atomicAdd(gu + i, guu);

if (_t < _tend - 2 * C){
@@ -102,80 +100,80 @@ __global__ void kernel_backward(const int B, const int T, const int C, const int
gyy[i] = gy[_t+2*C];
__syncthreads();

-const F rr_value = r[_t+2*C];
+const float rr_value = r[_t+2*C];

#pragma unroll
-for (int j = 0; j < N; j++){
-F x = vv[j] * kk;
-saaaa[j] = ww * (saaaa[j] + sbbbb[j] + float(x));
-sbbbb[j] = ww * (sbbbb[j] + float(x));
+for (int j = 0; j < _N_; j++){
+float x = vv[j] * kk;
+saaaa[j] = ww * (saaaa[j] + sbbbb[j] + x);
+sbbbb[j] = ww * (sbbbb[j] + x);

-atomicAdd(gw+i, rr_value * wwwww * F(saaaa[j]) * gyy[j]);
+atomicAdd(gw+i, rr_value * wwwww * saaaa[j] * gyy[j]);
}
}
}

#pragma unroll
-for (int j = 0; j < N; ++j)
-state[j * N + i] = 0;
+for (int j = 0; j < _N_; ++j)
+state[j * _N_ + i] = 0;

-for (int _t = (b+1)*T*C + h*N + i - C, _tend = b*T*C + h*N + i; _t >= _tend; _t -= C)
+for (int _t = (b+1)*T*C + h*_N_ + i - C, _tend = b*T*C + h*_N_ + i; _t >= _tend; _t -= C)
{
__syncthreads();
vv[i] = v[_t];
gyy[i] = gy[_t];
__syncthreads();

-const F rr = r[_t];
-F gkk = 0;
+const float rr = r[_t];
+float gkk = 0;

#pragma unroll
-for (int j = 0; j < N; j++)
+for (int j = 0; j < _N_; j++)
{
-F x = gyy[j] * rr;
-float s = state[j * N + i];
-state[j * N + i] = s * ww + float(x);
+float x = gyy[j] * rr;
+float s = state[j * _N_ + i];
+state[j * _N_ + i] = s * ww + x;

-gkk += vv[j] * (uu * x + F(s));
+gkk += vv[j] * (uu * x + s);
}
-gk[_t] = gkk;
+gk[_t] = F(gkk);
}

#pragma unroll
-for (int j = 0; j < N; ++j)
-state[j * N + i] = 0;
+for (int j = 0; j < _N_; ++j)
+state[j * _N_ + i] = 0;

-for (int _t = (b+1)*T*C + h*N + i - C, _tend = b*T*C + h*N + i; _t >= _tend; _t -= C)
+for (int _t = (b+1)*T*C + h*_N_ + i - C, _tend = b*T*C + h*_N_ + i; _t >= _tend; _t -= C)
{
__syncthreads();
kk[i] = k[_t];
rr[i] = r[_t];
__syncthreads();

-const F gy_value = gy[_t];
-F gvv = 0;
+const float gy_value = gy[_t];
+float gvv = 0;

#pragma unroll
-for (int j = 0; j < N; j++)
+for (int j = 0; j < _N_; j++)
{
-F x = gy_value * rr[j];
-float s = state[j * N + i];
-state[j * N + i] = s * w[j] + float(x);
+float x = gy_value * rr[j];
+float s = state[j * _N_ + i];
+state[j * _N_ + i] = s * w[j] + x;

-gvv += kk[j] * (u[j] * x + F(s));
+gvv += kk[j] * (u[j] * x + s);
}
-gv[_t] = gvv;
+gv[_t] = F(gvv);
}
}

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
-assert(H*N == C);
-kernel_forward<<<dim3(B * H), dim3(N)>>>(B, T, C, H, r, k, v, w, u, y);
+assert(H*_N_ == C);
+kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
}

-void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, float *gw, float *gu)
{
-assert(H*N == C);
-kernel_backward<<<dim3(B * H), dim3(N)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
+assert(H*_N_ == C);
+kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
}
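The pattern this file converges on: r/k/v/u/gy stay bf16 in global memory, every value is widened to float on load, all accumulation (the shared state, saaaa/sbbbb, grr/gkk/gvv) happens in fp32, and results are narrowed back to bf16 only on the final store; gw/gu become plain float buffers accumulated with atomicAdd. The kernels express this through the template parameter F (at::BFloat16) and explicit float()/F() casts. A minimal standalone sketch of the same round-trip, not taken from this commit (illustrative names, assumes CUDA 11+ for cuda_bf16.h and its intrinsics):

    #include <cuda_bf16.h>

    // bf16 in global memory, fp32 for the arithmetic, bf16 again on store.
    __global__ void scale_bf16(const __nv_bfloat16 *in, __nv_bfloat16 *out, float w, int n)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
        {
            float x = __bfloat162float(in[i]);   // widen on load
            out[i] = __float2bfloat16(x * w);    // compute in fp32, narrow on store
        }
    }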
4 changes: 2 additions & 2 deletions wkv5_bf16/cuda/wkv5_op.cpp
@@ -3,13 +3,13 @@
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
-void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, float *gw, float *gu);

void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
-cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
+cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<float>(), gu.data_ptr<float>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv5 forward");
