Skip to content

Commit e4deac4

Browse files
meastyshaneahmed
andauthored
🐛 Fix Errors in the slidegraph Example Notebook (#608)
Fixes some issues with the slidegraph notebooks. 1. Updates due to changes in STRtree in recent shapely versions 2. In the 'cell-composition' mode, add the 'filter_coordinates' step so that the mask is considered when generating graph nodes. Also made a small tweak to mask filter so that mask doesnt have to be single channel. 3. Fix the resolution of the plots being wrong when not using pre-generated model 4. Fixes for a couple of issues related to datatypes, maybe they crept in at some point due to numpy or torch version changes. I have also added a note to explain the last few cells of the inference notebook are for composition features only, as there are only pretrained model weights for that mode. --------- Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
1 parent 7369d8d commit e4deac4

File tree

5 files changed

+176
-125
lines changed

5 files changed

+176
-125
lines changed

examples/full-pipelines/slide-graph.ipynb

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@
638638
"\n",
639639
"def get_cell_compositions(\n",
640640
" wsi_path: str,\n",
641+
" mask_path: str,\n",
641642
" inst_pred_path: str,\n",
642643
" save_dir: str,\n",
643644
" num_types: int = 6,\n",
@@ -662,8 +663,6 @@
662663
" inst_boxes = np.array(inst_boxes)\n",
663664
"\n",
664665
" geometries = [shapely_box(*bounds) for bounds in inst_boxes]\n",
665-
" # An auxiliary dictionary to actually query the index within the source list\n",
666-
" index_by_id = {id(geo): idx for idx, geo in enumerate(geometries)}\n",
667666
" spatial_indexer = STRtree(geometries)\n",
668667
"\n",
669668
" # * Generate patch coordinates (in xy format)\n",
@@ -676,21 +675,30 @@
676675
" stride_shape=stride_shape,\n",
677676
" )\n",
678677
"\n",
678+
" # filter out coords which dont lie in mask\n",
679+
" selected_coord_indices = PatchExtractor.filter_coordinates(\n",
680+
" WSIReader.open(mask_path),\n",
681+
" patch_inputs,\n",
682+
" wsi_shape=wsi_shape,\n",
683+
" min_mask_ratio=0.5,\n",
684+
" )\n",
685+
" patch_inputs = patch_inputs[selected_coord_indices]\n",
686+
"\n",
679687
" bounds_compositions = []\n",
680688
" for bounds in patch_inputs:\n",
681689
" bounds_ = shapely_box(*bounds)\n",
682690
" indices = [\n",
683-
" index_by_id[id(geo)]\n",
691+
" geo\n",
684692
" for geo in spatial_indexer.query(bounds_)\n",
685-
" if bounds_.contains(geo)\n",
693+
" if bounds_.contains(geometries[geo])\n",
686694
" ]\n",
687695
" insts = [inst_pred[v][\"type\"] for v in indices]\n",
688696
" uids, freqs = np.unique(insts, return_counts=True)\n",
689697
" # A bound may not contain all types, hence, to sync\n",
690698
" # the array and placement across all types, we create\n",
691699
" # a holder then fill the count within.\n",
692700
" holder = np.zeros(num_types, dtype=np.int16)\n",
693-
" holder[uids] = freqs\n",
701+
" holder[uids.astype(int)] = freqs\n",
694702
" bounds_compositions.append(holder)\n",
695703
" bounds_compositions = np.array(bounds_compositions)\n",
696704
"\n",
@@ -706,8 +714,11 @@
706714
" inst_segmentor = NucleusInstanceSegmentor(\n",
707715
" pretrained_model=\"hovernet_fast-pannuke\",\n",
708716
" batch_size=16,\n",
709-
" num_postproc_workers=2,\n",
717+
" num_postproc_workers=4,\n",
718+
" num_loader_workers=4,\n",
710719
" )\n",
720+
" # bigger tile shape for postprocessing performance\n",
721+
" inst_segmentor.ioconfig.tile_shape = (4000, 4000)\n",
711722
" # Injecting customized preprocessing functions,\n",
712723
" # check the document or sample codes below for API\n",
713724
" inst_segmentor.model.preproc_func = preproc_func\n",
@@ -735,7 +746,7 @@
735746
"\n",
736747
" # TODO: parallelize this later if possible\n",
737748
" for idx, path in enumerate(output_paths):\n",
738-
" get_cell_compositions(wsi_paths[idx], path, save_dir)\n",
749+
" get_cell_compositions(wsi_paths[idx], msk_paths[idx], path, save_dir)\n",
739750
" return output_paths"
740751
]
741752
},
@@ -1035,7 +1046,7 @@
10351046
"outputs": [],
10361047
"source": [
10371048
"NODE_SIZE = 24\n",
1038-
"NODE_RESOLUTION = dict(resolution=0.5, units=\"mpp\")\n",
1049+
"NODE_RESOLUTION = dict(resolution=0.25, units=\"mpp\")\n",
10391050
"PLOT_RESOLUTION = dict(resolution=4.0, units=\"mpp\")"
10401051
]
10411052
},
@@ -1077,7 +1088,7 @@
10771088
"plot_resolution = reader.slide_dimensions(**PLOT_RESOLUTION)\n",
10781089
"fx = np.array(node_resolution) / np.array(plot_resolution)\n",
10791090
"\n",
1080-
"node_coordinates = np.array(graph.coords) / fx\n",
1091+
"node_coordinates = np.array(graph.coordinates) / fx\n",
10811092
"edges = graph.edge_index.T\n",
10821093
"\n",
10831094
"thumb = reader.slide_thumbnail(**PLOT_RESOLUTION)\n",
@@ -2458,7 +2469,7 @@
24582469
"\n",
24592470
"NODE_SIZE = 25\n",
24602471
"NUM_NODE_FEATURES = 4\n",
2461-
"NODE_RESOLUTION = dict(resolution=0.5, units=\"mpp\")\n",
2472+
"NODE_RESOLUTION = dict(resolution=0.25, units=\"mpp\")\n",
24622473
"PLOT_RESOLUTION = dict(resolution=4.0, units=\"mpp\")\n",
24632474
"\n",
24642475
"node_scaler = joblib.load(SCALER_PATH)\n",
@@ -2503,7 +2514,7 @@
25032514
"cmap = plt.get_cmap(\"inferno\")\n",
25042515
"graph = graph.to(\"cpu\")\n",
25052516
"\n",
2506-
"node_coordinates = np.array(graph.coords) / fx\n",
2517+
"node_coordinates = np.array(graph.coordinates) / fx\n",
25072518
"node_colors = (cmap(np.squeeze(node_activations))[..., :3] * 255).astype(np.uint8)\n",
25082519
"edges = graph.edge_index.T\n",
25092520
"\n",

