-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f5e29bf
commit b218ea2
Showing
5 changed files
with
335 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
296 changes: 296 additions & 0 deletions
296
example-get-started-experiments/code/notebooks/TrainSegModel-sagemaker.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,296 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pathlib import Path\n", | ||
"\n", | ||
"ROOT = Path(\"../\")\n", | ||
"DATA = ROOT / \"data\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import warnings\n", | ||
"warnings.filterwarnings(\"ignore\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import shutil\n", | ||
"from functools import partial\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"from box import ConfigBox\n", | ||
"from dvclive import Live\n", | ||
"from dvclive.fastai import DVCLiveCallback\n", | ||
"from fastai.data.all import Normalize, get_files\n", | ||
"from fastai.metrics import DiceMulti\n", | ||
"from fastai.vision.all import (Resize, SegmentationDataLoaders,\n", | ||
" imagenet_stats, models, unet_learner)\n", | ||
"from ruamel.yaml import YAML\n", | ||
"from PIL import Image" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Load data and split it into train/test\n", | ||
"\n", | ||
"We have some [data in DVC](https://dvc.org/doc/start/data-management/data-versioning) that we can pull. \n", | ||
"\n", | ||
"This data includes:\n", | ||
"* satellite images\n", | ||
"* masks of the swimming pools in each satellite image\n", | ||
"\n", | ||
"DVC can help connect your data to your repo, but it isn't necessary to have your data in DVC to start tracking experiments with DVC and DVCLive." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!dvc pull" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"test_regions = [\"REGION_1-\"]\n", | ||
"\n", | ||
"img_fpaths = get_files(DATA / \"pool_data\" / \"images\", extensions=\".jpg\")\n", | ||
"\n", | ||
"train_data_dir = DATA / \"train_data\"\n", | ||
"train_data_dir.mkdir(exist_ok=True)\n", | ||
"test_data_dir = DATA / \"test_data\"\n", | ||
"test_data_dir.mkdir(exist_ok=True)\n", | ||
"for img_path in img_fpaths:\n", | ||
" msk_path = DATA / \"pool_data\" / \"masks\" / f\"{img_path.stem}.png\"\n", | ||
" if any(region in str(img_path) for region in test_regions):\n", | ||
" shutil.copy(img_path, test_data_dir)\n", | ||
" shutil.copy(msk_path, test_data_dir)\n", | ||
" else:\n", | ||
" shutil.copy(img_path, train_data_dir)\n", | ||
" shutil.copy(msk_path, train_data_dir)" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Create a data loader\n", | ||
"\n", | ||
"Load and prepare the images and masks by creating a data loader." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_mask_path(x, train_data_dir):\n", | ||
" return Path(train_data_dir) / f\"{Path(x).stem}.png\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"bs = 8\n", | ||
"valid_pct = 0.20\n", | ||
"img_size = 256\n", | ||
"\n", | ||
"data_loader = SegmentationDataLoaders.from_label_func(\n", | ||
" path=train_data_dir,\n", | ||
" fnames=get_files(train_data_dir, extensions=\".jpg\"),\n", | ||
" label_func=partial(get_mask_path, train_data_dir=train_data_dir),\n", | ||
" codes=[\"not-pool\", \"pool\"],\n", | ||
" bs=bs,\n", | ||
" valid_pct=valid_pct,\n", | ||
" item_tfms=Resize(img_size),\n", | ||
" batch_tfms=[\n", | ||
" Normalize.from_stats(*imagenet_stats),\n", | ||
" ],\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Review a sample batch of data\n", | ||
"\n", | ||
"Below are some examples of the images overlaid with their masks." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"data_loader.show_batch(alpha=0.7)" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Train multiple models with different learning rates using `DVCLiveCallback`\n", | ||
"\n", | ||
"Set up model training, using DVCLive to capture the results of each experiment." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def dice(mask_pred, mask_true, classes=[0, 1], eps=1e-6):\n", | ||
" dice_list = []\n", | ||
" for c in classes:\n", | ||
" y_true = mask_true == c\n", | ||
" y_pred = mask_pred == c\n", | ||
" intersection = 2.0 * np.sum(y_true * y_pred)\n", | ||
" dice = intersection / (np.sum(y_true) + np.sum(y_pred) + eps)\n", | ||
" dice_list.append(dice)\n", | ||
" return np.mean(dice_list)\n", | ||
"\n", | ||
"def evaluate(learn):\n", | ||
" test_img_fpaths = sorted(get_files(DATA / \"test_data\", extensions=\".jpg\"))\n", | ||
" test_dl = learn.dls.test_dl(test_img_fpaths)\n", | ||
" preds, _ = learn.get_preds(dl=test_dl)\n", | ||
" masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)\n", | ||
" test_mask_fpaths = [\n", | ||
" get_mask_path(fpath, DATA / \"test_data\") for fpath in test_img_fpaths\n", | ||
" ]\n", | ||
" masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n", | ||
" dice_multi = 0.0\n", | ||
" for ii in range(len(masks_true)):\n", | ||
" mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n", | ||
" width, height = mask_true.shape[1], mask_true.shape[0]\n", | ||
" mask_pred = np.array(\n", | ||
" Image.fromarray(mask_pred).resize((width, height)),\n", | ||
" dtype=int\n", | ||
" )\n", | ||
" mask_true = np.array(mask_true, dtype=int)\n", | ||
" dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n", | ||
" return dice_multi" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"scrolled": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import tarfile\n", | ||
"\n", | ||
"train_arch = 'shufflenet_v2_x2_0'\n", | ||
"models_dir = ROOT / \"models\"\n", | ||
"models_dir.mkdir(exist_ok=True)\n", | ||
"results_dir = ROOT / \"results\" / \"train\"\n", | ||
"\n", | ||
"for base_lr in [0.001, 0.005]:\n", | ||
" # initialize dvclive, optionally provide output path, and save results as a dvc experiment\n", | ||
" with Live(str(results_dir), save_dvc_exp=True, report=\"notebook\") as live:\n", | ||
" # log a parameter\n", | ||
" live.log_param(\"train_arch\", train_arch)\n", | ||
" fine_tune_args = {\n", | ||
" 'epochs': 2,\n", | ||
" 'base_lr': base_lr\n", | ||
" }\n", | ||
" # log a dict of parameters\n", | ||
" live.log_params(fine_tune_args)\n", | ||
"\n", | ||
" learn = unet_learner(data_loader, \n", | ||
" arch=getattr(models, train_arch), \n", | ||
" metrics=DiceMulti)\n", | ||
" # train model and automatically capture metrics with DVCLiveCallback\n", | ||
" learn.fine_tune(\n", | ||
" **fine_tune_args,\n", | ||
" cbs=[DVCLiveCallback(live=live)])\n", | ||
"\n", | ||
" learn.export(fname=(models_dir / \"model.pkl\").absolute())\n", | ||
"\n", | ||
" # add additional post-training summary metrics\n", | ||
" live.summary[\"evaluate/dice_multi\"] = evaluate(learn)\n", | ||
"\n", | ||
" torch.save(learn.model, (models_dir / \"model.pth\").absolute())\n", | ||
"\n", | ||
" # save model artifact to dvc\n", | ||
" live.log_artifact(\n", | ||
" str(models_dir / \"model.pth\"),\n", | ||
" type=\"model\",\n", | ||
" name=\"pool-segmentation\",\n", | ||
" desc=\"This is a Computer Vision (CV) model that's segmenting out swimming pools from satellite images.\",\n", | ||
" labels=[\"cv\", \"segmentation\", \"satellite-images\", \"unet\"],\n", | ||
" )\n", | ||
"\n", | ||
" # For deploying the model to Sagemaker will also create a sagemaker model archive\n", | ||
" # and save it to dvc - this is because Sagemaker expects a particular archive structure\n", | ||
" sagemaker_dir = ROOT / \"sagemaker\"\n", | ||
"\n", | ||
" with tarfile.open(sagemaker_dir / \"model.tar.gz\", \"w:gz\") as tar:\n", | ||
" tar.add(sagemaker_dir / \"code\", \"code\")\n", | ||
" tar.add(models_dir / \"model.pth\", \"code/model.pth\")\n", | ||
"\n", | ||
" live.log_artifact(sagemaker_dir / \"model.tar.gz\")\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.6" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters