@@ -334,7 +334,6 @@ struct parse_attention : op_parser<parse_attention>
    {
        auto mask_index_shape = mask_index->get_shape();
        auto mask_index_lens  = mask_index_shape.lens();
-        bool mask_index_is_trash = false;

        if(mask_index_shape.type() != migraphx::shape::int32_type)
        {
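For context, the com.microsoft Attention contrib op this parser targets passes mask_index as int32, either as a raw mask of shape (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or as a padding-index vector of shape (batch_size) or (2 * batch_size). A hedged sketch of how the rank validation could continue (the throw message is illustrative, not from this diff):

    // Sketch only: accept the documented 1-D/2-D/3-D mask_index variants;
    // the 4-D megatron-style mask is omitted here for brevity
    if(mask_index_lens.empty() or mask_index_lens.size() > 3)
        MIGRAPHX_THROW("PARSE_ATTENTION: mask_index must have rank 1, 2, or 3");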
@@ -413,16 +412,27 @@ struct parse_attention : op_parser<parse_attention>
        return input_arguments;
    }

+    static std::vector<std::vector<instruction_ref>> qkv_split_per_head(const onnx_parser::node_info& info,
+                                                                        const std::vector<instruction_ref>& qkv_mats,
+                                                                        const attention_attr& attr_in,
+                                                                        const attention_infered& infered_in)
+    {
+        std::vector<std::vector<instruction_ref>> qkv_split;
+        return qkv_split;
+    }
+
    static instruction_ref scale_dot_attention_head(const onnx_parser::node_info& info,
-                                                    const instruction_ref& Q,
-                                                    const instruction_ref& K,
-                                                    const instruction_ref& V,
+                                                    const std::vector<instruction_ref>& QKV,
                                                    const instruction_ref& scale_factor,
                                                    const instruction_ref& mask,
                                                    const instruction_ref& bias,
                                                    bool masked    = false,
                                                    bool attn_bias = false)
    {
+        auto Q = QKV.at(0);
+        auto K = QKV.at(1);
+        auto V = QKV.at(2);
+
        auto k_trans = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), K);
        k_trans->debug_print();
        auto qk_out = info.add_instruction(make_op("dot"), Q, k_trans);
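qkv_split_per_head above is still a stub that returns an empty vector. A minimal sketch of the split it presumably will perform, slicing each of Q, K, V into num_heads chunks along the final hidden axis, assuming (batch, sequence_length, hidden_size) inputs whose hidden sizes divide evenly by num_heads (only attr_in.num_heads appears in this diff; the rest is illustrative):

    std::vector<std::vector<instruction_ref>> qkv_split;
    const std::size_t num_heads = attr_in.num_heads;
    for(std::size_t head = 0; head < num_heads; ++head)
    {
        std::vector<instruction_ref> one_head;
        for(const auto& mat : qkv_mats)
        {
            // Slice this head's chunk out of axis 2; slice ends are exclusive
            const std::size_t head_size = mat->get_shape().lens().back() / num_heads;
            one_head.push_back(info.add_instruction(
                make_op("slice",
                        {{"axes", {2}},
                         {"starts", {head * head_size}},
                         {"ends", {(head + 1) * head_size}}}),
                mat));
        }
        qkv_split.push_back(one_head);
    }
    return qkv_split;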
@@ -450,54 +460,42 @@ struct parse_attention : op_parser<parse_attention>
    }

    // Get Q, K, V matrices from the stacked weight matrix
-    static void input_linear_to_qkv(const onnx_parser::node_info& info,
-                                    const instruction_ref& input,
-                                    const instruction_ref& stacked_weights,
-                                    const std::vector<size_t>& qkv_sizes,
-                                    const instruction_ref& input_bias,
-                                    const bool has_input_bias,
-                                    instruction_ref& Q,
-                                    instruction_ref& K,
-                                    instruction_ref& V)
+    static std::vector<instruction_ref> input_linear_to_qkv(const onnx_parser::node_info& info,
+                                                            const instruction_ref& input,
+                                                            const instruction_ref& stacked_weights,
+                                                            const std::vector<size_t>& qkv_sizes,
+                                                            const instruction_ref& input_bias,
+                                                            const bool has_input_bias)
    {
        // Input encodes the batch, sequence_length and input_hidden_size (also known as embedding size)
        auto input_lens = input->get_shape().lens();

        auto stacked_weights_unsq = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), stacked_weights);
-        stacked_weights_unsq->debug_print();

        // Input stacked weights are (input_hidden_size, hidden_size + hidden_size + v_hidden_size) so slice out parts for each matrix
        // Since we know the input_hidden_size is one dimension we need to slice out the weight tensors accordingly before we perform matmul
        auto q_weight = info.add_instruction(make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {qkv_sizes.at(0)}}}), stacked_weights_unsq);
        auto k_weight = info.add_instruction(make_op("slice", {{"axes", {2}}, {"starts", {qkv_sizes.at(0)}}, {"ends", {qkv_sizes.at(0) + qkv_sizes.at(1)}}}), stacked_weights_unsq);
        auto v_weight = info.add_instruction(make_op("slice", {{"axes", {2}}, {"starts", {qkv_sizes.at(0) + qkv_sizes.at(1)}}, {"ends", {qkv_sizes.at(0) + qkv_sizes.at(1) + qkv_sizes.at(2)}}}), stacked_weights_unsq);

-        q_weight->debug_print();
-        k_weight->debug_print();
-        v_weight->debug_print();
-
        // Add in batch dimension to weights
        auto qk_lens = q_weight->get_shape().lens();
        qk_lens.at(0) = input_lens.at(0);
        auto v_lens = v_weight->get_shape().lens();
        v_lens.at(0) = input_lens.at(0);

-        std::cout << qk_lens.at(0) << "," << qk_lens.at(1) << "," << qk_lens.at(2) << std::endl;
-
        // Broadcast to batch size
        auto q_weight_bcast = info.add_instruction(make_op("multibroadcast", {{"out_lens", qk_lens}}), q_weight);
        auto k_weight_bcast = info.add_instruction(make_op("multibroadcast", {{"out_lens", qk_lens}}), k_weight);
        auto v_weight_bcast = info.add_instruction(make_op("multibroadcast", {{"out_lens", v_lens}}), v_weight);

-        q_weight_bcast->debug_print();
-        k_weight_bcast->debug_print();
-        v_weight_bcast->debug_print();
-        input->debug_print();
-
        // Broadcast by batch then multiply
-        Q = info.add_instruction(make_op("dot"), input, q_weight_bcast);
-        K = info.add_instruction(make_op("dot"), input, k_weight_bcast);
-        V = info.add_instruction(make_op("dot"), input, v_weight_bcast);
+        auto Q = info.add_instruction(make_op("dot"), input, q_weight_bcast);
+        auto K = info.add_instruction(make_op("dot"), input, k_weight_bcast);
+        auto V = info.add_instruction(make_op("dot"), input, v_weight_bcast);
+
+        std::vector<instruction_ref> qkv_mats{Q, K, V};
+        return qkv_mats;
    }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
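As a concrete check of the slice bounds above (MIGraphX slice follows ONNX semantics, so ends are exclusive), a hypothetical qkv_hidden_sizes of {768, 768, 768} partitions the stacked weight's last axis as:

    // Hypothetical qkv_sizes = {768, 768, 768}
    // q_weight <- columns [0, 768):     starts = 0,    ends = 768
    // k_weight <- columns [768, 1536):  starts = 768,  ends = 1536
    // v_weight <- columns [1536, 2304): starts = 1536, ends = 2304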
@@ -514,47 +512,48 @@ struct parse_attention : op_parser<parse_attention>
        auto input_data = inputs.at(0);
        auto weights    = inputs.at(1);

-        instruction_ref q;
-        instruction_ref k;
-        instruction_ref v;
        instruction_ref input_bias;
        bool has_input_bias = false;
-        input_linear_to_qkv(info, input_data, weights, parsed_attributes.qkv_hidden_sizes, input_bias, has_input_bias, q, k, v);

-        instruction_ref mask;
+        // Apply the linear stage to produce the QKV mats from the weight matrix
+        auto qkv_mats = input_linear_to_qkv(info, input_data, weights, parsed_attributes.qkv_hidden_sizes, input_bias, has_input_bias);
+
+        instruction_ref attn_mask;
        bool has_mask = false;
        instruction_ref attn_bias;
        bool has_bias = false;

-
-        instruction_ref present;
-
        // Used to scale all key values before any masking or other inputs
-        auto scale_factor = info.add_literal(migraphx::literal{migraphx::shape{k->get_shape().type()}, {std::sqrt(k->get_shape().elements())}});
+        auto scale_factor = info.add_literal(migraphx::literal{migraphx::shape{qkv_mats.at(0)->get_shape().type()},
+                                                               {std::sqrt(infered_attributes.query_size)}});

        instruction_ref output;
        // Get vector of attention heads and then concat the output results
        if(parsed_attributes.num_heads > 1)
        {
-            std::vector<instruction_ref> vec_of_attn_outs(parsed_attributes.num_heads);
-            std::transform(vec_of_attn_outs.begin(),
-                           vec_of_attn_outs.end(),
-                           vec_of_attn_outs.begin(),
-                           [&](auto&&) {
-                               return scale_dot_attention_head(info, q, k, v, scale_factor, mask, attn_bias, has_mask, has_bias);
-                           });
+            // Apply multi-head splitting of the QKV matrices prior to calculation
+            auto split_qkv = qkv_split_per_head(info, qkv_mats, parsed_attributes, infered_attributes);
+
+            std::vector<instruction_ref> vec_of_attn_outs;
+            std::transform(split_qkv.cbegin(),
+                           split_qkv.cend(),
+                           std::back_inserter(vec_of_attn_outs),
+                           [&](auto&& split_inputs) {
+                               return scale_dot_attention_head(info, split_inputs, scale_factor, attn_mask, attn_bias, has_mask, has_bias);
+                           });
            output = info.add_instruction(make_op("concat"), vec_of_attn_outs);
        }
        else
        {
-            output = scale_dot_attention_head(info, q, k, v, scale_factor, mask, attn_bias, has_mask, has_bias);
+            output = scale_dot_attention_head(info, qkv_mats, scale_factor, attn_mask, attn_bias, has_mask, has_bias);
        }

        output->debug_print();

        std::vector<instruction_ref> output_vec{};
        output_vec.push_back(output);

+        instruction_ref present;
        if(parsed_attributes.past_present_share_buffer)
        {
            present = output;
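For reference, each head computes softmax(QK^T / sqrt(d_k) + mask) * V, so the scale_factor literal built above holds sqrt(query_size) to serve as a divisor on the QK^T scores. The body of scale_dot_attention_head past the QK^T dot is not shown in this diff; a hedged sketch of the remaining steps, using add_common_op for the scalar and mask broadcasts:

    // Illustrative tail of an attention head, assuming rank-3
    // (batch, seq_q, seq_k) scores
    auto qk_scaled = info.add_common_op("div", qk_out, scale_factor);
    if(masked)
        qk_scaled = info.add_common_op("add", qk_scaled, mask);
    auto probs    = info.add_instruction(make_op("softmax", {{"axis", 2}}), qk_scaled);
    auto head_out = info.add_instruction(make_op("dot"), probs, V);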