Skip to content

Commit 32f91e4

Browse files
Clean up linear projection stage before multihead partitioning
Clean up debug output from input_linear_to_qkv and have its result passed out via a vector of instructions.
1 parent 5fff079 commit 32f91e4

File tree

1 file changed

+44
-45
lines changed

1 file changed

+44
-45
lines changed

src/onnx/parse_attention.cpp

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,6 @@ struct parse_attention : op_parser<parse_attention>
334334
{
335335
auto mask_index_shape = mask_index->get_shape();
336336
auto mask_index_lens = mask_index_shape.lens();
337-
bool mask_index_is_trash = false;
338337

339338
if(mask_index_shape.type() != migraphx::shape::int32_type)
340339
{
@@ -413,16 +412,27 @@ struct parse_attention : op_parser<parse_attention>
413412
return input_arguments;
414413
}
415414

415+
static std::vector<std::vector<instruction_ref>> qkv_split_per_head(const onnx_parser::node_info& info,
416+
const std::vector<instruction_ref>& qkv_mats,
417+
const attention_attr& attr_in,
418+
const attention_infered& infered_in)
419+
{
420+
std::vector<std::vector<instruction_ref>> qkv_split;
421+
return qkv_split;
422+
}
423+
416424
static instruction_ref scale_dot_attention_head(const onnx_parser::node_info& info,
417-
const instruction_ref& Q,
418-
const instruction_ref& K,
419-
const instruction_ref& V,
425+
const std::vector<instruction_ref>& QKV,
420426
const instruction_ref& scale_factor,
421427
const instruction_ref& mask,
422428
const instruction_ref& bias,
423429
bool masked=false,
424430
bool attn_bias=false)
425431
{
432+
auto Q = QKV.at(1);
433+
auto K = QKV.at(2);
434+
auto V = QKV.at(3);
435+
426436
auto k_trans = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), K);
427437
k_trans->debug_print();
428438
auto qk_out = info.add_instruction(make_op("dot"), Q, k_trans);
@@ -450,54 +460,42 @@ struct parse_attention : op_parser<parse_attention>
450460
}
451461

452462
// Get Q, K, V matricies from stacked weight matrix
453-
static void input_linear_to_qkv(const onnx_parser::node_info& info,
454-
const instruction_ref& input,
455-
const instruction_ref& stacked_weights,
456-
const std::vector<size_t>& qkv_sizes,
457-
const instruction_ref& input_bias,
458-
const bool has_input_bias,
459-
instruction_ref& Q,
460-
instruction_ref& K,
461-
instruction_ref& V)
463+
static std::vector<instruction_ref> input_linear_to_qkv(const onnx_parser::node_info& info,
464+
const instruction_ref& input,
465+
const instruction_ref& stacked_weights,
466+
const std::vector<size_t>& qkv_sizes,
467+
const instruction_ref& input_bias,
468+
const bool has_input_bias)
462469
{
463470
// Input encodes the batch, sequence_length and input_hidden_size (also known as embedding size)
464471
auto input_lens = input->get_shape().lens();
465472

466473
auto stacked_weights_unsq = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), stacked_weights);
467-
stacked_weights_unsq->debug_print();
468474

469475
// Input stacked weights are (input_hidden_size, hidden_size + hidden_size + v_hidden_size) so slice out parts for each matrix
470476
// Since we known the input_hidden size is one dimension wee need to slice out the weight tensors accordingly before we perform matmul
471477
auto q_weight = info.add_instruction(make_op("slice", {{"axes",{2}}, {"starts", {0}}, {"ends", {qkv_sizes.at(0)-1}}}), stacked_weights_unsq);
472478
auto k_weight = info.add_instruction(make_op("slice", {{"axes",{2}}, {"starts", {qkv_sizes.at(0)}}, {"ends", {qkv_sizes.at(1) + qkv_sizes.at(0) - 1}}}), stacked_weights_unsq);
473479
auto v_weight = info.add_instruction(make_op("slice", {{"axes",{2}}, {"starts", {qkv_sizes.at(0) + qkv_sizes.at(1)}}, {"ends", {qkv_sizes.at(0) + qkv_sizes.at(1) + qkv_sizes.at(2) -1 }}}), stacked_weights_unsq);
474480

475-
q_weight->debug_print();
476-
k_weight->debug_print();
477-
v_weight->debug_print();
478-
479481
// Add in batch dimension to weights
480482
auto qk_lens = q_weight->get_shape().lens();
481483
qk_lens.at(0) = input_lens.at(0);
482484
auto v_lens = v_weight->get_shape().lens();
483485
v_lens.at(0) = input_lens.at(0);
484486

