Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-animal pose functionality #3

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions spec/ndx-pose.extensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,66 @@ groups:
- neurodata_type_inc: PoseEstimationSeries
doc: Estimated position data for each body part.
quantity: '*'
- neurodata_type_inc: PoseGroupingSeries
doc: Part grouping metadata for the individual in multi-animal experiments.
quantity: '?'
- neurodata_type_inc: AnimalIdentitySeries
doc: Predicted identity of the individual in multi-animal experiments.
quantity: '?'
- neurodata_type_def: PoseGroupingSeries
neurodata_type_inc: TimeSeries
doc: Instance-level part grouping timeseries for the individual animal. This contains
metadata of the part grouping procedure for multi-animal pose trackers.
datasets:
- name: name
dtype: text
doc: Description of the type of localization, e.g., 'Centroid' or 'Bounding box'.
- name: data
dtype: float32
dims:
- num_frames
shape:
- null
doc: Score of the grouping approach that associated all of the keypoints to the
same animal within the frame.
- name: location
dtype: float32
dims:
- - num_frames
- x, y
- - num_frames
- x, y, z
- - num_frames
- x1, y1, x2, y2
- - num_frames
- x1, y1, z1, x2, y2, z2
shape:
- - null
- 2
- - null
- 3
- - null
- 4
- - null
- 6
doc: Animal location for two-stage (top-down) multi-animal models, e.g., centroid
or bounding box.
quantity: '?'
- neurodata_type_def: AnimalIdentitySeries
neurodata_type_inc: TimeSeries
doc: Identity of the animal predicted by a tracking or re-ID algorithm in multi-animal
experiments.
datasets:
- name: data
dtype: float32
dims:
- num_frames
shape:
- null
doc: Score of the identity assignment approach that associated all of the keypoints
to the same animal over frames, e.g., MOT tracking score or ID classification
probability.
- name: name
dtype: text
doc: Unique animal identifier, track label, or class name used to identify this
animal in the experiment.
6 changes: 4 additions & 2 deletions spec/ndx-pose.namespace.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ namespaces:
- Ryan Ly
- Ben Dichter
- Alexander Mathis
- Talmo Pereira
contact:
- rly@lbl.gov
- bdichter@lbl.gov
- alexander.mathis@epfl.ch
doc: NWB extension to store pose estimation data
- talmo@salk.edu
doc: NWB extension to store single or multi-animal pose tracking.
name: ndx-pose
schema:
- namespace: core
Expand All @@ -17,4 +19,4 @@ namespaces:
- NWBDataInterface
- NWBContainer
- source: ndx-pose.extensions.yaml
version: 0.1.0
version: 0.2.0
2 changes: 1 addition & 1 deletion src/pynwb/ndx_pose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
load_namespaces(ndx_pose_specpath)

from . import io as __io
from .pose import PoseEstimationSeries, PoseEstimation
from .pose import PoseEstimationSeries, PoseEstimation, PoseGroupingSeries, AnimalIdentitySeries
26 changes: 25 additions & 1 deletion src/pynwb/ndx_pose/pose.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from hdmf.utils import docval, popargs, get_docval, call_docval_func

from pynwb import register_class, TimeSeries
from pynwb import register_class, TimeSeries, get_class
# from pynwb.behavior import SpatialSeries
from pynwb.core import MultiContainerInterface
from pynwb.device import Device


PoseGroupingSeries = get_class("PoseGroupingSeries", "ndx-pose")
AnimalIdentitySeries = get_class("AnimalIdentitySeries", "ndx-pose")


