@@ -967,9 +967,9 @@ template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_s
967
967
inline typename std::enable_if_t <std::is_same_v<scalar_t , unsigned char >, void >
968
968
sdpa_int8_kernel_one_loop_impl (
969
969
const at::Tensor& output,
970
- const at::Tensor& query ,
971
- const at::Tensor& key ,
972
- const at::Tensor& value ,
970
+ const at::Tensor& q ,
971
+ const at::Tensor& k ,
972
+ const at::Tensor& v ,
973
973
double dropout_p,
974
974
bool is_causal,
975
975
at::Tensor& attention_mask,
@@ -984,9 +984,15 @@ sdpa_int8_kernel_one_loop_impl(
984
984
float a_scale,
985
985
int32_t o_zp,
986
986
float o_scale) {
987
- // Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
988
- // Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
989
- // Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
987
+ // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
988
+ // -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
989
+ // Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
990
+ // -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
991
+ // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
992
+ // -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
993
+ at::Tensor query = q.transpose (1 , 2 );
994
+ at::Tensor key = k.transpose (1 , 2 );
995
+ at::Tensor value = v.transpose (1 , 2 );
990
996
991
997
const auto accumulate_dtype = at::kFloat ;
992
998
@@ -1507,9 +1513,9 @@ template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_s
1507
1513
inline typename std::enable_if_t <std::is_same_v<scalar_t , unsigned char >, void >
1508
1514
sdpa_int8_kernel_several_loops_impl (
1509
1515
const at::Tensor& output,
1510
- const at::Tensor& query ,
1511
- const at::Tensor& key ,
1512
- const at::Tensor& value ,
1516
+ const at::Tensor& q ,
1517
+ const at::Tensor& k ,
1518
+ const at::Tensor& v ,
1513
1519
double dropout_p,
1514
1520
bool is_causal,
1515
1521
at::Tensor& attention_mask,
@@ -1524,9 +1530,15 @@ sdpa_int8_kernel_several_loops_impl(
1524
1530
float a_scale,
1525
1531
int32_t o_zp,
1526
1532
float o_scale) {
1527
- // Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
1528
- // Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
1529
- // Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
1533
+ // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
1534
+ // -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
1535
+ // Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
1536
+ // -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
1537
+ // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
1538
+ // -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
1539
+ at::Tensor query = q.transpose (1 , 2 );
1540
+ at::Tensor key = k.transpose (1 , 2 );
1541
+ at::Tensor value = v.transpose (1 , 2 );
1530
1542
1531
1543
const auto accumulate_dtype = at::kFloat ;
1532
1544
@@ -2347,27 +2359,27 @@ at::Tensor _scaled_dot_product_int8_cpu(
2347
2359
k_zp, k_scale,
2348
2360
v_zp, v_scale,
2349
2361
a_zp, a_scale,
2350
- o_zp, o_scale);
2362
+ o_zp, o_scale). transpose ( 1 , 2 ). contiguous (). transpose ( 1 , 2 ) ;
2351
2363
}
2352
2364
2353
2365
#ifdef CPU_CAPABILITY_AVX512
2354
- at::Tensor output = at::empty_like (query, query.options ());
2366
+ at::Tensor output = at::empty_like (query, query.options ()). transpose ( 1 , 2 ) ;
2355
2367
sdpa_int8_fused_kernel (output, query, key, value,
2356
2368
dropout_p, is_causal, attn_mask, scale,
2357
2369
q_zp, q_scale,
2358
2370
k_zp, k_scale,
2359
2371
v_zp, v_scale,
2360
2372
a_zp, a_scale,
2361
2373
o_zp, o_scale);
2362
- return output;
2374
+ return output. transpose ( 1 , 2 ) ;
2363
2375
#else
2364
2376
return sdpa_int8_math_kernel (query, key, value,
2365
2377
dropout_p, is_causal, attn_mask, scale,
2366
2378
q_zp, q_scale,
2367
2379
k_zp, k_scale,
2368
2380
v_zp, v_scale,
2369
2381
a_zp, a_scale,
2370
- o_zp, o_scale);
2382
+ o_zp, o_scale). transpose ( 1 , 2 ). contiguous (). transpose ( 1 , 2 ) ;
2371
2383
#endif // CPU_CAPABILITY_AVX512
2372
2384
}
2373
2385
0 commit comments