Skip to content

Commit 0fec081

Browse files
CISC authored and pwilkin committed
ggml : check cuda and metal argsort limits and add test (ggml-org#16323)
* check cuda argsort limits and add test
* add metal check
1 parent 2ca531e commit 0fec081

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3647,9 +3647,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3647 3647
case GGML_OP_CONV_TRANSPOSE_2D:
3648 3648
case GGML_OP_POOL_2D:
3649 3649
case GGML_OP_SUM:
3650 -
case GGML_OP_ARGSORT:
3651 3650
case GGML_OP_ACC:
3652 3651
return true;
3652 +
case GGML_OP_ARGSORT:
3653 +
// TODO: Support arbitrary column width
3654 +
return op->src[0]->ne[0] <= 1024;
3653 3655
case GGML_OP_SUM_ROWS:
3654 3656
case GGML_OP_MEAN:
3655 3657
case GGML_OP_GROUP_NORM:

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
683 683
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
684 684
case GGML_OP_PAD_REFLECT_1D:
685 685
case GGML_OP_TIMESTEP_EMBEDDING:
686 -
case GGML_OP_ARGSORT:
687 686
case GGML_OP_LEAKY_RELU:
688 687
return op->src[0]->type == GGML_TYPE_F32;
688 +
case GGML_OP_ARGSORT:
689 +
// TODO: Support arbitrary column width
690 +
return op->src[0]->ne[0] <= 1024;
689 691
case GGML_OP_ARANGE:
690 692
return true;
691 693
case GGML_OP_FLASH_ATTN_EXT:

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6719,6 +6719,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
6719 6719
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
6720 6720
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
6721 6721
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
6722 +
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
6722 6723
}
6723 6724

6724 6725
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {

0 commit comments

Comments
 (0)