Very poor gemmt performance compared to gemm and syrk #4921

Open
@david-cortes

I'm timing the operation t(X)*X on row-major matrices that have many more rows than columns.
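To make the operation concrete, here is a minimal, unoptimized sketch of what is being computed: the upper triangle of t(X)*X for a row-major X. The function name `crossprod_upper` is hypothetical and is not part of OpenBLAS or the reproduction code; it is only meant as a reference for the expected result.

```cpp
#include <cstddef>
#include <vector>

// Naive reference for C = t(X) * X, with X stored row-major (nrows x ncols).
// Only the upper triangle (col >= row) of the ncols x ncols result is written;
// the lower triangle is left at zero.
// NOTE: `crossprod_upper` is an illustrative name, not a BLAS routine.
std::vector<double> crossprod_upper(const std::vector<double>& X,
                                    std::size_t nrows, std::size_t ncols)
{
    std::vector<double> C(ncols * ncols, 0.0);
    for (std::size_t r = 0; r < nrows; r++) {
        const double* row = X.data() + r * ncols;
        for (std::size_t i = 0; i < ncols; i++) {
            for (std::size_t j = i; j < ncols; j++) {
                // Each input row contributes an outer product row' * row.
                C[i * ncols + j] += row[i] * row[j];
            }
        }
    }
    return C;
}
```

With many more rows than columns, the result is tiny (32 x 32 here) while the reduction dimension is very long, which is the shape being timed.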

For these inputs, gemmt is much slower than the equivalent syrk or gemm call: roughly 5x slower in the timings below.

Timings in milliseconds for an input of size 1,000,000 x 32 on an Intel i7-12700H, averaged over 3 runs:

  • gemmt: 216.178
  • syrk: 41.0468
  • gemm: 39.55553

Version: OpenBLAS 0.3.28, built with OpenMP and compiled from source (gcc, CMake build system). The same issue happens with pthreads, and the same timing difference is observed when running single-threaded.
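The reproduction code further down times one call per routine, with gemmt called first. As a sketch of how one might rule out one-time startup costs (for example, thread creation on the first parallel call), the helper below runs an untimed warm-up call and then averages several timed repetitions. The name `time_avg_ms` is hypothetical, and since the gap is also reported single-threaded, warm-up alone is unlikely to explain it.

```cpp
#include <chrono>

// Hypothetical helper: calls `fn` once untimed as a warm-up, then returns
// the mean wall-clock time in milliseconds over `reps` timed calls.
template <class Fn>
double time_avg_ms(Fn&& fn, int reps = 3)
{
    using clock = std::chrono::steady_clock;
    fn();  // warm-up call, excluded from the average
    double total_ms = 0.0;
    for (int r = 0; r < reps; r++) {
        const auto t1 = clock::now();
        fn();
        const auto t2 = clock::now();
        total_ms += std::chrono::duration<double, std::milli>(t2 - t1).count();
    }
    return total_ms / reps;
}
```

Each BLAS call from the reproduction code could be wrapped in a lambda and passed as `fn`, so that all three routines are measured under the same conditions.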

For reference, timings for other libraries:

  • MKL gemmt: 25.66533
  • MKL syrk: 12.57197
  • MKL gemm: 15.69447
  • tabmat's "sandwich" op: 29.3

Code to reproduce:

#include <iostream>
#include <chrono>
#include <random>
#include <memory>


#include <cblas.h>
extern "C" void cblas_dgemmt(const CBLAS_LAYOUT Layout, const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc);

using std::chrono::high_resolution_clock;
using std::chrono::duration;

int main()
{
    const size_t nrows = 1'000'000;
    const size_t ncols = 32;
    const size_t tot = nrows * ncols;

    std::mt19937 rng{123};
    std::normal_distribution norm_distr{0.0, 1.0};

    std::unique_ptr<double[]> X(new double[tot]);
    std::unique_ptr<double[]> out(new double[ncols*ncols]());
    for (size_t ix = 0; ix < tot; ix++) X[ix] = norm_distr(rng);

    // Upper triangle of t(X) * X via gemmt (n = ncols, k = nrows).
    auto t1 = high_resolution_clock::now();
    cblas_dgemmt(
        CblasRowMajor, CblasUpper, CblasTrans, CblasNoTrans,
        ncols, nrows,
        1., X.get(), ncols,
        X.get(), ncols,
        0., out.get(), ncols
    );
    auto t2 = high_resolution_clock::now();
    duration<double, std::milli> ms_double = t2 - t1;

    double sum_res = 0.;
    for (size_t ix = 0; ix < ncols*ncols; ix++) sum_res += out[ix];

    std::cout << "time gemmt:" << ms_double.count() << std::endl;
    std::cout << "sum gemmt:" << sum_res << std::endl;

    t1 = high_resolution_clock::now();
    cblas_dsyrk(
        CblasRowMajor, CblasUpper, CblasTrans,
        ncols, nrows,
        1., X.get(), ncols,
        0., out.get(), ncols
    );
    t2 = high_resolution_clock::now();
    ms_double = t2 - t1;
    sum_res = 0.;
    for (size_t ix = 0; ix < ncols*ncols; ix++) sum_res += out[ix];
    std::cout << "time syrk:" << ms_double.count() << std::endl;
    std::cout << "sum syrk:" << sum_res << std::endl;

    t1 = high_resolution_clock::now();
    cblas_dgemm(
        CblasRowMajor, CblasTrans, CblasNoTrans,
        ncols, ncols, nrows,
        1., X.get(), ncols,
        X.get(), ncols,
        0., out.get(), ncols
    );
    t2 = high_resolution_clock::now();
    ms_double = t2 - t1;
    sum_res = 0.;
    for (size_t ix = 0; ix < ncols*ncols; ix++) sum_res += out[ix];
    std::cout << "time gemm:" << ms_double.count() << std::endl;
    std::cout << "sum gemm:" << sum_res << std::endl;

    return 0;
}
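One note on the printed sums: with CblasUpper, gemmt and syrk are only expected to write the upper triangle of `out`, while gemm fills the whole matrix, so the "sum gemm" line will generally not match the other two. A small sketch of a helper that sums only the upper triangle, making the three checks directly comparable, follows. The name `sum_upper` is hypothetical.

```cpp
#include <cstddef>

// Sums only the entries with col >= row of a row-major n x n matrix,
// i.e. the part that an "upper" gemmt or syrk call is expected to fill.
// NOTE: `sum_upper` is an illustrative name, not part of any BLAS API.
double sum_upper(const double* C, std::size_t n)
{
    double s = 0.0;
    for (std::size_t i = 0; i < n; i++) {
        for (std::size_t j = i; j < n; j++) {
            s += C[i * n + j];
        }
    }
    return s;
}
```

Calling `sum_upper(out.get(), ncols)` after each routine, in place of the full-matrix loop, would give three values that should agree up to floating-point rounding.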
