Skip to content

Commit

Permalink
revert one declare_same_as usage
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Feb 14, 2023
1 parent 0cf2cd5 commit ffc5083
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,10 @@ def __call__(self, spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
remaining_dim = spatial_dim - mat_spatial_size
left = nn.expand_dim(left, dim=remaining_dim)
right = nn.expand_dim(right, dim=remaining_dim)
concat, out_spatial_dim_ = nn.concat(
cond.true, out_spatial_dim_ = nn.concat(
(left, remaining_dim), (self.pos_emb, self.clipped_spatial_dim), (right, remaining_dim)
)
concat, out_spatial_dim_ = nn.replace_dim(concat, in_dim=out_spatial_dim_, out_dim=out_spatial_dim)
cond.true = concat
out_spatial_dim_.declare_same_as(out_spatial_dim)

# False branch, spatial_dim <= self.clipping
cond.false, _ = nn.slice_nd(
Expand Down

0 comments on commit ffc5083

Please sign in to comment.