Skip to content

Commit 8933c53

Browse files
committed
update
1 parent b5b8cb3 commit 8933c53

File tree

2 files changed

+37
-20
lines changed

2 files changed

+37
-20
lines changed

torchao/csrc/cpu/int8_sdpa.cpp

+28-16
Original file line numberDiff line numberDiff line change
@@ -967,9 +967,9 @@ template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_s
967967
inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
968968
sdpa_int8_kernel_one_loop_impl(
969969
const at::Tensor& output,
970-
const at::Tensor& query,
971-
const at::Tensor& key,
972-
const at::Tensor& value,
970+
const at::Tensor& q,
971+
const at::Tensor& k,
972+
const at::Tensor& v,
973973
double dropout_p,
974974
bool is_causal,
975975
at::Tensor& attention_mask,
@@ -984,9 +984,15 @@ sdpa_int8_kernel_one_loop_impl(
984984
float a_scale,
985985
int32_t o_zp,
986986
float o_scale) {
987-
// Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
988-
// Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
989-
// Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
987+
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
988+
// -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
989+
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
990+
// -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
991+
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
992+
// -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
993+
at::Tensor query = q.transpose(1, 2);
994+
at::Tensor key = k.transpose(1, 2);
995+
at::Tensor value = v.transpose(1, 2);
990996

991997
const auto accumulate_dtype = at::kFloat;
992998

@@ -1507,9 +1513,9 @@ template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_s
15071513
inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
15081514
sdpa_int8_kernel_several_loops_impl(
15091515
const at::Tensor& output,
1510-
const at::Tensor& query,
1511-
const at::Tensor& key,
1512-
const at::Tensor& value,
1516+
const at::Tensor& q,
1517+
const at::Tensor& k,
1518+
const at::Tensor& v,
15131519
double dropout_p,
15141520
bool is_causal,
15151521
at::Tensor& attention_mask,
@@ -1524,9 +1530,15 @@ sdpa_int8_kernel_several_loops_impl(
15241530
float a_scale,
15251531
int32_t o_zp,
15261532
float o_scale) {
1527-
// Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
1528-
// Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
1529-
// Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
1533+
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
1534+
// -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
1535+
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
1536+
// -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
1537+
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
1538+
// -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
1539+
at::Tensor query = q.transpose(1, 2);
1540+
at::Tensor key = k.transpose(1, 2);
1541+
at::Tensor value = v.transpose(1, 2);
15301542

15311543
const auto accumulate_dtype = at::kFloat;
15321544

@@ -2347,27 +2359,27 @@ at::Tensor _scaled_dot_product_int8_cpu(
23472359
k_zp, k_scale,
23482360
v_zp, v_scale,
23492361
a_zp, a_scale,
2350-
o_zp, o_scale);
2362+
o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2);
23512363
}
23522364

23532365
#ifdef CPU_CAPABILITY_AVX512
2354-
at::Tensor output = at::empty_like(query, query.options());
2366+
at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2);
23552367
sdpa_int8_fused_kernel(output, query, key, value,
23562368
dropout_p, is_causal, attn_mask, scale,
23572369
q_zp, q_scale,
23582370
k_zp, k_scale,
23592371
v_zp, v_scale,
23602372
a_zp, a_scale,
23612373
o_zp, o_scale);
2362-
return output;
2374+
return output.transpose(1, 2);
23632375
#else
23642376
return sdpa_int8_math_kernel(query, key, value,
23652377
dropout_p, is_causal, attn_mask, scale,
23662378
q_zp, q_scale,
23672379
k_zp, k_scale,
23682380
v_zp, v_scale,
23692381
a_zp, a_scale,
2370-
o_zp, o_scale);
2382+
o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2);
23712383
#endif // CPU_CAPABILITY_AVX512
23722384
}
23732385

torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,13 @@ def int8_sdpa(match: Match, *args, **kwargs):
6464
counters["inductor"]["int8_fuse_attention"] += 1
6565
counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes)
6666

67-
return L[torch.ops.torchao.scaled_dot_product_int8.default](
68-
query,
69-
key,
70-
value,
67+
trans_query = L[aten.permute.default](query, [0, 2, 1, 3])
68+
trans_key = L[aten.permute.default](key, [0, 2, 1, 3])
69+
trans_value = L[aten.permute.default](value, [0, 2, 1, 3])
70+
output = L[torch.ops.torchao.scaled_dot_product_int8.default](
71+
trans_query,
72+
trans_key,
73+
trans_value,
7174
attn_mask,
7275
0.0, #dropout
7376
False, #is_causal
@@ -83,6 +86,8 @@ def int8_sdpa(match: Match, *args, **kwargs):
8386
o_zp,
8487
o_scale,
8588
)
89+
trans_output = L[aten.permute.default](output, [0, 2, 1, 3])
90+
return L[aten.clone.default](trans_output, memory_format=torch.contiguous_format)
8691

8792
return int8_sdpa
8893

0 commit comments

Comments
 (0)