This is part of my Torch IR containing "torch.aten._scaled_dot_product_flash_attention_for_cpu":
%false_174 = torch.constant.bool false
%none_175 = torch.constant.none
%106:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%73, %95, %101, %float0.000000e00, %false_174, %105, %none_175) : (!torch.vtensor<[2,32,6,128],f32>, !torch.vtensor<[2,32,208,128],f32>, !torch.vtensor<[2,32,208,128],f32>, !torch.float, !torch.bool, !torch.vtensor<[1,1,6,208],f32>, !torch.none) -> (!torch.vtensor<[2,32,6,128],f32>, !torch.vtensor<[2,32,6],f32>)
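For reference, here is a minimal eager-mode sketch with the same shapes (the variable names are mine, not from the original model); PyTorch itself accepts the broadcastable [1,1,6,208] mask without complaint:

```python
import torch
import torch.nn.functional as F

# Shapes taken from the IR above; variable names are hypothetical.
q = torch.randn(2, 32, 6, 128)     # [batch, heads, q_len, head_dim]
k = torch.randn(2, 32, 208, 128)   # [batch, heads, kv_len, head_dim]
v = torch.randn(2, 32, 208, 128)
mask = torch.zeros(1, 1, 6, 208)   # additive float mask, broadcastable

# Eager PyTorch broadcasts the [1,1,6,208] mask against the
# [2,32,6,208] attention scores with no error.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 32, 6, 128])
```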
When I try to use iree-opt to convert the whole Torch IR to Linalg IR, an error occurs:
(sd_shark_2) root@cltech218:/workspace/bailuan/official_iree/iree-build/tools# ./iree-opt --torch-to-iree test.mlir -o out.mlir
test.mlir:307:14: error: 'tm_tensor.attention' op query and mask batch dimension mismatch
%106:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%73, %95, %101, %float0.000000e00, %false_174, %105, %none_175) : (!torch.vtensor<[2,32,6,128],f32>, !torch.vtensor<[2,32,208,128],f32>, !torch.vtensor<[2,32,208,128],f32>, !torch.float, !torch.bool, !torch.vtensor<[1,1,6,208],f32>, !torch.none) -> (!torch.vtensor<[2,32,6,128],f32>, !torch.vtensor<[2,32,6],f32>)
             ^
test.mlir:307:14: note: see current operation: %424 = "tm_tensor.attention"(%405, %406, %407, %408, %423) <{operandSegmentSizes = array<i32: 4, 1>}> : (tensor<64x6x128xf32>, tensor<64x208x128xf32>, tensor<64x208x128xf32>, tensor<1x6x208xf32>, tensor<64x6x128xf32>) -> tensor<64x6x128xf32>
I noticed that before lowering, the query tensor of torch.aten._scaled_dot_product_flash_attention_for_cpu has shape [2,32,6,128] and the mask tensor has shape [1,1,6,208], but after lowering, tm_tensor.attention's query tensor has shape [64,6,128] while the mask tensor has shape [1,6,208], which is where the error comes from. Is this expected?
As far as I know, we usually flatten the q/k/v tensors by combining the last two dims rather than the first two. Is that right? (A small sketch of what I think is happening follows.)
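To make the mismatch concrete, here is a small sketch of the collapse as I read it from the error (it mirrors the shapes only, not the actual conversion code): the lowering combines the first two dims, giving 2*32 = 64 for q but 1*1 = 1 for the mask. Expanding the mask to the full batch-times-heads shape before flattening would make the batch dims agree:

```python
import torch

q = torch.randn(2, 32, 6, 128)
mask = torch.zeros(1, 1, 6, 208)

# The lowering collapses the leading batch and head dims:
q_flat = q.reshape(2 * 32, 6, 128)       # -> [64, 6, 128]
mask_flat = mask.reshape(1 * 1, 6, 208)  # -> [1, 6, 208]; batch dims now differ

# A possible fix: broadcast the mask to [2, 32, 6, 208] first, then collapse.
mask_fixed = mask.expand(2, 32, 6, 208).reshape(64, 6, 208)
assert q_flat.shape[0] == mask_fixed.shape[0]  # 64 == 64
```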