Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ezmsg/baseproc/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,5 @@ def partial_fit(self, message: AxisArray) -> None:
...

async def apartial_fit(self, message: AxisArray) -> None: ...
def partial_fit_transform(self, message: AxisArray) -> MessageOutType: ...
async def apartial_fit_transform(self, message: AxisArray) -> MessageOutType: ...
62 changes: 36 additions & 26 deletions src/ezmsg/baseproc/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import pickle
import typing
import warnings
from abc import ABC, abstractmethod

from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .processor import (
BaseProcessor,
Expand All @@ -14,7 +14,7 @@
)
from .protocols import MessageInType, MessageOutType, SettingsType, StateType
from .util.asio import run_coroutine_sync
from .util.message import SampleMessage, is_sample_message
from .util.message import is_sample_message
from .util.typeresolution import resolve_typevar


Expand Down Expand Up @@ -274,37 +274,47 @@ async def apartial_fit(self, message: AxisArray) -> None:
return self.partial_fit(message)

def __call__(self, message: MessageInType) -> MessageOutType | None:
"""
Adapt transformer with training data (and optionally labels)
in AxisArray with attrs["trigger"].

Args:
message: An AxisArray with optional trigger in attrs["trigger"],
containing labels (y) in attrs["trigger"].value and
data (X) in message.data

Returns: None
"""
if is_sample_message(message):
if isinstance(message, SampleMessage):
# Auto-convert old format → new format
message = replace(
message.sample,
attrs={**message.sample.attrs, "trigger": message.trigger},
)
return self.partial_fit(message)
warnings.warn(
f"{self.__class__.__name__}.__call__() received a sample message "
"(AxisArray with 'trigger' in attrs). Auto-routing to partial_fit "
"has been removed. Use partial_fit() for training only, or "
"partial_fit_transform() for training + inference.",
UserWarning,
stacklevel=2,
)
return super().__call__(message)

async def __acall__(self, message: MessageInType) -> MessageOutType | None:
if is_sample_message(message):
if isinstance(message, SampleMessage):
message = replace(
message.sample,
attrs={**message.sample.attrs, "trigger": message.trigger},
)
return await self.apartial_fit(message)
warnings.warn(
f"{self.__class__.__name__}.__acall__() received a sample message "
"(AxisArray with 'trigger' in attrs). Auto-routing to partial_fit "
"has been removed. Use apartial_fit() for training only, or "
"apartial_fit_transform() for training + inference.",
UserWarning,
stacklevel=2,
)
return await super().__acall__(message)

def partial_fit_transform(self, message: AxisArray) -> MessageOutType:
"""Train on the message, then run inference and return the result."""
msg_hash = self._hash_message(message)
if msg_hash != self._hash:
self._reset_state(message)
self._hash = msg_hash
self.partial_fit(message)
return self._process(message)

async def apartial_fit_transform(self, message: AxisArray) -> MessageOutType:
"""Async variant of partial_fit_transform."""
msg_hash = self._hash_message(message)
if msg_hash != self._hash:
self._reset_state(message)
self._hash = msg_hash
await self.apartial_fit(message)
return await self._aprocess(message)


class BaseAsyncTransformer(
BaseStatefulTransformer[SettingsType, MessageInType, MessageOutType, StateType],
Expand Down
9 changes: 7 additions & 2 deletions src/ezmsg/baseproc/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class BaseAdaptiveTransformerUnit(
INPUT_SAMPLE = ez.InputStream(AxisArray)
INPUT_SIGNAL = ez.InputStream(MessageInType)
OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)
OUTPUT_SAMPLE = ez.OutputStream(MessageOutType)

def create_processor(self) -> None:
# self.processor: AdaptiveTransformerType[SettingsType, MessageInType, MessageOutType, StateType]
Expand All @@ -241,8 +242,12 @@ async def on_signal(self, message: MessageInType) -> typing.AsyncGenerator:
yield self.OUTPUT_SIGNAL, result

@ez.subscriber(INPUT_SAMPLE)
async def on_sample(self, msg: AxisArray) -> None:
await self.processor.apartial_fit(msg)
@ez.publisher(OUTPUT_SAMPLE)
@profile_subpub(trace_oldest=False)
async def on_sample(self, msg: AxisArray) -> typing.AsyncGenerator:
result = await self.processor.apartial_fit_transform(msg)
if result is not None:
yield self.OUTPUT_SAMPLE, result


class BaseClockDrivenUnit(
Expand Down
24 changes: 17 additions & 7 deletions tests/test_baseproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class MockSettings:
@processor_state
class MockState:
iterations: int = 0
fit_count: int = 0
hash: int = -1


Expand Down Expand Up @@ -141,7 +142,7 @@ def _process(self, message: MockMessageA) -> MockMessageB:
return MockMessageB()

def partial_fit(self, message: AxisArray) -> None:
self._state.iterations += 1
self._state.fit_count += 1


class MockAsyncTransformer(BaseAsyncTransformer[MockSettings, MockMessageA, MockMessageB, MockState]):
Expand Down Expand Up @@ -770,20 +771,29 @@ class TestBaseAdaptiveTransformer:
def test_partial_fit(self):
transformer = MockAdaptiveTransformer()
transformer.partial_fit(mock_sample_message())
assert transformer.state.iterations == 1
assert transformer.state.fit_count == 1

@pytest.mark.asyncio
async def test_apartial_fit(self):
transformer = MockAdaptiveTransformer()
await transformer.apartial_fit(mock_sample_message())
assert transformer.state.iterations == 1
assert transformer.state.fit_count == 1

def test_call_with_sample_message(self):
def test_call_with_sample_message_warns(self):
transformer = MockAdaptiveTransformer()
sample_msg = mock_sample_message()
result = transformer(sample_msg)
assert result is None # partial_fit returns None
assert transformer.state.iterations == 1
with pytest.warns(UserWarning, match="Auto-routing to partial_fit"):
result = transformer(sample_msg)
assert isinstance(result, MockMessageB) # inference, not partial_fit
assert transformer.state.fit_count == 0 # partial_fit NOT called

def test_partial_fit_transform(self):
transformer = MockAdaptiveTransformer()
sample_msg = mock_sample_message()
result = transformer.partial_fit_transform(sample_msg)
assert isinstance(result, MockMessageB)
assert transformer.state.fit_count == 1
assert transformer.state.iterations == 1 # _process was called


class TestBaseAsyncTransformer:
Expand Down