Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 43 additions & 13 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ struct test_result {
};

// Printer classes for different output formats
enum class test_status_t { NOT_SUPPORTED, OK, FAIL };
enum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED };

struct test_operation_info {
std::string op_name;
Expand Down Expand Up @@ -687,6 +687,8 @@ struct printer {
virtual void print_backend_status(const backend_status_info & info) { (void) info; }

virtual void print_overall_summary(const overall_summary_info & info) { (void) info; }

virtual void print_failed_tests(const std::vector<std::string> & failed_tests) { (void) failed_tests; }
};

struct console_printer : public printer {
Expand Down Expand Up @@ -804,6 +806,17 @@ struct console_printer : public printer {
}
}

void print_failed_tests(const std::vector<std::string> & failed_tests) override {
if (failed_tests.empty()) {
return;
}

printf("\nFailing tests:\n");
for (const auto & test_name : failed_tests) {
printf(" %s\n", test_name.c_str());
}
}

private:
void print_test_console(const test_result & result) {
printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
Expand Down Expand Up @@ -1056,6 +1069,8 @@ struct test_case {

std::vector<ggml_tensor *> sentinels;

std::string current_op_name;

void add_sentinel(ggml_context * ctx) {
if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
return;
Expand Down Expand Up @@ -1127,7 +1142,10 @@ struct test_case {
}
}

bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
test_status_t eval(ggml_backend_t backend1,
ggml_backend_t backend2,
const char * op_names_filter,
printer * output_printer) {
mode = MODE_TEST;

ggml_init_params params = {
Expand All @@ -1144,11 +1162,12 @@ struct test_case {
add_sentinel(ctx);

ggml_tensor * out = build_graph(ctx);
std::string current_op_name = op_desc(out);
current_op_name = op_desc(out);

if (!matches_filter(out, op_names_filter)) {
//printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx);
return true;
return test_status_t::SKIPPED;
}

// check if the backends support the ops
Expand All @@ -1172,7 +1191,7 @@ struct test_case {
}

ggml_free(ctx);
return true;
return test_status_t::NOT_SUPPORTED;
}

// post-graph sentinel
Expand All @@ -1184,7 +1203,7 @@ struct test_case {
if (buf == NULL) {
printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
ggml_free(ctx);
return false;
return test_status_t::FAIL;
}

// build graph
Expand Down Expand Up @@ -1289,7 +1308,7 @@ struct test_case {
output_printer->print_test_result(result);
}

return test_passed;
return test_passed ? test_status_t::OK : test_status_t::FAIL;
}

bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
Expand All @@ -1306,7 +1325,7 @@ struct test_case {
GGML_ASSERT(ctx);

ggml_tensor * out = build_graph(ctx.get());
std::string current_op_name = op_desc(out);
current_op_name = op_desc(out);
if (!matches_filter(out, op_names_filter)) {
//printf(" %s: skipping\n", op_desc(out).c_str());
return true;
Expand Down Expand Up @@ -1435,8 +1454,9 @@ struct test_case {
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
GGML_ASSERT(ctx);

ggml_tensor * out = build_graph(ctx.get());
std::string current_op_name = op_desc(out);
ggml_tensor * out = build_graph(ctx.get());
current_op_name = op_desc(out);

if (!matches_filter(out, op_names_filter)) {
return true;
}
Expand Down Expand Up @@ -7356,16 +7376,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}

size_t n_ok = 0;
size_t tests_run = 0;
std::vector<std::string> failed_tests;
for (auto & test : test_cases) {
if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
continue;
}
tests_run++;
if (status == test_status_t::OK) {
n_ok++;
} else if (status == test_status_t::FAIL) {
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
}
}
output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));
output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
output_printer->print_failed_tests(failed_tests);

ggml_backend_free(backend_cpu);

return n_ok == test_cases.size();
return n_ok == tests_run;
}

if (mode == MODE_GRAD) {
Expand Down