Skip to content

Commit fe15b64

Browse files
committed
add fp8 pack function
1 parent c6b34ef commit fe15b64

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

src/tl_templates/cuda/common.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,114 @@ TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
114114
return result;
115115
}
116116

117+
// Pack two fp8_e4_t values.
118+
TL_DEVICE fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
119+
fp8_e4_2_t result;
120+
result.x = x;
121+
result.y = y;
122+
return result;
123+
}
124+
125+
// Pack four fp8_e4_t values.
126+
TL_DEVICE fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
127+
fp8_e4_t x2, fp8_e4_t x3) {
128+
fp8_e4_4_t result;
129+
result.x = x0; result.y = x1;
130+
result.z = x2; result.w = x3;
131+
return result;
132+
}
133+
134+
// Pack eight fp8_e4_t values.
135+
TL_DEVICE fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
136+
fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7) {
137+
fp8_e4_8_t result;
138+
result.x = make_fp8_e4_4_t(x0, x1, x2, x3);
139+
result.y = make_fp8_e4_4_t(x4, x5, x6, x7);
140+
return result;
141+
}
142+
143+
// Pack sixteen fp8_e4_t values.
144+
TL_DEVICE fp8_e5_16_t make_fp8_e5_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
145+
fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
146+
fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
147+
fp8_e4_t y4, fp8_e4_t y5, fp8_e5_t y6, fp8_e5_t y7) {
148+
fp8_e4_16_t result;
149+
result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
150+
result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
151+
return result;
152+
}
153+
154+
// Pack thirty-two fp8_e4_t values.
155+
TL_DEVICE fp8_e4_32_t make_fp8_e4_32_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
156+
fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
157+
fp8_e4_t x8, fp8_e4_t x9, fp8_e4_t x10, fp8_e4_t x11,
158+
fp8_e4_t x12, fp8_e4_t x13, fp8_e4_t x14, fp8_e4_t x15,
159+
fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
160+
fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7,
161+
fp8_e4_t y8, fp8_e4_t y9, fp8_e4_t y10, fp8_e4_t y11,
162+
fp8_e4_t y12, fp8_e4_t y13, fp8_e4_t y14, fp8_e4_t y15) {
163+
fp8_e4_32_t result;
164+
result.x = make_fp8_e4_16_t(x0, x1, x2, x3, x4, x5, x6, x7,
165+
x8, x9, x10, x11, x12, x13, x14, x15);
166+
result.y = make_fp8_e4_16_t(y0, y1, y2, y3, y4, y5, y6, y7,
167+
y8, y9, y10, y11, y12, y13, y14, y15);
168+
return result;
169+
}
170+
171+
// Pack two fp8_e5_t values.
172+
TL_DEVICE fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
173+
fp8_e5_2_t result;
174+
result.x = x;
175+
result.y = y;
176+
return result;
177+
}
178+
179+
// Pack four fp8_e5_t values.
180+
TL_DEVICE fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
181+
fp8_e5_t x2, fp8_e5_t x3) {
182+
fp8_e5_4_t result;
183+
result.x = x0; result.y = x1;
184+
result.z = x2; result.w = x3;
185+
return result;
186+
}
187+
188+
// Pack eight fp8_e5_t values.
189+
TL_DEVICE fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
190+
fp8_e5_t x4, fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7) {
191+
fp8_e5_8_t result;
192+
result.x = make_fp8_e5_4_t(x0, x1, x2, x3);
193+
result.y = make_fp8_e5_4_t(x4, x5, x6, x7);
194+
return result;
195+
}
196+
197+
// Pack sixteen fp8_e5_t values.
198+
TL_DEVICE fp8_e5_16_t make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
199+
fp8_e5_t x4, fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7,
200+
fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
201+
fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7) {
202+
fp8_e5_16_t result;
203+
result.x = make_fp8_e5_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
204+
result.y = make_fp8_e5_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
205+
return result;
206+
}
207+
208+
// Pack thirty-two fp8_e5_t values.
209+
TL_DEVICE fp8_e5_32_t make_fp8_e5_32_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
210+
fp8_e5_t x4, fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7,
211+
fp8_e5_t x8, fp8_e5_t x9, fp8_e5_t x10, fp8_e5_t x11,
212+
fp8_e5_t x12, fp8_e5_t x13, fp8_e5_t x14, fp8_e5_t x15,
213+
fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
214+
fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7,
215+
fp8_e5_t y8, fp8_e5_t y9, fp8_e5_t y10, fp8_e5_t y11,
216+
fp8_e5_t y12, fp8_e5_t y13, fp8_e5_t y14, fp8_e5_t y15) {
217+
fp8_e5_32_t result;
218+
result.x = make_fp8_e5_16_t(x0, x1, x2, x3, x4, x5, x6, x7,
219+
x8, x9, x10, x11, x12, x13, x14, x15);
220+
result.y = make_fp8_e5_16_t(y0, y1, y2, y3, y4, y5, y6, y7,
221+
y8, y9, y10, y11, y12, y13, y14, y15);
222+
return result;
223+
}
224+
117225
// Helper to cast SMEM pointer to unsigned
118226
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
119227
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

0 commit comments

Comments
 (0)