@register_class('PoseEstimationSeries', 'ndx-pose')
class PoseEstimationSeries(TimeSeries):
"""
Expand Down Expand Up @@ -52,6 +56,20 @@ class PoseEstimation(MultiContainerInterface):
'type': PoseEstimationSeries,
'attr': 'pose_estimation_series'
},
{
'add': 'add_pose_grouping_series',
'get': 'get_pose_grouping_series',
'create': 'create_pose_grouping_series',
'type': PoseGroupingSeries,
'attr': 'pose_grouping_series'
},
{
'add': 'add_animal_identity_series',
'get': 'get_animal_identity_series',
'create': 'create_animal_identity_series',
'type': AnimalIdentitySeries,
'attr': 'animal_identity_series'
},
{
'add': 'add_device',
'get': 'get_devices',
Expand All @@ -69,6 +87,8 @@ class PoseEstimation(MultiContainerInterface):

# TODO fill in doc
@docval({'name': 'pose_estimation_series', 'type': ('array_data', 'data'), 'doc': (''), 'default': None},
{'name': 'pose_grouping_series', 'type': ('array_data', 'data'), 'doc': (''), 'default': None},
{'name': 'animal_identity_series', 'type': ('array_data', 'data'), 'doc': (''), 'default': None},
{'name': 'name', 'type': str, 'doc': (''), 'default': 'PoseEstimation'},
{'name': 'description', 'type': str, 'doc': (''), 'default': None},
{'name': 'original_videos', 'type': ('array_data', 'data'), 'shape': (None, ),
Expand All @@ -88,13 +108,17 @@ def __init__(self, **kwargs):
"""
""" # TODO
pose_estimation_series, description = popargs('pose_estimation_series', 'description', kwargs)
pose_grouping_series = popargs('pose_grouping_series', kwargs)
animal_identity_series = popargs('animal_identity_series', kwargs)
original_videos, labeled_videos, = popargs('original_videos', 'labeled_videos', kwargs)
dimensions, scorer = popargs('dimensions', 'scorer', kwargs)
source_software, source_software_version = popargs('source_software', 'source_software_version', kwargs)
nodes, edges = popargs('nodes', 'edges', kwargs)
devices = popargs('devices', kwargs)
call_docval_func(super().__init__, kwargs)
self.pose_estimation_series = pose_estimation_series
self.pose_grouping_series = pose_grouping_series
self.animal_identity_series = animal_identity_series
self.description = description
self.original_videos = original_videos
self.labeled_videos = labeled_videos
Expand Down
115 changes: 112 additions & 3 deletions src/pynwb/tests/unit/test_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from pynwb import NWBFile
from pynwb.testing import TestCase

from ndx_pose import PoseEstimationSeries, PoseEstimation
from ndx_pose import PoseEstimationSeries, PoseEstimation, PoseGroupingSeries, AnimalIdentitySeries


def create_series():
def create_pose_series():
data = np.random.rand(100, 3) # num_frames x (x, y, z)
timestamps = np.linspace(0, 10, num=100) # a timestamp for every frame
confidence = np.random.rand(100) # a confidence value for every frame
Expand Down Expand Up @@ -81,7 +81,7 @@ def setUp(self):

def test_constructor(self):
"""Test that the constructor for PoseEstimation sets values as expected."""
pose_estimation_series = create_series()
pose_estimation_series = create_pose_series()
pe = PoseEstimation(
pose_estimation_series=pose_estimation_series,
description='Estimated positions of front paws using DeepLabCut.',
Expand Down Expand Up @@ -112,3 +112,112 @@ def test_constructor(self):
# self.assertEqual(len(pe.devices), 2)
# self.assertIs(pe.devices['camera1'], self.nwbfile.devices['camera1'])
# self.assertIs(pe.devices['camera2'], self.nwbfile.devices['camera2'])


class TestPoseGroupingSeriesConstructor(TestCase):

def test_constructor(self):
timestamps = np.linspace(0, 10, num=10) # a timestamp for every frame
centroid = np.random.rand(10, 2) # location of animal for every frame
score = np.random.rand(10,) # num_frames

s = PoseGroupingSeries(
name="Centroid",
timestamps=timestamps,
data=score,
location=centroid,
)
self.assertEqual(s.name, "Centroid")
np.testing.assert_array_equal(s.timestamps, timestamps)
np.testing.assert_array_equal(s.data, score)
np.testing.assert_array_equal(s.location, centroid)

bbox = np.random.rand(10, 4)

s = PoseGroupingSeries(
name="Bounding box",
timestamps=timestamps,
data=score,
location=bbox,
)
self.assertEqual(s.name, "Bounding box")
np.testing.assert_array_equal(s.timestamps, timestamps)
np.testing.assert_array_equal(s.data, score)
np.testing.assert_array_equal(s.location, bbox)

s = PoseGroupingSeries(
name="PAF matching score",
timestamps=timestamps,
data=score,
)
self.assertEqual(s.name, "PAF matching score")
np.testing.assert_array_equal(s.timestamps, timestamps)
np.testing.assert_array_equal(s.data, score)


class TestAnimalIdentitySeriesConstructor(TestCase):

def test_constructor(self):
timestamps = np.linspace(0, 10, num=10) # a timestamp for every frame
score = np.random.rand(10,) # num_frames

s = AnimalIdentitySeries(
name="Mouse1",
timestamps=timestamps,
data=score,
)
self.assertEqual(s.name, "Mouse1")
np.testing.assert_array_equal(s.timestamps, timestamps)
np.testing.assert_array_equal(s.data, score)



class TestPoseEstimationMultiAnimalConstructor(TestCase):

def test_constructor(self):
"""Test that the constructor for PoseEstimation sets values as expected."""
pose_estimation_series = create_pose_series()
n_frames = pose_estimation_series[0].data.shape[0]
pose_grouping_series = PoseGroupingSeries(
name="Centroid",
timestamps=pose_estimation_series[0].timestamps,
data=np.random.rand(n_frames),
location=np.random.rand(n_frames, 3),
)
animal_identity_series = AnimalIdentitySeries(
name="Mouse1",
timestamps=pose_estimation_series[0].timestamps,
data=np.random.rand(n_frames),
)

pe = PoseEstimation(
pose_estimation_series=pose_estimation_series,
pose_grouping_series=[pose_grouping_series],
animal_identity_series=[animal_identity_series],
description='Estimated positions of front paws using DeepLabCut.',
original_videos=['camera1.mp4', 'camera2.mp4'],
labeled_videos=['camera1_labeled.mp4', 'camera2_labeled.mp4'],
dimensions=[[640, 480], [1024, 768]],
scorer='DLC_resnet50_openfieldOct30shuffle1_1600',
source_software='DeepLabCut',
source_software_version='2.2b8',
nodes=['front_left_paw', 'front_right_paw'],
edges=[[0, 1]],
# devices=[self.nwbfile.devices['camera1'], self.nwbfile.devices['camera2']],
)

self.assertEqual(pe.name, 'PoseEstimation')
self.assertEqual(len(pe.pose_estimation_series), 2)
self.assertIs(pe.pose_estimation_series['front_left_paw'], pose_estimation_series[0])
self.assertIs(pe.pose_estimation_series['front_right_paw'], pose_estimation_series[1])
self.assertIs(pe.pose_grouping_series["Centroid"], pose_grouping_series)
self.assertIs(pe.animal_identity_series["Mouse1"], animal_identity_series)
self.assertEqual(pe.description, 'Estimated positions of front paws using DeepLabCut.')
self.assertEqual(pe.original_videos, ['camera1.mp4', 'camera2.mp4'])
self.assertEqual(pe.labeled_videos, ['camera1_labeled.mp4', 'camera2_labeled.mp4'])
self.assertEqual(pe.dimensions, [[640, 480], [1024, 768]])
self.assertEqual(pe.scorer, 'DLC_resnet50_openfieldOct30shuffle1_1600')
self.assertEqual(pe.source_software, 'DeepLabCut')
self.assertEqual(pe.source_software_version, '2.2b8')
self.assertEqual(pe.nodes, ['front_left_paw', 'front_right_paw'])
self.assertEqual(pe.edges, [[0, 1]])
69 changes: 64 additions & 5 deletions src/spec/create_extension_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
def main():
# these arguments were auto-generated from your cookiecutter inputs
ns_builder = NWBNamespaceBuilder(
doc='NWB extension to store pose estimation data',
doc='NWB extension to store single or multi-animal pose tracking.',
name='ndx-pose',
version='0.1.0',
author=['Ryan Ly', 'Ben Dichter', 'Alexander Mathis'],
contact=['rly@lbl.gov', 'bdichter@lbl.gov', 'alexander.mathis@epfl.ch'],
version='0.2.0',
author=['Ryan Ly', 'Ben Dichter', 'Alexander Mathis', 'Talmo Pereira'],
contact=['rly@lbl.gov', 'bdichter@lbl.gov', 'alexander.mathis@epfl.ch', 'talmo@princeton.edu'],
)

ns_builder.include_type('SpatialSeries', namespace='core')
Expand Down Expand Up @@ -61,6 +61,55 @@ def main():
],
)

pose_grouping_series = NWBGroupSpec(
neurodata_type_def='PoseGroupingSeries',
neurodata_type_inc='TimeSeries',
doc='Instance-level part grouping timeseries for the individual animal. This contains metadata of the part grouping procedure for multi-animal pose trackers.',
datasets=[
NWBDatasetSpec(
name='name',
doc="Description of the type of localization, e.g., 'Centroid' or 'Bounding box'.",
dtype='text'
),
NWBDatasetSpec(
name='data',
doc='Score of the grouping approach that associated all of the keypoints to the same animal within the frame.',
dtype='float32',
dims=['num_frames'],
shape=[None]
),
NWBDatasetSpec(
name='location',
doc='Animal location for two-stage (top-down) multi-animal models, e.g., centroid or bounding box.',
dtype='float32',
dims=[['num_frames', 'x, y'], ['num_frames', 'x, y, z'], ['num_frames', 'x1, y1, x2, y2'], ['num_frames', 'x1, y1, z1, x2, y2, z2']],
shape=[[None, 2], [None, 3], [None, 4], [None, 6]],
quantity="?"
),
],
)


animal_identity_series = NWBGroupSpec(
neurodata_type_def='AnimalIdentitySeries',
neurodata_type_inc='TimeSeries',
doc='Identity of the animal predicted by a tracking or re-ID algorithm in multi-animal experiments.',
datasets=[
NWBDatasetSpec(
name='data',
doc='Score of the identity assignment approach that associated all of the keypoints to the same animal over frames, e.g., MOT tracking score or ID classification probability.',
dtype='float32',
dims=['num_frames'],
shape=[None],
),
NWBDatasetSpec(
name='name',
doc='Unique animal identifier, track label, or class name used to identify this animal in the experiment.',
dtype='text'
),
],
)

pose_estimation = NWBGroupSpec(
neurodata_type_def='PoseEstimation',
neurodata_type_inc='NWBDataInterface',
Expand All @@ -73,6 +122,16 @@ def main():
doc='Estimated position data for each body part.',
quantity='*',
),
NWBGroupSpec(
neurodata_type_inc='PoseGroupingSeries',
doc='Part grouping metadata for the individual in multi-animal experiments.',
quantity='?',
),
NWBGroupSpec(
neurodata_type_inc='AnimalIdentitySeries',
doc='Predicted identity of the individual in multi-animal experiments.',
quantity='?',
),
],
datasets=[
NWBDatasetSpec(
Expand Down Expand Up @@ -154,7 +213,7 @@ def main():
# ],
)

new_data_types = [pose_estimation_series, pose_estimation]
new_data_types = [pose_estimation_series, pose_estimation, pose_grouping_series, animal_identity_series]

# export the spec to yaml files in the spec folder
output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'spec'))
Expand Down