-
Notifications
You must be signed in to change notification settings - Fork 0
/
trait_pipelines.py
2299 lines (2154 loc) · 85.4 KB
/
trait_pipelines.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Extract traits in a pipeline based on a trait graph."""
import json
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import attrs
import networkx as nx
import numpy as np
import pandas as pd
from sleap_roots.angle import (
get_node_ind,
get_root_angle,
get_vector_angles_from_gravity,
)
from sleap_roots.bases import (
get_base_ct_density,
get_base_length,
get_base_length_ratio,
get_base_median_ratio,
get_base_tip_dist,
get_base_xs,
get_base_ys,
get_bases,
get_root_widths,
)
from sleap_roots.convhull import (
get_chull_area,
get_chull_intersection_vectors,
get_chull_intersection_vectors_left,
get_chull_intersection_vectors_right,
get_chull_line_lengths,
get_chull_max_height,
get_chull_max_width,
get_chull_perimeter,
get_convhull,
get_chull_areas_via_intersection,
get_chull_area_via_intersection_below,
get_chull_area_via_intersection_above,
)
from sleap_roots.ellipse import (
fit_ellipse,
get_ellipse_a,
get_ellipse_b,
get_ellipse_ratio,
)
from sleap_roots.lengths import get_curve_index, get_max_length_pts, get_root_lengths
from sleap_roots.networklength import (
get_bbox,
get_network_distribution,
get_network_distribution_ratio,
get_network_length,
get_network_solidity,
get_network_width_depth_ratio,
)
from sleap_roots.points import (
associate_lateral_to_primary,
filter_plants_with_unexpected_ct,
filter_roots_with_nans,
get_all_pts_array,
get_count,
get_filtered_lateral_pts,
get_filtered_primary_pts,
get_nodes,
join_pts,
)
from sleap_roots.scanline import (
count_scanline_intersections,
get_scanline_first_ind,
get_scanline_last_ind,
)
from sleap_roots.series import Series
from sleap_roots.summary import SUMMARY_SUFFIXES, get_summary
from sleap_roots.tips import get_tip_xs, get_tip_ys, get_tips
warnings.filterwarnings(
"ignore",
message="invalid value encountered in intersection",
category=RuntimeWarning,
module="shapely",
)
warnings.filterwarnings(
"ignore", message="All-NaN slice encountered", category=RuntimeWarning
)
warnings.filterwarnings(
"ignore", message="All-NaN axis encountered", category=RuntimeWarning
)
warnings.filterwarnings(
"ignore",
message="Degrees of freedom <= 0 for slice.",
category=RuntimeWarning,
module="numpy",
)
warnings.filterwarnings(
"ignore", message="Mean of empty slice", category=RuntimeWarning
)
warnings.filterwarnings(
"ignore",
message="invalid value encountered in sqrt",
category=RuntimeWarning,
module="skimage",
)
warnings.filterwarnings(
"ignore",
message="invalid value encountered in double_scalars",
category=RuntimeWarning,
)
warnings.filterwarnings(
"ignore",
message="invalid value encountered in scalar divide",
category=RuntimeWarning,
module="ellipse",
)
class NumpyArrayEncoder(json.JSONEncoder):
    """Custom JSON encoder for NumPy types.

    Converts `np.ndarray` objects to (nested) lists and NumPy scalar types
    (e.g. `np.int64`, `np.float32`, `np.bool_`) to the equivalent built-in
    Python scalars so they can be serialized by the `json` module.
    """

    def default(self, obj):
        """Serialize NumPy objects to JSON-compatible Python types.

        Args:
            obj: The object `json` could not serialize natively.

        Returns:
            A list for NumPy arrays, or a built-in Python scalar for NumPy
            scalar types.

        Raises:
            TypeError: If `obj` is not a NumPy type (raised by the base class).
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # `np.generic` covers all NumPy scalar types (np.int64, np.float32,
        # np.bool_, ...), not just np.int64 as before; `.item()` converts to
        # the closest built-in Python type.
        if isinstance(obj, np.generic):
            return obj.item()
        # Let the base class default method raise the TypeError
        return json.JSONEncoder.default(self, obj)
