Skip to content

Commit c644b31

Browse files
1 parent fe71e40 commit c644b31

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/openlifu/seg/skinseg.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,83 @@ def vtk_img_from_array_and_affine(vol_array:np.ndarray, affine:np.ndarray) -> vt
125125
vtk_img.GetPointData().SetScalars(vol_array_vtk)
126126

127127
return vtk_img
128+
129+
def create_closed_surface_from_labelmap(
130+
binary_labelmap:vtk.vtkImageData,
131+
decimation_factor:float=0.,
132+
smoothing_factor:float=0.5
133+
) -> vtk.vtkPolyData:
134+
""" Create a surface mesh vtkPolyData from a binary labelmap vtkImageData.
135+
136+
Args:
137+
binary_labelmap: input vtkImageData binary labelmap
138+
decimation_factor: 0.0 for no decimation, 1.0 for maximum reduction.
139+
smoothing_factor: 0.0 for no smoothing, 1.0 for maximum smoothing.
140+
141+
Returns:
142+
vtkPolyData: the resulting surface mesh
143+
144+
The algorithm here is based on the labelmap-to-closed-surface algorithm in 3D Slicer:
145+
https://github.com/Slicer/Slicer/blob/677932127c73a6c78654d4afd9458a655a4eef63/Libs/vtkSegmentationCore/vtkBinaryLabelmapToClosedSurfaceConversionRule.cxx#L246-L476
146+
"""
147+
148+
# step 1: pad by 1 pixel all around with 0s, to ensure that the surface is still closed
149+
# even if the labelmap runs up against the image boundary.
150+
padder = vtk.vtkImageConstantPad()
151+
padder.SetInputData(binary_labelmap)
152+
extent = binary_labelmap.GetExtent()
153+
padder.SetOutputWholeExtent(
154+
extent[0] - 1, extent[1] + 1,
155+
extent[2] - 1, extent[3] + 1,
156+
extent[4] - 1, extent[5] + 1,
157+
)
158+
padder.Update()
159+
padded_labelmap = padder.GetOutput()
160+
161+
# step 1: extract surface
162+
flying_edges = vtk.vtkDiscreteFlyingEdges3D()
163+
flying_edges.SetInputData(padded_labelmap)
164+
flying_edges.ComputeGradientsOff()
165+
flying_edges.ComputeNormalsOff()
166+
flying_edges.Update()
167+
surface_mesh = flying_edges.GetOutput()
168+
169+
# step 2: decimation
170+
if decimation_factor > 0.0:
171+
decimator = vtk.vtkDecimatePro()
172+
decimator.SetInputData(surface_mesh)
173+
decimator.SetFeatureAngle(60)
174+
decimator.SplittingOff()
175+
decimator.PreserveTopologyOn()
176+
decimator.SetMaximumError(1)
177+
decimator.SetTargetReduction(decimation_factor)
178+
decimator.Update()
179+
surface_mesh = decimator.GetOutput()
180+
181+
# step 3: smoothing
182+
if smoothing_factor > 0.0:
183+
smoother = vtk.vtkWindowedSincPolyDataFilter()
184+
smoother.SetInputData(surface_mesh)
185+
186+
# map smoothing factor to passband and iterations, copying the approach taken by Slicer
187+
passband = pow(10.0, -4.0 * smoothing_factor)
188+
num_iterations = 20 + int(smoothing_factor * 40)
189+
190+
smoother.SetNumberOfIterations(num_iterations)
191+
smoother.SetPassBand(passband)
192+
smoother.BoundarySmoothingOff()
193+
smoother.FeatureEdgeSmoothingOff()
194+
smoother.NonManifoldSmoothingOn()
195+
smoother.NormalizeCoordinatesOn()
196+
smoother.Update()
197+
surface_mesh = smoother.GetOutput()
198+
199+
# step 4: compute normals
200+
normals = vtk.vtkPolyDataNormals()
201+
normals.SetInputData(surface_mesh)
202+
normals.ConsistencyOn()
203+
normals.SplittingOff()
204+
normals.Update()
205+
surface_mesh = normals.GetOutput()
206+
207+
return surface_mesh

tests/test_skinseg.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from openlifu.seg.skinseg import (
88
compute_foreground_mask,
9+
create_closed_surface_from_labelmap,
910
take_largest_connected_component,
1011
vtk_img_from_array_and_affine,
1112
)
@@ -62,3 +63,21 @@ def test_vtk_img_from_array_and_affine():
6263
point_id = vtk_img.FindPoint([x,y,z])
6364

6465
assert vtk_img.GetPointData().GetScalars().GetTuple1(point_id) == pytest.approx(vol_array[i,j,k])
66+
67+
def test_create_closed_surface_from_labelmap():
68+
# create a ball of radius 7 for a labelmap
69+
labelmap = np.zeros((20,20,20))
70+
sphere_radius = 7
71+
sphere_center = np.array([10,10,10])
72+
add_ball(labelmap, tuple(sphere_center), sphere_radius)
73+
labelmap_vtk = vtk_img_from_array_and_affine(labelmap, affine = np.eye(4))
74+
75+
# run the algorithm to be tested
76+
surface = create_closed_surface_from_labelmap(labelmap_vtk, decimation_factor=0.5)
77+
78+
# verify that the points on the generated mesh are not too far off being at distance 7 from the ball center
79+
points = surface.GetPoints()
80+
for i in range(points.GetNumberOfPoints()):
81+
point_position = np.array(points.GetPoint(i))
82+
point_distance_from_sphere_center = np.linalg.norm(point_position - sphere_center, ord=2)
83+
assert np.abs(point_distance_from_sphere_center - sphere_radius) < 1.

0 commit comments

Comments
 (0)