forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfloat_subbyte.h
788 lines (619 loc) · 23.7 KB
/
float_subbyte.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Defines classes for FP4/FP6 datatypes
*/
#pragma once
#include "cutlass/arch/config.h"
#include "cutlass/float8.h"
// FP4 types are available starting CUDA 12+
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUDA_FP4_ENABLED 1
#endif
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED))
# define CUDA_PTX_FP4FP6_CVT_ENABLED 1
#endif
#include "cutlass/cutlass.h"
#include "cutlass/exmy_base.h"
#include "cute/util/type_traits.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
// FP4 and FP6 types.
// Forward declarations: float_e2m3_t declares an explicit constructor taking a
// float_e3m2_t before float_e3m2_t itself is defined.
struct float_e2m1_t;
struct float_e3m2_t;

// E2M1 (FP4):
//   Sign bit + 2 exponent bits + 1 mantissa bit
//   Representable values: +-{0, 0.5, 1, 1.5, 2, 3, 4, 6}  (range [-6, +6])
//   has_Inf:    false
//   has_NaN:    false
//   has_denorm: true
//   Exponent bias (exp_bias): 1
struct float_e2m1_t : public float_exmy_base<cutlass::detail::FpEncoding::E2M1, float_e2m1_t> {
  using Base = float_exmy_base<cutlass::detail::FpEncoding::E2M1, float_e2m1_t>;

  float_e2m1_t() = default;

  /// Encode an integer (explicit: the value is re-encoded in 4 bits).
  CUTLASS_HOST_DEVICE
  explicit float_e2m1_t(int value) : Base(value) {}

  /// Encode a single-precision value.
  CUTLASS_HOST_DEVICE
  explicit float_e2m1_t(float value) : Base(value) {}

  /// Encode a double: it is narrowed to float first, then encoded.
  CUTLASS_HOST_DEVICE
  explicit float_e2m1_t(double value) : Base(static_cast<float>(value)) {}

  /// Implicit conversion from the shared encoding base type.
  CUTLASS_HOST_DEVICE
  float_e2m1_t(Base value) : Base(value) {}
};
namespace detail {

// E2M1 element type used to select the correct MMA type and TMA type.
// Note: its encoding base is parameterized on float_e2m1_t (not on this type itself),
// unlike the *_unpack8bits_t types defined for FP6.
struct float_e2m1_unpacksmem_t : public float_exmy_base<cutlass::detail::FpEncoding::E2M1, float_e2m1_t> {
  using Base = float_exmy_base<cutlass::detail::FpEncoding::E2M1, float_e2m1_t>;

  float_e2m1_unpacksmem_t() = default;

  // NOTE(review): this copy constructor is user-provided rather than defaulted, which makes
  // the type non-trivially-copyable (and deletes the copy constructor of any union holding
  // it) -- confirm that is intended.
  CUTLASS_HOST_DEVICE
  float_e2m1_unpacksmem_t(float_e2m1_unpacksmem_t const& other) : Base(other) {}

  /// Encode an integer.
  CUTLASS_HOST_DEVICE
  explicit float_e2m1_unpacksmem_t(int value) : Base(value) {}

  /// Encode a single-precision value.
  CUTLASS_HOST_DEVICE
  explicit float_e2m1_unpacksmem_t(float value) : Base(value) {}

  /// Encode a double after narrowing it to float.
  CUTLASS_HOST_DEVICE
  explicit float_e2m1_unpacksmem_t(double value) : Base(static_cast<float>(value)) {}

  /// Implicit conversion from the shared encoding base type.
  CUTLASS_HOST_DEVICE
  float_e2m1_unpacksmem_t(Base value) : Base(value) {}
};

} // namespace detail
/// Defines the size of an element in bits - specialized for float_e2m1_t
/// (sign + 2 exponent + 1 mantissa = 4 bits).
template <>
struct sizeof_bits<float_e2m1_t> {
  static constexpr int value = 4;
};

/// Defines the size of an element in bits - specialized for float_e2m1_unpacksmem_t
/// (same 4-bit E2M1 encoding).
template <>
struct sizeof_bits<detail::float_e2m1_unpacksmem_t> {
  static constexpr int value = 4;
};

