-
Notifications
You must be signed in to change notification settings - Fork 363
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Thanks for the fix! Approval after small comment.
@@ -135,7 +135,12 @@ def rq_decomposition( | |||
raise NotImplementedError( | |||
"Backend '{}' has not implemented rq_decomposition.".format(self.name)) | |||
|
|||
def concat(self, values: Sequence[Tensor], axis) -> Tensor: | |||
def shape_concat(self, values: Sequence[Tensor], axis) -> Tensor: | |||
"""Concatenate a sequence of tensors together about the given axis.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you change the description here to be explicitly only for shape calculations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added to the description, now reads:
"""Concatenate a sequence of tensors together about the given axis,
intended only for use in shape calculations"""
Hmm. Still seems like there's a shape based error for the Pytorch backend. Can you investigate? |
return np.concatenate(values, axis) | ||
|
||
def concat(self, values: Tensor, axis: int = 0) -> Tensor: | ||
return np.stack(values, axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aw, that would explain it. You'll need to use pytorch's concat equivalent here instead of numpy.
def concat(self, values: Sequence[Tensor], axis: int = 0) -> Tensor: | ||
new_shape = None | ||
if axis == 0: | ||
new_shape = ShellTensor(values) | ||
else: | ||
new_shape = self.shape_concat(values, axis) | ||
return new_shape | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't this this is correct. ShellTensor
is a tensor type that only stores its shape and has no concrete values, so when you do a normal concatenation of multiple ShellTensors
, you'll want to add the value of that specified axis together.
Though this function really isn't important at all. Just throw a NotImplementedError
and we'll add it when we actually need it.
@@ -39,9 +39,12 @@ def convert_to_tensor(self, tensor: Tensor) -> Tensor: | |||
result = self.jax.jit(lambda x: x)(tensor) | |||
return result | |||
|
|||
def concat(self, values: Tensor, axis: int) -> Tensor: | |||
def shape_concat(self, values: Tensor, axis: int) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doubly come to think of it, we don't need the axis number for shape_concat
at all since we always use just -1. Can you remove this argument?
It looks as though the Travis build for python 3.7 failed due to an error when downloading pytype. The Travis build for python 3.6 successfully completed. |
Addresses issue #350, renames existing
backend.concat
methods tobackend.shape_concat
and implementbackend.concat
methods using backendstack
method. If namebackend.shape_concat
is too similar to existingbackend.concat_shape
methods could alternatively foldbackend.concat
into existing method of same name as special case whenaxis == 0
. Did not squash commits.