@@ -11,18 +11,18 @@ private:
1111 uint8_t * ptr;
1212
1313public:
14- int total_bytes;
14+ int64_t total_bytes;
1515
1616 __device__ __forceinline__ Buffer () : ptr(nullptr ), total_bytes(0 ) {}
1717
1818 __device__ __forceinline__ Buffer (void *& gbl_ptr, int num_elems, int offset = 0 ) {
1919 total_bytes = num_elems * sizeof (dtype_t );
20- ptr = reinterpret_cast <uint8_t *>(gbl_ptr) + offset * sizeof (dtype_t );
21- gbl_ptr = reinterpret_cast <uint8_t *>(gbl_ptr) + total_bytes;
20+ ptr = static_cast <uint8_t *>(gbl_ptr) + offset * sizeof (dtype_t );
21+ gbl_ptr = static_cast <uint8_t *>(gbl_ptr) + total_bytes;
2222 }
2323
2424 __device__ __forceinline__ Buffer advance_also (void *& gbl_ptr) {
25- gbl_ptr = reinterpret_cast <uint8_t *>(gbl_ptr) + total_bytes;
25+ gbl_ptr = static_cast <uint8_t *>(gbl_ptr) + total_bytes;
2626 return *this ;
2727 }
2828
@@ -35,30 +35,30 @@ template <typename dtype_t, int kNumRanks = 1>
3535struct AsymBuffer {
3636private:
3737 uint8_t * ptrs[kNumRanks ];
38- int num_bytes;
38+ int64_t num_bytes;
3939
4040public:
41- int total_bytes;
41+ int64_t total_bytes;
4242
4343 __device__ __forceinline__ AsymBuffer (void *& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0 , int num_sms = 1 , int offset = 0 ) {
4444 EP_STATIC_ASSERT (kNumRanks == 1 , " " );
4545 num_bytes = num_elems * sizeof (dtype_t );
4646
47- int per_channel_bytes = num_bytes * num_ranks;
47+ int64_t per_channel_bytes = num_bytes * num_ranks;
4848 total_bytes = per_channel_bytes * num_sms;
49- ptrs[0 ] = reinterpret_cast <uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
50- gbl_ptr = reinterpret_cast <uint8_t *>(gbl_ptr) + total_bytes;
49+ ptrs[0 ] = static_cast <uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
50+ gbl_ptr = static_cast <uint8_t *>(gbl_ptr) + total_bytes;
5151 }
5252
5353 __device__ __forceinline__ AsymBuffer (void ** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0 , int num_sms = 1 , int offset = 0 ) {
5454 EP_STATIC_ASSERT (kNumRanks > 1 , " " );
5555 num_bytes = num_elems * sizeof (dtype_t );
5656
57- int per_channel_bytes = num_bytes * num_ranks;
57+ int64_t per_channel_bytes = num_bytes * num_ranks;
5858 total_bytes = per_channel_bytes * num_sms;
5959 for (int i = 0 ; i < kNumRanks ; ++i) {
60- ptrs[i] = reinterpret_cast <uint8_t *>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
61- gbl_ptrs[i] = reinterpret_cast <uint8_t *>(gbl_ptrs[i]) + total_bytes;
60+ ptrs[i] = static_cast <uint8_t *>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
61+ gbl_ptrs[i] = static_cast <uint8_t *>(gbl_ptrs[i]) + total_bytes;
6262 }
6363 }
6464
@@ -69,14 +69,14 @@ public:
6969 }
7070
7171 __device__ __forceinline__ AsymBuffer advance_also (void *& gbl_ptr) {
72- gbl_ptr = reinterpret_cast <uint8_t *>(gbl_ptr) + total_bytes;
72+ gbl_ptr = static_cast <uint8_t *>(gbl_ptr) + total_bytes;
7373 return *this ;
7474 }
7575
7676 template <int kNumAlsoRanks >
7777 __device__ __forceinline__ AsymBuffer advance_also (void ** gbl_ptrs) {
7878 for (int i = 0 ; i < kNumAlsoRanks ; ++i)
79- gbl_ptrs[i] = reinterpret_cast <uint8_t *>(gbl_ptrs[i]) + total_bytes;
79+ gbl_ptrs[i] = static_cast <uint8_t *>(gbl_ptrs[i]) + total_bytes;
8080 return *this ;
8181 }
8282
@@ -97,19 +97,19 @@ private:
9797 // NOTES: for non-decoupled case, `recv_ptr` is not used
9898 uint8_t * send_ptr;
9999 uint8_t * recv_ptr;
100- int num_bytes;
100+ int64_t num_bytes;
101101
102102public:
103- int total_bytes;
103+ int64_t total_bytes;
104104
105105 __device__ __forceinline__ SymBuffer (void *& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0 , int num_sms = 1 ) {
106106 num_bytes = num_elems * sizeof (dtype_t );
107107
108- int per_channel_bytes = num_bytes * num_ranks;
108+ int64_t per_channel_bytes = num_bytes * num_ranks;
109109 total_bytes = per_channel_bytes * num_sms * (static_cast <int >(kDecoupled ) + 1 );
110- send_ptr = reinterpret_cast <uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id;
111- recv_ptr = reinterpret_cast <uint8_t *>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
112- gbl_ptr = reinterpret_cast <uint8_t *>(gbl_ptr) + total_bytes;
110+ send_ptr = static_cast <uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id;
111+ recv_ptr = static_cast <uint8_t *>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
112+ gbl_ptr = static_cast <uint8_t *>(gbl_ptr) + total_bytes;
113113 }
114114
115115 __device__ __forceinline__ dtype_t * send_buffer (int idx = 0 ) {
0 commit comments