/// Absolute value of an E2M1 number: the raw bits are rewrapped as the encoding base
/// type and passed to the base-type abs() overload.
CUTLASS_HOST_DEVICE
float_e2m1_t abs(float_e2m1_t const& val) {
  float_e2m1_t::Base const as_base{val.raw()};
  return float_e2m1_t(abs(as_base));
}
// E2M3 (FP6):
//   Sign bit + 2 exponent bits + 3 mantissa bits
//   Range: [-7.5, +7.5]
//   has_Inf:    false
//   has_NaN:    false
//   has_denorm: true
//   Exponent bias (exp_bias): 1
struct float_e2m3_t : public float_exmy_base<cutlass::detail::FpEncoding::E2M3, float_e2m3_t> {
  using Base = float_exmy_base<cutlass::detail::FpEncoding::E2M3, float_e2m3_t>;

  float_e2m3_t() = default;

  /// Encode an integer.
  CUTLASS_HOST_DEVICE explicit float_e2m3_t(int value) : Base(value) {}

  /// Encode a single-precision value.
  CUTLASS_HOST_DEVICE explicit float_e2m3_t(float value) : Base(value) {}

  /// Encode a double after narrowing it to float.
  CUTLASS_HOST_DEVICE explicit float_e2m3_t(double value) : Base(static_cast<float>(value)) {}

  /// Implicit conversion from the shared encoding base type.
  CUTLASS_HOST_DEVICE float_e2m3_t(Base value) : Base(value) {}

  /// Cross-format conversion from E3M2; defined out of line after float_e3m2_t is complete.
  CUTLASS_HOST_DEVICE explicit float_e2m3_t(float_e3m2_t value);
};
namespace detail {

// E2M3 value held in an 8-bit container (sizeof_bits == 8); used in registers.
struct float_e2m3_unpack8bits_t: public float_exmy_base<cutlass::detail::FpEncoding::E2M3, float_e2m3_unpack8bits_t> {
  using Base = float_exmy_base<cutlass::detail::FpEncoding::E2M3, float_e2m3_unpack8bits_t>;

  float_e2m3_unpack8bits_t() = default;

  CUTLASS_HOST_DEVICE explicit float_e2m3_unpack8bits_t(int value) : Base(value) {}
  CUTLASS_HOST_DEVICE explicit float_e2m3_unpack8bits_t(float value) : Base(value) {}
  /// Doubles are narrowed to float before encoding.
  CUTLASS_HOST_DEVICE explicit float_e2m3_unpack8bits_t(double value) : Base(static_cast<float>(value)) {}
  /// Implicit conversion from the shared encoding base type.
  CUTLASS_HOST_DEVICE float_e2m3_unpack8bits_t(Base value) : Base(value) {}
};

// E2M3 element type used to select the correct MMA type and TMA type.
// Note: its encoding base is parameterized on float_e2m3_t, not on this type.
struct float_e2m3_unpacksmem_t : public float_exmy_base<cutlass::detail::FpEncoding::E2M3, float_e2m3_t> {
  using Base = float_exmy_base<cutlass::detail::FpEncoding::E2M3, float_e2m3_t>;

  float_e2m3_unpacksmem_t() = default;

  // NOTE(review): user-provided (non-defaulted) copy constructor -> not trivially copyable.
  CUTLASS_HOST_DEVICE float_e2m3_unpacksmem_t(float_e2m3_unpacksmem_t const& other) : Base(other) {}

  CUTLASS_HOST_DEVICE explicit float_e2m3_unpacksmem_t(int value) : Base(value) {}
  CUTLASS_HOST_DEVICE explicit float_e2m3_unpacksmem_t(float value) : Base(value) {}
  /// Doubles are narrowed to float before encoding.
  CUTLASS_HOST_DEVICE explicit float_e2m3_unpacksmem_t(double value) : Base(static_cast<float>(value)) {}
  /// Implicit conversion from the shared encoding base type.
  CUTLASS_HOST_DEVICE float_e2m3_unpacksmem_t(Base value) : Base(value) {}
};

} // namespace detail
/// Defines the size of an element in bits - specialized for float_e2m3_t
/// (sign + 2 exponent + 3 mantissa = 6 bits).
template <>
struct sizeof_bits<float_e2m3_t> {
  static constexpr int value = 6;
};