@attrs.define
class TraitDef:
    """Definition of how to compute a trait.

    Attributes:
        name: Unique identifier for the trait.
        fn: Function used to compute the trait's value.
        input_traits: List of trait names that should be computed before the current
            trait and are expected as input positional arguments to `fn`.
        scalar: Indicates if the trait is scalar (has a dimension of 0 per frame). If
            `True`, the trait is also listed in `SCALAR_TRAITS`.
        include_in_csv: `True` indicates the trait should be included in downstream CSV
            files.
        kwargs: Additional keyword arguments to be passed to the `fn` function. These
            arguments are not reused from previously computed traits.
        description: String describing the trait for documentation purposes.

    Notes:
        The `fn` specified will be called with a pattern like:

        ```
        trait_def = TraitDef(
            name="my_trait",
            fn=compute_my_trait,
            input_traits=["input_trait_1", "input_trait_2"],
            scalar=True,
            include_in_csv=True,
            kwargs={"kwarg1": True}
        )
        traits[trait_def.name] = trait_def.fn(
            *[traits[input_trait] for input_trait in trait_def.input_traits],
            **trait_def.kwargs
        )
        ```

        For this example, the last line is equivalent to:

        ```
        traits["my_trait"] = trait_def.fn(
            traits["input_trait_1"], traits["input_trait_2"],
            kwarg1=True
        )
        ```
    """

    # NOTE: field order defines the positional-argument interface of the
    # attrs-generated __init__ and must not be changed.
    name: str
    fn: Callable
    input_traits: List[str]
    scalar: bool
    include_in_csv: bool
    kwargs: Dict[str, Any] = attrs.field(factory=dict)
    description: Optional[str] = None
