Skip to content

Commit

Permalink
Update wkv6state_cuda.cu
Browse files Browse the repository at this point in the history
  • Loading branch information
BlinkDL authored Dec 13, 2024
1 parent 852cd8f commit a6854e4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions RWKV-v5/cuda/wkv6state_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ __global__ void kernel_backward_111(const int B, const int T, const int C, const
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_u += h*_N_;
_s += h*_N_*_N_ + i;
_s += b*H*_N_*_N_ + h*_N_*_N_ + i;

__shared__ float u_[_N_];
__shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
Expand Down Expand Up @@ -195,7 +195,7 @@ __global__ void kernel_backward_222(const int B, const int T, const int C, const
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_s += h*_N_*_N_ + i;
_s += b*H*_N_*_N_ + h*_N_*_N_ + i;

__shared__ float v[_N_], gy[_N_];
float state[_N_], saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0};
Expand Down

0 comments on commit a6854e4

Please sign in to comment.