/// Defines the size of an element in bits - specialized for float_e2m3_unpacksmem_t
/// (same 6-bit E2M3 encoding).
template <>
struct sizeof_bits<detail::float_e2m3_unpacksmem_t> {
  static constexpr int value = 6;
};

/// Absolute value of an E2M3 number: the raw bits are rewrapped as the encoding base
/// type and passed to the base-type abs() overload.
CUTLASS_HOST_DEVICE
float_e2m3_t abs(float_e2m3_t const& val) {
  float_e2m3_t::Base const as_base{val.raw()};
  return float_e2m3_t(abs(as_base));
}
// E3M2 (FP6):
//   Sign bit + 3 exponent bits + 2 mantissa bits
//   Range: [-28, +28]
//   has_Inf:    false
//   has_NaN:    false
//   has_denorm: true
//   Exponent bias (exp_bias): 3
struct float_e3m2_t : public float_exmy_base<cutlass::detail::FpEncoding::E3M2, float_e3m2_t> {
  using Base = float_exmy_base<cutlass::detail::FpEncoding::E3M2, float_e3m2_t>;

  float_e3m2_t() = default;

  /// Encode an integer.
  CUTLASS_HOST_DEVICE explicit float_e3m2_t(int value) : Base(value) {}

  /// Encode a single-precision value.
  CUTLASS_HOST_DEVICE explicit float_e3m2_t(float value) : Base(value) {}

  /// Encode a double after narrowing it to float.
  CUTLASS_HOST_DEVICE explicit float_e3m2_t(double value) : Base(static_cast<float>(value)) {}

  /// Implicit conversion from the shared encoding base type.
  CUTLASS_HOST_DEVICE float_e3m2_t(Base value) : Base(value) {}

  /// Cross-format conversion from E2M3; defined out of line below.
  CUTLASS_HOST_DEVICE explicit float_e3m2_t(float_e2m3_t value);
};
namespace detail {

// E3M2 value held in an 8-bit container (sizeof_bits == 8).
struct float_e3m2_unpack8bits_t : public float_exmy_base<cutlass::detail::FpEncoding::E3M2, float_e3m2_unpack8bits_t> {
  using Base = float_exmy_base<cutlass::detail::FpEncoding::E3M2, float_e3m2_unpack8bits_t>;

  float_e3m2_unpack8bits_t() = default;

  CUTLASS_HOST_DEVICE explicit float_e3m2_unpack8bits_t(int value) : Base(value) {}
  CUTLASS_HOST_DEVICE explicit float_e3m2_unpack8bits_t(float value) : Base(value) {}
  /// Doubles are narrowed to float before encoding.
  CUTLASS_HOST_DEVICE explicit float_e3m2_unpack8bits_t(double value) : Base(static_cast<float>(value)) {}
  /// Implicit conversion from the shared encoding base type.
  CUTLASS_HOST_DEVICE float_e3m2_unpack8bits_t(Base value) : Base(value) {}
};

// E3M2 element type used to select the correct MMA type and TMA type.
// Note: its encoding base is parameterized on float_e3m2_t, not on this type.
struct float_e3m2_unpacksmem_t : public float_exmy_base<cutlass::detail::FpEncoding::E3M2, float_e3m2_t> {
  using Base = float_exmy_base<cutlass::detail::FpEncoding::E3M2, float_e3m2_t>;

  float_e3m2_unpacksmem_t() = default;

  // NOTE(review): user-provided (non-defaulted) copy constructor -> not trivially copyable.
  CUTLASS_HOST_DEVICE float_e3m2_unpacksmem_t(float_e3m2_unpacksmem_t const& other) : Base(other) {}

  CUTLASS_HOST_DEVICE explicit float_e3m2_unpacksmem_t(int value) : Base(value) {}
  CUTLASS_HOST_DEVICE explicit float_e3m2_unpacksmem_t(float value) : Base(value) {}
  /// Doubles are narrowed to float before encoding.
  CUTLASS_HOST_DEVICE explicit float_e3m2_unpacksmem_t(double value) : Base(static_cast<float>(value)) {}
  /// Implicit conversion from the shared encoding base type.
  CUTLASS_HOST_DEVICE float_e3m2_unpacksmem_t(Base value) : Base(value) {}
};

} // namespace detail
/// Defines the size of an element in bits - specialized for float_e3m2_t
/// (sign + 3 exponent + 2 mantissa = 6 bits).
template <>
struct sizeof_bits<float_e3m2_t> {
  static constexpr int value = 6;
};

