[Performance]: Comparing performance of mat-vec multiplication #27018

Open
michalsustr opened this issue Oct 11, 2024 · 2 comments
Labels: category: CPU (OpenVINO CPU plugin), performance (Performance related topics), support_request


michalsustr commented Oct 11, 2024

Versions

openvino                               2024.4.0
torch                                  2.4.1
nncf                                   2.13.0
g++ (conda-forge gcc 14.1.0-1) 14.1.0

Operating System

Ubuntu 18.04 (LTS)

Device used for inference

CPU

OpenVINO installation

PyPI

Programming Language

Python

Hardware Architecture

x86 (64 bits)

Model used

Matrix-vector multiplication

Model quantization

Yes

Target Platform

$ lscpu
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                32
On-line CPU(s) list:   0-31
Thread(s) per core:    1
Core(s) per socket:    32
Socket(s):             1
NUMA node(s):          1
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 207
Model name:            INTEL(R) XEON(R) PLATINUM 8562Y+
Stepping:              2
CPU MHz:               3800.000
CPU max MHz:           2801.0000
CPU min MHz:           800.0000
BogoMIPS:              5600.00
Virtualization:        VT-x
L1d cache:             48K
L1i cache:             32K
L2 cache:              2048K
L3 cache:              61440K
NUMA node0 CPU(s):     0-31
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid md_clear pconfig flush_l1d arch_capabilities

Performance issue description

Hello,

I am exploring the landscape of CPU inference for latency-sensitive applications and benchmarking various implementations.

As a first test, I ran the simplest possible benchmark: multiplying a 256x256 matrix by a 256-element vector.
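This rough estimate shows how much work the benchmark involves. It is a sketch that assumes fp32 data and counts each multiply-add as 2 FLOPs.

One call performs 65,536 multiply-adds over a 256 KiB matrix. Each matrix element is loaded once and used once, which gives an arithmetic intensity of only 0.5 FLOP per byte. Matrix-vector products like this are therefore usually limited by memory or cache bandwidth rather than by arithmetic throughput. Note that the 256 KiB matrix fits in the 2048K L2 cache reported by lscpu above.

```python
# Rough work/footprint estimate for one 256x256 fp32 matrix-vector product.
N = 256
BYTES_PER_FP32 = 4

macs = N * N                             # one multiply-add per matrix element
flops = 2 * macs                         # counting multiply and add separately
matrix_bytes = N * N * BYTES_PER_FP32    # the matrix dominates memory traffic
vector_bytes = 2 * N * BYTES_PER_FP32    # input vector x plus output vector y
intensity = flops / matrix_bytes         # FLOPs per byte of matrix data

print(f"MACs: {macs}")                                 # MACs: 65536
print(f"FLOPs: {flops}")                               # FLOPs: 131072
print(f"matrix: {matrix_bytes // 1024} KiB")           # matrix: 256 KiB
print(f"vectors: {vector_bytes} bytes")                # vectors: 2048 bytes
print(f"arithmetic intensity: {intensity} FLOP/byte")  # arithmetic intensity: 0.5 FLOP/byte
```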

I would like to understand why OpenVINO is slow on this benchmark. What am I doing wrong?

I ran the following files, each pinned to a single core with taskset -c 0 <file>:

(A) Main benchmark with OpenVINO: benchmark_matmul_module.py

h,runner,min latency (us)
256,torch vanilla,5.992013029754162
256,torch jit,7.668975740671158
256,torch export,33.001066185534
256,openvino,20.659063011407852
256,openvino q-onnx,27.11499109864235
256,openvino q-nncf,21.42402809113264

(B) PyTorch benchmark with various data types: benchmark_mat_vec_torch.py

Mat-vec multiplication for size 256x256, torch.uint8 min time: 18.816092 us
Mat-vec multiplication for size 256x256, torch.int8 min time: 19.229017 us
Mat-vec multiplication for size 256x256, torch.float32 min time: 5.059992 us
Mat-vec multiplication for size 256x256, torch.float16 min time: 24.214038 us
Mat-vec multiplication for size 256x256, torch.bfloat16 min time: 20.070001 us

(C) Hand-written C++ (courtesy of ChatGPT; it would have taken me weeks to write this myself):

matvecmul_vanilla.cpp: 29 us (fp64)
matvecmul_avx512_fp16.cpp: 2.51 us (3.68 us with the alternative 32-element inner loop)
matvecmul_avx512_fp32.cpp: **1.58 us**

From these results, OpenVINO performs about as well as the naive C++ for loop.
It is more than 10x slower than the ChatGPT-generated AVX-512 FP32 implementation (20.7 us vs. 1.58 us).
Interestingly, PyTorch also lags behind that implementation (about 6 us).
I also wanted to try AMX tile instructions, but I have not yet found a compiler that supports them.
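Listing the reported minimum latencies next to the fastest kernel makes the gap concrete. The numbers in this small sketch are copied verbatim from the runs above. The GFLOP/s column assumes 2 x 256 x 256 FLOPs per call and includes any timer or interpreter overhead in each measurement.

```python
# Minimum latencies in microseconds, copied from the results reported above.
reported_us = {
    "torch vanilla": 5.992013029754162,
    "torch jit": 7.668975740671158,
    "openvino": 20.659063011407852,
    "openvino q-nncf": 21.42402809113264,
    "matvecmul_vanilla.cpp": 29.0,
    "matvecmul_avx512_fp16.cpp": 2.51,
    "matvecmul_avx512_fp32.cpp": 1.58,
}

N = 256
FLOPS_PER_CALL = 2 * N * N  # one multiply + one add per matrix element

fastest_us = min(reported_us.values())
for name, us in sorted(reported_us.items(), key=lambda item: item[1]):
    slowdown = us / fastest_us                    # relative to the fastest kernel
    gflops = FLOPS_PER_CALL / (us * 1e-6) / 1e9   # achieved throughput
    print(f"{name:30s} {us:7.2f} us {slowdown:6.1f}x {gflops:7.1f} GFLOP/s")
```

By this measure, OpenVINO's best time is about 13x that of the AVX-512 FP32 kernel, and plain PyTorch is about 3.8x.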

Step-by-step reproduction

# File: benchmark_matmul_module.py

import os
import time
from collections import defaultdict
from typing import Callable

import nncf
import numpy as np
import openvino
import openvino.properties.hint as hint
import torch
import torch.nn as nn
from onnxruntime import quantization, InferenceSession
from onnxruntime.quantization import QuantFormat
from openvino.runtime import Core
from openvino.runtime.utils.types import get_dtype


def model_to_onnx_file(file_name: str, model, example_inputs):
    original_mode = model.training
    model.eval()
    try:
        with torch.no_grad():
            torch.onnx.export(model, example_inputs, file_name)
    finally:
        model.train(original_mode)


class MatMulModule(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        # Register the weight as a buffer so that module-level .to(...) calls convert it too
        self.register_buffer("a", torch.randn(h, h).to(DTYPE))

    def forward(self, x):
        return torch.matmul(self.a, x)

    def input_data(self):
        return tuple([
            torch.rand([self.h, 1], dtype=DTYPE).requires_grad_(False),  # x
        ])

def get_matmul_network(h):
    return MatMulModule(h).to(dtype=DTYPE).requires_grad_(False).eval()


def make_sample_onnx_inputs(model: str | bytes):
    ort_sess = InferenceSession(model, providers=["CPUExecutionProvider"])
    return {
        node.name: np.random.rand(*node.shape).astype(np.float32)
        for node in ort_sess.get_inputs()
    }

def make_onnx_file(model, input_data, quantize=True, out_file="x.onnx"):
    model_to_onnx_file(out_file, model, input_data)

    if quantize:
        input_data = make_sample_onnx_inputs(out_file)
        quantization.quantize_static(out_file, out_file,
                                     calibration_data_reader=DummyDataReader(input_data),
                                     quant_format=QuantFormat.QDQ
                                     )

class DummyDataReader(quantization.CalibrationDataReader):
    def __init__(self, input_data):
        self.input_data = iter([input_data])
        self.done = False

    def get_next(self):
        try:
            return next(self.input_data)
        except StopIteration:
            return  None


def make_openvino_model(model_file: str, cpu_isa="AVX10_1_512_AMX_FP16", nncf_quantization=False) -> openvino.CompiledModel:
    """Loads an ONNX model file into OpenVINO, optionally quantizes it with NNCF, and compiles it for CPU."""
    os.environ.update({
        # https://oneapi-src.github.io/oneDNN/dev_guide_cpu_dispatcher_control.html
        "ONEDNN_MAX_CPU_ISA": cpu_isa,
    })

    # Initialize OpenVINO's Inference Engine
    core = Core()
    core.set_property("CPU", {
        hint.execution_mode: hint.ExecutionMode.PERFORMANCE,
        hint.performance_mode: hint.PerformanceMode.LATENCY,
        hint.inference_precision: openvino.Type.bf16,  # computes in bf16, while the PyTorch runners use DTYPE = float32
        "NUM_STREAMS": 1,  # Number of parallel inference streams
        "AFFINITY": "CORE",
        hint.enable_cpu_pinning: True,
        hint.enable_hyper_threading: False,
        hint.num_requests: 1,
        hint.scheduling_core_type: hint.SchedulingCoreType.PCORE_ONLY,

        "INFERENCE_NUM_THREADS": 1,  # Set the number of threads

        # -- LLM stuff
        # hint.dynamic_quantization_group_size:
        # hint.kv_cache_precision

        # -- Not supported by CPU
        # hint.model_distribution_policy: hint.ModelDistributionPolicy.TENSOR_PARALLEL,
        # hint.allow_auto_batching: 0,
        # hint.model_priority: hint.Priority.HIGH,

    })
    # TODO: latency optimizations https://docs.openvino.ai/2024/openvino-workflow/running-inference/optimize-inference/optimizing-latency/model-caching-overview.html

    model = core.read_model(model=model_file)
    if nncf_quantization:
        # NOTE: relies on the module-level `input_data` defined under __main__
        model = apply_nncf_quantization(model, input_data)
    return core.compile_model(model=model, device_name="CPU")


def apply_nncf_quantization(model, input_data):
    calibration_dataset = nncf.Dataset([input_data])
    return nncf.quantize(
        model, calibration_dataset,
        target_device=nncf.TargetDevice.CPU,
    )

def fill_tensor_random(tensor):
    dtype = get_dtype(tensor.element_type)
    rand_min, rand_max = (0, 1) if dtype == bool else (np.iinfo(np.uint8).min, np.iinfo(np.uint8).max)
    # np.random.uniform excludes high: add 1 to have it generated
    if np.dtype(dtype).kind in ['i', 'u', 'b']:
        rand_max += 1
    rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(0)))
    if 0 == tensor.get_size():
        raise RuntimeError("Models with dynamic shapes aren't supported. Input tensors must have specific shapes before inference")
    tensor.data[:] = rs.uniform(rand_min, rand_max, list(tensor.shape)).astype(dtype)


# Based on https://docs.openvino.ai/2024/learn-openvino/openvino-samples/sync-benchmark.html
def benchmark_openvino_sync(compiled_model):
    ireq = compiled_model.create_infer_request()
    # Fill input data for the ireq
    for model_input in compiled_model.inputs:
        fill_tensor_random(ireq.get_tensor(model_input))
    # Warm up
    ireq.infer()
    # Benchmark for seconds_to_run seconds and at least niter iterations
    seconds_to_run = 10
    niter = 10
    latencies = []
    start = time.perf_counter()
    time_point = start
    time_point_to_finish = start + seconds_to_run
    while time_point < time_point_to_finish or len(latencies) < niter:
        ireq.infer()
        iter_end = time.perf_counter()
        latencies.append((iter_end - time_point) * 1e6)
        time_point = iter_end
    # print_benchmark_results(latencies, top=len(latencies))
    return min(latencies)


def benchmark_model(model, inputs, iterations=10000):
    """Return the minimum latency of model(*inputs) over `iterations` calls, in microseconds."""
    durations = []
    with torch.no_grad():
        for _ in range(iterations):
            start_time = time.perf_counter()
            model(*inputs)
            end_time = time.perf_counter()
            durations.append((end_time - start_time) * 1e6)
        return min(durations)

def make_runner_model(runner, network, input_data):
    if runner == "torch vanilla":
        return network
    if runner == "torch jit":
        return torch.jit.script(network)
    if runner == "torch export":
        return torch.export.export(network, input_data).module()
    if runner == "openvino":
        make_onnx_file(network, input_data, out_file=model_file, quantize=False)
        return make_openvino_model(model_file=model_file, cpu_isa=cpu_isa)
    if runner == "openvino q-onnx":
        make_onnx_file(network, input_data, out_file=model_file, quantize=True)
        return make_openvino_model(model_file=model_file, cpu_isa=cpu_isa)
    if runner == "openvino q-nncf":
        make_onnx_file(network, input_data, out_file=model_file, quantize=False)
        return make_openvino_model(model_file=model_file, cpu_isa=cpu_isa, nncf_quantization=True)

def benchmark_runner(runner, network, input_data):
    if "openvino" in runner:
        return benchmark_openvino_sync(network)
    else:
        return benchmark_model(network, input_data)

if __name__ == "__main__":
    # ARGS
    DTYPE = torch.float32
    cpu_isa = "AVX10_1_512_AMX_FP16"  # Highest possible ISA
    model_file = "x.onnx"
    results = defaultdict(dict)
    runners = ["torch vanilla",
               "torch jit",
               "torch export",
               "openvino",
               "openvino q-onnx",
               "openvino q-nncf",
               ]
    hs = [256]

    for h in hs:
        network = get_matmul_network(h)
        input_data = network.input_data()
        for runner in runners:
            runner_model = make_runner_model(runner, network, input_data)
            results[(h,)][runner] = benchmark_runner(runner, runner_model, input_data)

    print("----")
    for params, tech_results in results.items():
        for name, value in tech_results.items():
            vals = [str(k) for k in params] + [str(name)] + [str(value)]
            print(",".join(vals))


    print("----")
    for params, tech_results in results.items():
        min_v = 1e10
        min_k = None
        for name, value in tech_results.items():
            if value is None:
                continue
            if value < min_v:
                min_v = value
                min_k = name
        vals = [str(k) for k in params] + [str(min_k)] + [str(min_v)]
        print(",".join(vals))
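All Python-side numbers above are measured around a single Python call, so they include interpreter and timer overhead that the C++ timings do not. This hedged sketch estimates that floor with the same min-of-N approach as benchmark_model, plus a warm-up phase. The names min_latency_us and noop are hypothetical and are not part of the original scripts.

```python
import time


def min_latency_us(fn, *args, iterations=10_000, warmup=100):
    """Minimum wall-clock latency of fn(*args) in microseconds, measured after a warm-up."""
    for _ in range(warmup):
        fn(*args)
    best = float("inf")
    for _ in range(iterations):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best * 1e6


def noop(*args):
    """Does nothing, so timing it measures only call and timer overhead."""
    return None


if __name__ == "__main__":
    floor_us = min_latency_us(noop)
    resolution_s = time.get_clock_info("perf_counter").resolution
    print(f"Harness floor (empty Python call): {floor_us:.3f} us")
    print(f"perf_counter resolution: {resolution_s * 1e6:.3f} us")
```

Subtracting this floor from the Python-measured numbers gives only a rough estimate of the time spent inside the framework call. It does not separate dispatch work inside torch.matmul or ireq.infer() from the arithmetic itself.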
    
# File: benchmark_mat_vec_torch.py

import torch
import time

def benchmark_min_time_torch(matrix_size, dtype=torch.float32, num_of_runs=100):
    """
    Benchmark the minimum time to multiply a square matrix by a column vector.

    Args:
        matrix_size (int): Size of the square matrix (and length of the vector).
        dtype (torch.dtype): Data type of the matrix and vector.
        num_of_runs (int): Number of runs to perform.

    Returns:
        float: The minimum time taken (in microseconds) to perform the matrix-vector multiplication.
    """
    # Create random matrices on CPU
    a = torch.randn(matrix_size, matrix_size, device="cpu").to(dtype=dtype)
    b = torch.randn(matrix_size, 1, device="cpu").to(dtype=dtype)

    min_time = float('inf')  # Initialize minimum time to infinity

    with torch.no_grad():
        for _ in range(num_of_runs):
            start_time = time.perf_counter()  # Record start time
            torch.mm(a, b)            # Perform matrix multiplication
            end_time = time.perf_counter()    # Record end time

            run_time = end_time - start_time  # Calculate elapsed time
            if run_time < min_time:
                min_time = run_time           # Update minimum time if current run is faster
    return min_time * 1e6

if __name__ == "__main__":
    matrix_size = 256
    num_of_runs = 100

    for dtype in [torch.uint8, torch.int8, torch.float32, torch.float16, torch.bfloat16]:
        min_time = benchmark_min_time_torch(matrix_size=matrix_size, dtype=dtype, num_of_runs=num_of_runs)
        print(f"Mat-vec multiplication for size {matrix_size}x{matrix_size}, {dtype} min time: {min_time:.6f} us")
// File: matvecmul_vanilla.cpp
#include <iostream>
#include <chrono>
#include <iomanip> // For std::setw

// Constants for matrix and vector sizes
const size_t SIZE = 256;
const size_t NUM_RUNS = 10000;

// Function to perform matrix-vector multiplication
void multiply(const double matrix[SIZE][SIZE], const double vec[SIZE], double result[SIZE]) {
    for (size_t i = 0; i < SIZE; ++i) {
        double sum = 0.0;
        for (size_t j = 0; j < SIZE; ++j) {
            sum += matrix[i][j] * vec[j];
        }
        result[i] = sum;
    }
}

int main() {
    // Initialize the matrix with matrix[i][j] = i + j
    double matrix[SIZE][SIZE];
    for (size_t i = 0; i < SIZE; ++i) {
        for (size_t j = 0; j < SIZE; ++j) {
            matrix[i][j] = static_cast<double>(i + j);
        }
    }

    // Initialize the vector with vec[j] = j
    double vec[SIZE];
    for (size_t j = 0; j < SIZE; ++j) {
        vec[j] = static_cast<double>(j);
    }

    // Array to store the result
    double result[SIZE] = {0.0};

    // Minimum duration in fractional microseconds; an integer microseconds
    // duration would truncate every sub-microsecond difference.
    std::chrono::duration<double, std::micro> min_duration =
        std::chrono::duration<double, std::micro>::max();

    // Accumulate the results so the compiler cannot discard the computation
    volatile double checksum = 0.0;

    // Perform the multiplication NUM_RUNS times
    for (size_t run = 0; run < NUM_RUNS; ++run) {
        // Record the start time
        auto start = std::chrono::high_resolution_clock::now();

        // Perform matrix-vector multiplication
        multiply(matrix, vec, result);

        // Record the end time
        auto end = std::chrono::high_resolution_clock::now();

        // Calculate the duration of this multiplication
        std::chrono::duration<double, std::micro> duration = end - start;

        // Update the minimum duration if this run was faster
        if (duration < min_duration) {
            min_duration = duration;
        }

        // Use the result so the multiplication is not optimized away
        for (size_t i = 0; i < SIZE; ++i) {
            checksum += result[i];
        }
    }

    // Output the minimum duration observed
    std::cout << "Minimum time for matrix-vector multiplication over " << NUM_RUNS << " runs: "
              << min_duration.count() << " microseconds.\n";
    std::cout << "Checksum: " << checksum << "\n";

    return 0;
}
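matvecmul_vanilla.cpp fills its inputs deterministically (matrix[i][j] = i + j, vec[j] = j), so its output has a closed form that can serve as a correctness check:

result[i] = sum over j of (i + j) * j = i * S1 + S2, where S1 = sum of j = 32640 and S2 = sum of j^2 = 5559680 for j = 0..255.

This small Python sketch compares a naive loop, mirroring the C++ multiply(), against that formula:

```python
N = 256

# Same deterministic inputs as matvecmul_vanilla.cpp.
matrix = [[float(i + j) for j in range(N)] for i in range(N)]
vec = [float(j) for j in range(N)]


def multiply(matrix, vec):
    """Naive row-by-row dot products, mirroring the C++ multiply()."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


S1 = N * (N - 1) // 2                  # sum of j   for j = 0..N-1 -> 32640
S2 = (N - 1) * N * (2 * N - 1) // 6    # sum of j^2 for j = 0..N-1 -> 5559680
expected = [float(i * S1 + S2) for i in range(N)]

result = multiply(matrix, vec)
# All values are integers below 2**53, so double-precision arithmetic is exact here.
assert result == expected
print(result[0], result[-1])  # 5559680.0 13882880.0
```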
// File: matvecmul_avx512_fp32.cpp

#include <immintrin.h>
#include <chrono>
#include <iostream>
#include <vector>
#include <random>
#include <algorithm>
#include <cstring>  // For memset
#include <cstdlib>  // For aligned_alloc and free

// Define the size of the matrix and vector
constexpr int N = 256;

// Function to perform AVX-512 optimized matrix-vector multiplication
// Fastest iteration time: 1.578 microseconds
void matvec_avx512(const float* matrix, const float* vector, float* result) {
    for (int i = 0; i < N; ++i) {
        // Initialize accumulator to zero
        __m512 acc = _mm512_setzero_ps();

        // Process 16 elements at a time
        for (int j = 0; j < N; j += 16) {
            // Load 16 floats from the matrix row
            __m512 mat_val = _mm512_load_ps(&matrix[i * N + j]);

            // Load 16 floats from the vector
            __m512 vec_val = _mm512_load_ps(&vector[j]);

            // Multiply and accumulate
            acc = _mm512_fmadd_ps(mat_val, vec_val, acc);
        }

        // Horizontally add the 16 floats in the accumulator to get the final dot product
        float acc_sum = _mm512_reduce_add_ps(acc);

        // Store the result
        result[i] = acc_sum;
    }
}


// Compile with
// g++ -O3 -mavx512f -std=c++17 matvecmul_avx512_fp32.cpp -o matvecmul_avx512_fp32
int main() {
    // Allocate aligned memory for matrix, vector, and result
    // _mm512_load_ps requires 64-byte-aligned addresses
    float* matrix;
    float* vector;
    float* result;

    // Using aligned_alloc (C++17). Alternatively, you can use _mm_malloc or posix_memalign
    matrix = static_cast<float*>(aligned_alloc(64, N * N * sizeof(float)));
    vector = static_cast<float*>(aligned_alloc(64, N * sizeof(float)));
    result = static_cast<float*>(aligned_alloc(64, N * sizeof(float)));

    if (!matrix || !vector || !result) {
        std::cerr << "Memory allocation failed!" << std::endl;
        return 1;
    }

    // Initialize the matrix and vector with random floats
    std::mt19937 rng(42); // Fixed seed for reproducibility
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);

    for (int i = 0; i < N * N; ++i) {
        matrix[i] = dist(rng);
    }

    for (int i = 0; i < N; ++i) {
        vector[i] = dist(rng);
    }

    // Number of benchmarking iterations
    const int iterations = 1000;

    // Variable to store the fastest time
    double fastest_time = 1e12;

    // Variable to prevent compiler optimizations
    volatile float checksum = 0.0f;

    for (int iter = 0; iter < iterations; ++iter) {
        // Optional: Clear the result array
        // memset(result, 0, N * sizeof(float));

        // Start timer
        auto start = std::chrono::high_resolution_clock::now();

        // Perform matrix-vector multiplication
        matvec_avx512(matrix, vector, result);

        // End timer
        auto end = std::chrono::high_resolution_clock::now();

        // Calculate elapsed time in microseconds
        std::chrono::duration<double, std::micro> elapsed = end - start;

        // Update the fastest time
        if (elapsed.count() < fastest_time) {
            fastest_time = elapsed.count();
        }

        // Accumulate the result to prevent compiler from optimizing away the computation
        for (int i = 0; i < N; ++i) {
            checksum += result[i];
        }
    }

    // Print the fastest time
    std::cout << "Fastest iteration time: " << fastest_time << " microseconds" << std::endl;

    // Print checksum (optional, to verify correctness)
    std::cout << "Checksum: " << checksum << std::endl;

    // Free allocated memory
    free(matrix);
    free(vector);
    free(result);

    return 0;
}
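matvec_avx512 computes each output element as a 16-lane blocked dot product. Lane k of the accumulator sums row[j + k] * x[j + k] over j = 0, 16, 32, and so on. At the end, _mm512_reduce_add_ps adds the 16 lanes together.

The scalar Python model below reproduces that order of operations for checking results. It is a sketch, and blocked_dot and matvec_blocked are hypothetical names. Python floats are double precision, so it models only the accumulation order, not fp32 rounding. Its final sum(acc) is also not guaranteed to add the lanes in the same order as _mm512_reduce_add_ps.

```python
import random

LANES = 16  # one __m512 register holds 16 fp32 values


def blocked_dot(row, x, lanes=LANES):
    """Dot product accumulated lane-wise, like the _mm512_fmadd_ps loop."""
    assert len(row) == len(x) and len(row) % lanes == 0
    acc = [0.0] * lanes
    for j in range(0, len(row), lanes):  # one iteration per vector FMA
        for k in range(lanes):
            acc[k] += row[j + k] * x[j + k]
    return sum(acc)  # horizontal reduction of the 16 lanes


def matvec_blocked(matrix, x):
    """Apply blocked_dot to every matrix row."""
    return [blocked_dot(row, x) for row in matrix]


if __name__ == "__main__":
    N = 256
    rng = random.Random(42)
    matrix = [[rng.uniform(-1, 1) for _ in range(N)] for _ in range(N)]
    x = [rng.uniform(-1, 1) for _ in range(N)]
    blocked = matvec_blocked(matrix, x)
    naive = [sum(a * b for a, b in zip(row, x)) for row in matrix]
    # Any difference comes only from the order of summation.
    print(max(abs(b - n) for b, n in zip(blocked, naive)))
```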
// File: matvecmul_avx512_fp16.cpp

#include <immintrin.h>
#include <chrono>
#include <iostream>
#include <vector>
#include <random>
#include <algorithm>
#include <cstring>   // For memset
#include <cstdlib>   // For aligned_alloc and free
#include <cstdint>   // For uint16_t
#include <cassert>

// Define the size of the matrix and vector
constexpr int N = 256;

// Helper function to convert float32 to float16 (returns uint16_t representation)
uint16_t float_to_float16(float value) {
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));

    uint16_t sign = (bits >> 16) & 0x8000;
    int16_t exponent = ((bits >> 23) & 0xFF) - 112; // 127 - 15
    uint16_t mantissa = (bits >> 13) & 0x3FF;

    if (exponent <= 0) {
        // Subnormal or zero
        if (exponent < -10) {
            // Too small, becomes zero
            exponent = 0;
            mantissa = 0;
        } else {
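            // Subnormal: the half value is mantissa * 2^-24, so shift the full
            // 24-bit significand (implicit leading 1 included) right by 14 - exponent.
            // Use a 32-bit temporary: the significand does not fit in uint16_t.
            uint32_t full_mantissa = (bits & 0x7FFFFF) | 0x800000;
            mantissa = static_cast<uint16_t>(full_mantissa >> (14 - exponent));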
            exponent = 0;
        }
    } else if (exponent >= 31) {
        // Overflow, set to infinity
        exponent = 31;
        mantissa = 0;
    }

    return sign | (exponent << 10) | mantissa;
}

// Helper function to convert float16 (uint16_t) to float32
float float16_to_float(uint16_t h) {
    uint32_t sign = (h & 0x8000) >> 15;  // 32-bit so that sign << 31 is well-defined
    uint16_t exponent = (h & 0x7C00) >> 10;
    uint16_t mantissa = h & 0x03FF;

    uint32_t bits = 0;
    if (exponent == 0) {
        if (mantissa == 0) {
            // Zero
            bits = sign << 31;
        } else {
            // Subnormal
            // Normalize the number
            int e = -14;
            while ((mantissa & 0x0400) == 0) {
                mantissa <<= 1;
                e--;
            }
            mantissa &= 0x03FF;
            exponent = e + 127;
            bits = (sign << 31) | (exponent << 23) | (mantissa << 13);
        }
    } else if (exponent == 31) {
        // Inf or NaN
        bits = (sign << 31) | (0xFF << 23) | (mantissa << 13);
    } else {
        // Normalized
        bits = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13);
    }

    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Function to perform AVX-512 optimized matrix-vector multiplication with float16 data
void matvec_avx512_fp16(const uint16_t* matrix, const uint16_t* vector, uint16_t* result) {
    for (int i = 0; i < N; ++i) {
        __m512 acc = _mm512_setzero_ps(); // Initialize accumulator to zero

        // Fastest iteration time: 2.513 microseconds
        for (int j = 0; j < N; j += 16) { // Process 16 float16 elements at a time
            // Load 16 float16 elements from matrix row
            __m256i mat_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&matrix[i * N + j]));
            __m512 mat_val = _mm512_cvtph_ps(mat_half); // Convert to float32

            // Load 16 float16 elements from vector
            __m256i vec_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&vector[j]));
            __m512 vec_val = _mm512_cvtph_ps(vec_half); // Convert to float32

            // Multiply and accumulate
            acc = _mm512_fmadd_ps(mat_val, vec_val, acc);
        }

// Alternative for loop.
// Fastest iteration time: 3.68 microseconds
//        for (int j = 0; j < N; j += 32) { // Process 32 float16 elements at a time
//            // Load first 16 float16 elements from matrix row
//            __m256i mat_half1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&matrix[i * N + j]));
//            __m512 mat_val1 = _mm512_cvtph_ps(mat_half1); // Convert to float32
//
//            // Load first 16 float16 elements from vector
//            __m256i vec_half1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&vector[j]));
//            __m512 vec_val1 = _mm512_cvtph_ps(vec_half1); // Convert to float32
//
//            // Multiply and accumulate
//            acc = _mm512_fmadd_ps(mat_val1, vec_val1, acc);
//
//            // Load next 16 float16 elements from matrix row
//            __m256i mat_half2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&matrix[i * N + j + 16]));
//            __m512 mat_val2 = _mm512_cvtph_ps(mat_half2); // Convert to float32
//
//            // Load next 16 float16 elements from vector
//            __m256i vec_half2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&vector[j + 16]));
//            __m512 vec_val2 = _mm512_cvtph_ps(vec_half2); // Convert to float32
//
//            // Multiply and accumulate
//            acc = _mm512_fmadd_ps(mat_val2, vec_val2, acc);
//        }

        // Horizontally add the 16 floats in the accumulator to get the final dot product
        float acc_sum = _mm512_reduce_add_ps(acc);

        // Convert the float32 sum back to float16
        __m512 h_sum = _mm512_set1_ps(acc_sum);
        __m256i h_result = _mm512_cvtps_ph(h_sum, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

        // Store the first float16 result
        result[i] = static_cast<uint16_t>(_mm256_extract_epi16(h_result, 0));
    }
}


// Compile with:
// g++ -O3 -mavx512f -mavx512fp16 -std=c++17 matvecmul_avx512_fp16.cpp -o matvecmul_avx512_fp16
int main() {
    // Allocate aligned memory for matrix, vector, and result
    // AVX-512 requires 64-byte alignment
    uint16_t* matrix;
    uint16_t* vector;
    uint16_t* result;

    // Using aligned_alloc (C++17). Ensure that the size is a multiple of alignment (64 bytes)
    matrix = static_cast<uint16_t*>(aligned_alloc(64, N * N * sizeof(uint16_t)));
    vector = static_cast<uint16_t*>(aligned_alloc(64, N * sizeof(uint16_t)));
    result = static_cast<uint16_t*>(aligned_alloc(64, N * sizeof(uint16_t)));

    if (!matrix || !vector || !result) {
        std::cerr << "Memory allocation failed!" << std::endl;
        return 1;
    }

    // Initialize the matrix and vector with random float16 values
    std::mt19937 rng(42); // Fixed seed for reproducibility
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);

    // Temporary float32 buffers for conversion
    std::vector<float> temp_matrix(N * N);
    std::vector<float> temp_vector(N);

    // Generate random float32 data
    for (int i = 0; i < N * N; ++i) {
        temp_matrix[i] = dist(rng);
    }

    for (int i = 0; i < N; ++i) {
        temp_vector[i] = dist(rng);
    }

    // Convert float32 data to float16 (uint16_t)
    for (int i = 0; i < N * N; ++i) {
        matrix[i] = float_to_float16(temp_matrix[i]);
    }

    for (int i = 0; i < N; ++i) {
        vector[i] = float_to_float16(temp_vector[i]);
    }

    // Number of benchmarking iterations
    const int iterations = 1000;

    // Variable to store the fastest time
    double fastest_time = 1e12;

    // Variable to prevent compiler optimizations
    volatile float checksum = 0.0f;

    for (int iter = 0; iter < iterations; ++iter) {
        // Start timer
        auto start = std::chrono::high_resolution_clock::now();

        // Perform matrix-vector multiplication
        matvec_avx512_fp16(matrix, vector, result);

        // End timer
        auto end = std::chrono::high_resolution_clock::now();

        // Calculate elapsed time in microseconds
        std::chrono::duration<double, std::micro> elapsed = end - start;

        // Update the fastest time
        if (elapsed.count() < fastest_time) {
            fastest_time = elapsed.count();
        }

        // Accumulate the result to prevent compiler from optimizing away the computation
        for (int i = 0; i < N; ++i) {
            checksum += float16_to_float(result[i]);
        }
    }

    // Print the fastest time
    std::cout << "Fastest iteration time: " << fastest_time << " microseconds" << std::endl;

    // Print checksum (optional, to verify correctness)
    std::cout << "Checksum: " << checksum << std::endl;

    // Free allocated memory
    free(matrix);
    free(vector);
    free(result);

    return 0;
}
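The float_to_float16 helper truncates instead of rounding. Even a correct implementation can therefore differ from an IEEE round-to-nearest conversion by one unit in the last place.

One way to sanity-check such a converter is sketched below in Python; the function names are hypothetical. The sketch ports the helper, using a subnormal shift of 14 - exponent so that the implicit leading 1 lands in the right bit. It then compares the result with the struct module's 'e' (IEEE 754 half-precision) format, which rounds to nearest-even.

```python
import struct


def float_to_half_bits_trunc(value: float) -> int:
    """Port of float_to_float16: float32 -> IEEE half bits, truncating (no rounding)."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    sign = (bits >> 16) & 0x8000
    exponent = ((bits >> 23) & 0xFF) - 112  # rebias: 127 - 15
    mantissa = (bits >> 13) & 0x3FF
    if exponent <= 0:
        if exponent < -10:
            exponent, mantissa = 0, 0  # too small: flush to signed zero
        else:
            # Half subnormal: significand (with implicit 1) >> (14 - exponent)
            mantissa = ((bits & 0x7FFFFF) | 0x800000) >> (14 - exponent)
            exponent = 0
    elif exponent >= 31:
        exponent, mantissa = 31, 0  # overflow (and NaN) -> infinity
    return sign | (exponent << 10) | mantissa


def half_bits_nearest(value: float) -> int:
    """Reference: struct's 'e' format packs IEEE 754 half with round-to-nearest-even."""
    return struct.unpack("<H", struct.pack("<e", value))[0]


def ulp_gap(value: float) -> int:
    """Magnitude gap, in half-precision ulps, between the rounded reference and the port."""
    return (half_bits_nearest(value) & 0x7FFF) - (float_to_half_bits_trunc(value) & 0x7FFF)


if __name__ == "__main__":
    for v in [1.0, -1.0, 0.1, -0.3333, 2.0**-20, 3e-5, 6e-5, 0.99999]:
        print(f"{v!r:>12} trunc=0x{float_to_half_bits_trunc(v):04x} "
              f"nearest=0x{half_bits_nearest(v):04x} gap={ulp_gap(v)}")
```

A gap of 0 or 1 is expected from truncation. A larger gap, or a mismatched sign, points to a conversion bug.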

Issue submission checklist

  • I'm reporting a performance issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
michalsustr added the performance and support_request labels Oct 11, 2024
michalsustr changed the title from "[Performance]:" to "[Performance]: Comparing performance of mat-vec multiplication" Oct 11, 2024
rkazants added the category: CPU label Oct 11, 2024
rkazants (Contributor) commented:

@dmitry-gorokhov, @mg-intel, please take a look.

wenjiew commented Oct 18, 2024

@usstq Maybe you can take a quick look? Thanks!