@attrs.define
class Pipeline:
    """Pipeline for computing traits.

    Attributes:
        traits: List of `TraitDef` objects.
        trait_map: Dictionary mapping trait names to their definitions.
        trait_computation_order: List of trait names in the order they should be
            computed.
    """

    traits: List[TraitDef] = attrs.field(init=False)
    trait_map: Dict[str, TraitDef] = attrs.field(init=False)
    trait_computation_order: List[str] = attrs.field(init=False)

    def __attrs_post_init__(self):
        """Build pipeline objects from traits list."""
        # Build list of trait definitions.
        self.traits = self.define_traits()

        # Check that trait names are unique.
        trait_names = [trait.name for trait in self.traits]
        if len(trait_names) != len(set(trait_names)):
            raise ValueError("Trait names must be unique.")

        # Map trait names to their definitions.
        self.trait_map = {trait_def.name: trait_def for trait_def in self.traits}

        # Determine computation order by topologically sorting the nodes.
        self.trait_computation_order = self.get_computation_order()

    def define_traits(self) -> List[TraitDef]:
        """Return list of `TraitDef` objects. Implemented by subclasses."""
        raise NotImplementedError

    def get_computation_order(self) -> List[str]:
        """Determine computation order by topologically sorting the nodes.

        Returns:
            A list of trait names in the order they should be computed.
        """
        # Infer dependency edges (input_trait -> trait) from the trait map.
        edges = []
        for trait_def in self.traits:
            for input_trait in trait_def.input_traits:
                edges.append((input_trait, trait_def.name))

        # Build networkx graph from inferred edges.
        G = nx.DiGraph()
        # Add every defined trait explicitly so a trait with no input traits
        # and no dependents (an isolated node) is still included in the
        # computation order; add_edges_from alone would silently drop it.
        G.add_nodes_from([trait_def.name for trait_def in self.traits])
        G.add_edges_from(edges)

        # Determine computation order by topologically sorting the nodes.
        trait_computation_order = list(nx.topological_sort(G))
        return trait_computation_order

    @property
    def summary_traits(self) -> List[str]:
        """List of non-scalar traits that get summarized for the CSV."""
        return [
            trait.name
            for trait in self.traits
            if trait.include_in_csv and not trait.scalar
        ]

    @property
    def csv_traits(self) -> List[str]:
        """List of frame-level traits to include in the CSV."""
        csv_traits = []
        for trait in self.traits:
            if trait.include_in_csv:
                if trait.scalar:
                    csv_traits.append(trait.name)
                else:
                    # Non-scalar traits are represented by their summary
                    # statistics (one column per suffix).
                    csv_traits.extend(
                        [f"{trait.name}_{suffix}" for suffix in SUMMARY_SUFFIXES]
                    )
        return csv_traits

    @property
    def csv_traits_multiple_plants(self) -> List[str]:
        """List of frame-level traits to include in the CSV for multiple plants."""
        csv_traits = []
        for trait in self.traits:
            if trait.include_in_csv:
                csv_traits.append(trait.name)
        return csv_traits

    def compute_frame_traits(self, traits: Dict[str, Any]) -> Dict[str, Any]:
        """Compute traits based on the pipeline.

        Args:
            traits: Dictionary of traits where keys are trait names and values are
                the trait values.

        Returns:
            A dictionary of computed traits.
        """
        # Initialize traits container with initial data (copy so the caller's
        # dictionary is not mutated).
        traits = traits.copy()

        # Compute traits!
        for trait_name in self.trait_computation_order:
            if trait_name in traits:
                # Skip traits already computed (e.g. the initial input points).
                continue

            # Get trait definition.
            trait_def = self.trait_map[trait_name]

            # Compute trait based on trait definition.
            traits[trait_name] = trait_def.fn(
                *[traits[input_trait] for input_trait in trait_def.input_traits],
                **trait_def.kwargs,
            )

        return traits

    def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, Any]:
        """Return initial traits for a plant frame.

        Args:
            plant: The plant `Series` object.
            frame_idx: The index of the current frame.

        Returns:
            A dictionary of initial traits.

        This is defined on a per-pipeline basis as different plant species will have
        different initial points to be used as starting traits.

        Most commonly, this will be the primary and lateral root points for the
        current frame.
        """
        raise NotImplementedError

    def compute_plant_traits(
        self,
        plant: Series,
        write_csv: bool = False,
        output_dir: str = ".",
        csv_suffix: str = ".traits.csv",
        return_non_scalar: bool = False,
    ) -> pd.DataFrame:
        """Compute traits for a plant.

        Args:
            plant: The plant image series as a `Series` object.
            write_csv: A boolean value. If True, it writes per plant detailed
                CSVs with traits for every instance on every frame.
            output_dir: The directory to write the CSV files to.
            csv_suffix: If `write_csv` is `True`, a CSV file will be saved with the same
                name as the plant's `{plant.series_name}{csv_suffix}`.
            return_non_scalar: If `True`, return all non-scalar traits as well as the
                summarized traits.

        Returns:
            The computed traits as a pandas DataFrame.
        """
        traits = []
        for frame in range(len(plant)):
            # Get initial traits for the frame.
            initial_traits = self.get_initial_frame_traits(plant, frame)
            # Compute traits via the frame-level pipeline.
            frame_traits = self.compute_frame_traits(initial_traits)

            # Compute trait summaries for non-scalar traits.
            for trait_name in self.summary_traits:
                trait_summary = get_summary(
                    frame_traits[trait_name], prefix=f"{trait_name}_"
                )
                frame_traits.update(trait_summary)

            # Add metadata.
            frame_traits["plant_name"] = plant.series_name
            frame_traits["frame_idx"] = frame
            traits.append(frame_traits)
        traits = pd.DataFrame(traits)

        # Move metadata columns to the front.
        plant_name = traits.pop("plant_name")
        frame_idx = traits.pop("frame_idx")
        traits = pd.concat([plant_name, frame_idx, traits], axis=1)

        if write_csv:
            csv_name = Path(output_dir) / f"{plant.series_name}{csv_suffix}"
            traits[["plant_name", "frame_idx"] + self.csv_traits].to_csv(
                csv_name, index=False
            )

        if return_non_scalar:
            return traits
        else:
            return traits[["plant_name", "frame_idx"] + self.csv_traits]

    def compute_multiple_dicots_traits(
        self,
        series: Series,
        write_json: bool = False,
        json_suffix: str = ".all_frames_traits.json",
        write_csv: bool = False,
        csv_suffix: str = ".all_frames_summary.csv",
    ):
        """Computes plant traits for pipelines with multiple plants over all frames in a series.

        Args:
            series: The Series object containing the primary and lateral root points.
            write_json: Whether to write the aggregated traits to a JSON file. Default is False.
            json_suffix: The suffix to append to the JSON file name. Default is ".all_frames_traits.json".
            write_csv: Whether to write the summary statistics to a CSV file. Default is False.
            csv_suffix: The suffix to append to the CSV file name. Default is ".all_frames_summary.csv".

        Returns:
            A dictionary containing the series name, group, qc_fail, aggregated traits, and summary statistics.
        """
        # Initialize the return structure with the series name and group
        result = {
            "series": str(series.series_name),
            "group": str(series.group),
            "qc_fail": series.qc_fail,
            "traits": {},
            "summary_stats": {},
        }

        # Check if the series has frames to process
        if len(series) == 0:
            print(f"Series '{series.series_name}' contains no frames to process.")
            # Return early with the initialized structure
            return result

        # Initialize a separate dictionary to hold the aggregated traits across all frames
        aggregated_traits = {}

        # Instantiate the per-plant pipeline once; it is stateless across
        # frames, so re-creating it inside the loop (as before) was wasted work.
        dicot_pipeline = DicotPipeline()

        # Iterate over frames in series
        for frame in range(len(series)):
            # Get initial points and number of plants per frame
            initial_frame_traits = self.get_initial_frame_traits(series, frame)
            # Compute initial associations and perform filter operations
            frame_traits = self.compute_frame_traits(initial_frame_traits)

            # Extract the plant associations for this frame
            associations = frame_traits["plant_associations_dict"]

            for assoc in associations.values():
                primary_pts = assoc["primary_points"]
                lateral_pts = assoc["lateral_points"]
                # Get the initial frame traits for this plant using the primary and lateral points
                initial_frame_traits = {
                    "primary_pts": primary_pts,
                    "lateral_pts": lateral_pts,
                }
                # Use the dicot pipeline to compute the plant traits on this frame
                plant_traits = dicot_pipeline.compute_frame_traits(initial_frame_traits)

                # For each plant's traits in the frame
                for trait_name, trait_value in plant_traits.items():
                    # Not all traits are added to the aggregated traits dictionary
                    if trait_name in dicot_pipeline.csv_traits_multiple_plants:
                        if trait_name not in aggregated_traits:
                            # Initialize the trait array if it's the first frame
                            aggregated_traits[trait_name] = [np.atleast_1d(trait_value)]
                        else:
                            # Append new trait values for subsequent frames
                            aggregated_traits[trait_name].append(
                                np.atleast_1d(trait_value)
                            )

        # After processing, flatten each trait's per-frame arrays into one array.
        aggregated_traits = {
            trait: np.concatenate(arrays, axis=0)
            for trait, arrays in aggregated_traits.items()
        }
        result["traits"] = aggregated_traits

        # Write to JSON if requested
        if write_json:
            json_name = f"{series.series_name}{json_suffix}"
            try:
                with open(json_name, "w") as f:
                    json.dump(
                        result, f, cls=NumpyArrayEncoder, ensure_ascii=False, indent=4
                    )
                print(f"Aggregated traits saved to {json_name}")
            except IOError as e:
                print(f"Error writing JSON file '{json_name}': {e}")

        # Compute summary statistics and update result
        summary_stats = {}
        for trait_name, trait_values in aggregated_traits.items():
            trait_stats = get_summary(trait_values, prefix=f"{trait_name}_")
            summary_stats.update(trait_stats)
        result["summary_stats"] = summary_stats

        # Optionally write summary stats to CSV
        if write_csv:
            csv_name = f"{series.series_name}{csv_suffix}"
            try:
                summary_df = pd.DataFrame([summary_stats])
                summary_df.insert(0, "series", series.series_name)
                summary_df.to_csv(csv_name, index=False)
                print(f"Summary statistics saved to {csv_name}")
            except IOError as e:
                print(f"Failed to write CSV file '{csv_name}': {e}")

        # Return the final result structure
        return result

    def compute_multiple_dicots_traits_for_groups(
        self,
        series_list: List[Series],
        output_dir: str = "grouped_traits",
        write_json: bool = False,
        json_suffix: str = ".grouped_traits.json",
        write_csv: bool = False,
        csv_suffix: str = ".grouped_summary.csv",
    ) -> List[
        Dict[str, Union[str, List[str], Dict[str, Union[List[float], np.ndarray]]]]
    ]:
        """Aggregates plant traits over groups of samples.

        Args:
            series_list: A list of Series objects containing the primary and lateral root points for each sample.
            output_dir: The directory to write the JSON and CSV files to. Default is "grouped_traits".
            write_json: Whether to write the aggregated traits to a JSON file. Default is False.
            json_suffix: The suffix to append to the JSON file name. Default is ".grouped_traits.json".
            write_csv: Whether to write the summary statistics to a CSV file. Default is False.
            csv_suffix: The suffix to append to the CSV file name. Default is ".grouped_summary.csv".

        Returns:
            A list of dictionaries containing the aggregated traits and summary statistics for each group.
        """
        # Input Validation
        if not isinstance(series_list, list) or not all(
            isinstance(series, Series) for series in series_list
        ):
            raise ValueError("series_list must be a list of Series objects.")

        # Group series by their group property
        series_groups = {}
        for series in series_list:
            # Exclude series with qc_fail flag set to 1
            if int(series.qc_fail) == 1:
                print(f"Skipping series '{series.series_name}' due to qc_fail flag.")
                continue

            # Get the group name from the series object
            group_name = str(series.group)
            if group_name not in series_groups:
                series_groups[group_name] = {"names": [], "series": []}
            # Store series names and objects in the dictionary
            series_groups[group_name]["names"].append(str(series.series_name))
            series_groups[group_name]["series"].append(series)  # Store Series objects

        # Initialize the list to hold the results for each group
        grouped_results = []

        # Iterate over each group of series
        for group_name, group_data in series_groups.items():
            # Initialize the return structure with the group name
            group_result = {
                "group": group_name,
                "series": group_data["names"],  # Use series names
                "traits": {},
            }

            # Aggregate traits over all samples in the group
            aggregated_traits = {}

            # Iterate over each series in the group
            for series in group_data["series"]:
                print(f"Processing series '{series.series_name}'")
                # Get the trait results for each series in the group
                result = self.compute_multiple_dicots_traits(
                    series=series, write_json=False, write_csv=False
                )
                # Aggregate the series traits into the group traits
                for trait, values in result["traits"].items():
                    # Ensure values are at least 1D
                    values = np.atleast_1d(values)
                    if trait not in aggregated_traits:
                        aggregated_traits[trait] = values
                    else:
                        # Concatenate the current values with the existing array
                        aggregated_traits[trait] = np.concatenate(
                            (aggregated_traits[trait], values)
                        )

            group_result["traits"] = aggregated_traits
            print(f"Finished processing group '{group_name}'")

            # Write to JSON if requested
            if write_json:
                # Make the output directory if it doesn't exist
                Path(output_dir).mkdir(parents=True, exist_ok=True)
                # Construct the JSON file name
                json_name = f"{group_name}{json_suffix}"
                # Join the output directory with the JSON file name
                json_path = Path(output_dir) / json_name
                try:
                    with open(json_path, "w") as f:
                        json.dump(
                            group_result,
                            f,
                            cls=NumpyArrayEncoder,
                            ensure_ascii=False,
                            indent=4,
                        )
                    print(
                        f"Aggregated traits for group {group_name} saved to {str(json_path)}"
                    )
                except IOError as e:
                    print(f"Error writing JSON file '{str(json_path)}': {e}")

            # Compute summary statistics
            summary_stats = {}
            for trait, trait_values in aggregated_traits.items():
                trait_stats = get_summary(trait_values, prefix=f"{trait}_")
                summary_stats.update(trait_stats)
            group_result["summary_stats"] = summary_stats

            # Write summary stats to CSV if requested
            if write_csv:
                # Make the output directory if it doesn't exist
                Path(output_dir).mkdir(parents=True, exist_ok=True)
                # Construct the CSV file name
                csv_name = f"{group_name}{csv_suffix}"
                # Join the output directory with the CSV file name
                csv_path = Path(output_dir) / csv_name
                try:
                    summary_df = pd.DataFrame([summary_stats])
                    summary_df.insert(0, "genotype", group_name)
                    summary_df.to_csv(csv_path, index=False)
                    print(
                        f"Summary statistics for group {group_name} saved to {str(csv_path)}"
                    )
                except IOError as e:
                    print(f"Failed to write CSV file '{str(csv_path)}': {e}")

            # Append the group result to the list of results
            grouped_results.append(group_result)

        return grouped_results

    def compute_batch_traits(
        self,
        plants: List[Series],
        write_csv: bool = False,
        csv_path: str = "traits.csv",
    ) -> pd.DataFrame:
        """Compute traits for a batch of plants.

        Args:
            plants: List of `Series` objects.
            write_csv: If `True`, write the computed traits to a CSV file.
            csv_path: Path to write the CSV file to.

        Returns:
            A pandas DataFrame of computed traits summarized over all frames of each
            plant. The resulting dataframe will have a row for each plant and a column
            for each plant-level summarized trait.

            Summarized traits are prefixed with the trait name and an underscore,
            followed by the summary statistic.
        """
        all_traits = []
        for plant in plants:
            print(f"Processing series: {plant.series_name}")
            # Compute frame level traits for the plant.
            plant_traits = self.compute_plant_traits(plant)

            # Summarize frame level traits.
            plant_summary = {"plant_name": plant.series_name}
            for trait_name in self.csv_traits:
                trait_summary = get_summary(
                    plant_traits[trait_name], prefix=f"{trait_name}_"
                )
                plant_summary.update(trait_summary)
            all_traits.append(plant_summary)

        # Build dataframe from list of frame-level summaries.
        all_traits = pd.DataFrame(all_traits)

        if write_csv:
            all_traits.to_csv(csv_path, index=False)
            print(f"Batch traits saved to {csv_path}")

        return all_traits

    def compute_batch_multiple_dicots_traits(
        self,
        all_series: List[Series],
        write_csv: bool = False,
        csv_path: str = "traits.csv",
    ) -> pd.DataFrame:
        """Compute traits for a batch of series with multiple dicots.

        Args:
            all_series: List of `Series` objects.
            write_csv: If `True`, write the computed traits to a CSV file.
            csv_path: Path to write the CSV file to.

        Returns:
            A pandas DataFrame of computed traits summarized over all frames of each
            series. The resulting dataframe will have a row for each series and a column
            for each series-level summarized trait.

            Summarized traits are prefixed with the trait name and an underscore,
            followed by the summary statistic.
        """
        all_series_summaries = []

        for series in all_series:
            print(f"Processing series '{series.series_name}'")
            # Use the updated function and access its return value
            series_result = self.compute_multiple_dicots_traits(
                series, write_json=False, write_csv=False
            )
            # Prepare the series-level summary.
            series_summary = {
                "series_name": series_result["series"],
                **series_result["summary_stats"],  # Unpack summary_stats
            }
            all_series_summaries.append(series_summary)

        # Convert list of dictionaries to a DataFrame
        all_series_summaries_df = pd.DataFrame(all_series_summaries)

        # Write to CSV if requested
        if write_csv:
            all_series_summaries_df.to_csv(csv_path, index=False)
            print(f"Computed traits for all series saved to {csv_path}")

        return all_series_summaries_df

    def compute_batch_multiple_dicots_traits_for_groups(
        self,
        all_series: List[Series],
        output_dir: str = "grouped_traits",
        write_json: bool = False,
        write_csv: bool = False,
        csv_path: str = "group_summarized_traits.csv",
    ) -> pd.DataFrame:
        """Compute traits for a batch of grouped series with multiple dicots.

        Args:
            all_series: List of `Series` objects.
            output_dir: The directory to write the JSON and CSV files to. Default is "grouped_traits".
            write_json: If `True`, write each set of group traits to a JSON file.
            write_csv: If `True`, write the computed traits to a CSV file.
            csv_path: Path to write the CSV file to.

        Returns:
            A pandas DataFrame of computed traits summarized over all frames of each
            group. The resulting dataframe will have a row for each series and a column
            for each series-level summarized trait.

            Summarized traits are prefixed with the trait name and an underscore,
            followed by the summary statistic.
        """
        # Check if the input list is empty
        if not all_series:
            raise ValueError("The input list 'all_series' is empty.")

        try:
            # Compute traits for each group of series
            grouped_results = self.compute_multiple_dicots_traits_for_groups(
                all_series,
                output_dir=output_dir,
                write_json=write_json,
                write_csv=False,
            )
        except Exception as e:
            raise RuntimeError(f"Error computing traits for groups: {e}")

        # Prepare the list of dictionaries for the DataFrame
        all_group_summaries = []
        for group_result in grouped_results:
            # Validate the expected key exists in the result
            if "summary_stats" not in group_result:
                raise KeyError(
                    "Expected key 'summary_stats' not found in group result."
                )
            # Assuming 'group' key exists in group_result and it indicates the genotype
            genotype = group_result.get(
                "group", "Unknown Genotype"
            )  # Default to "Unknown Genotype" if not found
            # Start with a dictionary containing the genotype
            group_summary = {"genotype": genotype}
            # Add each trait statistic from the summary_stats dictionary to the group_summary
            # This assumes summary_stats is a dictionary where keys are trait names and values are the statistics
            for trait, statistic in group_result["summary_stats"].items():
                group_summary[trait] = statistic
            all_group_summaries.append(group_summary)

        # Create a DataFrame from the list of dictionaries
        all_group_summaries_df = pd.DataFrame(all_group_summaries)

        # Write to CSV if requested
        if write_csv:
            try:
                all_group_summaries_df.to_csv(csv_path, index=False)
                print(f"Computed traits for all groups saved to {csv_path}")
            except Exception as e:
                raise IOError(f"Failed to write computed traits to CSV: {e}")

        return all_group_summaries_df
