@@ -594,46 +594,46 @@ bool validate_flash_attention_args(
     const Tensor& key,
     const Tensor& value,
     const optional<Tensor>& attn_mask) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(query.dim() == 4, "query must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(key.dim() == 4, "key must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(value.dim() == 4, "value must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(query.dim() == 4, "query must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(key.dim() == 4, "key must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(value.dim() == 4, "value must be a 4D tensor");
 
   // Sizes
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
       "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.scalar_type() == ScalarType::Float), "Query must be Float type");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.scalar_type() == key.scalar_type()) &&
           (query.scalar_type() == value.scalar_type()),
       "Key and Value must have the same data type as Query");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       !attn_mask.has_value() || attn_mask.value().dim() == 2,
       "Attention mask must be a 2D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       !attn_mask.has_value() ||
           attn_mask.value().scalar_type() == query.scalar_type(),
       "Attention mask must be a 2D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
       "key cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(key.dim_order().data(), key.dim()),
       "value cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(value.dim_order().data(), value.dim()),
       "value cache must be in contiguous dim order");
 
   if (attn_mask.has_value()) {
-    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+    ET_LOG_MSG_AND_RETURN_UNLESS(
         is_contiguous_dim_order(
             attn_mask.value().dim_order().data(), attn_mask.value().dim()),
         "value cache must be in contiguous dim order");
@@ -647,21 +647,21 @@ bool validate_cache_params(
     const Tensor& v_cache,
     int64_t start_pos,
     int64_t seq_length) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       k_cache.dim() == 4, "kcache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       v_cache.dim() == 4, "v_cache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       start_pos < k_cache.size(1),
       "start_pos must be less than key cache at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       start_pos < v_cache.size(1),
       "start_pos must be less than value cache at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (start_pos + seq_length) <= k_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       " start pos: %" PRId64 ", seq_length: %" PRId64
@@ -671,7 +671,7 @@ bool validate_cache_params(
       seq_length,
       k_cache.size(1));
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (start_pos + seq_length) <= v_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       " start pos: %" PRId64 ", seq_length: %" PRId64
@@ -682,11 +682,11 @@ bool validate_cache_params(
       v_cache.size(1));
 
   // Make sure they are in contiguous dim order
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
       "key cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
       "value cache must be in contiguous dim order");
 
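For context on the rename: every call site above only swaps the macro name, and the check-and-early-return behavior is unchanged. Below is a minimal, hypothetical sketch of that "log a message and return false unless the condition holds" pattern; LOG_MSG_AND_RETURN_UNLESS and validate_dims_example are illustrative names, not ExecuTorch APIs, and the real ET_LOG_MSG_AND_RETURN_UNLESS takes a printf-style format (as the PRId64 call sites show) and logs through the runtime's logging facility rather than fprintf.

// Hypothetical sketch of the pattern, not the actual ExecuTorch macro.
#include <cstdio>

#define LOG_MSG_AND_RETURN_UNLESS(cond, msg)                         \
  do {                                                               \
    if (!(cond)) {                                                   \
      std::fprintf(stderr, "Check failed (%s): %s\n", #cond, (msg)); \
      return false;                                                  \
    }                                                                \
  } while (0)

// Usage mirrors the call sites in the diff; the ints stand in for tensor dims.
bool validate_dims_example(int query_dim, int key_dim) {
  LOG_MSG_AND_RETURN_UNLESS(query_dim == 4, "query must be a 4D tensor");
  LOG_MSG_AND_RETURN_UNLESS(key_dim == 4, "key must be a 4D tensor");
  return true;
}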