Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 30, 2024
1 parent 0722969 commit de91cff
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 28 deletions.
22 changes: 14 additions & 8 deletions monailabel/tasks/infer/basic_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import copy
import logging
import os
import shutil
import time
from abc import abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from monai.data import decollate_batch
from monai.inferers import Inferer, SimpleInferer, SlidingWindowInferer
Expand All @@ -25,13 +27,10 @@
from monailabel.interfaces.exception import MONAILabelError, MONAILabelException
from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
from monailabel.interfaces.utils.transform import dump_data, run_transforms
from monailabel.tasks.infer.prompt_utils import check_prompts_format, prompt_run_inferer
from monailabel.transform.cache import CacheTransformDatad
from monailabel.transform.writer import ClassificationWriter, DetectionWriter, Writer
from monailabel.utils.others.generic import device_list, device_map, name_to_device
from monailabel.tasks.infer.prompt_utils import prompt_run_inferer, check_prompts_format

import shutil
import numpy as np

rearrange, _ = optional_import("einops", name="rearrange")

Expand Down Expand Up @@ -404,7 +403,7 @@ def __call__(self, data):
if d is None:
return run_transforms(data, transforms, log_prefix="PRE", use_compose=False)
return run_transforms(d, post_cache, log_prefix="PRE", use_compose=False) if post_cache else d
print('Finddddddddddd data pathhhhhhhhhhhhh:', data)
print("Finddddddddddd data pathhhhhhhhhhhhh:", data)
return run_transforms(data, transforms, log_prefix="PRE", use_compose=False)

def run_invert_transforms(self, data: Dict[str, Any], pre_transforms, names):
Expand Down Expand Up @@ -513,8 +512,15 @@ def run_inferer(self, data: Dict[str, Any], convert_to_batch=True, device="cuda"
network = self._get_network(device, data)
modelname = data.get("model", None)
if network:
if "vista" in modelname:
return prompt_run_inferer(data, inferer, network, input_key=self.input_key, output_label_key=self.output_label_key, device=device)
if "vista" in modelname:
return prompt_run_inferer(
data,
inferer,
network,
input_key=self.input_key,
output_label_key=self.output_label_key,
device=device,
)
else:
inputs = data[self.input_key]
inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
Expand All @@ -523,7 +529,7 @@ def run_inferer(self, data: Dict[str, Any], convert_to_batch=True, device="cuda"

with torch.no_grad():
outputs = inferer(
inputs,
inputs,
network,
)

Expand Down
29 changes: 20 additions & 9 deletions monailabel/tasks/infer/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import os
import shutil
import torch
import numpy as np
from monai.utils import optional_import
from monai.data import decollate_batch
from typing import Any, Dict

import numpy as np
import torch
from einops import rearrange
from monai.data import decollate_batch
from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")

Expand All @@ -20,6 +21,7 @@ def transform_points(point, affine):
point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
return point


def check_prompts_format(label_prompt, points, point_labels):
"""check the format of user prompts
label_prompt: [1,2,3,4,...,B] List of tensors
Expand All @@ -28,7 +30,9 @@ def check_prompts_format(label_prompt, points, point_labels):
"""
# check prompt is given
if label_prompt is None and points is None:
everything_labels = list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132]))
everything_labels = list(
{i + 1 for i in range(132)} - {2, 16, 18, 20, 21, 23, 24, 25, 26, 27, 128, 129, 130, 131, 132}
)
if everything_labels is not None:
label_prompt = [torch.tensor(_) for _ in everything_labels]

Expand Down Expand Up @@ -68,7 +72,16 @@ def check_prompts_format(label_prompt, points, point_labels):
raise ValueError("Points must be given if point labels are given.")
return label_prompt, points, point_labels

def prompt_run_inferer(data: Dict[str, Any], inferer, network, input_key="image", output_label_key="pred", device="cuda", convert_to_batch=True):

