Amend docstring and add test for Flatten module (pytorch#42084)
Summary:
I noticed that when PR pytorch#22245 introduced `nn.Flatten`, the docstring had a bug that prevented it from rendering properly on the web; this PR addresses that. It also adds a unit test for this module.

**Actual**
![image](https://user-images.githubusercontent.com/13088001/88483672-cf896a00-cf3f-11ea-8b1b-a30d152e1368.png)

**Expected**
![image](https://user-images.githubusercontent.com/13088001/88483642-86391a80-cf3f-11ea-8333-0964a027a172.png)

Pull Request resolved: pytorch#42084

Reviewed By: mrshenli

Differential Revision: D22756662

Pulled By: ngimel

fbshipit-source-id: 60c58c18c9a68854533196ed6b9e9fb0d4f83520
alvgaona authored and facebook-github-bot committed Jul 27, 2020
1 parent 4290d0b commit 3e121d9
Showing 2 changed files with 16 additions and 3 deletions.
9 changes: 9 additions & 0 deletions test/test_nn.py
```diff
@@ -8326,6 +8326,15 @@ def test_functional_grad_conv(self):
             torch.nn.grad._grad_input_padding(torch.rand(1, 2, 3), [1, 2, 5], (1,), (0,), (3,))
         self.assertEqual(len(w), 1)
 
+    def test_flatten(self):
+        tensor_input = torch.randn(2, 1, 2, 3)
+
+        # Flatten Tensor
+
+        flatten = nn.Flatten(start_dim=1, end_dim=-1)
+        tensor_output = flatten(tensor_input)
+        self.assertEqual(tensor_output.size(), torch.Size([2, 6]))
+
     def test_unflatten(self):
         tensor_input = torch.randn(2, 50)
```
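The new test checks that flattening dims 1 through -1 of a `(2, 1, 2, 3)` tensor yields shape `(2, 6)`. The shape arithmetic behind `nn.Flatten` can be sketched in plain Python (no torch dependency; `flatten_shape` is a hypothetical helper written for illustration, not part of PyTorch):

```python
from math import prod

def flatten_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical sketch of nn.Flatten's shape arithmetic:
    collapse dims [start_dim, end_dim] into a single dim."""
    n = len(shape)
    start = start_dim % n  # normalize negative indices
    end = end_dim % n
    collapsed = prod(shape[start:end + 1])
    return shape[:start] + (collapsed,) + shape[end + 1:]

print(flatten_shape((2, 1, 2, 3)))  # (2, 6), matching the new test
```

With the defaults `start_dim=1, end_dim=-1`, the batch dimension is preserved and everything after it is collapsed, which is why the test expects `torch.Size([2, 6])`.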
10 changes: 7 additions & 3 deletions torch/nn/modules/flatten.py
```diff
@@ -8,20 +8,24 @@
 class Flatten(Module):
     r"""
     Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
-    Args:
-        start_dim: first dim to flatten (default = 1).
-        end_dim: last dim to flatten (default = -1).
+
     Shape:
         - Input: :math:`(N, *dims)`
         - Output: :math:`(N, \prod *dims)` (for the default case).
+
+    Args:
+        start_dim: first dim to flatten (default = 1).
+        end_dim: last dim to flatten (default = -1).
+
     Examples::
         >>> input = torch.randn(32, 1, 5, 5)
         >>> m = nn.Sequential(
         >>>     nn.Conv2d(1, 32, 5, 1, 1),
         >>>     nn.Flatten()
         >>> )
         >>> output = m(input)
         >>> output.size()
         torch.Size([32, 288])
     """
     __constants__ = ['start_dim', 'end_dim']
     start_dim: int
```
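The `torch.Size([32, 288])` in the docstring example follows from the standard convolution output-size formula: `Conv2d(1, 32, 5, 1, 1)` maps a 5x5 input to 32 channels of 3x3, and the default `Flatten()` collapses everything after the batch dim. A quick torch-free sketch (`conv2d_out` is a hypothetical helper for illustration):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Standard convolution output-size formula:
    # floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

# Docstring example: 5x5 input, Conv2d(1, 32, 5, 1, 1) -> 32 channels of 3x3
spatial = conv2d_out(5, kernel=5, stride=1, padding=1)
print(spatial)                  # 3
print(32 * spatial * spatial)   # 288, i.e. torch.Size([32, 288]) after Flatten()
```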
