Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename grav_index to curve_index #68

Merged
merged 3 commits into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sleap_roots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sleap_roots.convhull
import sleap_roots.ellipse
import sleap_roots.networklength
import sleap_roots.lengths
import sleap_roots.points
import sleap_roots.scanline
import sleap_roots.series
Expand All @@ -16,4 +17,4 @@

# Define package version.
# This is read dynamically by setuptools in pyproject.toml to determine the release version.
__version__ = "0.0.4"
__version__ = "0.0.5"
18 changes: 9 additions & 9 deletions sleap_roots/lengths.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ def get_root_lengths_max(pts: np.ndarray) -> np.ndarray:
return max_length


def get_grav_index(
def get_curve_index(
lengths: Union[float, np.ndarray], base_tip_dists: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
"""Calculate the gravitropism index of a root.
"""Calculate the curvature index of a root.

The gravitropism index quantifies the curviness of the root's growth. A higher
gravitropism index indicates a curvier root (less responsive to gravity), while a
The curvature index quantifies the curviness of the root's growth. A higher
curvature index indicates a curvier root (less responsive to gravity), while a
lower index indicates a straighter root (more responsive to gravity). The index is
computed as the difference between the maximum root length and straight-line
distance from the base to the tip of the root, normalized by the root length.
Expand All @@ -129,7 +129,7 @@ def get_grav_index(
root(s). Can be a scalar or a 1D numpy array of shape `(instances,)`.

Returns:
Gravitropism index of the root(s), quantifying its/their curviness. Will be a
Curvature index of the root(s), quantifying its/their curviness. Will be a
scalar if input is scalar, or a 1D numpy array of shape `(instances,)`
otherwise.
"""
Expand All @@ -144,8 +144,8 @@ def get_grav_index(
if lengths.shape != base_tip_dists.shape:
raise ValueError("The shapes of lengths and base_tip_dists must match.")

# Calculate the gravitropism index where possible
grav_index = np.where(
# Calculate the curvature index where possible
curve_index = np.where(
(~np.isnan(lengths))
& (~np.isnan(base_tip_dists))
& (lengths > 0)
Expand All @@ -156,6 +156,6 @@ def get_grav_index(

# Return scalar or array based on the input type
if is_scalar_input:
return grav_index.item()
return curve_index.item()
else:
return grav_index
return curve_index
2 changes: 1 addition & 1 deletion sleap_roots/scanline.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_scanline_last_ind(scanline_intersection_counts: np.ndarray):
Return:
Scalar of count_scanline_interaction index for the last interaction.
"""
# get the first scanline index using scanline_intersection_counts
# get the last scanline index using scanline_intersection_counts
if np.where((scanline_intersection_counts > 0))[0].shape[0] > 0:
scanline_last_ind = np.where((scanline_intersection_counts > 0))[0][-1]
return scanline_last_ind
Expand Down
18 changes: 9 additions & 9 deletions sleap_roots/trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
get_ellipse_b,
get_ellipse_ratio,
)
from sleap_roots.lengths import get_grav_index, get_max_length_pts, get_root_lengths
from sleap_roots.lengths import get_curve_index, get_max_length_pts, get_root_lengths
from sleap_roots.networklength import (
get_bbox,
get_network_distribution,
Expand Down Expand Up @@ -811,13 +811,13 @@ def define_traits(self) -> List[TraitDef]:
description="Scalar of base median ratio.",
),
TraitDef(
name="grav_index",
fn=get_grav_index,
name="curve_index",
fn=get_curve_index,
input_traits=["primary_length", "primary_base_tip_dist"],
scalar=True,
include_in_csv=True,
kwargs={},
description="Scalar of primary root gravity index.",
description="Scalar of primary root curvature index.",
),
TraitDef(
name="base_length_ratio",
Expand Down Expand Up @@ -1189,13 +1189,13 @@ def define_traits(self) -> List[TraitDef]:
"tip(s) of the main root(s).",
),
TraitDef(
name="main_grav_indices",
name="main_curve_indices",
fn=get_base_tip_dist,
input_traits=["main_base_pts", "main_tip_pts"],
scalar=False,
include_in_csv=True,
kwargs={},
description="Gravitropism index for each main root.",
description="Curvature index for each main root.",
),
TraitDef(
name="network_solidity",
Expand Down Expand Up @@ -1291,13 +1291,13 @@ def define_traits(self) -> List[TraitDef]:
"convex hull.",
),
TraitDef(
name="grav_index",
fn=get_grav_index,
name="curve_index",
fn=get_curve_index,
input_traits=["primary_length", "primary_base_tip_dist"],
scalar=True,
include_in_csv=True,
kwargs={},
description="Scalar of primary root gravity index.",
description="Scalar of primary root curvature index.",
),
TraitDef(
name="primary_base_tip_dist",
Expand Down
44 changes: 22 additions & 22 deletions tests/test_lengths.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sleap_roots.lengths import (
get_grav_index,
get_curve_index,
get_root_lengths,
get_root_lengths_max,
get_max_length_pts,
Expand Down Expand Up @@ -146,8 +146,8 @@ def lengths_all_nan():
return np.array([np.nan, np.nan, np.nan])


# tests for get_grav_index function
def test_get_grav_index_canola(canola_h5):
# tests for get_curve_index function
def test_get_curve_index_canola(canola_h5):
series = Series.load(
canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
)
Expand All @@ -158,22 +158,22 @@ def test_get_grav_index_canola(canola_h5):
bases = get_bases(max_length_pts)
tips = get_tips(max_length_pts)
base_tip_dist = get_base_tip_dist(bases, tips)
grav_index = get_grav_index(primary_length, base_tip_dist)
np.testing.assert_almost_equal(grav_index, 0.08898137324716636)
curve_index = get_curve_index(primary_length, base_tip_dist)
np.testing.assert_almost_equal(curve_index, 0.08898137324716636)


def test_get_grav_index():
def test_get_curve_index():
# Test 1: Scalar inputs where length > base_tip_dist
# Gravitropism index should be (10 - 8) / 10 = 0.2
assert get_grav_index(10, 8) == 0.2
# Curvature index should be (10 - 8) / 10 = 0.2
assert get_curve_index(10, 8) == 0.2

# Test 2: Scalar inputs where length and base_tip_dist are zero
# Should return NaN as length is zero
assert np.isnan(get_grav_index(0, 0))
assert np.isnan(get_curve_index(0, 0))

# Test 3: Scalar inputs where length < base_tip_dist
# Should return NaN as it's an invalid case
assert np.isnan(get_grav_index(5, 10))
assert np.isnan(get_curve_index(5, 10))

# Test 4: Array inputs covering various cases
# Case 1: length > base_tip_dist, should return 0.2
Expand All @@ -183,35 +183,35 @@ def test_get_grav_index():
lengths = np.array([10, 0, 5, 15])
base_tip_dists = np.array([8, 0, 10, 12])
expected = np.array([0.2, np.nan, np.nan, 0.2])
result = get_grav_index(lengths, base_tip_dists)
result = get_curve_index(lengths, base_tip_dists)
assert np.allclose(result, expected, equal_nan=True)

# Test 5: Mismatched shapes between lengths and base_tip_dists
# Should raise a ValueError
with pytest.raises(ValueError):
get_grav_index(np.array([10, 20]), np.array([8]))
get_curve_index(np.array([10, 20]), np.array([8]))

# Test 6: Array inputs with NaN values
# Case 1: length > base_tip_dist, should return 0.2
# Case 2 and 3: either length or base_tip_dist is NaN, should return NaN
lengths = np.array([10, np.nan, np.nan])
base_tip_dists = np.array([8, 8, np.nan])
expected = np.array([0.2, np.nan, np.nan])
result = get_grav_index(lengths, base_tip_dists)
result = get_curve_index(lengths, base_tip_dists)
assert np.allclose(result, expected, equal_nan=True)


def test_get_grav_index_shape():
def test_get_curve_index_shape():
# Check if scalar inputs result in scalar output
result = get_grav_index(10, 8)
result = get_curve_index(10, 8)
assert isinstance(
result, (int, float)
), f"Expected scalar output, got {type(result)}"

# Check if array inputs result in array output
lengths = np.array([10, 15])
base_tip_dists = np.array([8, 12])
result = get_grav_index(lengths, base_tip_dists)
result = get_curve_index(lengths, base_tip_dists)
assert isinstance(
result, np.ndarray
), f"Expected np.ndarray output, got {type(result)}"
Expand All @@ -225,7 +225,7 @@ def test_get_grav_index_shape():
# Check the shape of output for larger array inputs
lengths = np.array([10, 15, 20, 25])
base_tip_dists = np.array([8, 12, 18, 22])
result = get_grav_index(lengths, base_tip_dists)
result = get_curve_index(lengths, base_tip_dists)
assert (
result.shape == lengths.shape
), f"Output shape {result.shape} does not match input shape {lengths.shape}"
Expand All @@ -235,22 +235,22 @@ def test_nan_values():
lengths = np.array([10, np.nan, 30])
base_tip_dists = np.array([8, 16, np.nan])
np.testing.assert_array_equal(
get_grav_index(lengths, base_tip_dists), np.array([0.2, np.nan, np.nan])
get_curve_index(lengths, base_tip_dists), np.array([0.2, np.nan, np.nan])
)


def test_zero_lengths():
lengths = np.array([0, 20, 30])
base_tip_dists = np.array([0, 16, 24])
np.testing.assert_array_equal(
get_grav_index(lengths, base_tip_dists), np.array([np.nan, 0.2, 0.2])
get_curve_index(lengths, base_tip_dists), np.array([np.nan, 0.2, 0.2])
)


def test_invalid_scalar_values():
assert np.isnan(get_grav_index(np.nan, 8))
assert np.isnan(get_grav_index(10, np.nan))
assert np.isnan(get_grav_index(0, 8))
assert np.isnan(get_curve_index(np.nan, 8))
assert np.isnan(get_curve_index(10, np.nan))
assert np.isnan(get_curve_index(0, 8))


# tests for `get_root_lengths`
Expand Down
12 changes: 6 additions & 6 deletions tests/test_trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ def test_younger_monocot_pipeline(rice_h5, rice_folder):

# Value range assertions for traits
assert (
rice_traits["grav_index"].fillna(0) >= 0
).all(), "grav_index in rice_traits contains negative values"
rice_traits["curve_index"].fillna(0) >= 0
).all(), "curve_index in rice_traits contains negative values"
assert (
all_traits["grav_index_median"] >= 0
).all(), "grav_index in all_traits contains negative values"
all_traits["curve_index_median"] >= 0
).all(), "curve_index in all_traits contains negative values"
assert (
all_traits["main_grav_indices_mean_median"] >= 0
).all(), "main_grav_indices_mean_median in all_traits contains negative values"
all_traits["main_curve_indices_mean_median"] >= 0
).all(), "main_curve_indices_mean_median in all_traits contains negative values"
assert (
(0 <= rice_traits["main_angles_proximal_p95"])
& (rice_traits["main_angles_proximal_p95"] <= 180)
Expand Down