examples/inference-pipelines/slide-graph.ipynb

Lines changed: 151 additions & 111 deletions
Large diffs are not rendered by default.

requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# torch installation
2-
--extra-index-url https://download.pytorch.org/whl/cu117; sys_platform != "darwin"
2+
--extra-index-url https://download.pytorch.org/whl/cu118; sys_platform != "darwin"
33
albumentations>=1.3.0
44
Click>=8.1.3
55
defusedxml>=0.7.1

tiatoolbox/tools/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def build(
395395

396396
return {
397397
"x": feature_centroids,
398-
"edge_index": edge_index,
398+
"edge_index": edge_index.astype(np.int64),
399399
"coordinates": point_centroids,
400400
}
401401

tiatoolbox/tools/patchextraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def filter_coordinates(
284284
tissue_mask = mask_reader.img
285285

286286
# Scaling the coordinates_list to the `tissue_mask` array resolution
287-
scale_factors = np.array(tissue_mask.shape[::-1]) / np.array(wsi_shape)
287+
scale_factors = np.array(tissue_mask.shape[1::-1]) / np.array(wsi_shape)
288288
scaled_coords = coordinates_list.copy().astype(np.float32)
289289
scaled_coords[:, [0, 2]] *= scale_factors[0]
290290
scaled_coords[:, [0, 2]] = np.clip(

0 commit comments

Comments
 (0)