/// Defines the size of an element in bits - specialized for float_e3m2_unpacksmem_t
/// (same 6-bit E3M2 encoding).
template <>
struct sizeof_bits<detail::float_e3m2_unpacksmem_t> {
  static constexpr int value = 6;
};

/// Absolute value of an E3M2 number: the raw bits are rewrapped as the encoding base
/// type and passed to the base-type abs() overload.
CUTLASS_HOST_DEVICE
float_e3m2_t abs(float_e3m2_t const& val) {
  float_e3m2_t::Base const as_base{val.raw()};
  return float_e3m2_t(abs(as_base));
}

/// Defines the size of an element in bits - specialized for float_e3m2_unpack8bits_t:
/// the 6-bit value occupies a full 8-bit container.
template <>
struct sizeof_bits<detail::float_e3m2_unpack8bits_t> {
  static constexpr int value = 8;
};

/// Defines the size of an element in bits - specialized for float_e2m3_unpack8bits_t:
/// the 6-bit value occupies a full 8-bit container.
template <>
struct sizeof_bits<detail::float_e2m3_unpack8bits_t> {
  static constexpr int value = 8;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Get the register type used in kernel
//
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {

/// Maps a packed FP6 element type to its 8-bit "unpacked" counterpart used in registers.
/// The primary template is only declared here, so requesting a type without one of the
/// specializations below fails to compile.
template <typename PackedElement>
struct get_unpacked_element_type;

/// float_e2m3_t -> float_e2m3_unpack8bits_t
template <>
struct get_unpacked_element_type<float_e2m3_t> {
  using type = float_e2m3_unpack8bits_t;
};

/// float_e3m2_t -> float_e3m2_unpack8bits_t
template <>
struct get_unpacked_element_type<float_e3m2_t> {
  using type = float_e3m2_unpack8bits_t;
};

} // namespace detail
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// float_e2m3_t <=> float_e3m2_t conversions
//
// Each conversion turns the source into a float and then re-encodes that float in the
// destination format with convert_from_float(). E3M2 spans [-28, +28] while E2M3 only spans
// [-7.5, +7.5]. How out-of-range values are handled is decided by convert_from_float, which
// is not visible here. TODO confirm that it saturates.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
CUTLASS_HOST_DEVICE
float_e2m3_t::float_e2m3_t(float_e3m2_t x)
{
  float const widened = static_cast<float>(x);
  storage = convert_from_float(widened).storage;
}

CUTLASS_HOST_DEVICE
float_e3m2_t::float_e3m2_t(float_e2m3_t x)
{
  float const widened = static_cast<float>(x);
  storage = convert_from_float(widened).storage;
}
///////////////////////////////////////////////////////////////////////////////////////////////////

///////////////////////////////////////////////////////////////
///
/// Umbrella floating-point 6-bit data type : type_erased_dynamic_float6_t
/// Use this umbrella type when the concrete FP6 format is only known at
/// runtime, i.e. the user supplies it in the runtime argument list.
///
/// Runtime formats currently compatible with type_erased_dynamic_float6_t:
///   MXF8F6F4Format::E2M3
///   MXF8F6F4Format::E3M2
///
///////////////////////////////////////////////////////////////
union type_erased_dynamic_float6_t {
  // Both alternatives overlay the same storage. Keep e2m3 first: it is the member
  // that aggregate initialization of the union initializes.
  cutlass::float_e2m3_t e2m3;
  cutlass::float_e3m2_t e3m2;

  /// View the stored value as E2M3.
  CUTLASS_HOST_DEVICE
  explicit operator cutlass::float_e2m3_t() const { return this->e2m3; }

