[BUG] TorchLayer does not work correctly with broadcasting and tuple returns #5762

Closed
@dwierichs

Description

Expected behavior

The code below runs without raising an error.

Actual behavior

It raises a RuntimeError because of an invalid internal reshape in TorchLayer._evaluate_qnode (see the traceback below).

Additional information

If the QNode returns a list instead of a tuple, i.e. return [qml.expval(qml.Z(0)), qml.expval(qml.Z(1))], the example works as expected.

Source code

import numpy as np
import pennylane as qml
import torch

n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev)
def qnode(inputs, weights):
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return qml.expval(qml.Z(0)), qml.expval(qml.Z(1))

weight_shapes = {"weights": (3, n_qubits, 3)}

qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
x = torch.tensor(np.random.random((5, 2))) # Batched inputs with batch dim 5 for 2 qubits
qlayer.forward(x)

Tracebacks

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 18
     16 qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
     17 x = torch.tensor(np.random.random((5, 2))) # Batched inputs with batch dim 5 for 2 qubits
---> 18 qlayer.forward(x)

File ~/repos/pennylane/pennylane/qnn/torch.py:402, in TorchLayer.forward(self, inputs)
    399     inputs = torch.reshape(inputs, (-1, inputs.shape[-1]))
    401 # calculate the forward pass as usual
--> 402 results = self._evaluate_qnode(inputs)
    404 if isinstance(results, tuple):
    405     if has_batch_dim:

File ~/repos/pennylane/pennylane/qnn/torch.py:439, in TorchLayer._evaluate_qnode(self, x)
    436     return torch.hstack(_res).type(x.dtype)
    438 if isinstance(res, tuple) and len(res) > 1:
--> 439     return tuple(_combine_dimensions(r) for r in res)
    441 return _combine_dimensions(res)

File ~/repos/pennylane/pennylane/qnn/torch.py:439, in <genexpr>(.0)
    436     return torch.hstack(_res).type(x.dtype)
    438 if isinstance(res, tuple) and len(res) > 1:
--> 439     return tuple(_combine_dimensions(r) for r in res)
    441 return _combine_dimensions(res)

File ~/repos/pennylane/pennylane/qnn/torch.py:435, in TorchLayer._evaluate_qnode.<locals>._combine_dimensions(_res)
    433 def _combine_dimensions(_res):
    434     if len(x.shape) > 1:
--> 435         _res = [torch.reshape(r, (x.shape[0], -1)) for r in _res]
    436     return torch.hstack(_res).type(x.dtype)

File ~/repos/pennylane/pennylane/qnn/torch.py:435, in <listcomp>(.0)
    433 def _combine_dimensions(_res):
    434     if len(x.shape) > 1:
--> 435         _res = [torch.reshape(r, (x.shape[0], -1)) for r in _res]
    436     return torch.hstack(_res).type(x.dtype)

RuntimeError: shape '[5, -1]' is invalid for input of size 1
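The failure can be reproduced without PennyLane or Torch using a NumPy analogue of the reshaping logic shown in the traceback. This is a sketch under stated assumptions: combine_dimensions, evaluate_current and evaluate_patched are hypothetical stand-ins that mirror the traceback's source lines 433 to 441, with np.reshape in place of torch.reshape (NumPy raises ValueError where Torch raises RuntimeError). With a tuple return, each batched expectation value of shape (5,) is passed to _combine_dimensions on its own, so the list comprehension iterates over its five scalar entries and tries to reshape a size-1 value to (5, -1). A list return skips the tuple branch, so each (5,) result is reshaped to (5, 1) and the results are stacked into shape (5, 2). evaluate_patched shows one possible, hypothetical change to the dispatch; it is not necessarily how the issue was fixed upstream.

```python
import numpy as np

BATCH = 5  # batch dimension used in the issue's example input


def combine_dimensions(res, batch_size):
    # Hypothetical NumPy analogue of the _combine_dimensions helper in the
    # traceback: reshape every element of `res` to (batch_size, -1), then
    # stack the pieces column-wise.
    return np.hstack([np.reshape(r, (batch_size, -1)) for r in res])


def evaluate_current(res, batch_size):
    # Mirrors the dispatch in the traceback: any tuple with more than one
    # entry is combined entry by entry.
    if isinstance(res, tuple) and len(res) > 1:
        return tuple(combine_dimensions(r, batch_size) for r in res)
    return combine_dimensions(res, batch_size)


def evaluate_patched(res, batch_size):
    # Hypothetical fix: only combine entry by entry when the entries are
    # themselves sequences; a tuple of flat arrays is treated like a list.
    if isinstance(res, tuple) and len(res) > 1 and all(
        isinstance(r, (list, tuple)) for r in res
    ):
        return tuple(combine_dimensions(r, batch_size) for r in res)
    return combine_dimensions(res, batch_size)


# Two batched expectation values, each of shape (BATCH,).
expvals = [np.zeros(BATCH), np.ones(BATCH)]

# List return: each (5,) array is reshaped to (5, 1) and stacked.
list_out = evaluate_current(list(expvals), BATCH)
print(list_out.shape)  # (5, 2)

# Tuple return: iterating over one (5,) array yields size-1 scalars, which
# cannot be reshaped to (5, -1), matching the RuntimeError in the traceback.
try:
    evaluate_current(tuple(expvals), BATCH)
    tuple_failed = False
except ValueError as err:
    tuple_failed = True
    print("ValueError:", err)

# With the hypothetical dispatch, the tuple behaves like the list.
patched_out = evaluate_patched(tuple(expvals), BATCH)
print(patched_out.shape)  # (5, 2)
```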

System information

PennyLane version: dev

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Metadata

Assignees: No one assigned
Labels: bug 🐛 (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
