Skip to content

Commit ca31c32

Browse files
authored
Rename "HALF" and "sh" to "BFLOAT16" and "sb"
1 parent 5800758 commit ca31c32

9 files changed

Lines changed: 178 additions & 222 deletions

File tree

cblas.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE
392392
/* convert BFLOAT16 array to double array */
393393
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
394394
/* dot production of BFLOAT16 input arrays, and output as float */
395-
float cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
395+
float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
396396

397397
#ifdef __cplusplus
398398
}

common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ typedef unsigned long BLASULONG;
260260
#ifndef BFLOAT16
261261
#include <stdint.h>
262262
typedef uint16_t bfloat16;
263-
#define HALFCONVERSION 1
263+
#define BFLOAT16CONVERSION 1
264264
#endif
265265

266266
#ifdef USE64BITINT
@@ -303,7 +303,7 @@ typedef int blasint;
303303
#define SIZE 8
304304
#define BASE_SHIFT 3
305305
#define ZBASE_SHIFT 4
306-
#elif defined(HALF)
306+
#elif defined(BFLOAT16)
307307
#define IFLOAT bfloat16
308308
#define XFLOAT IFLOAT
309309
#define FLOAT float

common_interface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *);
5454
double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *);
5555
xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *);
5656

57-
float BLASFUNC(shdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *);
57+
float BLASFUNC(sbdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *);
5858
void BLASFUNC(shstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *);
5959
void BLASFUNC(shdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *);
6060
void BLASFUNC(sbf16tos) (blasint *, bfloat16 *, blasint *, float *, blasint *);
@@ -474,7 +474,7 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint
474474

475475
/* Level 3 routines */
476476

477-
void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
477+
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
478478
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
479479
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
480480
float *, blasint *, float *, blasint *, float *, float *, blasint *);

common_level1.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ float sdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
4646
double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
4747
double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG);
4848
xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
49-
float shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
49+
float sbdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
5050

5151
void shstobf16_k(BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
5252
void shdtobf16_k(BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);

common_level3.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
5555
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
5656

5757

58-
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
58+
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
5959
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
6060
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
6161
float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
@@ -78,10 +78,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
7878
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
7979
#endif
8080

81-
int shgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
82-
int shgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
83-
int shgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
84-
int shgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
81+
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
82+
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
83+
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
84+
int sbgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
8585
int sgemm_incopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
8686
int sgemm_itcopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
8787
int sgemm_oncopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
@@ -505,7 +505,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
505505
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
506506
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
507507

508-
int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
508+
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
509509
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
510510
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
511511

@@ -534,10 +534,10 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
534534
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
535535
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);
536536

537-
int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
538-
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
539-
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
540-
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
537+
int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
538+
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
539+
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
540+
int sbgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
541541

542542
int sgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
543543
int sgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
@@ -631,10 +631,10 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
631631
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
632632
#endif
633633

634-
int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
635-
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
636-
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
637-
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
634+
int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
635+
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
636+
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
637+
int sbgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
638638

639639
int sgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
640640
int sgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);

common_macro.h

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@
644644

645645
#define GEADD_K DGEADD_K
646646

647-
#elif defined(HALF)
647+
#elif defined(BFLOAT16)
648648

649649
#define D_TO_BF16_K SHDTOBF16_K
650650
#define D_BF16_TO_K DBF16TOD_K
@@ -662,7 +662,7 @@
662662
#define ASUM_K SASUM_K
663663
#define DOTU_K SDOTU_K
664664
#define DOTC_K SDOTC_K
665-
#define BF16_DOT_K SHDOT_K
665+
#define BF16_DOT_K SBDOT_K
666666
#define AXPYU_K SAXPYU_K
667667
#define AXPYC_K SAXPYC_K
668668
#define AXPBY_K SAXPBY_K
@@ -682,32 +682,32 @@
682682
#define NRM2_K SNRM2_K
683683
#define SYMV_THREAD_U SSYMV_THREAD_U
684684
#define SYMV_THREAD_L SSYMV_THREAD_L
685-
#define GEMM_BETA SHGEMM_BETA
686-
#define GEMM_KERNEL_N SHGEMM_KERNEL
687-
#define GEMM_KERNEL_L SHGEMM_KERNEL
688-
#define GEMM_KERNEL_R SHGEMM_KERNEL
689-
#define GEMM_KERNEL_B SHGEMM_KERNEL
690-
691-
#define GEMM_NN SHGEMM_NN
692-
#define GEMM_CN SHGEMM_TN
693-
#define GEMM_TN SHGEMM_TN
694-
#define GEMM_NC SHGEMM_NT
695-
#define GEMM_NT SHGEMM_NT
696-
#define GEMM_CC SHGEMM_TT
697-
#define GEMM_CT SHGEMM_TT
698-
#define GEMM_TC SHGEMM_TT
699-
#define GEMM_TT SHGEMM_TT
700-
#define GEMM_NR SHGEMM_NN
701-
#define GEMM_TR SHGEMM_TN
702-
#define GEMM_CR SHGEMM_TN
703-
#define GEMM_RN SHGEMM_NN
704-
#define GEMM_RT SHGEMM_NT
705-
#define GEMM_RC SHGEMM_NT
706-
#define GEMM_RR SHGEMM_NN
707-
#define GEMM_ONCOPY SHGEMM_ONCOPY
708-
#define GEMM_OTCOPY SHGEMM_OTCOPY
709-
#define GEMM_INCOPY SHGEMM_INCOPY
710-
#define GEMM_ITCOPY SHGEMM_ITCOPY
685+
#define GEMM_BETA SBGEMM_BETA
686+
#define GEMM_KERNEL_N SBGEMM_KERNEL
687+
#define GEMM_KERNEL_L SBGEMM_KERNEL
688+
#define GEMM_KERNEL_R SBGEMM_KERNEL
689+
#define GEMM_KERNEL_B SBGEMM_KERNEL
690+
691+
#define GEMM_NN SBGEMM_NN
692+
#define GEMM_CN SBGEMM_TN
693+
#define GEMM_TN SBGEMM_TN
694+
#define GEMM_NC SBGEMM_NT
695+
#define GEMM_NT SBGEMM_NT
696+
#define GEMM_CC SBGEMM_TT
697+
#define GEMM_CT SBGEMM_TT
698+
#define GEMM_TC SBGEMM_TT
699+
#define GEMM_TT SBGEMM_TT
700+
#define GEMM_NR SBGEMM_NN
701+
#define GEMM_TR SBGEMM_TN
702+
#define GEMM_CR SBGEMM_TN
703+
#define GEMM_RN SBGEMM_NN
704+
#define GEMM_RT SBGEMM_NT
705+
#define GEMM_RC SBGEMM_NT
706+
#define GEMM_RR SBGEMM_NN
707+
#define GEMM_ONCOPY SBGEMM_ONCOPY
708+
#define GEMM_OTCOPY SBGEMM_OTCOPY
709+
#define GEMM_INCOPY SBGEMM_INCOPY
710+
#define GEMM_ITCOPY SBGEMM_ITCOPY
711711
#define SYMM_THREAD_LU SSYMM_THREAD_LU
712712
#define SYMM_THREAD_LL SSYMM_THREAD_LL
713713
#define SYMM_THREAD_RU SSYMM_THREAD_RU
@@ -723,22 +723,22 @@
723723
#define HEMM_THREAD_RU SHEMM_THREAD_RU
724724
#define HEMM_THREAD_RL SHEMM_THREAD_RL
725725