  /// View the stored value as E3M2.
  CUTLASS_HOST_DEVICE
  explicit operator cutlass::float_e3m2_t() const { return this->e3m2; }
};

/// Packed width of the umbrella type: 6 bits, matching both alternatives.
template <>
struct sizeof_bits<type_erased_dynamic_float6_t> {
  static constexpr int value = 6;
};
///////////////////////////////////////////////////////////////
///
/// Umbrella floating-point 4-bit data type : type_erased_dynamic_float4_t
/// Use this umbrella type when the concrete FP4 format is only known at
/// runtime, i.e. the user supplies it in the runtime argument list.
///
/// Runtime formats currently compatible with type_erased_dynamic_float4_t:
///   MXF8F6F4Format::E2M1
///
///////////////////////////////////////////////////////////////
union type_erased_dynamic_float4_t {
  // Currently the only alternative.
  cutlass::float_e2m1_t e2m1;

  /// View the stored value as E2M1.
  CUTLASS_HOST_DEVICE
  explicit operator cutlass::float_e2m1_t() const { return this->e2m1; }
};

/// Packed width of the umbrella type: 4 bits.
template <>
struct sizeof_bits<type_erased_dynamic_float4_t> {
  static constexpr int value = 4;
};
///////////////////////////////////////////////////////////////
/// MX/NV types for float6 and float4
/// Intended to be used in builders
///////////////////////////////////////////////////////////////

/// MXFP6 descriptor: FP6 data (E2M3, E3M2, or runtime-selected via
/// type_erased_dynamic_float6_t) paired with UE8M0 scale factors.
template <class F6Type>
struct mx_float6_t
{
  // The diagnostic lists every accepted type, including the type-erased one,
  // so users hitting this assert see the full set of valid choices.
  static_assert(cute::is_same_v<F6Type,cutlass::float_e2m3_t>
             || cute::is_same_v<F6Type,cutlass::float_e3m2_t>
             || cute::is_same_v<F6Type,type_erased_dynamic_float6_t>
             , "Only float_e2m3_t, float_e3m2_t, and type_erased_dynamic_float6_t can have scale factors for MXFP6");
  using ScaleFactorType = cutlass::float_ue8m0_t;
  using DataType = F6Type;
};

/// MXFP6 whose element format is selected at runtime.
using type_erased_dynamic_mx_float6_t = mx_float6_t<type_erased_dynamic_float6_t>;
/// MXFP4 descriptor: FP4 data (E2M1 or runtime-selected via
/// type_erased_dynamic_float4_t) paired with UE8M0 scale factors.
template <class F4Type>
struct mx_float4_t
{
  static_assert(cute::is_same_v<F4Type,cutlass::float_e2m1_t>
             || cute::is_same_v<F4Type,type_erased_dynamic_float4_t>
             , "Only float_e2m1_t and type_erased_dynamic_float4_t can have scale factors for MXFP4");
  using ScaleFactorType = cutlass::float_ue8m0_t;
  using DataType = F4Type;
};

/// MXFP4 whose element format is selected at runtime.
using type_erased_dynamic_mx_float4_t = mx_float4_t<type_erased_dynamic_float4_t>;
/// NVFP4 descriptor: FP4 data (E2M1 or runtime-selected via
/// type_erased_dynamic_float4_t) paired with UE4M3 scale factors.
template <class F4Type>
struct nv_float4_t
{
  static_assert(cute::is_same_v<F4Type,cutlass::float_e2m1_t>
             || cute::is_same_v<F4Type,type_erased_dynamic_float4_t>
             , "Only float_e2m1_t and type_erased_dynamic_float4_t can have scale factors for NVFP4");
  using ScaleFactorType = cutlass::float_ue4m3_t;
  using DataType = F4Type;
};

/// NVFP4 whose element format is selected at runtime.
using type_erased_dynamic_nv_float4_t = nv_float4_t<type_erased_dynamic_float4_t>;
namespace detail {

/// Counterpart of type_erased_dynamic_float6_t built from the *_unpacksmem_t element types.
union type_erased_dynamic_float6_unpacksmem_t {
  cutlass::detail::float_e2m3_unpacksmem_t e2m3_unpacksmem;
  cutlass::detail::float_e3m2_unpacksmem_t e3m2_unpacksmem;

  /// View the stored value as E2M3 (unpacksmem flavor).
  CUTLASS_HOST_DEVICE
  explicit operator cutlass::detail::float_e2m3_unpacksmem_t() const { return this->e2m3_unpacksmem; }