485-
std::cout << qk_lens.at(0) << "," << qk_lens.at(1) << "," << qk_lens.at(2) << std::endl;
486-
487487
//Broadcast to batch size
488488
auto q_weight_bcast = info.add_instruction(make_op("multibroadcast", {{"out_lens", qk_lens}}), q_weight);
489489
auto k_weight_bcast = info.add_instruction(make_op("multibroadcast", {{"out_lens", qk_lens}}), k_weight);
490490
auto v_weight_bcast = info.add_instruction(make_op("multibroadcast", {{"out_lens", v_lens}}), v_weight);
491491

492-
q_weight_bcast->debug_print();
493-
k_weight_bcast->debug_print();
494-
v_weight_bcast->debug_print();
495-
input->debug_print();
496-
497492
// Broadcast by batch then multiply
498-
Q = info.add_instruction(make_op("dot"), input, q_weight_bcast);
499-
K = info.add_instruction(make_op("dot"), input, k_weight_bcast);
500-
V = info.add_instruction(make_op("dot"), input, v_weight_bcast);
493+
auto Q = info.add_instruction(make_op("dot"), input, q_weight_bcast);
494+
auto K = info.add_instruction(make_op("dot"), input, k_weight_bcast);
495+
auto V = info.add_instruction(make_op("dot"), input, v_weight_bcast);
496+
497+
std::vector<instruction_ref>qkv_mats{Q, K, V};
498+
return qkv_mats;
501499
}
502500

503501
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
@@ -514,47 +512,48 @@ struct parse_attention : op_parser<parse_attention>
514512
auto input_data = inputs.at(0);
515513
auto weights = inputs.at(1);
516514

517-
instruction_ref q;
518-
instruction_ref k;
519-
instruction_ref v;
520515
instruction_ref input_bias;
521516
bool has_input_bias = false;
522-
input_linear_to_qkv(info, input_data, weights, parsed_attributes.qkv_hidden_sizes, input_bias, has_input_bias, q, k, v);
523517

524-
instruction_ref mask;
518+
// Apply linear stage to QKV mats from weight matrix
519+
auto qkv_mats = input_linear_to_qkv(info, input_data, weights, parsed_attributes.qkv_hidden_sizes, input_bias, has_input_bias);
520+
521+
instruction_ref attn_mask;
525522
bool has_mask = false;
526523
instruction_ref attn_bias;
527524
bool has_bias = false;
528525

529-
530-
instruction_ref present;
531-
532526
// Used to scale all key values before any masking or other inputs
533-
auto scale_factor = info.add_literal(migraphx::literal{migraphx::shape{k->get_shape().type()}, {std::sqrt(k->get_shape().elements()) } } );
527+
auto scale_factor = info.add_literal(migraphx::literal{migraphx::shape{qkv_mats.at(0)->get_shape().type()},
528+
{std::sqrt(infered_attributes.query_size)}});
534529

535530
instruction_ref output;
536531
//Get vector of attention heads and then concat the output results
537532
if(parsed_attributes.num_heads > 1)
538533
{
539-
std::vector<instruction_ref> vec_of_attn_outs(parsed_attributes.num_heads);
540-
std::transform(vec_of_attn_outs.begin(),
541-
vec_of_attn_outs.end(),
542-
vec_of_attn_outs.begin(),
543-
[&](auto&&) {
544-
return scale_dot_attention_head(info, q, k, v, scale_factor, mask, attn_bias, has_mask, has_bias);
545-
});
534+
// Apply multi head splitting of qkv matrix prior to calculation
535+
auto split_qkv = qkv_split_per_head(info, qkv_mats, parsed_attributes, infered_attributes);
536+
537+
std::vector<instruction_ref> vec_of_attn_outs;
538+
std::transform(split_qkv.cbegin(),
539+
split_qkv.cend(),
540+
std::back_inserter(vec_of_attn_outs),
541+
[&](auto && split_inputs) {
542+
return scale_dot_attention_head(info, split_inputs, scale_factor, attn_mask, attn_bias, has_mask, has_bias);
543+
});
546544
output = info.add_instruction(make_op("concat"), vec_of_attn_outs);
547545
}
548546
else
549547
{
550-
output = scale_dot_attention_head(info, q, k, v, scale_factor, mask, attn_bias, has_mask, has_bias);
548+
output = scale_dot_attention_head(info, qkv_mats, scale_factor, attn_mask, attn_bias, has_mask, has_bias);
551549
}
552550

553551
output->debug_print();
554552

555553
std::vector<instruction_ref> output_vec{};
556554
output_vec.push_back(output);
557555

556+
instruction_ref present;
558557
if(parsed_attributes.past_present_share_buffer)
559558
{
560559
present = output;

0 commit comments

Comments
 (0)