Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Auto3DSeg] Add mlflow support in autorunner. #7176

Merged
merged 10 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class AutoRunner:
zip url will be downloaded and extracted into the work_dir.
allow_skip: a switch passed to BundleGen process which determines if some Algo in the default templates
can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer.
mlflow_tracking_uri: a tracking URI for the MLflow server, which can be a local directory or the address of
a remote tracking server; MLflow runs will be recorded locally in the algorithms' model folders if the value is None.
kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.

Expand Down Expand Up @@ -209,6 +211,7 @@ def __init__(
not_use_cache: bool = False,
templates_path_or_url: str | None = None,
allow_skip: bool = True,
mlflow_tracking_uri: str | None = None,
wyli marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Any,
):
logger.info(f"AutoRunner using work directory {work_dir}")
Expand All @@ -220,6 +223,7 @@ def __init__(
self.algos = algos
self.templates_path_or_url = templates_path_or_url
self.allow_skip = allow_skip
self.mlflow_tracking_uri = mlflow_tracking_uri
self.kwargs = deepcopy(kwargs)

if input is None and os.path.isfile(self.data_src_cfg_name):
Expand Down Expand Up @@ -783,6 +787,7 @@ def run(self):
templates_path_or_url=self.templates_path_or_url,
data_stats_filename=self.datastats_filename,
data_src_cfg_name=self.data_src_cfg_name,
mlflow_tracking_uri=self.mlflow_tracking_uri,
)

if self.gpu_customization:
Expand Down
34 changes: 34 additions & 0 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, template_path: PathLike):
self.template_path = template_path
self.data_stats_files = ""
self.data_list_file = ""
self.mlflow_tracking_uri = None
self.output_path = ""
self.name = ""
self.best_metric = None
Expand Down Expand Up @@ -129,6 +130,17 @@ def set_data_source(self, data_src_cfg: str) -> None:
"""
self.data_list_file = data_src_cfg

def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None:
    """
    Store the MLflow tracking URI on this algorithm object.

    The value is kept as-is in ``self.mlflow_tracking_uri``; no validation is performed here.

    Args:
        mlflow_tracking_uri: tracking URI for the MLflow server, which can be a local directory or the
            address of a remote tracking server. If None, MLflow runs will be recorded locally in the
            algorithms' model folder.
    """
    # the attribute is initialised to None in __init__, hence the type-checker suppression
    self.mlflow_tracking_uri = mlflow_tracking_uri  # type: ignore

def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict:
"""
The configuration files defined when constructing this Algo instance might not have a complete training
Expand Down Expand Up @@ -432,6 +444,9 @@ class BundleGen(AlgoGen):
data_stats_filename: the path to the data stats file (generated by DataAnalyzer).
data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of
{"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}.
mlflow_tracking_uri: a tracking URI for the MLflow server, which can be a local directory or the address of
a remote tracking server; MLflow runs will be recorded locally in the algorithms' model folders if
the value is None.
.. code-block:: bash

python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml"
Expand All @@ -444,6 +459,7 @@ def __init__(
templates_path_or_url: str | None = None,
data_stats_filename: str | None = None,
data_src_cfg_name: str | None = None,
mlflow_tracking_uri: str | None = None,
):
if algos is None or isinstance(algos, (list, tuple, str)):
if templates_path_or_url is None:
Expand Down Expand Up @@ -496,6 +512,7 @@ def __init__(

self.data_stats_filename = data_stats_filename
self.data_src_cfg_name = data_src_cfg_name
self.mlflow_tracking_uri = mlflow_tracking_uri
self.history: list[dict] = []

def set_data_stats(self, data_stats_filename: str) -> None:
Expand Down Expand Up @@ -524,6 +541,21 @@ def get_data_src(self):
"""Get the data source filename"""
return self.data_src_cfg_name

def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None:
    """
    Set the tracking URI for the MLflow server.

    The value is stored in ``self.mlflow_tracking_uri`` without validation.

    Args:
        mlflow_tracking_uri: a tracking URI for the MLflow server, which can be a local directory or the
            address of a remote tracking server; MLflow runs will be recorded locally in the algorithms'
            model folder if the value is None.
    """
    self.mlflow_tracking_uri = mlflow_tracking_uri

def get_mlflow_tracking_uri(self) -> str | None:
    """
    Get the tracking URI for the MLflow server.

    Returns:
        the value currently stored in ``self.mlflow_tracking_uri``, which may be None when no tracking URI
        has been configured.
    """
    return self.mlflow_tracking_uri

def get_history(self) -> list:
    """
    Return the history kept on this object (bundleAlgo objects with their names/identifiers).

    The stored list itself is returned, not a copy.
    """
    recorded = self.history
    return recorded
Expand Down Expand Up @@ -575,9 +607,11 @@ def generate(
for f_id in ensure_tuple(fold_idx):
data_stats = self.get_data_stats()
data_src_cfg = self.get_data_src()
mlflow_tracking_uri = self.get_mlflow_tracking_uri()
gen_algo = deepcopy(algo)
gen_algo.set_data_stats(data_stats)
gen_algo.set_data_source(data_src_cfg)
gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri)
name = f"{gen_algo.name}_{f_id}"

if allow_skip:
Expand Down