Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/ezmsg/baseproc/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from dataclasses import dataclass

import ezmsg.core as ez

from .util.message import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray

# --- Processor state decorator ---
processor_state = functools.partial(dataclass, unsafe_hash=True, frozen=False, init=False)
Expand Down Expand Up @@ -138,12 +137,12 @@ def stateful_op(


class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
def partial_fit(self, message: SampleMessage) -> None:
def partial_fit(self, message: AxisArray) -> None:
"""Update transformer state using labeled training data.

This method should update the internal state/parameters of the transformer
based on the provided labeled samples, without performing any transformation.
"""
...

async def apartial_fit(self, message: SampleMessage) -> None: ...
async def apartial_fit(self, message: AxisArray) -> None: ...
32 changes: 23 additions & 9 deletions src/ezmsg/baseproc/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import typing
from abc import ABC, abstractmethod

from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .processor import (
BaseProcessor,
BaseProducer,
Expand Down Expand Up @@ -256,38 +259,49 @@ def stateful_op(
class BaseAdaptiveTransformer(
BaseStatefulTransformer[
SettingsType,
MessageInType | SampleMessage,
MessageInType,
MessageOutType | None,
StateType,
],
ABC,
typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
):
@abstractmethod
def partial_fit(self, message: SampleMessage) -> None: ...
def partial_fit(self, message: AxisArray) -> None: ...

async def apartial_fit(self, message: SampleMessage) -> None:
async def apartial_fit(self, message: AxisArray) -> None:
"""Override me if you need async partial fitting."""
return self.partial_fit(message)

def __call__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
def __call__(self, message: MessageInType) -> MessageOutType | None:
"""
Adapt transformer with training data (and optionally labels)
in SampleMessage
in AxisArray with attrs["trigger"].

Args:
message: An instance of SampleMessage with optional
labels (y) in message.trigger.value.data and
data (X) in message.sample.data
message: An AxisArray with optional trigger in attrs["trigger"],
containing labels (y) in attrs["trigger"].value and
data (X) in message.data

Returns: None
"""
if is_sample_message(message):
if isinstance(message, SampleMessage):
# Auto-convert old format → new format
message = replace(
message.sample,
attrs={**message.sample.attrs, "trigger": message.trigger},
)
return self.partial_fit(message)
return super().__call__(message)

async def __acall__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
async def __acall__(self, message: MessageInType) -> MessageOutType | None:
if is_sample_message(message):
if isinstance(message, SampleMessage):
message = replace(
message.sample,
attrs={**message.sample.attrs, "trigger": message.trigger},
)
return await self.apartial_fit(message)
return await super().__acall__(message)

Expand Down
5 changes: 2 additions & 3 deletions src/ezmsg/baseproc/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .processor import BaseConsumer, BaseProducer, BaseTransformer
from .protocols import MessageInType, MessageOutType, SettingsType
from .stateful import BaseAdaptiveTransformer, BaseStatefulConsumer, BaseStatefulTransformer
from .util.message import SampleMessage
from .util.profile import profile_subpub
from .util.typeresolution import resolve_typevar

Expand Down Expand Up @@ -223,7 +222,7 @@ class BaseAdaptiveTransformerUnit(
ABC,
typing.Generic[SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType],
):
INPUT_SAMPLE = ez.InputStream(SampleMessage)
INPUT_SAMPLE = ez.InputStream(AxisArray)
INPUT_SIGNAL = ez.InputStream(MessageInType)
OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)

Expand All @@ -242,7 +241,7 @@ async def on_signal(self, message: MessageInType) -> typing.AsyncGenerator:
yield self.OUTPUT_SIGNAL, result

@ez.subscriber(INPUT_SAMPLE)
async def on_sample(self, msg: SampleMessage) -> None:
async def on_sample(self, msg: AxisArray) -> None:
await self.processor.apartial_fit(msg)


Expand Down
22 changes: 19 additions & 3 deletions src/ezmsg/baseproc/util/message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
import typing
import warnings
from dataclasses import dataclass, field

from ezmsg.util.messages.axisarray import AxisArray
Expand All @@ -19,13 +20,28 @@ class SampleTriggerMessage:

@dataclass
class SampleMessage:
"""
.. deprecated::
``SampleMessage`` is deprecated. Use ``AxisArray`` with
``attrs={"trigger": SampleTriggerMessage(...)}`` instead.
"""

trigger: SampleTriggerMessage
"""The time, window, and value (if any) associated with the trigger."""

sample: AxisArray
"""The data sampled around the trigger."""

def __post_init__(self):
warnings.warn(
"SampleMessage is deprecated. Use AxisArray with " "attrs={'trigger': SampleTriggerMessage(...)} instead.",
DeprecationWarning,
stacklevel=2,
)


def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
"""Check if the message is a SampleMessage."""
return hasattr(message, "trigger")
def is_sample_message(message: typing.Any) -> bool:
"""Detect old SampleMessage OR new AxisArray-with-trigger."""
if isinstance(message, SampleMessage):
return True
return isinstance(message, AxisArray) and "trigger" in getattr(message, "attrs", {})
20 changes: 11 additions & 9 deletions tests/test_baseproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import pickle
from types import NoneType
from typing import Any
from unittest.mock import MagicMock

import numpy as np
import pytest
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.baseproc import (
BaseAdaptiveTransformer,
Expand All @@ -21,7 +22,7 @@
BaseTransformer,
CompositeProcessor,
CompositeProducer,
SampleMessage,
SampleTriggerMessage,
_get_base_processor_message_in_type,
_get_base_processor_message_out_type,
_get_base_processor_settings_type,
Expand Down Expand Up @@ -135,11 +136,11 @@ class MockAdaptiveTransformer(BaseAdaptiveTransformer[MockSettings, MockMessageA
def _reset_state(self, message: MockMessageA) -> None:
self._state.iterations = 0

def _process(self, message: MockMessageA | SampleMessage) -> MockMessageB:
def _process(self, message: MockMessageA) -> MockMessageB:
self._state.iterations += 1
return MockMessageB()

def partial_fit(self, message: SampleMessage) -> None:
def partial_fit(self, message: AxisArray) -> None:
self._state.iterations += 1


Expand Down Expand Up @@ -756,10 +757,13 @@ def test_stateful_op(self):
assert new_state[0].iterations == 1


# Mock SampleMessage for testing BaseAdaptiveTransformer
# Helper to create an AxisArray with trigger in attrs for testing BaseAdaptiveTransformer
def mock_sample_message():
sample_message = MagicMock(spec=SampleMessage)
return sample_message
return AxisArray(
data=np.zeros((1, 1)),
dims=["time", "ch"],
attrs={"trigger": SampleTriggerMessage()},
)


class TestBaseAdaptiveTransformer:
Expand All @@ -776,9 +780,7 @@ async def test_apartial_fit(self):

def test_call_with_sample_message(self):
transformer = MockAdaptiveTransformer()
# Create a sample message with a trigger attribute
sample_msg = mock_sample_message()
setattr(sample_msg, "trigger", None)
result = transformer(sample_msg)
assert result is None # partial_fit returns None
assert transformer.state.iterations == 1
Expand Down