Skip to content

Commit fd3d6b5

Browse files
committed
Refine code
1 parent 5730278 commit fd3d6b5

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchao/csrc/cpu/aten_kernels/float8_linear.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,11 +492,12 @@ void _float8_linear_impl(
492492

493493
at::parallel_for(0, num_parallel_blocks, 1, [&](int64_t begin, int64_t end) {
494494
// Get the address of pre-allocated buffers
495-
float* y_buf = y_buffer.data_ptr<float>() + at::get_thread_num() * block_size;
495+
int tid = at::get_thread_num();
496+
float* y_buf = y_buffer.data_ptr<float>() + tid * block_size;
496497
at::BFloat16 *dqA_buffer = nullptr, *dqB_buffer = nullptr;
497498
float* ukernel_buf = nullptr;
498499
#if defined(CPU_CAPABILITY_AVX512)
499-
at::BFloat16* micro_gemm_buf = micro_gemm_buffer.data_ptr<at::BFloat16>() + at::get_thread_num() * buffer_size;
500+
at::BFloat16* micro_gemm_buf = micro_gemm_buffer.data_ptr<at::BFloat16>() + tid * buffer_size;
500501
ukernel_buf = reinterpret_cast<float*>(micro_gemm_buf);
501502
#ifndef CPUBLAS_BRGEMM_F8F8F32
502503
dqA_buffer = micro_gemm_buf;

0 commit comments

Comments
 (0)