726-
#define GEMM_THREAD_NN SHGEMM_THREAD_NN
727-
#define GEMM_THREAD_CN SHGEMM_THREAD_TN
728-
#define GEMM_THREAD_TN SHGEMM_THREAD_TN
729-
#define GEMM_THREAD_NC SHGEMM_THREAD_NT
730-
#define GEMM_THREAD_NT SHGEMM_THREAD_NT
731-
#define GEMM_THREAD_CC SHGEMM_THREAD_TT
732-
#define GEMM_THREAD_CT SHGEMM_THREAD_TT
733-
#define GEMM_THREAD_TC SHGEMM_THREAD_TT
734-
#define GEMM_THREAD_TT SHGEMM_THREAD_TT
735-
#define GEMM_THREAD_NR SHGEMM_THREAD_NN
736-
#define GEMM_THREAD_TR SHGEMM_THREAD_TN
737-
#define GEMM_THREAD_CR SHGEMM_THREAD_TN
738-
#define GEMM_THREAD_RN SHGEMM_THREAD_NN
739-
#define GEMM_THREAD_RT SHGEMM_THREAD_NT
740-
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
741-
#define GEMM_THREAD_RR SHGEMM_THREAD_NN
726+
#define GEMM_THREAD_NN SBGEMM_THREAD_NN
727+
#define GEMM_THREAD_CN SBGEMM_THREAD_TN
728+
#define GEMM_THREAD_TN SBGEMM_THREAD_TN
729+
#define GEMM_THREAD_NC SBGEMM_THREAD_NT
730+
#define GEMM_THREAD_NT SBGEMM_THREAD_NT
731+
#define GEMM_THREAD_CC SBGEMM_THREAD_TT
732+
#define GEMM_THREAD_CT SBGEMM_THREAD_TT
733+
#define GEMM_THREAD_TC SBGEMM_THREAD_TT
734+
#define GEMM_THREAD_TT SBGEMM_THREAD_TT
735+
#define GEMM_THREAD_NR SBGEMM_THREAD_NN
736+
#define GEMM_THREAD_TR SBGEMM_THREAD_TN
737+
#define GEMM_THREAD_CR SBGEMM_THREAD_TN
738+
#define GEMM_THREAD_RN SBGEMM_THREAD_NN
739+
#define GEMM_THREAD_RT SBGEMM_THREAD_NT
740+
#define GEMM_THREAD_RC SBGEMM_THREAD_NT
741+
#define GEMM_THREAD_RR SBGEMM_THREAD_NN
742742

743743
#ifdef UNIT
744744

@@ -2491,9 +2491,9 @@
24912491
#if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64)
24922492
extern BLASLONG gemm_offset_a;
24932493
extern BLASLONG gemm_offset_b;
2494-
extern BLASLONG shgemm_p;
2495-
extern BLASLONG shgemm_q;
2496-
extern BLASLONG shgemm_r;
2494+
extern BLASLONG sbgemm_p;
2495+
extern BLASLONG sbgemm_q;
2496+
extern BLASLONG sbgemm_r;
24972497
extern BLASLONG sgemm_p;
24982498
extern BLASLONG sgemm_q;
24992499
extern BLASLONG sgemm_r;

0 commit comments

Comments
 (0)