
Commit

fix #151
fangwei123456 committed Dec 10, 2021
1 parent ee2b22f commit 732f39a
Showing 1 changed file with 19 additions and 2 deletions.
spikingjelly/clock_driven/neuron_kernel.py
@@ -857,7 +857,7 @@ def create_bptt_kernel(sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, d
         __syncthreads();
         if (threadIdx.x == 0)
         {
-            grad_reciprocal_tau[0] = sdata[0];
+            atomicAdd(grad_reciprocal_tau, sdata[0]);
         }
     }
     '''
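For readers skimming the diff: the old line wrote each block's reduced partial sum straight into grad_reciprocal_tau[0], so concurrent blocks could overwrite one another's contribution; atomicAdd accumulates every block's result instead. A minimal sketch of that pattern follows, assuming a hypothetical kernel name, a 256-thread block, and a zero-initialized output scalar (this is not SpikingJelly's actual kernel).

// Sketch only: per-block tree reduction in shared memory, then one
// atomicAdd per block into a single global accumulator.
#include <cuda_runtime.h>

__global__ void reduce_grad_sketch(const float* __restrict__ partial, int n,
                                   float* __restrict__ grad_tau)
{
    __shared__ float sdata[256];          // assumes blockDim.x == 256
    const unsigned int tid = threadIdx.x;
    const unsigned int idx = blockIdx.x * blockDim.x + tid;

    sdata[tid] = (idx < n) ? partial[idx] : 0.0f;
    __syncthreads();

    // Fold the block's values pairwise until sdata[0] holds the block sum.
    for (unsigned int stride = blockDim.x / 2; stride > 0; stride >>= 1)
    {
        if (tid < stride)
        {
            sdata[tid] += sdata[tid + stride];
        }
        __syncthreads();
    }

    if (tid == 0)
    {
        // A plain "grad_tau[0] = sdata[0];" here would race across blocks;
        // atomicAdd sums every block's partial result.
        atomicAdd(grad_tau, sdata[0]);
    }
}

A launch such as reduce_grad_sketch<<<num_blocks, 256>>>(partial, n, grad_tau) then leaves the full sum in grad_tau[0], provided grad_tau was zeroed beforehand.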
@@ -952,7 +952,22 @@ def create_bptt_kernel(sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, d
         __syncthreads();
         if (threadIdx.x == 0)
         {
-            grad_reciprocal_tau[0] = __hadd(__low2half(sdata[0]), __high2half(sdata[0]));
+            //grad_reciprocal_tau[0] = __hadd(__low2half(sdata[0]), __high2half(sdata[0]));
+            /*
+            The 32-bit floating-point version of atomicAdd() is only supported by devices of compute capability 2.x and higher.
+            The 64-bit floating-point version of atomicAdd() is only supported by devices of compute capability 6.x and higher.
+            The 32-bit __half2 floating-point version of atomicAdd() is only supported by devices of compute capability 6.x and higher. The atomicity of the __half2 or __nv_bfloat162 add operation is guaranteed separately for each of the two __half or __nv_bfloat16 elements; the entire __half2 or __nv_bfloat162 is not guaranteed to be atomic as a single 32-bit access.
+            The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
+            The 16-bit __nv_bfloat16 floating-point version of atomicAdd() is only supported by devices of compute capability 8.x and higher.
+            */
+            atomicAdd(grad_reciprocal_tau, __hadd(__low2half(sdata[0]), __high2half(sdata[0])));
         }
     }
     '''
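The half-precision kernel keeps its shared accumulator as __half2, i.e. two packed partial sums per entry. The new code folds the two lanes with __hadd and then uses the __half overload of atomicAdd, which NVIDIA documents as requiring compute capability 7.x or higher. A hedged sketch of just that tail step, with hypothetical names rather than the library's own:

// Sketch only: final step of a __half2 block reduction. sdata[0] packs two
// partial sums; fold its low and high lanes and atomically add the result
// to the global fp16 gradient (atomicAdd on __half needs sm_70 or newer).
#include <cuda_fp16.h>

__device__ void accumulate_half2_sum(const __half2* sdata, __half* grad_tau)
{
    if (threadIdx.x == 0)
    {
        const __half block_sum = __hadd(__low2half(sdata[0]), __high2half(sdata[0]));
        atomicAdd(grad_tau, block_sum);
    }
}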
@@ -973,6 +988,8 @@ def forward(ctx, x_seq: torch.Tensor, v_last: torch.Tensor, reciprocal_tau: torc
     elif x_seq.dtype == torch.float16:
         dtype = 'fp16'
         cp_dtype = np.half
+        assert torch.cuda.get_device_capability(device)[0] >= 7, "MultiStepParametricLIFNodePTT can not run in the current device with float16 because the 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher."
+
     else:
         raise NotImplementedError
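The Python-side assert above rejects float16 inputs on devices below compute capability 7.x before any kernel is launched, since the __half overload of atomicAdd simply does not exist there. The same constraint can be mirrored on the device side with an architecture guard; the helper below is a hypothetical illustration, not part of the repository:

// Hypothetical illustration: compile the fp16 atomic accumulator only for
// architectures that provide atomicAdd(__half*, __half), i.e. sm_70+.
#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
__device__ __forceinline__ void accumulate_grad_half(__half* dst, __half val)
{
    atomicAdd(dst, val);   // native 16-bit __half atomic add
}
#endif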

