Skip to content

Conversation

@nkovela1
Copy link
Collaborator

This PR contains some small ops fixes for Torch unit tests, particularly those relating to incompatibility with numpy ops. Subsequent PRs will contain remaining fixes.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

- [Misra, 2019](https://arxiv.org/abs/1908.08681)
"""
x = backend.convert_to_tensor(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is torch related, can we move it to the torch backend?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling mish in activations.py does x * backend.nn.tanh(backend.nn.softplus(x)) a few lines above this, and this will multiply a numpy array (x) by a Torch tensor, which is not allowed. The conversion to a tensor for x (first arg) must therefore be done here.

if isinstance(self._values[k], list):
avg = np.mean(
self._values[k][0] / max(1, self._values[k][1])
backend.convert_to_numpy(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe Haifeng has already provided a fix that should make this change unnecessary (earlier today)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, reverted thanks!

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@fchollet fchollet merged commit e3af2a6 into main Jun 12, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants