diff --git a/spikingjelly/clock_driven/neuron_kernel.py b/spikingjelly/clock_driven/neuron_kernel.py
index 33ce0001..bd968733 100644
--- a/spikingjelly/clock_driven/neuron_kernel.py
+++ b/spikingjelly/clock_driven/neuron_kernel.py
@@ -857,7 +857,7 @@ def create_bptt_kernel(sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, d
                 __syncthreads();
                 if (threadIdx.x == 0)
                 {
-                    grad_reciprocal_tau[0] = sdata[0];
+                    atomicAdd(grad_reciprocal_tau, sdata[0]);
                 }
             }
             '''
@@ -952,7 +952,22 @@ def create_bptt_kernel(sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, d
                 __syncthreads();
                 if (threadIdx.x == 0)
                 {
-                    grad_reciprocal_tau[0] = __hadd(__low2half(sdata[0]), __high2half(sdata[0]));
+                    //grad_reciprocal_tau[0] = __hadd(__low2half(sdata[0]), __high2half(sdata[0]));
+
+                    /*
+                    The 32-bit floating-point version of atomicAdd() is only supported by devices of compute capability 2.x and higher.
+
+                    The 64-bit floating-point version of atomicAdd() is only supported by devices of compute capability 6.x and higher.
+
+                    The 32-bit __half2 floating-point version of atomicAdd() is only supported by devices of compute capability 6.x and higher. The atomicity of the __half2 or __nv_bfloat162 add operation is guaranteed separately for each of the two __half or __nv_bfloat16 elements; the entire __half2 or __nv_bfloat162 is not guaranteed to be atomic as a single 32-bit access.
+
+                    The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
+
+                    The 16-bit __nv_bfloat16 floating-point version of atomicAdd() is only supported by devices of compute capability 8.x and higher.
+                    */
+
+                    atomicAdd(grad_reciprocal_tau, __hadd(__low2half(sdata[0]), __high2half(sdata[0])));
+
                 }
             }
             '''
@@ -973,6 +988,8 @@ def forward(ctx, x_seq: torch.Tensor, v_last: torch.Tensor, reciprocal_tau: torc
         elif x_seq.dtype == torch.float16:
             dtype = 'fp16'
             cp_dtype = np.half
+            assert torch.cuda.get_device_capability(device)[0] >= 7, "MultiStepParametricLIFNodePTT cannot run on the current device with float16 because the 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher."
+
         else:
             raise NotImplementedError
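
Note on the two kernel changes above: when the bptt kernel is launched with more than one block, each block reduces its threads' contributions to grad_reciprocal_tau into sdata[0]. With the old plain store, every block overwrote grad_reciprocal_tau[0] and only the last writer survived; atomicAdd accumulates all block results. The guard added to forward() follows from the quoted CUDA documentation: the __half overload of atomicAdd used by the fp16 kernel only exists on devices of compute capability 7.x and higher.

Below is a minimal standalone CUDA sketch of the same pattern: a shared-memory tree reduction per block followed by a single atomicAdd per block into a global accumulator. It is illustration only, not part of the patch; the kernel name block_sum_atomic, the fixed block size of 256, and the demo main() are invented for this example.

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void block_sum_atomic(const float* __restrict__ x, float* out, int n)
    {
        __shared__ float sdata[256];
        const int tid = threadIdx.x;
        const int i = blockIdx.x * blockDim.x + tid;

        // Each thread loads one element (0 if out of range).
        sdata[tid] = (i < n) ? x[i] : 0.0f;
        __syncthreads();

        // Shared-memory tree reduction within the block (blockDim.x must be a power of two).
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1)
        {
            if (tid < stride)
            {
                sdata[tid] += sdata[tid + stride];
            }
            __syncthreads();
        }

        // A plain `out[0] = sdata[0]` here would let blocks overwrite each other;
        // atomicAdd accumulates every block's partial sum, as in the patch.
        if (tid == 0)
        {
            atomicAdd(out, sdata[0]);
        }
    }

    int main()
    {
        const int n = 1 << 20;
        float *x = nullptr, *out = nullptr;
        cudaMallocManaged(&x, n * sizeof(float));
        cudaMallocManaged(&out, sizeof(float));
        for (int i = 0; i < n; ++i) x[i] = 1.0f;
        *out = 0.0f;

        block_sum_atomic<<<(n + 255) / 256, 256>>>(x, out, n);
        cudaDeviceSynchronize();
        printf("sum = %f (expected %d)\n", *out, n);

        cudaFree(x);
        cudaFree(out);
        return 0;
    }

The float atomicAdd used in the sketch (and in the fp32 hunk) is available on all compute capabilities 2.x and higher, so only the fp16 path needs the runtime capability check added to forward().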