|
638 | 638 | "\n", |
639 | 639 | "def get_cell_compositions(\n", |
640 | 640 | " wsi_path: str,\n", |
| 641 | + " mask_path: str,\n", |
641 | 642 | " inst_pred_path: str,\n", |
642 | 643 | " save_dir: str,\n", |
643 | 644 | " num_types: int = 6,\n", |
|
662 | 663 | " inst_boxes = np.array(inst_boxes)\n", |
663 | 664 | "\n", |
664 | 665 | " geometries = [shapely_box(*bounds) for bounds in inst_boxes]\n", |
665 | | - " # An auxiliary dictionary to actually query the index within the source list\n", |
666 | | - " index_by_id = {id(geo): idx for idx, geo in enumerate(geometries)}\n", |
667 | 666 | " spatial_indexer = STRtree(geometries)\n", |
668 | 667 | "\n", |
669 | 668 | " # * Generate patch coordinates (in xy format)\n", |
|
676 | 675 | " stride_shape=stride_shape,\n", |
677 | 676 | " )\n", |
678 | 677 | "\n", |
| 678 | + " # filter out coords which dont lie in mask\n", |
| 679 | + " selected_coord_indices = PatchExtractor.filter_coordinates(\n", |
| 680 | + " WSIReader.open(mask_path),\n", |
| 681 | + " patch_inputs,\n", |
| 682 | + " wsi_shape=wsi_shape,\n", |
| 683 | + " min_mask_ratio=0.5,\n", |
| 684 | + " )\n", |
| 685 | + " patch_inputs = patch_inputs[selected_coord_indices]\n", |
| 686 | + "\n", |
679 | 687 | " bounds_compositions = []\n", |
680 | 688 | " for bounds in patch_inputs:\n", |
681 | 689 | " bounds_ = shapely_box(*bounds)\n", |
682 | 690 | " indices = [\n", |
683 | | - " index_by_id[id(geo)]\n", |
| 691 | + " geo\n", |
684 | 692 | " for geo in spatial_indexer.query(bounds_)\n", |
685 | | - " if bounds_.contains(geo)\n", |
| 693 | + " if bounds_.contains(geometries[geo])\n", |
686 | 694 | " ]\n", |
687 | 695 | " insts = [inst_pred[v][\"type\"] for v in indices]\n", |
688 | 696 | " uids, freqs = np.unique(insts, return_counts=True)\n", |
689 | 697 | " # A bound may not contain all types, hence, to sync\n", |
690 | 698 | " # the array and placement across all types, we create\n", |
691 | 699 | " # a holder then fill the count within.\n", |
692 | 700 | " holder = np.zeros(num_types, dtype=np.int16)\n", |
693 | | - " holder[uids] = freqs\n", |
| 701 | + " holder[uids.astype(int)] = freqs\n", |
694 | 702 | " bounds_compositions.append(holder)\n", |
695 | 703 | " bounds_compositions = np.array(bounds_compositions)\n", |
696 | 704 | "\n", |
|
706 | 714 | " inst_segmentor = NucleusInstanceSegmentor(\n", |
707 | 715 | " pretrained_model=\"hovernet_fast-pannuke\",\n", |
708 | 716 | " batch_size=16,\n", |
709 | | - " num_postproc_workers=2,\n", |
| 717 | + " num_postproc_workers=4,\n", |
| 718 | + " num_loader_workers=4,\n", |
710 | 719 | " )\n", |
| 720 | + " # bigger tile shape for postprocessing performance\n", |
| 721 | + " inst_segmentor.ioconfig.tile_shape = (4000, 4000)\n", |
711 | 722 | " # Injecting customized preprocessing functions,\n", |
712 | 723 | " # check the document or sample codes below for API\n", |
713 | 724 | " inst_segmentor.model.preproc_func = preproc_func\n", |
|
735 | 746 | "\n", |
736 | 747 | " # TODO: parallelize this later if possible\n", |
737 | 748 | " for idx, path in enumerate(output_paths):\n", |
738 | | - " get_cell_compositions(wsi_paths[idx], path, save_dir)\n", |
| 749 | + " get_cell_compositions(wsi_paths[idx], msk_paths[idx], path, save_dir)\n", |
739 | 750 | " return output_paths" |
740 | 751 | ] |
741 | 752 | }, |
|
1035 | 1046 | "outputs": [], |
1036 | 1047 | "source": [ |
1037 | 1048 | "NODE_SIZE = 24\n", |
1038 | | - "NODE_RESOLUTION = dict(resolution=0.5, units=\"mpp\")\n", |
| 1049 | + "NODE_RESOLUTION = dict(resolution=0.25, units=\"mpp\")\n", |
1039 | 1050 | "PLOT_RESOLUTION = dict(resolution=4.0, units=\"mpp\")" |
1040 | 1051 | ] |
1041 | 1052 | }, |
|
1077 | 1088 | "plot_resolution = reader.slide_dimensions(**PLOT_RESOLUTION)\n", |
1078 | 1089 | "fx = np.array(node_resolution) / np.array(plot_resolution)\n", |
1079 | 1090 | "\n", |
1080 | | - "node_coordinates = np.array(graph.coords) / fx\n", |
| 1091 | + "node_coordinates = np.array(graph.coordinates) / fx\n", |
1081 | 1092 | "edges = graph.edge_index.T\n", |
1082 | 1093 | "\n", |
1083 | 1094 | "thumb = reader.slide_thumbnail(**PLOT_RESOLUTION)\n", |
|
2458 | 2469 | "\n", |
2459 | 2470 | "NODE_SIZE = 25\n", |
2460 | 2471 | "NUM_NODE_FEATURES = 4\n", |
2461 | | - "NODE_RESOLUTION = dict(resolution=0.5, units=\"mpp\")\n", |
| 2472 | + "NODE_RESOLUTION = dict(resolution=0.25, units=\"mpp\")\n", |
2462 | 2473 | "PLOT_RESOLUTION = dict(resolution=4.0, units=\"mpp\")\n", |
2463 | 2474 | "\n", |
2464 | 2475 | "node_scaler = joblib.load(SCALER_PATH)\n", |
|
2503 | 2514 | "cmap = plt.get_cmap(\"inferno\")\n", |
2504 | 2515 | "graph = graph.to(\"cpu\")\n", |
2505 | 2516 | "\n", |
2506 | | - "node_coordinates = np.array(graph.coords) / fx\n", |
| 2517 | + "node_coordinates = np.array(graph.coordinates) / fx\n", |
2507 | 2518 | "node_colors = (cmap(np.squeeze(node_activations))[..., :3] * 255).astype(np.uint8)\n", |
2508 | 2519 | "edges = graph.edge_index.T\n", |
2509 | 2520 | "\n", |
|
0 commit comments