Refactor MAISI tutorial, migrate GenerativeAI import #1779

Merged
merged 38 commits into from
Aug 15, 2024
Changes from 1 commit
Commits
f544302
change mor_op position, rm generative import in vae
Can-Zhao Aug 12, 2024
9da99fb
rm xformer import in vae
Can-Zhao Aug 12, 2024
dcad093
add load ckpt functions, inference notebook can run
Can-Zhao Aug 12, 2024
148484b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
a551102
fix logging format, add inference script, add details in readme
Can-Zhao Aug 12, 2024
966583f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
4c9c780
fix typo
Can-Zhao Aug 12, 2024
0b3e3dc
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 12, 2024
1272254
update readme
Can-Zhao Aug 12, 2024
a45c9f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
4af5f5a
reformat
Can-Zhao Aug 12, 2024
4aa9401
update readme
Can-Zhao Aug 12, 2024
b392528
clear directory in code
Can-Zhao Aug 12, 2024
a8ca930
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
b0dda95
clear directory in code
Can-Zhao Aug 12, 2024
1bc2a7d
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 12, 2024
094e938
change epoch to 1 to save notebook run time
Can-Zhao Aug 12, 2024
6b064d1
rm controllable size in inference input
Can-Zhao Aug 12, 2024
fb080de
mv some description to Readme, use subset of data to train in noteboo…
Can-Zhao Aug 12, 2024
5b942f1
rm dir info
Can-Zhao Aug 12, 2024
837cfb7
rm xformer
Can-Zhao Aug 13, 2024
04a78bf
rm generative in sample.py, clean print in utils.py
Can-Zhao Aug 13, 2024
e97a51d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
d5dbabe
clean legacy code in utils.py
Can-Zhao Aug 13, 2024
1b287fe
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 13, 2024
11be5ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
c223fad
rm repeated comment, rm xformer description, rm generative repo descr…
Can-Zhao Aug 13, 2024
daf5c2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
402f4c0
rm xformer description
Can-Zhao Aug 13, 2024
3c178c8
Merge branch 'main' into refactor_maisi
Can-Zhao Aug 13, 2024
288e73d
change docstring
Can-Zhao Aug 13, 2024
2659218
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 13, 2024
4e774bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
0835773
Merge branch 'main' into refactor_maisi
Can-Zhao Aug 14, 2024
c56b0f2
refactor for new controlnet
Can-Zhao Aug 14, 2024
b91b83c
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 14, 2024
59907c4
refactor inference
Can-Zhao Aug 14, 2024
7b3c389
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 14, 2024
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 12, 2024
commit 966583f779fe372e71be848c89888e26f2efa8ff
2 changes: 1 addition & 1 deletion generation/maisi/README.md
@@ -41,7 +41,7 @@ The information for the inference input, like body region and anatomy to generat
- `"controllable_anatomy_size"`: a list of controllable anatomies and their size scales (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a relatively small hepatic tumor (around the 30th percentile). The output will contain paired images and segmentation masks for the controllable anatomies.
- `"body_region"`: If "controllable_anatomy_size" is not specified, "body_region" will be used to constrain the region of generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower".
- `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json".
- `"autoencoder_sliding_window_infer_size"`: to save GPU memory, we use sliding window inference when decoding latents to image when `"output_size"` is large. This is the patch size of the sliding window. A small value reduces GPU memory usage but increases time cost. The values need to be divisible by 16.
- `"autoencoder_sliding_window_infer_overlap"`: a float between 0 and 1. A large value reduces stitching artifacts when combining patches during sliding window inference but increases time cost. If you do not observe seam lines in the generated image, you can use a smaller value to save inference time.
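The options above can be sketched as a small Python config fragment. The key names follow the README, but the concrete values and the `validate_sliding_window` helper are illustrative, not part of the tutorial:

```python
# Hypothetical inference config illustrating the options described above.
# Values are made up; only the key names come from the README.
inference_config = {
    "controllable_anatomy_size": [["liver", 0.5], ["hepatic tumor", 0.3]],
    "body_region": ["abdomen"],  # ignored when controllable_anatomy_size is set
    "anatomy_list": ["liver", "hepatic tumor"],
    "output_size": [512, 512, 512],
    "autoencoder_sliding_window_infer_size": [96, 96, 96],
    "autoencoder_sliding_window_infer_overlap": 0.6,
}


def validate_sliding_window(cfg):
    """Check the two constraints stated in the README: patch sizes divisible
    by 16, and overlap in [0, 1]."""
    size = cfg["autoencoder_sliding_window_infer_size"]
    overlap = cfg["autoencoder_sliding_window_infer_overlap"]
    if any(s % 16 != 0 for s in size):
        raise ValueError(f"patch size {size} must be divisible by 16")
    if not 0.0 <= overlap <= 1.0:
        raise ValueError(f"overlap {overlap} must be between 0 and 1")
    return True


validate_sliding_window(inference_config)
```

Checking these constraints up front fails fast, instead of surfacing as a shape mismatch deep inside the autoencoder's sliding window decoding.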


65 changes: 32 additions & 33 deletions generation/maisi/scripts/inference.py
@@ -1,12 +1,12 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# # MAISI Inference Script
@@ -27,6 +27,7 @@
from scripts.utils import define_instance, load_autoencoder_ckpt, load_diffusion_ckpt
from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image


def main():
parser = argparse.ArgumentParser(description="maisi.controlnet.training")
parser.add_argument(
@@ -60,19 +61,19 @@ def main():
# ## Set deterministic training for reproducibility
if args.random_seed is not None:
set_determinism(seed=args.random_seed)

# ## Setup data directory
# You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.
# This allows you to save results and reuse downloads.
# If not specified a temporary directory will be used.

directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
os.makedirs(directory, exist_ok=True)
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

files = [
{
"path": "models/autoencoder_epoch273.pt",
@@ -107,21 +108,21 @@
"url": "https://drive.google.com/file/d/16MKsDKkHvDyF2lEir4dzlxwex_GHStUf/view?usp=sharing",
},
]

for file in files:
file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
download_url(url=file["url"], filepath=file["path"])

# ## Read in environment setting, including data directory, model directory, and output directory
# The information for data directory, model directory, and output directory are saved in ./configs/environment.json
env_dict = json.load(open(args.environment_file, "r"))
for k, v in env_dict.items():
# Update the path to the downloaded dataset in MONAI_DATA_DIRECTORY
val = v if "datasets/" not in v else os.path.join(root_dir, v)
setattr(args, k, val)
print(f"{k}: {val}")
print("Global config variables have been loaded.")

# ## Read in configuration setting, including network definition, body region and anatomy to generate, etc.
#
# The information used for both training and inference, like network definition, is stored in "./configs/config_maisi.json". Training and inference should use the same "./configs/config_maisi.json".
@@ -133,18 +134,18 @@
# - `"controllable_anatomy_size"`: a list of controllable anatomies and their size scales (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a relatively small hepatic tumor (around the 30th percentile). The output will contain paired images and segmentation masks for the controllable anatomies.
# - `"body_region"`: If "controllable_anatomy_size" is not specified, "body_region" will be used to constrain the region of generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower".
# - `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json".
# - `"autoencoder_sliding_window_infer_size"`: to save GPU memory, we use sliding window inference when decoding latents to image when `"output_size"` is large. This is the patch size of the sliding window. A small value reduces GPU memory usage but increases time cost. The values need to be divisible by 16.
# - `"autoencoder_sliding_window_infer_overlap"`: a float between 0 and 1. A large value reduces stitching artifacts when combining patches during sliding window inference but increases time cost. If you do not observe seam lines in the generated image, you can use a smaller value to save inference time.
config_dict = json.load(open(args.config_file, "r"))
for k, v in config_dict.items():
setattr(args, k, v)

# check the format of inference inputs
config_infer_dict = json.load(open(args.inference_file, "r"))
for k, v in config_infer_dict.items():
setattr(args, k, v)
print(f"{k}: {v}")

check_input(
args.body_region,
args.anatomy_list,
@@ -155,43 +156,40 @@
)
latent_shape = [args.latent_channels, args.output_size[0] // 4, args.output_size[1] // 4, args.output_size[2] // 4]
print("Network definition and inference inputs have been loaded.")




# ## Initialize networks and noise scheduler, then load the trained model weights.
# The networks and noise scheduler are defined in `config_file`. We will read them in and load the model weights.
noise_scheduler = define_instance(args, "noise_scheduler")
mask_generation_noise_scheduler = define_instance(args, "mask_generation_noise_scheduler")

device = torch.device("cuda")

autoencoder = define_instance(args, "autoencoder_def").to(device)
checkpoint_autoencoder = load_autoencoder_ckpt(args.trained_autoencoder_path)
autoencoder.load_state_dict(checkpoint_autoencoder)

diffusion_unet = define_instance(args, "diffusion_unet_def").to(device)
checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path)
new_dict = load_diffusion_ckpt(diffusion_unet.state_dict(), checkpoint_diffusion_unet["unet_state_dict"])
diffusion_unet.load_state_dict(new_dict, strict=True)
scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)

controlnet = define_instance(args, "controlnet_def").to(device)
checkpoint_controlnet = torch.load(args.trained_controlnet_path)
monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)

mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder_def").to(device)
checkpoint_mask_generation_autoencoder = load_autoencoder_ckpt(args.trained_mask_generation_autoencoder_path)
mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)

mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion_def").to(device)
checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path)
mask_generation_diffusion_unet.load_old_state_dict(checkpoint_mask_generation_diffusion_unet)
mask_generation_scale_factor = args.mask_generation_scale_factor

print("All the trained model weights have been loaded.")



# ## Define the LDM Sampler, which contains functions that will perform the inference.
ldm_sampler = LDMSampler(
args.body_region,
@@ -225,16 +223,17 @@
autoencoder_sliding_window_infer_size=args.autoencoder_sliding_window_infer_size,
autoencoder_sliding_window_infer_overlap=args.autoencoder_sliding_window_infer_overlap,
)

print(f"The generated image/mask pairs will be saved in {args.output_dir}.")
output_filenames = ldm_sampler.sample_multiple_images(args.num_output_samples)
print("MAISI image/mask generation finished")


if __name__ == "__main__":
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
main()