def prompt_run_inferer(
data: Dict[str, Any],
inferer,
network,
input_key="image",
output_label_key="pred",
device="cuda",
convert_to_batch=True,
):
# Retrieve label_prompt, points, and point_labels
label_prompt, points, point_labels = (
data.get("label_prompt", None),
Expand All @@ -95,7 +108,7 @@ def prompt_run_inferer(data: Dict[str, Any], inferer, network, input_key="image"
if points is not None:
points = torch.as_tensor([points])

original_spatial_shape = np.array(data['image_meta_dict']['spatial_shape'])
original_spatial_shape = np.array(data["image_meta_dict"]["spatial_shape"])
resized_spatial_shape = np.array(data[input_key].shape[1:])
scaling_factors = resized_spatial_shape / original_spatial_shape
transformed_point = points * scaling_factors
Expand All @@ -111,11 +124,9 @@ def prompt_run_inferer(data: Dict[str, Any], inferer, network, input_key="image"
inputs = inputs[None].to(torch.device(device))
inputs = inputs.to(torch.device(device))


with torch.no_grad():
outputs = inferer(inputs, network, point_coords=points, point_labels=point_labels, class_vector=label_prompt)


if device.startswith("cuda"):
torch.cuda.empty_cache()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ export default class AutoSegmentation extends BaseTab {

if (this.props.viewConstants.SupportedClasses) {
const labels = this.props.viewConstants.SupportedClasses
let labelIndex = 1;
let labelIndex = 1;

for (const key in labels) {
const organName = labels[key];
console.log(organName)
if (organName.toLowerCase() !== 'background') {
const hexColor = segmentColors[labelIndex] || '#000000';
const hexColor = segmentColors[labelIndex] || '#000000';

selectedOrgans[organName] = { checked: false, color: hexColor };
labelIndex++;
Expand Down Expand Up @@ -107,13 +107,13 @@ export default class AutoSegmentation extends BaseTab {

onChangeOrgans = (organ, evt) => {
this.setState((prevState) => {
const selectedOrgans = { ...prevState.selectedOrgans };
const selectedOrgans = { ...prevState.selectedOrgans };

selectedOrgans[organ] = {
...selectedOrgans[organ],
checked: evt.target.checked,
};

return { selectedOrgans };
});
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ export default class PointPrompts extends BaseTab {

if (this.props.viewConstants.SupportedClasses) {
const labels = this.props.viewConstants.SupportedClasses
let labelIndex = 1;
let labelIndex = 1;

for (const key in labels) {
const organName = labels[key];
if (organName.toLowerCase() !== 'background') {
const hexColor = segmentColors[labelIndex] || '#000000';
const hexColor = segmentColors[labelIndex] || '#000000';
selectedOrgans[organName] = { checked: false, color: hexColor };
labelIndex++;
}
Expand Down Expand Up @@ -107,7 +107,7 @@ export default class PointPrompts extends BaseTab {
const config = this.props.onOptionsConfig();
const params =
config && config.infer && config.infer[model] ? config.infer[model] : {};


// let seriesInstanceUID = viewConstants.SeriesInstanceUID;

Expand Down Expand Up @@ -299,13 +299,13 @@ export default class PointPrompts extends BaseTab {

onChangeOrgans = (organ, evt) => {
this.setState((prevState) => {
const selectedOrgans = { ...prevState.selectedOrgans };
const selectedOrgans = { ...prevState.selectedOrgans };

selectedOrgans[organ] = {
...selectedOrgans[organ],
checked: evt.target.checked,
};

return { selectedOrgans };
});
};
Expand Down Expand Up @@ -465,7 +465,7 @@ export default class PointPrompts extends BaseTab {
</tr>

{Object.entries(this.state.selectedOrgans).map(([organ, { color, checked }]) => (

<tr
key={organ}
className='clickable-row'
Expand Down
2 changes: 1 addition & 1 deletion plugins/ohifv3_vista/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ cp -r ../../monailabel/endpoints/static/ohif www/html/ohif
cp -f config/monai_label.js www/html/ohif/app-config.js

# nginx -p `pwd` -c config/nginx.conf -e logs/error.log
nginx -p `pwd` -c config/nginx.conf
nginx -p `pwd` -c config/nginx.conf

0 comments on commit de91cff

Please sign in to comment.