added size_splits to functional (pytorch#3837)
ptrblck authored and soumith committed Jan 4, 2018
1 parent dc76db3 commit 7c729e6
Showing 2 changed files with 53 additions and 13 deletions.
22 changes: 22 additions & 0 deletions test/test_torch.py
@@ -4269,6 +4269,28 @@ def test_split(self):
             self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0)
             start = start + target_size[dim]
 
+        # Variable sections split
+        tensor = torch.randn(20, 10)
+        dim = 0
+        split_sizes = [5, 5, 10]
+        target_sizes = ([[5, 10], [5, 10], [10, 10]])
+        splits = tensor.split(split_sizes, dim)
+        start = 0
+        for target_size, split in zip(target_sizes, splits):
+            self.assertEqual(split.size(), target_size)
+            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0)
+            start = start + target_size[dim]
+
+        split_sizes = [2, 2, 6]
+        target_sizes = ([20, 2], [20, 2], [20, 6])
+        dim = 1
+        splits = tensor.split(split_sizes, dim)
+        start = 0
+        for target_size, split in zip(target_sizes, splits):
+            self.assertEqual(split.size(), target_size)
+            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0)
+            start = start + target_size[dim]
+
     def test_chunk(self):
         tensor = torch.rand(4, 7)
         num_chunks = 3
44 changes: 31 additions & 13 deletions torch/functional.py
@@ -10,27 +10,45 @@
 ]
 
 
-def split(tensor, split_size, dim=0):
-    r"""Splits the tensor into chunks all of size :attr:`split_size` (if possible).
+def split(tensor, split_size_or_sections, dim=0):
+    """Splits the tensor into chunks.
+    If ``split_size_or_sections`` is an integer type, then ``tensor`` will be
+    split into equally sized chunks (if possible).
     Last chunk will be smaller if the tensor size along a given dimension
-    is not divisible by :attr:`split_size`.
+    is not divisible by ``split_size``.
+    If ``split_size_or_sections`` is a list, then ``tensor`` will be split
+    into ``len(split_size_or_sections)`` chunks with sizes in ``dim`` according
+    to ``split_size_or_sections``.
     Arguments:
-        tensor (Tensor): the tensor to split
-        split_size (int): size of a single chunk
-        dim (int): dimension along which to split the tensor
+        tensor (Tensor): tensor to split.
+        split_size_or_sections (int) or (list(int)): size of a single chunk or
+            list of sizes for each chunk
+        dim (int): dimension along which to split the tensor.
     """
     if dim < 0:
         dim += tensor.dim()
     dim_size = tensor.size(dim)
-    num_splits = (dim_size + split_size - 1) // split_size
-    last_split_size = split_size - (split_size * num_splits - dim_size)
-
-    def get_split_size(i):
-        return split_size if i < num_splits - 1 else last_split_size
-    return tuple(tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) for i
-                 in _range(0, num_splits))
+    if isinstance(split_size_or_sections, int):
+        split_size = split_size_or_sections
+        num_splits = (dim_size + split_size - 1) // split_size
+        last_split_size = split_size - (split_size * num_splits - dim_size)
+
+        def get_split_size(i):
+            return split_size if i < num_splits - 1 else last_split_size
+        return tuple(tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) for i
+                     in _range(0, num_splits))
+
+    else:
+        if dim_size != sum(split_size_or_sections):
+            raise ValueError("Sum of split sizes exceeds tensor dim")
+        split_indices = [0] + split_size_or_sections
+        split_indices = torch.cumsum(torch.Tensor(split_indices), dim=0)
+
+        return tuple(
+            tensor.narrow(int(dim), int(start), int(length))
+            for start, length in zip(split_indices, split_size_or_sections))
 
 
 def chunk(tensor, chunks, dim=0):
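For quick reference, a minimal usage sketch of the two code paths in the updated split (not part of the commit; shapes mirror the new test, and the printed sizes follow from the docstring's description of the integer and list cases):

    import torch

    x = torch.randn(20, 10)

    # Integer split size: equally sized chunks, the last one smaller if
    # the dimension is not evenly divisible.
    chunks = torch.split(x, 6, dim=0)
    print([c.size(0) for c in chunks])   # [6, 6, 6, 2]

    # List of sections (the new path): one chunk per entry, with sizes taken
    # verbatim along `dim`; their sum must equal the size of that dimension.
    a, b, c = x.split([5, 5, 10], 0)
    print(a.size(), b.size(), c.size())
    # torch.Size([5, 10]) torch.Size([5, 10]) torch.Size([10, 10])

    # A mismatched sum raises ValueError, per the check added in this commit:
    # x.split([5, 5, 5], 0)  # would raise, since 5 + 5 + 5 != 20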
