Description
This is part of my Torch IR containing "torch.aten._scaled_dot_product_flash_attention_for_cpu":
```mlir
%false_174 = torch.constant.bool false
%none_175 = torch.constant.none
%106:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%73, %95, %101, %float0.000000e00, %false_174, %105, %none_175) : (!torch.vtensor<[2,32,6,128],f32>, !torch.vtensor<[2,32,208,128],f32>, !torch.vtensor<[2,32,208,128],f32>, !torch.float, !torch.bool, !torch.vtensor<[1,1,6,208],f32>, !torch.none) -> (!torch.vtensor<[2,32,6,128],f32>, !torch.vtensor<[2,32,6],f32>)
```
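For context, here is a minimal PyTorch sketch that produces an SDPA call with these shapes (the real model is not shown here; the tensor contents are random placeholders). Eager PyTorch accepts the [1,1,6,208] mask because attn_mask only needs to be broadcastable over q's batch and head dims:

```python
import torch
import torch.nn.functional as F

# Shapes copied from the IR above.
q = torch.randn(2, 32, 6, 128)     # query
k = torch.randn(2, 32, 208, 128)   # key
v = torch.randn(2, 32, 208, 128)   # value
mask = torch.zeros(1, 1, 6, 208)   # additive mask, broadcast over batch/heads

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 32, 6, 128])
```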
When I try to use iree-opt to convert the whole Torch IR to Linalg IR, an error occurs:
```
(sd_shark_2) root@cltech218:/workspace/bailuan/official_iree/iree-build/tools# ./iree-opt --torch-to-iree test.mlir -o out.mlir
test.mlir:307:14: error: 'tm_tensor.attention' op query and mask batch dimension mismatch
%106:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%73, %95, %101, %float0.000000e00, %false_174, %105, %none_175) : (!torch.vtensor<[2,32,6,128],f32>, !torch.vtensor<[2,32,208,128],f32>, !torch.vtensor<[2,32,208,128],f32>, !torch.float, !torch.bool, !torch.vtensor<[1,1,6,208],f32>, !torch.none) -> (!torch.vtensor<[2,32,6,128],f32>, !torch.vtensor<[2,32,6],f32>)
             ^
test.mlir:307:14: note: see current operation: %424 = "tm_tensor.attention"(%405, %406, %407, %408, %423) <{operandSegmentSizes = array<i32: 4, 1>}> : (tensor<64x6x128xf32>, tensor<64x208x128xf32>, tensor<64x208x128xf32>, tensor<1x6x208xf32>, tensor<64x6x128xf32>) -> tensor<64x6x128xf32>
```
I noticed that before lowering, torch.aten._scaled_dot_product_flash_attention_for_cpu's q tensor has shape [2,32,6,128] and its mask tensor has shape [1,1,6,208], but after lowering, tm_tensor.attention's q tensor has shape 64x6x128 while its mask tensor has shape 1x6x208: the first two dims of q (2*32) were collapsed into 64, while the mask's first two dims (1*1) collapsed into 1, which is where the error comes from. Is this expected?
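To make the mismatch concrete, here is the shape arithmetic as I understand it (a sketch of what the lowering seems to do, not the actual IREE code):

```python
q_shape    = (2, 32, 6, 128)
mask_shape = (1, 1, 6, 208)

# The lowering collapses the first two dims (batch, heads) into one batch dim:
q_flat    = (q_shape[0] * q_shape[1], *q_shape[2:])           # (64, 6, 128)
mask_flat = (mask_shape[0] * mask_shape[1], *mask_shape[2:])  # (1, 6, 208)

# tm_tensor.attention then sees batch 64 for q but batch 1 for the mask,
# which triggers "query and mask batch dimension mismatch".
assert q_flat[0] != mask_flat[0]
```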
As far as I know, we usually flatten the q/k/v tensors by combining the last two dims, not the first two dims. Is that right?
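If the lowering is not supposed to handle the broadcast, one model-side workaround I could try is materializing the mask's leading dims before export (hypothetical; I have not confirmed this avoids the error):

```python
# Expand the size-1 batch/head dims so the mask collapses to batch 64 like q.
mask_full = mask.expand(q.shape[0], q.shape[1], -1, -1)  # [2, 32, 6, 208]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_full, dropout_p=0.0)
```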