Revert "Auto3DSeg skip trained algos" #6295

Merged · 16 commits · Apr 5, 2023
35 changes: 29 additions & 6 deletions .github/workflows/integration.yml
@@ -9,8 +9,8 @@ jobs:
   integration-py3:
     container:
       image: nvcr.io/nvidia/pytorch:22.04-py3  # CUDA 11.6 py38
-      options: --gpus all  # shm-size 4g works fine
-    runs-on: [self-hosted, linux, x64, integration]
+      options: --gpus "device=0" --ipc host  # shm-size 4g works fine
+    runs-on: [self-hosted, linux, x64, command]
     steps:
       # checkout the pull request branch
       - uses: actions/checkout@v3
@@ -34,7 +34,7 @@ jobs:
         run: |
           which python
           python -m pip install --upgrade pip wheel
-          python -m pip install --upgrade torch torchvision
+          python -m pip install --upgrade torch torchvision torchaudio
           python -m pip install -r requirements-dev.txt
           rm -rf /github/home/.cache/torch/hub/mmars/
       - name: Run integration tests
@@ -43,14 +43,37 @@ jobs:
          git config --global --add safe.directory /__w/MONAI/MONAI
          git clean -ffdx
          nvidia-smi
-         export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
+         export CUDA_VISIBLE_DEVICES=$(python -m tests.utils -c 1 | tail -n 1)
          echo $CUDA_VISIBLE_DEVICES
          trap 'if pgrep python; then pkill python; fi;' ERR
-         python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
+         python -c $'import torch\na=[torch.zeros(1,device=f"cuda:{i}") for i in range(torch.cuda.device_count())];\nwhile True:print(a)' > /dev/null &
          python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
          python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
+
+         # test auto3dseg
+         BUILD_MONAI=0 ./runtests.sh --build
+         python -m tests.test_auto3dseg_ensemble
+         python -m tests.test_auto3dseg_hpo
+         python -m tests.test_integration_autorunner
+         python -m tests.test_integration_gpu_customization
+
+         # test latest template
+         cd ../
+         git clone --depth 1 --branch main --single-branch https://github.com/Project-MONAI/research-contributions.git
+         ls research-contributions/
+         cp -r research-contributions/auto3dseg/algorithm_templates ../MONAI/
+         cd research-contributions && git log -1 && cd ..
+         export OMP_NUM_THREADS=4
+         export MKL_NUM_THREADS=4
+         export MONAI_TESTING_ALGO_TEMPLATE=algorithm_templates
+         python -m tests.test_auto3dseg_ensemble
+         python -m tests.test_auto3dseg_hpo
+         python -m tests.test_integration_autorunner
+         python -m tests.test_integration_gpu_customization
+
+         # the other tests
          BUILD_MONAI=1 ./runtests.sh --build --net
-         BUILD_MONAI=1 ./runtests.sh --build --unittests --disttests
+         BUILD_MONAI=1 ./runtests.sh --build --unittests
          if pgrep python; then pkill python; fi
         shell: bash
       - name: Add reaction
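Since the container now exposes a single GPU (`--gpus "device=0"`), the background keep-alive command can no longer hard-code two devices. As a minimal sketch, the new one-liner expands to roughly the following (illustrative Python, not part of the diff):

```python
# Expanded form of the keep-alive one-liner: allocate a tensor on every
# visible GPU rather than assuming cuda:0/cuda:1, so it also works in the
# single-GPU container. In the workflow this runs in the background with
# stdout redirected to /dev/null.
import torch

a = [torch.zeros(1, device=f"cuda:{i}") for i in range(torch.cuda.device_count())]
while True:
    print(a)  # keeps a live CUDA context on each visible device
```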
20 changes: 7 additions & 13 deletions monai/apps/auto3dseg/auto_runner.py
@@ -281,7 +281,7 @@ def __init__(
         # determine if we need to analyze, algo_gen or train from cache, unless manually provided
         self.analyze = not self.cache["analyze"] if analyze is None else analyze
         self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
-        self.train = train
+        self.train = not self.cache["train"] if train is None else train
         self.ensemble = ensemble  # last step, no need to check

         self.set_training_params()
@@ -758,8 +758,7 @@ def run(self):
             logger.info("Skipping algorithm generation...")

         # step 3: algo training
-        auto_train_choice = self.train is None
-        if self.train or (auto_train_choice and not self.cache["train"]):
+        if self.train:
             history = import_bundle_algo_history(self.work_dir, only_trained=False)

             if len(history) == 0:
@@ -768,15 +767,10 @@ def run(self):
                     "Possibly the required algorithms generation step was not completed."
                 )

-            if auto_train_choice:
-                history = [h for h in history if not h["is_trained"]]  # skip trained
-
-            if len(history) > 0:
-                if not self.hpo:
-                    self._train_algo_in_sequence(history)
-                else:
-                    self._train_algo_in_nni(history)
-
+            if not self.hpo:
+                self._train_algo_in_sequence(history)
+            else:
+                self._train_algo_in_nni(history)
             self.export_cache(train=True)
         else:
             logger.info("Skipping algorithm training...")
@@ -804,4 +798,4 @@ def run(self):
                 self.save_image(pred)
             logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")

-        logger.info("Auto3Dseg pipeline is completed successfully.")
+        logger.info("Auto3Dseg pipeline is complete successfully.")
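The heart of the revert is the `self.train` assignment together with the simplified `run()` branch: rather than filtering out already-trained algorithms one by one, the runner decides once, up front, whether to train, from the tri-state `train` argument and the cached progress. A minimal sketch of the restored semantics, where `resolve_train_flag` is an illustrative helper rather than MONAI API:

```python
# Illustrative helper mirroring the restored line:
#   self.train = not self.cache["train"] if train is None else train
def resolve_train_flag(train, cache):
    return (not cache["train"]) if train is None else train

assert resolve_train_flag(None, {"train": False}) is True   # fresh run: train
assert resolve_train_flag(None, {"train": True}) is False   # cache says trained: skip
assert resolve_train_flag(True, {"train": True}) is True    # explicit request wins: train anyway
```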
2 changes: 1 addition & 1 deletion monai/apps/auto3dseg/bundle_gen.py
@@ -35,7 +35,7 @@
 from monai.utils import ensure_tuple

 logger = get_logger(module_name=__name__)
-ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "4af80e1")
+ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "7758ad1")

 __all__ = ["BundleAlgo", "BundleGen"]

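Because `ALGO_HASH` is read from the environment at module import time, the bumped value is only a default; the algorithm templates can still be pinned to any other commit of `Project-MONAI/research-contributions` without code changes. A hypothetical usage sketch:

```python
import os

# Must be set before monai.apps.auto3dseg.bundle_gen is first imported,
# since ALGO_HASH is evaluated when the module loads. The hash value here
# is illustrative; any valid commit of the templates repo works.
os.environ["MONAI_ALGO_HASH"] = "7758ad1"

from monai.apps.auto3dseg import BundleGen  # picks up the override
```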
9 changes: 3 additions & 6 deletions monai/apps/auto3dseg/utils.py
@@ -47,14 +47,11 @@ def import_bundle_algo_history(
         if isinstance(algo, BundleAlgo):  # algo's template path needs override
             algo.template_path = algo_meta_data["template_path"]

-        best_metrics = "best_metrics"
-        is_trained = best_metrics in algo_meta_data
-
         if only_trained:
-            if is_trained:
-                history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data[best_metrics]})
+            if "best_metrics" in algo_meta_data:
+                history.append({name: algo})
         else:
-            history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data.get(best_metrics, None)})
+            history.append({name: algo})

     return history

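After this change, `import_bundle_algo_history` returns to its simpler record format: each entry is a single-item dict mapping the algorithm name to the algo object, and `only_trained=True` keeps only the algos whose metadata contains `best_metrics`. A short consumption sketch (the work dir path is illustrative):

```python
from monai.apps.auto3dseg.utils import import_bundle_algo_history

# Each record is {name: algo}; the "is_trained"/"best_metrics" bookkeeping
# keys introduced by the reverted change are gone again.
history = import_bundle_algo_history("./workdir", only_trained=False)
for record in history:
    for name, algo in record.items():
        print(name, type(algo).__name__)
```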
2 changes: 0 additions & 2 deletions monai/data/utils.py
@@ -17,7 +17,6 @@
 import math
 import os
 import pickle
-import warnings
 from collections import abc, defaultdict
 from collections.abc import Generator, Iterable, Mapping, Sequence, Sized
 from copy import deepcopy
@@ -786,7 +785,6 @@ def rectify_header_sform_qform(img_nii):
         return img_nii

     norm = affine_to_spacing(img_nii.affine, r=d)
-    warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}")

     img_nii.header.set_zooms(norm)
     return img_nii
2 changes: 1 addition & 1 deletion tests/test_retinanet.py
@@ -97,7 +97,7 @@
     TEST_CASES_TS.append([model, *case])


-@SkipIfBeforePyTorchVersion((1, 9))
+@SkipIfBeforePyTorchVersion((1, 12))
 @unittest.skipUnless(has_torchvision, "Requires torchvision")
 class TestRetinaNet(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
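The version gate rises from PyTorch 1.9 to 1.12. As a rough standard-library sketch of what `SkipIfBeforePyTorchVersion((1, 12))` amounts to (MONAI's own decorator parses version strings more carefully):

```python
import unittest

import torch

# Parse e.g. "1.13.0a0+d321be6" -> (1, 13); good enough for this sketch.
_torch_ver = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])

@unittest.skipUnless(_torch_ver >= (1, 12), "Requires PyTorch 1.12+")
class TestRetinaNetGate(unittest.TestCase):
    def test_runs_only_on_new_torch(self):
        self.assertTrue(True)
```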
7 changes: 6 additions & 1 deletion tests/utils.py
@@ -11,6 +11,7 @@

 from __future__ import annotations

+import argparse
 import copy
 import datetime
 import functools
@@ -784,6 +785,7 @@ def query_memory(n=2):
     bash_string = "nvidia-smi --query-gpu=power.draw,temperature.gpu,memory.used --format=csv,noheader,nounits"

     try:
+        print(f"query memory with n={n}")
         p1 = Popen(bash_string.split(), stdout=PIPE)
         output, error = p1.communicate()
         free_memory = [x.split(",") for x in output.decode("utf-8").split("\n")[:-1]]
@@ -842,5 +844,8 @@ def command_line_tests(cmd, copy_env=True):
     TEST_DEVICES.append([torch.device("cuda")])

 if __name__ == "__main__":
-    print("\n", query_memory(), sep="\n")  # print to stdout
+    parser = argparse.ArgumentParser(prog="util")
+    parser.add_argument("-c", "--count", default=2, help="max number of gpus")
+    args = parser.parse_args()
+    print("\n", query_memory(int(args.count)), sep="\n")  # print to stdout
     sys.exit(0)
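Together with the workflow change above, this lets CI request a single suitable GPU: `python -m tests.utils -c 1 | tail -n 1` prints the chosen device index on the last line, which the workflow exports as `CUDA_VISIBLE_DEVICES`. A hypothetical equivalent from Python, assuming it runs at the repository root with `nvidia-smi` available:

```python
import subprocess

# Equivalent of:
#   export CUDA_VISIBLE_DEVICES=$(python -m tests.utils -c 1 | tail -n 1)
out = subprocess.run(
    ["python", "-m", "tests.utils", "-c", "1"],
    capture_output=True, text=True, check=True,
).stdout
gpu_index = out.strip().splitlines()[-1]  # the `| tail -n 1` part
print(f"CUDA_VISIBLE_DEVICES={gpu_index}")
```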