Skip to content

Commit

Permalink
concat
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Oct 27, 2021
1 parent bf31bd1 commit a41a9b3
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions nn/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,23 @@
Array (Tensor) functions
"""

from typing import Optional, Tuple, List
from typing import Optional, Union, Tuple, List
from returnn.util.basic import NotSpecified
from .base import LayerRef
from .base import LayerRef, Layer


def concat(sources: Union[List[LayerRef], Tuple[LayerRef, ...]], *,
axis: Optional[str] = NotSpecified,
name: Optional[str] = None) -> Layer:
"""
Concatenates multiple sources (by default in feature axis).
"""
if axis is NotSpecified or axis is None or axis.upper() == "F":
# standard case
from .base import make_layer
return make_layer({"class": "copy", "from": sources}, name=name or "concat")
else:
raise NotImplementedError(f"Cannot handle concat with axis {axis!r} yet")


def split(source: LayerRef, *,
Expand Down

0 comments on commit a41a9b3

Please sign in to comment.