  /// View the stored value as E3M2 (unpacksmem flavor).
  CUTLASS_HOST_DEVICE
  explicit operator cutlass::detail::float_e3m2_unpacksmem_t() const { return this->e3m2_unpacksmem; }
};

/// Counterpart of type_erased_dynamic_float4_t built from the *_unpacksmem_t element types.
union type_erased_dynamic_float4_unpacksmem_t {
  cutlass::detail::float_e2m1_unpacksmem_t e2m1_unpacksmem;

  /// View the stored value as E2M1 (unpacksmem flavor).
  CUTLASS_HOST_DEVICE
  explicit operator cutlass::detail::float_e2m1_unpacksmem_t() const { return this->e2m1_unpacksmem; }
};

} // namespace detail

/// 6 bits per element, like the FP6 alternatives it holds.
template <>
struct sizeof_bits<detail::type_erased_dynamic_float6_unpacksmem_t> {
  static constexpr int value = 6;
};

/// 4 bits per element, like the FP4 alternative it holds.
template <>
struct sizeof_bits<detail::type_erased_dynamic_float4_unpacksmem_t> {
  static constexpr int value = 4;
};
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Standard Library operations and definitions
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(__CUDACC_RTC__)
namespace std {

/// Numeric limits common to all sub-byte float types (FP4 E2M1, FP6 E2M3 / E3M2).
template <typename T>
struct float_subbyte_base_numeric_limits
{
private:
  using type = T;

public:
  static bool const is_specialized = true;
  static bool const is_signed = true;
  static bool const is_integer = false;
  static bool const is_exact = false;
  static bool const has_quiet_NaN = false;
  static bool const has_signaling_NaN = false;
  static bool const has_denorm_loss = true;
  static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present;
  static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest;
  static bool const is_iec559 = false;
  static bool const is_bounded = true;
  static bool const is_modulo = false;
  // NOTE(review): std::numeric_limits<T>::digits conventionally counts the implicit leading bit
  // (stored mantissa bits + 1); this reports only the stored mantissa bits.
  static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS;
  static bool const has_infinity = false;

  /// Least positive value (bit pattern 0x01, the same pattern denorm_min() returns).
  static type min() { return type::bitcast(0x01); }
  /// Maximum finite value
  static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); }
  /// Returns maximum rounding error
  static type round_error() { return type(0.5f); }
  /// Returns the encoding's INF_MASK bit pattern.
  /// NOTE(review): has_infinity is false for these types, so this is not a true infinity.
  static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); }
  /// Returns the INF_MASK bit pattern (has_quiet_NaN is false).
  static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); }
  /// Returns the INF_MASK bit pattern (has_signaling_NaN is false).
  static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); }
  /// Returns smallest positive subnormal value
  static type denorm_min() { return type::bitcast(0x01); }
};

/// Numeric limits for float_e2m1_t
template <>
struct numeric_limits<cutlass::float_e2m1_t> : public float_subbyte_base_numeric_limits<cutlass::float_e2m1_t>
{
  /// Minimum finite value: sign bit (0x8) | largest magnitude (0x7) = -6.
  static cutlass::float_e2m1_t lowest() { return cutlass::float_e2m1_t::bitcast(0xf); }
  /// Machine epsilon: 1.5 - 1.0 = 0.5, encoded as the subnormal 0x1.
  static cutlass::float_e2m1_t epsilon() { return cutlass::float_e2m1_t::bitcast(0x1); }
};

/// Numeric limits for float_e2m3_t
template <>
struct numeric_limits<cutlass::float_e2m3_t> : public float_subbyte_base_numeric_limits<cutlass::float_e2m3_t>
{
  /// Minimum finite value: sign bit (0x20) | largest magnitude (0x1f: exp 11, mantissa 111) = -7.5.
  /// (0x2f would decode to -1.875, which is not the most negative value.)
  static cutlass::float_e2m3_t lowest() { return cutlass::float_e2m3_t::bitcast(0x3f); }
  /// Machine epsilon: 1.125 - 1.0 = 0.125, encoded as the subnormal 0x1.
  static cutlass::float_e2m3_t epsilon() { return cutlass::float_e2m3_t::bitcast(0x1); }
};

/// Numeric limits for float_e3m2_t
template <>
struct numeric_limits<cutlass::float_e3m2_t> : public float_subbyte_base_numeric_limits<cutlass::float_e3m2_t>
{
  /// Minimum finite value: sign bit (0x20) | largest magnitude (0x1f: exp 111, mantissa 11) = -28.
  /// (0x2f would decode to -1.75, which is not the most negative value.)
  static cutlass::float_e3m2_t lowest() { return cutlass::float_e3m2_t::bitcast(0x3f); }
  /// Machine epsilon: 1.25 - 1.0 = 0.25 = 2^-2, encoded as 0x4 (exp 001, mantissa 00).
  static cutlass::float_e3m2_t epsilon() { return cutlass::float_e3m2_t::bitcast(0x4); }
};

} // namespace std
#endif
namespace cutlass {
namespace platform {

/// Numeric limits common to all sub-byte float types (FP4 E2M1, FP6 E2M3 / E3M2, and the
/// 8-bit-container FP6 variants).
///
/// All member functions are CUTLASS_HOST_DEVICE, consistent with max(), so the limits can be
/// queried from device code as well as host code.
template <typename T>
struct float_subbyte_base_numeric_limits
{
private:
  using type = T;

public:
  static bool const is_specialized = true;
  static bool const is_signed = true;
  static bool const is_integer = false;
  static bool const is_exact = false;
  static bool const has_quiet_NaN = false;
  static bool const has_signaling_NaN = false;
  static bool const has_denorm_loss = true;
  static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present;
  static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest;
  static bool const is_iec559 = false;
  static bool const is_bounded = true;
  static bool const is_modulo = false;
  // NOTE(review): std::numeric_limits<T>::digits conventionally counts the implicit leading bit
  // (stored mantissa bits + 1); this reports only the stored mantissa bits.
  static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS;
  static bool const has_infinity = false;

  /// Least positive value (bit pattern 0x01, the same pattern denorm_min() returns).
  CUTLASS_HOST_DEVICE static type min() { return type::bitcast(0x01); }
  /// Maximum finite value
  CUTLASS_HOST_DEVICE static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); }
  /// Returns maximum rounding error
  CUTLASS_HOST_DEVICE static type round_error() { return type(0.5f); }
  /// Returns the encoding's INF_MASK bit pattern.
  /// NOTE(review): has_infinity is false for these types, so this is not a true infinity.
  CUTLASS_HOST_DEVICE static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); }
  /// Returns the INF_MASK bit pattern (has_quiet_NaN is false).
  CUTLASS_HOST_DEVICE static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); }
  /// Returns the INF_MASK bit pattern (has_signaling_NaN is false).
  CUTLASS_HOST_DEVICE static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); }
  /// Returns smallest positive subnormal value
  CUTLASS_HOST_DEVICE static type denorm_min() { return type::bitcast(0x01); }
};

/// Forward Declaration
template <class T>
struct numeric_limits;

/// Numeric limits for float_e2m1_t
template <>
struct numeric_limits<cutlass::float_e2m1_t> : public float_subbyte_base_numeric_limits<cutlass::float_e2m1_t>
{
  /// Minimum finite value: sign bit (0x8) | largest magnitude (0x7) = -6.
  CUTLASS_HOST_DEVICE static cutlass::float_e2m1_t lowest() { return cutlass::float_e2m1_t::bitcast(0xf); }
  /// Machine epsilon: 1.5 - 1.0 = 0.5, encoded as the subnormal 0x1.
  CUTLASS_HOST_DEVICE static cutlass::float_e2m1_t epsilon() { return cutlass::float_e2m1_t::bitcast(0x1); }
};

/// Numeric limits for float_e2m3_t
template <>
struct numeric_limits<cutlass::float_e2m3_t> : public float_subbyte_base_numeric_limits<cutlass::float_e2m3_t>
{
  /// Minimum finite value: sign bit (0x20) | largest magnitude (0x1f) = -7.5.
  /// (0x2f would decode to -1.875, which is not the most negative value.)
  CUTLASS_HOST_DEVICE static cutlass::float_e2m3_t lowest() { return cutlass::float_e2m3_t::bitcast(0x3f); }
  /// Machine epsilon: 1.125 - 1.0 = 0.125, encoded as the subnormal 0x1.
  CUTLASS_HOST_DEVICE static cutlass::float_e2m3_t epsilon() { return cutlass::float_e2m3_t::bitcast(0x1); }
};

