Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ readme = "README.md"
requires-python = ">=3.10.15"
dynamic = ["version"]
dependencies = [
"array-api-compat>=1.11.0",
"ezmsg>=3.6.0",
"ezmsg-baseproc>=1.2.1",
"ezmsg-event>=0.6.0",
Expand All @@ -34,6 +35,7 @@ lint = [
test = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
"mlx>=0.18.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
]
docs = [
"sphinx>=8.0",
Expand Down
4 changes: 2 additions & 2 deletions src/ezmsg/simbiophys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@

# Dynamic Colored Noise
from .dynamic_colored_noise import (
ColoredNoiseFilterState,
DynamicColoredNoiseSettings,
DynamicColoredNoiseState,
DynamicColoredNoiseTransformer,
DynamicColoredNoiseUnit,
compute_kasdin_coefficients,
compute_kasdin_coefficients_batch,
)

# EEG
Expand Down Expand Up @@ -112,12 +112,12 @@
"CosineEncoderTransformer",
"CosineEncoderUnit",
# Dynamic Colored Noise
"ColoredNoiseFilterState",
"DynamicColoredNoiseSettings",
"DynamicColoredNoiseState",
"DynamicColoredNoiseTransformer",
"DynamicColoredNoiseUnit",
"compute_kasdin_coefficients",
"compute_kasdin_coefficients_batch",
# DNSS LFP
"DNSSLFPProducer",
"DNSSLFPSettings",
Expand Down
24 changes: 17 additions & 7 deletions src/ezmsg/simbiophys/cosine_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
AxisArray with shape (n_samples, output_ch) containing encoded values.
"""

import typing
from pathlib import Path

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from array_api_compat import get_namespace
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
Expand Down Expand Up @@ -82,10 +83,10 @@ class CosineEncoderState:
ch_axis: Pre-built channel axis for output messages.
"""

baseline: npt.NDArray[np.floating] | None = None
modulation: npt.NDArray[np.floating] | None = None
pd: npt.NDArray[np.floating] | None = None
speed_modulation: npt.NDArray[np.floating] | None = None
baseline: typing.Any = None
modulation: typing.Any = None
pd: typing.Any = None
speed_modulation: typing.Any = None
ch_axis: AxisArray.CoordinateAxis | None = None

@property
Expand Down Expand Up @@ -216,9 +217,18 @@ def _reset_state(self, message: AxisArray) -> None:
seed=self.settings.seed,
)

# Convert state parameters to the input array's backend
xp = get_namespace(message.data)
if xp is not np:
self.state.baseline = xp.asarray(self.state.baseline)
self.state.modulation = xp.asarray(self.state.modulation)
self.state.pd = xp.asarray(self.state.pd)
self.state.speed_modulation = xp.asarray(self.state.speed_modulation)

def _process(self, message: AxisArray) -> AxisArray:
"""Transform polar coordinates to encoded output."""
polar = np.asarray(message.data, dtype=np.float64)
polar = message.data
xp = get_namespace(polar)

if polar.ndim != 2 or polar.shape[1] != 2:
raise ValueError(f"Expected polar coords with shape (n_samples, 2), got {polar.shape}")
Expand All @@ -231,7 +241,7 @@ def _process(self, message: AxisArray) -> AxisArray:
# State arrays are pre-shaped to (1, output_ch) for broadcasting
output = (
self.state.baseline
+ self.state.modulation * magnitude * np.cos(angle - self.state.pd)
+ self.state.modulation * magnitude * xp.cos(angle - self.state.pd)
+ self.state.speed_modulation * magnitude
)

Expand Down
Loading