Skip to content

Commit

Permalink
Correct crown-curve-indices definition in trait pipeline (#83)
Browse files Browse the repository at this point in the history
* add mermaid diagram notebook

* display mermaid diagrams and save them as pngs

* add tests for json serialization

* check attributes of labeled frame instead of labeledframe which does not assert as equal

* updated documentation for loading series from slp paths

* update DicotPipeline notebook with `Series` changes

* modify curve index function so that warnings are not output when lengths are not valid

* fix definition of `crown_curve_indices`

* fix `crown_curve_indices` in `OlderMonocotPipeline`

* assert curve indices are between 0 and 1 in tests
  • Loading branch information
eberrigan authored Aug 26, 2024
1 parent fe7cdae commit afc4fa4
Show file tree
Hide file tree
Showing 8 changed files with 577 additions and 49 deletions.
88 changes: 52 additions & 36 deletions notebooks/DicotPipeline.ipynb

Large diffs are not rendered by default.

451 changes: 451 additions & 0 deletions notebooks/Pipeline_mermaid_diagrams.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sleap_roots/lengths.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_curve_index(
& (~np.isnan(base_tip_dists))
& (lengths > 0)
& (lengths >= base_tip_dists),
(lengths - base_tip_dists) / lengths,
(lengths - base_tip_dists) / np.where(lengths != 0, lengths, np.nan),
np.nan,
)

Expand Down
9 changes: 4 additions & 5 deletions sleap_roots/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,10 @@ def load_series_from_slps(
) -> List[Series]:
"""Load a list of Series from a list of .slp paths.
To load the `Series`, the files must be named with the following convention:
To load the `Series`, the files must be named with the following convention.
The `slp_paths` are expected to have the `series_name` in the filename and "primary",
"lateral", or "crown" in the filename to differentiate the predictions.
h5_path: '/path/to/scan/series_name.h5'
primary_path: '/path/to/scan/series_name.model{model_id}.rootprimary.slp'
lateral_path: '/path/to/scan/series_name.model{model_id}.rootlateral.slp'
crown_path: '/path/to/scan/series_name.model{model_id}.rootcrown.slp'
Note that everything is expected to be in the same folder.
Our pipeline outputs prediction files with this format:
Expand All @@ -500,7 +499,7 @@ def load_series_from_slps(
if h5s:
# Get directory of the h5s
h5_dir = Path(slp_paths[0]).parent
# Generate the path to the .h5 file
# Create path to the .h5 file
h5_path = h5_dir / f"{series_name}.h5"
else:
h5_path = None
Expand Down
8 changes: 4 additions & 4 deletions sleap_roots/trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,8 +1597,8 @@ def define_traits(self) -> List[TraitDef]:
),
TraitDef(
name="crown_curve_indices",
fn=get_base_tip_dist,
input_traits=["crown_base_pts", "crown_tip_pts"],
fn=get_curve_index,
input_traits=["crown_lengths", "crown_base_tip_dists"],
scalar=False,
include_in_csv=True,
kwargs={},
Expand Down Expand Up @@ -1974,8 +1974,8 @@ def define_traits(self) -> List[TraitDef]:
),
TraitDef(
name="crown_curve_indices",
fn=get_base_tip_dist,
input_traits=["crown_base_pts", "crown_tip_pts"],
fn=get_curve_index,
input_traits=["crown_lengths", "crown_base_tip_dists"],
scalar=False,
include_in_csv=True,
kwargs={},
Expand Down
1 change: 0 additions & 1 deletion tests/test_lengths.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ def test_invalid_scalar_values():
assert np.isnan(get_curve_index(0, 8))


# tests for `get_root_lengths`
def test_curve_index_float():
assert get_curve_index(10.0, 5.0) == 0.5

Expand Down
14 changes: 12 additions & 2 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,12 @@ def test_get_frame_rice_10do(
# Get the crown labeled frame
crown_lf = frames.get("crown")

assert crown_lf == expected_labeled_frame
# Compare the attributes of the labeled frames
assert crown_lf.frame_idx == expected_labeled_frame.frame_idx
assert crown_lf.instances == expected_labeled_frame.instances
assert crown_lf.video.filename == expected_labeled_frame.video.filename
assert crown_lf.video.shape == expected_labeled_frame.video.shape
assert crown_lf.video.backend == expected_labeled_frame.video.backend
assert series.series_name == "0K9E8BI"


Expand All @@ -302,7 +307,12 @@ def test_get_frame_rice_10do_no_video(
# Get the crown labeled frame
crown_lf = frames.get("crown")

assert crown_lf == expected_labeled_frame
# Compare the attributes of the labeled frames
assert crown_lf.frame_idx == expected_labeled_frame.frame_idx
assert crown_lf.instances == expected_labeled_frame.instances
assert crown_lf.video.filename == expected_labeled_frame.video.filename
assert crown_lf.video.shape == expected_labeled_frame.video.shape
assert crown_lf.video.backend == expected_labeled_frame.video.backend
assert series.series_name == "0K9E8BI"


Expand Down
53 changes: 53 additions & 0 deletions tests/test_trait_pipelines.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import numpy as np
import pandas as pd
import json
import pytest

from sleap_roots.trait_pipelines import (
DicotPipeline,
YoungerMonocotPipeline,
OlderMonocotPipeline,
MultipleDicotPipeline,
NumpyArrayEncoder,
)
from sleap_roots.series import (
Series,
Expand All @@ -15,6 +19,47 @@
)


def test_numpy_array_serialization():
    """Check that a NumPy array dumped with `NumpyArrayEncoder` decodes to a list."""
    encoded = json.dumps(np.array([1, 2, 3]), cls=NumpyArrayEncoder)
    decoded = json.loads(encoded)
    assert decoded == [1, 2, 3]


def test_numpy_int64_serialization():
    """Check that a `np.int64` dumped with `NumpyArrayEncoder` decodes to an int."""
    encoded = json.dumps(np.int64(42), cls=NumpyArrayEncoder)
    decoded = json.loads(encoded)
    assert decoded == 42


def test_unsupported_type_serialization():
    """Check that dumping an arbitrary object with `NumpyArrayEncoder` raises `TypeError`."""

    class Opaque:
        pass

    unsupported = Opaque()
    with pytest.raises(TypeError):
        json.dumps(unsupported, cls=NumpyArrayEncoder)


def test_mixed_data_serialization():
    """Check a dict mixing NumPy and built-in values after a JSON round trip.

    The NumPy array and `np.int64` entries are expected to decode as a plain
    list and int; the built-in entries are expected to decode unchanged.
    """
    payload = {
        "array": np.array([1, 2, 3]),
        "int64": np.int64(42),
        "regular_int": 99,
        "list": [4, 5, 6],
        "dict": {"key": "value"},
    }
    round_tripped = json.loads(json.dumps(payload, cls=NumpyArrayEncoder))
    assert round_tripped == {
        "array": [1, 2, 3],
        "int64": 42,
        "regular_int": 99,
        "list": [4, 5, 6],
        "dict": {"key": "value"},
    }


def test_dicot_pipeline(
canola_h5,
soy_h5,
Expand Down Expand Up @@ -107,12 +152,17 @@ def test_younger_monocot_pipeline(rice_pipeline_output_folder):
assert (
rice_traits["curve_index"].fillna(0) >= 0
).all(), "curve_index in rice_traits contains negative values"
assert rice_traits["curve_index"].fillna(0).max() <= 1, "curve_index in rice_traits contains values greater than 1"
assert (
all_traits["curve_index_median"] >= 0
).all(), "curve_index in all_traits contains negative values"
assert all_traits["curve_index_median"].max() <= 1, "curve_index in all_traits contains values greater than 1"
assert (
all_traits["crown_curve_indices_mean_median"] >= 0
).all(), "crown_curve_indices_mean_median in all_traits contains negative values"
assert (
all_traits["crown_curve_indices_mean_median"] <= 1
).all(), "crown_curve_indices_mean_median in all_traits contains values greater than 1"
assert (
(0 <= rice_traits["crown_angles_proximal_p95"])
& (rice_traits["crown_angles_proximal_p95"] <= 180)
Expand Down Expand Up @@ -169,6 +219,9 @@ def test_older_monocot_pipeline(rice_10do_pipeline_output_folder):
assert (
all_traits["crown_curve_indices_mean_median"] >= 0
).all(), "crown_curve_indices_mean_median in all_traits contains negative values"
assert (
all_traits["crown_curve_indices_mean_median"] <= 1
).all(), "crown_curve_indices_mean_median in all_traits contains values greater than 1"
assert (
(0 <= rice_traits["crown_angles_proximal_p95"])
& (rice_traits["crown_angles_proximal_p95"] <= 180)
Expand Down

0 comments on commit afc4fa4

Please sign in to comment.