@attrs.define
class DicotPipeline(Pipeline):
"""Pipeline for computing traits for dicot plants (primary + lateral roots).
Attributes:
img_height: Image height.
root_width_tolerance: Difference in projection norm between right and left side.
n_scanlines: Number of scan lines, np.nan for no interaction.
network_fraction: Length found in the lower fraction value of the network.
"""
img_height: int = 1080
root_width_tolerance: float = 0.02
n_scanlines: int = 50
network_fraction: float = 2 / 3
def define_traits(self) -> List[TraitDef]:
"""Define the trait computation pipeline for dicot plants."""
trait_definitions = [
TraitDef(
name="primary_max_length_pts",
fn=get_max_length_pts,
input_traits=["primary_pts"],
scalar=False,
include_in_csv=False,
kwargs={},
description="Points of the primary root with maximum length.",
),
TraitDef(
name="pts_all_array",
fn=get_all_pts_array,
input_traits=["primary_max_length_pts", "lateral_pts"],
scalar=False,
include_in_csv=False,
kwargs={},
description="Landmark points within a given frame as a flat array"
"of coordinates.",
),
TraitDef(
name="pts_list",
fn=join_pts,
input_traits=["primary_max_length_pts", "lateral_pts"],
scalar=False,
include_in_csv=False,
kwargs={},
description="A list of instance arrays, each having shape `(nodes, 2)`.",
),
TraitDef(
name="root_widths",
fn=get_root_widths,
input_traits=["primary_max_length_pts", "lateral_pts"],
scalar=False,
include_in_csv=True,
kwargs={
"tolerance": self.root_width_tolerance,
"return_inds": False,
},
description="Estimate root width using bases of lateral roots.",
),
TraitDef(
name="lateral_count",
fn=get_count,
input_traits=["lateral_pts"],
scalar=True,
include_in_csv=True,
kwargs={},
description="Get the number of lateral roots.",
),
TraitDef(
name="lateral_proximal_node_inds",
fn=get_node_ind,
input_traits=["lateral_pts"],
scalar=False,
include_in_csv=False,
kwargs={"proximal": True},
description="Get the indices of the proximal nodes of lateral roots.",
),
TraitDef(
name="lateral_distal_node_inds",
fn=get_node_ind,
input_traits=["lateral_pts"],
scalar=False,
include_in_csv=False,
kwargs={"proximal": False},
description="Get the indices of the distal nodes of lateral roots.",
),
TraitDef(
name="lateral_lengths",
fn=get_root_lengths,
input_traits=["lateral_pts"],
scalar=False,
include_in_csv=True,
kwargs={},
description="Array of lateral root lengths of shape `(instances,)`.",
),
TraitDef(
name="lateral_base_pts",
fn=get_bases,
input_traits=["lateral_pts"],
scalar=False,
include_in_csv=False,
kwargs={},
description="Array of lateral bases `(instances, (x, y))`.",
),
TraitDef(
name="lateral_tip_pts",
fn=get_tips,
input_traits=["lateral_pts"],
scalar=False,
include_in_csv=False,
kwargs={},
description="Array of lateral tips `(instances, (x, y))`.",
),
TraitDef(
name="scanline_intersection_counts",
fn=count_scanline_intersections,
input_traits=["pts_list"],
scalar=False,
include_in_csv=True,
kwargs={
"height": self.img_height,
"n_line": self.n_scanlines,
},
description="Array of intersections of each scanline `(n_scanlines,)`.",
),
TraitDef(
name="lateral_angles_distal",
fn=get_root_angle,
input_traits=["lateral_pts", "lateral_distal_node_inds"],
scalar=False,
include_in_csv=True,
kwargs={"proximal": False, "base_ind": 0},
description="Array of lateral distal angles in degrees `(instances,)`.",
),
TraitDef(
name="lateral_angles_proximal",
fn=get_root_angle,
input_traits=["lateral_pts", "lateral_proximal_node_inds"],
scalar=False,
include_in_csv=True,
kwargs={"proximal": True, "base_ind": 0},
description="Array of lateral proximal angles in degrees "
"`(instances,)`.",
),
TraitDef(
name="network_solidity",
fn=get_network_solidity,
input_traits=["network_length", "chull_area"],
scalar=True,
include_in_csv=True,
kwargs={},
description="Scalar of the total network length divided by the network"
"convex area.",
),
TraitDef(
name="ellipse",
fn=fit_ellipse,
input_traits=["pts_all_array"],
scalar=False,
include_in_csv=False,
kwargs={},
description="Tuple of (a, b, ratio) containing the semi-major axis "
"length, semi-minor axis length, and the ratio of the major to minor "
"lengths.",
),
TraitDef(
name="bounding_box",
fn=get_bbox,
input_traits=["pts_all_array"],
scalar=False,
include_in_csv=False,
kwargs={},
description="Tuple of four parameters in bounding box.",
),
TraitDef(
name="convex_hull",
fn=get_convhull,
input_traits=["pts_all_array"],
scalar=False,
include_in_csv=False,
kwargs={},
description="Convex hull of the points.",
),
TraitDef(
name="primary_proximal_node_ind",
fn=get_node_ind,
input_traits=["primary_max_length_pts"],
scalar=True,
include_in_csv=False,
kwargs={"proximal": True},
description="Get the indices of the proximal nodes of primary roots.",
),
TraitDef(
name="primary_angle_proximal",
fn=get_root_angle,
input_traits=["primary_max_length_pts", "primary_proximal_node_ind"],
scalar=True,
include_in_csv=True,