Skip to content

Commit 09e08c3

Browse files
committed
fix: singleconvblock key
1 parent 72b6e8b commit 09e08c3

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "Equimo"
3-
version = "0.4.0-alpha.7"
3+
version = "0.4.0-alpha.9"
44
description = "Implementation of popular vision models in Jax"
55
readme = "README.md"
66
requires-python = ">=3.10"

src/equimo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.4.0-alpha.7"
1+
__version__ = "0.4.0-alpha.9"

src/equimo/layers/convolution.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from curses import KEY_EIC
12
from typing import Callable, Optional, Sequence, Tuple
23

34
import equinox as eqx
@@ -1022,12 +1023,14 @@ def __call__(
10221023
key: PRNGKeyArray,
10231024
inference: Optional[bool] = None,
10241025
):
1025-
key_dropout, key_droppath = jr.split(key, 2)
1026+
key_sdwc, key_ec, key_mdwc, key_proj, key_dropout, key_droppath = jr.split(
1027+
key, 6
1028+
)
10261029

1027-
out = self.start_dw_conv(x)
1028-
out = self.expand_conv(out)
1029-
out = self.middle_dw_conv(out)
1030-
out = self.proj_conv(out)
1030+
out = self.start_dw_conv(x, inference=inference, key=key_sdwc)
1031+
out = self.expand_conv(out, inference=inference, key=key_ec)
1032+
out = self.middle_dw_conv(out, inference=inference, key=key_mdwc)
1033+
out = self.proj_conv(out, inference=inference, key=key_proj)
10311034

10321035
out = self.dropout(out, inference=inference, key=key_dropout)
10331036

0 commit comments

Comments
 (0)