/// Numeric limits for float_e3m2_t
template <>
struct numeric_limits<cutlass::float_e3m2_t> : public float_subbyte_base_numeric_limits<cutlass::float_e3m2_t>
{
  /// Minimum finite value: sign bit (0x20) | largest magnitude (0x1f) = -28.
  /// (0x2f would decode to -1.75, which is not the most negative value.)
  CUTLASS_HOST_DEVICE static cutlass::float_e3m2_t lowest() { return cutlass::float_e3m2_t::bitcast(0x3f); }
  /// Machine epsilon: 1.25 - 1.0 = 0.25, encoded as 0x4 (exp 001, mantissa 00).
  CUTLASS_HOST_DEVICE static cutlass::float_e3m2_t epsilon() { return cutlass::float_e3m2_t::bitcast(0x4); }
};

/// Numeric limits for float_e2m3_unpack8bits_t (same E2M3 encoding as float_e2m3_t)
template <>
struct numeric_limits<cutlass::detail::float_e2m3_unpack8bits_t> : public float_subbyte_base_numeric_limits<cutlass::detail::float_e2m3_unpack8bits_t>
{
  /// Minimum finite value: sign bit (0x20) | largest magnitude (0x1f) = -7.5.
  CUTLASS_HOST_DEVICE static cutlass::detail::float_e2m3_unpack8bits_t lowest() { return cutlass::detail::float_e2m3_unpack8bits_t::bitcast(0x3f); }
  /// Machine epsilon: 0.125, encoded as the subnormal 0x1.
  CUTLASS_HOST_DEVICE static cutlass::detail::float_e2m3_unpack8bits_t epsilon() { return cutlass::detail::float_e2m3_unpack8bits_t::bitcast(0x1); }
};

/// Numeric limits for float_e3m2_unpack8bits_t (same E3M2 encoding as float_e3m2_t)
template <>
struct numeric_limits<cutlass::detail::float_e3m2_unpack8bits_t> : public float_subbyte_base_numeric_limits<cutlass::detail::float_e3m2_unpack8bits_t>
{
  /// Minimum finite value: sign bit (0x20) | largest magnitude (0x1f) = -28.
  CUTLASS_HOST_DEVICE static cutlass::detail::float_e3m2_unpack8bits_t lowest() { return cutlass::detail::float_e3m2_unpack8bits_t::bitcast(0x3f); }
  /// Machine epsilon: 0.25, encoded as 0x4.
  CUTLASS_HOST_DEVICE static cutlass::detail::float_e3m2_unpack8bits_t epsilon() { return cutlass::detail::float_e3m2_unpack8bits_t::bitcast(0x4); }
};

} // namespace platform
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// User-defined literals
//
CUTLASS_HOST_DEVICE
cutlass::float_e2m1_t operator"" _fe2m1(long double x)
{
return cutlass::float_e2m1_t(float(x));
}
CUTLASS_HOST_DEVICE
cutlass::float_e2m1_t operator"" _fe2m1(unsigned long long int x)
{
return cutlass::float_e2m1_t(int(x));
}
CUTLASS_HOST_DEVICE
cutlass::float_e2m3_t operator"" _fe2m3(long double x)
{
return cutlass::float_e2m3_t(float(x));
}
CUTLASS_HOST_DEVICE
cutlass::float_e2m3_t operator"" _fe2m3(unsigned long long int x)
{
return cutlass::float_e2m3_t(int(x));
}
CUTLASS_HOST_DEVICE
cutlass::float_e3m2_t operator"" _fe3m2(long double x)
{
return cutlass::float_e3m2_t(float(x));
}
CUTLASS_HOST_DEVICE
cutlass::float_e3m2_t operator"" _fe3m2(unsigned long long int x)
{
return cutlass::float_e3m2_t(int(x));
}
/////////////////////////////////////////////////////////////////////////////////////////////////