Skip to content

Commit e3235ba

Browse files
authored
Merge pull request #50 from nipreps/add_gm_threshold
WIP: Add gray matter threshold option to reference mask workflow
2 parents 643e14b + b22e544 commit e3235ba

File tree

9 files changed

+113
-6
lines changed

9 files changed

+113
-6
lines changed

petprep/data/reference_mask/config.json

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
"refmask_indices": [47, 8],
66
"erode_by_voxels": 1,
77
"exclude_indices": [7, 46, 172, 1007, 2007, 1009, 2009, 1011, 2011, 1013, 2013, 16],
8-
"dilate_by_voxels": 3
8+
"dilate_by_voxels": 3,
9+
"gm_prob_threshold": 0.8
10+
},
11+
"thalamus":
12+
{
13+
"refmask_indices": [10, 49],
14+
"erode_by_voxels": 1,
15+
"exclude_indices": [14, 24],
16+
"dilate_by_voxels": 3,
17+
"gm_prob_threshold": 0.2
918
}
1019
},
1120
"wm": {
@@ -16,7 +25,8 @@
1625
"exclude_indices": [10, 49, 12, 51, 13, 52, 11, 50, 4, 43, 31, 63],
1726
"dilate_by_voxels": 3,
1827
"smooth_fwhm_mm": 10,
19-
"target_volume_ml": 40
28+
"target_volume_ml": 40,
29+
"gm_prob_threshold": 0
2030
}
2131
}
2232
}

