Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

#350 make vec from scalars #425

Closed
wants to merge 14 commits into from

Conversation

Travmatth
Copy link
Contributor

Addresses issue #350, renames existing backend.concat methods to backend.shape_concat and implement backend.concat methods using backend stack method. If name backend.shape_concat is too similar to existing backend.concat_shape methods could alternatively fold backend.concat into existing method of same name as special case when axis == 0. Did not squash commits.

Copy link
Contributor

@chaserileyroberts chaserileyroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Thanks for the fix! Approval after small comment.

@@ -135,7 +135,12 @@ def rq_decomposition(
raise NotImplementedError(
"Backend '{}' has not implemented rq_decomposition.".format(self.name))

def concat(self, values: Sequence[Tensor], axis) -> Tensor:
def shape_concat(self, values: Sequence[Tensor], axis) -> Tensor:
"""Concatenate a sequence of tensors together about the given axis."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change the description here to be explicitly only for shape calculations?

Copy link
Contributor Author

@Travmatth Travmatth Jan 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added to the description, now reads:

"""Concatenate a sequence of tensors together about the given axis,
intended only for use in shape calculations"""

@chaserileyroberts
Copy link
Contributor

Hmm. Still seems like there's a shape based error for the Pytorch backend. Can you investigate?

return np.concatenate(values, axis)

def concat(self, values: Tensor, axis: int = 0) -> Tensor:
return np.stack(values, axis)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aw, that would explain it. You'll need to use pytorch's concat equivalent here instead of numpy.

Comment on lines 118 to 125
def concat(self, values: Sequence[Tensor], axis: int = 0) -> Tensor:
new_shape = None
if axis == 0:
new_shape = ShellTensor(values)
else:
new_shape = self.shape_concat(values, axis)
return new_shape

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't this this is correct. ShellTensor is a tensor type that only stores its shape and has no concrete values, so when you do a normal concatenation of multiple ShellTensors, you'll want to add the value of that specified axis together.

Though this function really isn't important at all. Just throw a NotImplementedError and we'll add it when we actually need it.

@@ -39,9 +39,12 @@ def convert_to_tensor(self, tensor: Tensor) -> Tensor:
result = self.jax.jit(lambda x: x)(tensor)
return result

def concat(self, values: Tensor, axis: int) -> Tensor:
def shape_concat(self, values: Tensor, axis: int) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doubly come to think of it, we don't need the axis number for shape_concat at all since we always use just -1. Can you remove this argument?

@Travmatth
Copy link
Contributor Author

It looks as though the Travis build for python 3.7 failed due to an error when downloading pytype. The Travis build for python 3.6 successfully completed.

@Travmatth Travmatth closed this Jan 28, 2021
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants