[Bug]: Matmul error on GPU plugin with f32 precision #28881

Open
@AntoineLejeuneEuresys

Description

OpenVINO Version

2024.4.0-16579-c3152d32c9c-releases/2024/4 (pip)

Operating System

Windows System

Device used for inference

GPU

Framework

None

Model used

No response

Issue description

For some specific input shapes, the MatMul operator on the GPU plugin produces results that differ substantially from PyTorch's. The difference only occurs with f32 inference precision. In the reproducer below, only batch size 8 is affected.

Step-by-step reproduction

import numpy as np
import torch
import torch.nn as nn
import openvino as ov
import openvino.properties.hint as hints
from openvino.runtime import Core

print('PyTorch version', torch.__version__)
print('OV version', ov.runtime.__version__)

torch.manual_seed(123)
np.random.seed(141)

class Model(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.weights = torch.tensor(np.random.standard_normal((16, 24))).float()

    def forward(self, x):
        return torch.matmul(x, self.weights)


m = Model()
for batch_size in [1, 2, 4, 8, 16]:
    inp = torch.tensor(np.random.standard_normal((batch_size, 2, 1, 16))).float()
    ref = m(inp).detach().numpy()

    torch.onnx.export(m, inp, "model.onnx")

    # Run with OpenVINO

    params = {hints.execution_mode: hints.ExecutionMode.ACCURACY, hints.inference_precision: "f32"}
    core = Core()
    compiled = core.compile_model("model.onnx", "GPU", params)
    req = compiled.create_infer_request()
    out = req.infer(np.array(inp))
    out = next(iter(out.values()))
    print(f"Batch size {batch_size} diff: {np.max(np.abs(ref - out))}")
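The reproducer only prints the maximum absolute difference. To see which output elements diverge, a small helper along the following lines could be called on `ref` and `out` inside the loop. This is a sketch only: `report_mismatch` and its tolerances are hypothetical and not part of OpenVINO or PyTorch, and the arrays here are synthetic stand-ins with the batch-size-8 output shape rather than real GPU results.

```python
import numpy as np


def report_mismatch(ref, out, rtol=1e-4, atol=1e-5):
    """Return the indices of elements where `out` is not close to `ref`.

    Hypothetical helper; the tolerances are assumptions, not project defaults.
    """
    ref = np.asarray(ref)
    out = np.asarray(out)
    bad = ~np.isclose(out, ref, rtol=rtol, atol=atol)
    return np.argwhere(bad)


# Synthetic stand-in for the batch-size-8 case: output shape (8, 2, 1, 24).
ref = np.zeros((8, 2, 1, 24), dtype=np.float32)
out = ref.copy()
out[5, 1, 0, :] = 3.0  # pretend one row of the output is wrong

idx = report_mismatch(ref, out)
print(f"{len(idx)} mismatching elements")
print("affected batch indices:", sorted({int(i[0]) for i in idx}))
```

Knowing whether the mismatches cluster in particular batch indices or output columns may help narrow down which part of the GPU kernel is involved.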

Relevant log output

PyTorch version 2.6.0+cpu
OV version 2024.4.0-16579-c3152d32c9c-releases/2024/4
Batch size 1 diff: 0.0
Batch size 2 diff: 0.0
Batch size 4 diff: 0.0
Batch size 8 diff: 15.743196487426758
Batch size 16 diff: 0.0
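For context on how large a difference f32 rounding alone could explain, the sketch below compares a float32 matmul against a float64 reference in NumPy, using the same shapes as the reproducer. It does not run OpenVINO, and the function name `f32_matmul_error` and the use of `np.random.default_rng` are choices made for this illustration, so the random values differ from those in the reproducer.

```python
import numpy as np

rng = np.random.default_rng(141)  # seed chosen for illustration only


def f32_matmul_error(batch_size):
    """Max abs difference between an f32 matmul and an f64 reference."""
    weights = rng.standard_normal((16, 24))
    inp = rng.standard_normal((batch_size, 2, 1, 16))
    ref64 = np.matmul(inp, weights)
    out32 = np.matmul(inp.astype(np.float32), weights.astype(np.float32))
    return float(np.max(np.abs(ref64 - out32)))


for batch_size in [1, 2, 4, 8, 16]:
    err = f32_matmul_error(batch_size)
    print(f"Batch size {batch_size} f32-vs-f64 diff: {err:.3e}")
```

With a reduction length of only 16 and standard-normal inputs, the rounding error should be many orders of magnitude smaller than the 15.74 reported for batch size 8, which suggests the GPU result is wrong rather than merely imprecise.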

Issue submission checklist

  • I'm reporting an issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
