Simple MNIST classifier example using PyTorch Lightning.
Install Miniconda if you haven't already.
Create a Conda environment and install dependencies. This MNIST example uses:
- PyTorch Lightning for constructing a data module and model that can run across machines and use accelerator devices (GPUs/TPUs) when available (see the sketch after this list).
- Hydra for cleanly specifying and overriding configuration.
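For orientation, here is a minimal sketch of what such a Lightning data module could look like. The class name, split sizes, and defaults below are illustrative assumptions, not necessarily what this repo's code uses.

```python
# Illustrative sketch only; the actual data module in this repo may differ.
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    """Downloads MNIST and serves train/val/test dataloaders."""

    def __init__(self, data_dir: str = "data/", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # runs once per node; safe place to download the dataset
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.train_set, self.val_set = random_split(full, [55000, 5000])
        self.test_set = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)
```

Because the Lightning Trainer handles device placement and distribution, the same data module and model can run unchanged on CPU, GPU, or multi-node setups.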
```
conda env create --file environment.yml
```

Before running any commands below, activate the environment:

```
conda activate mnist
```

Pip-based installation is not recommended: it gives less control over dependencies, can be less optimized, and some packages are released only through Conda. However, a requirements.txt file is provided here to demonstrate a pip-based setup.

```
pip install -r requirements.txt
```

Inspect or modify the Hydra config at config.yaml. Then run:
```
# train
python mnist/train.py
# to change configs
python mnist/train.py datamodule.batch_size=64 model.hidden_size=128
# resume training from last checkpoint
python mnist/train.py resume=true
```

This runs training with validation at the end of each epoch, followed by a test pass once training is done. Metrics are logged with torchmetrics.
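As a rough sketch of how those metrics might be logged with torchmetrics inside the LightningModule (the class name, architecture, and metric choices below are assumptions for illustration):

```python
# Illustrative sketch only; the actual model in this repo may differ.
import pytorch_lightning as pl
import torch
import torchmetrics
from torch import nn


class MNISTClassifier(pl.LightningModule):
    def __init__(self, hidden_size: int = 64, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 10),
        )
        # torchmetrics objects keep separate running state for val and test
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self.net(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x)
        self.log("val_loss", nn.functional.cross_entropy(logits, y), prog_bar=True)
        self.log("val_acc", self.val_acc(logits, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x)
        self.log("test_acc", self.test_acc(logits, y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```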
With this config-based approach, any variant of a run can be expressed as parameter overrides to the Hydra config, so hyperparameters can be tuned without any code changes.
The entrypoint train.py returns a float to be used as the optimization objective; the metrics logged during training are accessible via trainer.callback_metrics, and cfg.metric specifies which field to return.
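A sketch of how such an entrypoint could be wired up is shown below; the instantiation targets, config layout, and resume handling are assumptions for illustration, while cfg.metric and trainer.callback_metrics follow the description above.

```python
# Illustrative sketch of a Hydra entrypoint; mnist/train.py may differ in detail.
import hydra
import pytorch_lightning as pl
from hydra.utils import instantiate
from omegaconf import DictConfig


@hydra.main(config_path="../config", config_name="config", version_base=None)
def train(cfg: DictConfig) -> float:
    datamodule = instantiate(cfg.datamodule)  # e.g. an MNIST LightningDataModule
    model = instantiate(cfg.model)            # e.g. a LightningModule classifier
    trainer: pl.Trainer = instantiate(cfg.trainer)

    # "resume=true" could map to resuming from the last checkpoint
    # (ckpt_path="last" in Lightning 2.x); the repo's handling may differ.
    ckpt_path = "last" if cfg.get("resume") else None
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    trainer.test(model, datamodule=datamodule)

    # Return the metric named by cfg.metric so a Hydra sweeper (e.g. Optuna)
    # can use it as the optimization objective.
    return float(trainer.callback_metrics[cfg.metric])


if __name__ == "__main__":
    train()
```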
Hydra has an Optuna sweeper plugin. To run hyperparameter tuning, specify the parameter overrides and search space in config/optuna/, then run the following:
```
# hyperparameter search using Optuna + Hydra. Configure in config/optuna.yaml
# view Optuna sweeper config
python mnist/train.py hydra/sweeper=optuna +optuna=tune -c hydra -p hydra.sweeper
# run the Optuna sweeper, using optuna/tune.yaml to define the hyperparameter search space
python mnist/train.py hydra/sweeper=optuna +optuna=tune --multirun
```

For dstack usage, including interactive development, see the workflows defined in .dstack/workflows/*.yaml.
First, start dstack hub:
```
# setup dstack
pip install -U dstack
# start dstack hub
dstack start
```

Then, in a new shell, initialize the dstack project and run a workflow:
```
# initialize project
dstack init
# run training workflow: pip-train
dstack run pip-train
# or run hyperparameter search workflow: pip-tune
dstack run pip-tune
```