
[Performance]: Building a 3-layer MLP network with PyTorch, inference slows down after conversion with OpenVINO #22751

Open · qiwang067 opened this issue Feb 9, 2024 · 6 comments
Labels: category: PyTorch FE (OpenVINO PyTorch Frontend), conformance, performance (Performance related topics), support_request

Comments

qiwang067 commented Feb 9, 2024

OpenVINO Version

2023.3.0

Operating System

Ubuntu 20.04 (LTS)

Device used for inference

None

OpenVINO installation

PyPi

Programming Language

Python

Hardware Architecture

x86 (64 bits)

Model used

MLP

Model quantization

No

Target Platform

No response

Performance issue description

Using a 3-layer MLP as the policy network for reinforcement learning with PyTorch, but inference slows down after conversion with OpenVINO (a sketch of both conversion paths follows the model definition below):

  1. Without OpenVINO, inference 1000 times: 0.016s
  2. Converted directly to OpenVINO, inference 1000 times: 0.039s
  3. Converted to ONNX first, then to OpenVINO, inference 1000 times: 0.046s

3-layer MLP with PyTorch:

import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, n_states, n_actions, hidden_dim=128):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, n_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
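
For reference, a rough sketch of the two conversion paths timed above, using the MLP class just defined (the example input shape, the "mlp.onnx" file name, and the CPU device are assumptions for illustration, not taken from the original script):

import time
import numpy as np
import torch
import openvino as ov

# Layer sizes taken from the values given later in this thread (n_states=4, n_actions=2, hidden_dim=256).
model = MLP(n_states=4, n_actions=2, hidden_dim=256)
model.eval()
example = torch.randn(1, 4)

# Path 2: convert the PyTorch module directly to an OpenVINO model in memory.
ov_direct = ov.convert_model(model, example_input=[example])
compiled_direct = ov.compile_model(ov_direct, device_name="CPU")

# Path 3: export to ONNX first, then convert the ONNX file.
torch.onnx.export(model, example, "mlp.onnx")  # "mlp.onnx" is a placeholder file name
ov_from_onnx = ov.convert_model("mlp.onnx")
compiled_onnx = ov.compile_model(ov_from_onnx, device_name="CPU")

# Time 1000 inferences per compiled model, mirroring the numbers above.
x = example.numpy().astype(np.float32)
for name, compiled in [("direct", compiled_direct), ("via onnx", compiled_onnx)]:
    start = time.time()
    for _ in range(1000):
        compiled([x])
    print(f"{name}: {time.time() - start:.3f}s")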

Step-by-step reproduction

Download and execute the following program:
https://github.com/qiwang067/openvino_rl/blob/main/dqn.py

Issue submission checklist

  • I'm reporting a performance issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
qiwang067 added the performance (Performance related topics) and support_request labels on Feb 9, 2024
@andrei-kochin (Contributor)

@qiwang067 please share the pip freeze output or a requirements.txt to install, as the mentioned script fails on my end with:

    state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (64,) + inhomogeneous part.

https://github.com/qiwang067/openvino_rl/blob/main/dqn.py#L150 returns one additional argument in my environment.
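
A common cause of this exact ValueError (an assumption here, since dqn.py is an external script) is the gym → Gymnasium API change: reset() and step() return extra values, so observations of mixed shapes end up in the replay buffer and np.array(state_batch) can no longer build a rectangular array. A minimal sketch of unpacking the newer API:

import gymnasium as gym  # assumption: the repro uses CartPole through gym/gymnasium

env = gym.make("CartPole-v1")

# Old gym API:  state = env.reset()
# Gymnasium returns (observation, info); keeping only the observation keeps
# the stored states homogeneous.
state, info = env.reset()

# Old gym API:  next_state, reward, done, info = env.step(action)
# Gymnasium returns one extra value (terminated and truncated are split).
action = env.action_space.sample()
next_state, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated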

andrei-kochin added the category: PyTorch FE (OpenVINO PyTorch Frontend) label on Feb 9, 2024
@qiwang067 (Author)

@andrei-kochin please check the following link:
https://github.com/qiwang067/openvino_rl/blob/main/requirements.txt

@mvafin (Contributor) commented Mar 28, 2024

What reference are you comparing with? Is it torch eager? I see improved performance in my small script compared to torch eager. Could you provide n_states, n_actions, hidden_dim, and the shape of the input tensor you use to infer the model? Also, what hardware do you use for inference?

@qiwang067 (Author)

I compared with the original torch version, not torch eager.
n_states: 4, n_actions: 2, hidden_dim: 256, shape of the input tensor: torch.Size([1, 4]).
Hardware: 13th Gen Intel(R) Core(TM) i9-13900K

@mvafin (Contributor) commented Apr 10, 2024

@qiwang067 I still see better performance with OpenVINO:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import openvino as ov
import time

class MLP(nn.Module):
    def __init__(self, n_states,n_actions,hidden_dim=128):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_states, hidden_dim) 
        self.fc2 = nn.Linear(hidden_dim,hidden_dim) 
        self.fc3 = nn.Linear(hidden_dim, n_actions) 
        
    def forward(self, x):
        x = F.relu(self.fc1(x)) 
        x = F.relu(self.fc2(x))
        return self.fc3(x)

n_states = 4
n_actions = 2
hidden_dim = 256

input_data = np.random.randn(1, 4).astype(np.float32)
pt_inputs = torch.from_numpy(input_data)

torch_model = MLP(n_states, n_actions, hidden_dim)
# Convert the PyTorch module to OpenVINO and compile it for CPU.
ov_model = ov.convert_model(torch_model, example_input=[input_data])
ov_compiled_model = ov.compile_model(ov_model, device_name="CPU")
# Check that both backends produce the same output before timing.
with torch.no_grad():
    torch_output = torch_model(pt_inputs).numpy()
ov_output = ov_compiled_model([input_data])[0]
np.testing.assert_allclose(ov_output, torch_output, atol=1e-4)

# Average per-inference latency of eager PyTorch over n runs.
n = 10000
start = time.time()
with torch.no_grad():
    for i in range(n):
        torch_model(pt_inputs)
print(f"pt: {(time.time() - start) / n:.6f}")

# Average per-inference latency of the compiled OpenVINO model over n runs.
start = time.time()
for i in range(n):
    ov_compiled_model([input_data])
print(f"ov: {(time.time() - start) / n:.6f}")

Returns:

pt: 0.000049
ov: 0.000025

Could you provide a script that reproduces your issue?
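
For illustration, a variant of the same measurement that reuses a single InferRequest and sets the LATENCY performance hint; both are assumptions for the sketch rather than anything required by OpenVINO, and the nn.Sequential stand-in (same layer sizes as the MLP above) just keeps the snippet self-contained:

import time
import numpy as np
import torch.nn as nn
import openvino as ov

# Stand-in for the 3-layer MLP above, with the sizes reported in this thread.
torch_model = nn.Sequential(
    nn.Linear(4, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 2),
)

input_data = np.random.randn(1, 4).astype(np.float32)
ov_model = ov.convert_model(torch_model, example_input=[input_data])

core = ov.Core()
# The LATENCY hint and the reused InferRequest are illustrative choices, not requirements.
compiled = core.compile_model(ov_model, "CPU", {"PERFORMANCE_HINT": "LATENCY"})
request = compiled.create_infer_request()

n = 10000
start = time.time()
for _ in range(n):
    request.infer([input_data])
print(f"ov (reused request): {(time.time() - start) / n:.6f}")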