petprep/interfaces/reference_mask.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
class ExtractRefRegionInputSpec(BaseInterfaceInputSpec):
1616
seg_file = File(exists=True, mandatory=True, desc='Segmentation NIfTI file')
17+
gm_probseg = File(exists=True, desc='Gray matter probability map for thresholding')
1718
config_file = File(exists=True, mandatory=True, desc='Path to the config.json file')
1819
segmentation_type = traits.Str(mandatory=True, desc="Type of segmentation (e.g. 'gtm', 'wm')")
1920
region_name = traits.Str(
@@ -32,6 +33,9 @@ class ExtractRefRegion(SimpleInterface):
3233

3334
def _run_interface(self, runtime):
3435
seg_img = nib.load(self.inputs.seg_file)
36+
gm_prob_img = None
37+
if isdefined(self.inputs.gm_probseg):
38+
gm_prob_img = nib.load(self.inputs.gm_probseg)
3539

3640
if isdefined(self.inputs.override_indices):
3741
cfg = {'refmask_indices': list(self.inputs.override_indices)}
@@ -50,7 +54,11 @@ def _run_interface(self, runtime):
5054

5155
from petprep.utils.reference_mask import generate_reference_region
5256

53-
refmask_img = generate_reference_region(seg_img=seg_img, config=cfg)
57+
refmask_img = generate_reference_region(
58+
seg_img=seg_img,
59+
config=cfg,
60+
gm_probseg_img=gm_prob_img,
61+
)
5462

5563
out_file = os.path.abspath('refmask.nii.gz')
5664
nib.save(refmask_img, out_file)

petprep/interfaces/tests/test_reference_mask.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,34 @@ def test_extract_refregion_override_missing_config(tmp_path):
116116
out = nb.load(res.outputs.refmask_file).get_fdata()
117117
assert out[2, 2, 2] == 1
118118
assert out.sum() == 1
119+
120+
121+
def test_extract_refregion_gm_threshold(tmp_path):
122+
data = np.zeros((5, 5, 5), dtype=np.uint8)
123+
data[1, 1, 1] = 1
124+
data[2, 2, 2] = 1
125+
seg = tmp_path / 'seg.nii.gz'
126+
nb.Nifti1Image(data, np.eye(4)).to_filename(seg)
127+
gm = np.zeros((5, 5, 5), dtype=np.float32)
128+
gm[1, 1, 1] = 0.4
129+
gm[2, 2, 2] = 0.8
130+
gm_file = tmp_path / 'gm.nii.gz'
131+
nb.Nifti1Image(gm, np.eye(4)).to_filename(gm_file)
132+
133+
cfg = _create_config(tmp_path, [1], {'gm_prob_threshold': 0.5})
134+
135+
node = pe.Node(
136+
ExtractRefRegion(
137+
seg_file=str(seg),
138+
gm_probseg=str(gm_file),
139+
config_file=str(cfg),
140+
segmentation_type='testseg',
141+
region_name='region',
142+
),
143+
name='er5',
144+
base_dir=str(tmp_path),
145+
)
146+
res = node.run()
147+
out = nb.load(res.outputs.refmask_file).get_fdata()
148+
assert out.sum() == 1
149+
assert out[2, 2, 2] == 1

petprep/utils/reference_mask.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from skimage.morphology import ball, binary_dilation, binary_erosion
55

66

7-
def generate_reference_region(seg_img: nib.Nifti1Image, config: dict) -> nib.Nifti1Image:
7+
def generate_reference_region(
8+
seg_img: nib.Nifti1Image,
9+
config: dict,
10+
gm_probseg_img: nib.Nifti1Image | None = None,
11+
) -> nib.Nifti1Image:
812
"""Generate a reference region using a flexible config.
913
1014
Config keys:
@@ -14,6 +18,8 @@ def generate_reference_region(seg_img: nib.Nifti1Image, config: dict) -> nib.Nif
1418
- dilate_by_voxels (int, optional): Dilation radius for excluded regions.
1519
- smooth_fwhm_mm (float, optional): FWHM for smoothing the target region.
1620
- target_volume_ml (float, optional): Keep only the top N voxels by smoothed value.
21+
- gm_prob_threshold (float, optional): Threshold the final mask using a
22+
gray matter probability map. Requires ``gm_probseg_img``.
1723
1824
Returns:
1925
nib.Nifti1Image: Final reference mask.
@@ -52,4 +58,12 @@ def generate_reference_region(seg_img: nib.Nifti1Image, config: dict) -> nib.Nif
5258
threshold = values[-target_voxels]
5359
mask = ((smoothed >= threshold) & (mask > 0)).astype(np.uint8)
5460

61+
# Step 5: Optional gray matter probability thresholding
62+
if gm_probseg_img is not None and 'gm_prob_threshold' in config:
63+
gm_prob = gm_probseg_img.get_fdata()
64+
if gm_prob.shape != mask.shape:
65+
raise ValueError('gm_probseg_img does not match segmentation shape')
66+
mask = (mask > 0) & (gm_prob >= config['gm_prob_threshold'])
67+
mask = mask.astype(np.uint8)
68+
5569
return nib.Nifti1Image(mask, affine, header)

petprep/utils/tests/test_reference_mask.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,22 @@ def test_generate_reference_region_large_target_volume():
6262
out_img = generate_reference_region(img, config)
6363
mask = out_img.get_fdata()
6464
assert mask.sum() == seg.sum()
65+
66+
67+
def test_generate_reference_region_gm_threshold():
68+
seg = np.zeros((5, 5, 5), dtype=np.uint8)
69+
seg[2, 2, 2] = 1
70+
seg[3, 3, 3] = 1
71+
seg_img = nb.Nifti1Image(seg, np.eye(4))
72+
73+
gm = np.zeros((5, 5, 5), dtype=np.float32)
74+
gm[2, 2, 2] = 0.4
75+
gm[3, 3, 3] = 0.8
76+
gm_img = nb.Nifti1Image(gm, np.eye(4))
77+
78+
config = {'refmask_indices': [1], 'gm_prob_threshold': 0.5}
79+
80+
out = generate_reference_region(seg_img, config, gm_probseg_img=gm_img)
81+
data = out.get_fdata()
82+
assert data.sum() == 1
83+
assert data[3, 3, 3] == 1

petprep/workflows/pet/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def init_pet_wf(
243243
('t1w_preproc', 'inputnode.t1w_preproc'),
244244
('t1w_mask', 'inputnode.t1w_mask'),
245245
('t1w_dseg', 'inputnode.t1w_dseg'),
246+
('t1w_tpms', 'inputnode.t1w_tpms'),
246247
('subjects_dir', 'inputnode.subjects_dir'),
247248
('subject_id', 'inputnode.subject_id'),
248249
('fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),

petprep/workflows/pet/fit.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def init_pet_fit_wf(
164164
't1w_preproc',
165165
't1w_mask',
166166
't1w_dseg',
167+
't1w_tpms',
167168
'subjects_dir',
168169
'subject_id',
169170
'fsnative2t1w_xfm',
@@ -448,6 +449,8 @@ def init_pet_fit_wf(
448449
name='refmask_report_wf',
449450
)
450451

452+
gm_select = pe.Node(niu.Select(index=0), name='select_gm_probseg')
453+
451454
pet_ref_tacs_wf = init_pet_ref_tacs_wf(name='pet_ref_tacs_wf')
452455
pet_ref_tacs_wf.inputs.inputnode.metadata = str(
453456
Path(pet_file).with_suffix('').with_suffix('.json')
@@ -470,6 +473,8 @@ def init_pet_fit_wf(
470473
)
471474
ds_ref_tacs.inputs.source_file = pet_file
472475

476+
workflow.connect([(inputnode, gm_select, [('t1w_tpms', 'inlist')])])
477+
473478
workflow.connect(
474479
[
475480
(
@@ -479,6 +484,11 @@ def init_pet_fit_wf(
479484
('outputnode.segmentation', 'inputnode.seg_file'),
480485
],
481486
),
487+
(
488+
gm_select,
489+
refmask_wf,
490+
[('out', 'inputnode.gm_probseg')],
491+
),
482492
(
483493
refmask_wf,
484494
outputnode,

petprep/workflows/pet/reference_mask.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ def init_pet_refmask_wf(
4242

4343
workflow = pe.Workflow(name=name)
4444

45-
inputnode = pe.Node(IdentityInterface(fields=['seg_file']), name='inputnode')
45+
inputnode = pe.Node(
46+
IdentityInterface(fields=['seg_file', 'gm_probseg']),
47+
name='inputnode',
48+
)
4649
outputnode = pe.Node(IdentityInterface(fields=['refmask_file']), name='outputnode')
4750

4851
extract_mask = pe.Node(ExtractRefRegion(), name='extract_refregion')
@@ -56,7 +59,14 @@ def init_pet_refmask_wf(
5659

5760
workflow.connect(
5861
[
59-
(inputnode, extract_mask, [('seg_file', 'seg_file')]),
62+
(
63+
inputnode,
64+
extract_mask,
65+
[
66+
('seg_file', 'seg_file'),
67+
('gm_probseg', 'gm_probseg'),
68+
],
69+
),
6070
(extract_mask, outputnode, [('refmask_file', 'refmask_file')]),
6171
]
6272
)

petprep/workflows/pet/tests/test_fit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ def test_refmask_report_connections(bids_root: Path, tmp_path: Path):
235235
'inputnode.source_files',
236236
) in petref_edge['connect']
237237

238+
gm_node = wf.get_node('select_gm_probseg')
239+
edge_prob = wf._graph.get_edge_data(gm_node, wf.get_node('pet_refmask_wf'))
240+
assert ('out', 'inputnode.gm_probseg') in edge_prob['connect']
241+
238242
assert any(name.startswith('pet_ref_tacs_wf') for name in wf.list_node_names())
239243
assert 'ds_ref_tacs' in wf.list_node_names()
240244
ds_tacs = wf.get_node('ds_ref_tacs')

0 commit comments

Comments
 (0)