Skip to content

fix TokenAttention #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ optional-dependencies.docs = [
"ipython",
"myst-nb>=1.1",
"pandas",
"scvi-tools",
"scvi-tools>=1.3.1",
"setuptools", # Until pybtex >0.23.0 releases: https://bitbucket.org/pybtex-devs/pybtex/issues/169/
"sphinx>=8",
"sphinx-autodoc-typehints",
Expand All @@ -78,7 +78,7 @@ optional-dependencies.embedding = [
"transformers",
]
optional-dependencies.external = [
"scvi-tools",
"scvi-tools>=1.3.1",
]
optional-dependencies.pp = [
"pertpy",
Expand Down
7 changes: 4 additions & 3 deletions src/cellflow/networks/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,10 +464,11 @@ def __call__(
token_shape = (len(x), 1)
class_token = nn.Embed(num_embeddings=1, features=x.shape[-1])(jnp.int32(jnp.zeros(token_shape)))
z = jnp.concatenate((class_token, x), axis=-2)
token_mask = jnp.zeros((x.shape[0], 1, x.shape[1] + 1, x.shape[1] + 1))
token_mask = token_mask.at[:, :, 0, :].set(1)
token_mask = token_mask.at[:, :, :, 0].set(1)
token_mask = jnp.ones((x.shape[0], 1, x.shape[1] + 1, x.shape[1] + 1))
token_mask = token_mask.at[:, :, 1:, 1:].set(mask)
cls_token_to_data = mask[0, 0, :, :].sum(axis=0) > 0
token_mask = token_mask.at[:, :, 0, 1:].set(cls_token_to_data)
token_mask = token_mask.at[:, :, 1:, 0].set(cls_token_to_data)

# attention
attention = nn.MultiHeadDotProductAttention(
Expand Down
29 changes: 29 additions & 0 deletions tests/networks/test_aggregators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import jax
import jax.numpy as jnp
import pytest

from cellflow.networks._set_encoders import ConditionEncoder
from cellflow.networks._utils import SeedAttentionPooling, TokenAttentionPooling


class TestAggregator:
@pytest.mark.parametrize("agg", [TokenAttentionPooling, SeedAttentionPooling])
def test_mask_impact_on_TokenAttentionPooling(self, agg):
rng = jax.random.PRNGKey(0)
init_rng, mask_rng = jax.random.split(rng, 2)
condition = jax.random.normal(rng, (2, 3, 7))
condition = jnp.concatenate((condition, jnp.zeros((2, 1, 7))), axis=1)
cond_encoder = ConditionEncoder(32)
_, attn_mask = cond_encoder._get_masks({"conditions": condition})
random_mask = jax.random.bernoulli(mask_rng, 0.5, attn_mask.shape).astype(jnp.int32)
agg = agg()
variables = agg.init(init_rng, condition, random_mask, training=True)
out = agg.apply(variables, condition, attn_mask, training=True)
out_rand = agg.apply(variables, condition, random_mask, training=True)
# output dim = input dim for TokenAttentionPooling, output dim = 64 by default in SeedAttentionPooling
assert out.shape[0] == 2
assert out.shape[1] == 7 if isinstance(agg, TokenAttentionPooling) else 64
assert out_rand.shape[0] == 2
assert out_rand.shape[1] == 7 if isinstance(agg, TokenAttentionPooling) else 64
assert not jnp.allclose(out[0], out_rand[0], atol=1e-6)
assert not jnp.allclose(out[1], out_rand[1], atol=1e-6)
Loading