-
Notifications
You must be signed in to change notification settings - Fork 9.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add a script to run all pytorch examples #591
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
#!/bin/sh | ||
# | ||
# This script runs through the code in each of the python examples. | ||
# The purpose is just as an integrtion test, not to actually train | ||
# models in any meaningful way. For that reason, most of these set | ||
# epochs = 1. | ||
# | ||
# Optionally specify a comma separated list of examples to run. | ||
# can be run as: | ||
# ./run_python_examples.sh "install_deps,run_all,clean" | ||
# to pip install dependencies (other than pytorch), run all examples, | ||
# and remove temporary/changed data files. | ||
# Expects pytorch to be installed. | ||
|
||
BASE_DIR=`pwd`"/"`dirname $0` | ||
EXAMPLES=`echo $1 | sed -e 's/ //g'` | ||
|
||
if which nvcc ; then | ||
echo "using cuda" | ||
CUDA=1 | ||
CUDA_FLAG="--cuda" | ||
else | ||
echo "not using cuda" | ||
CUDA=0 | ||
CUDA_FLAG="" | ||
fi | ||
|
||
ERRORS="" | ||
|
||
function error() { | ||
ERR=$1 | ||
ERRORS="$ERRORS\n$ERR" | ||
echo $ERR | ||
} | ||
|
||
function install_deps() { | ||
echo "installing requirements" | ||
cat $BASE_DIR/*/requirements.txt | \ | ||
sort -u | \ | ||
# testing the installed version of torch, so don't pip install it. | ||
grep -vE '^torch$' | \ | ||
pip install -r /dev/stdin || \ | ||
{ error "failed to install dependencies"; exit 1; } | ||
} | ||
|
||
function start() { | ||
EXAMPLE=${FUNCNAME[1]} | ||
cd $BASE_DIR/$EXAMPLE | ||
echo "Running example: $EXAMPLE" | ||
} | ||
|
||
function dcgan() { | ||
start | ||
if [ ! -d "lsun" ]; then | ||
echo "cloning repo to get lsun dataset" | ||
git clone https://github.com/fyu/lsun || { error "couldn't clone lsun repo needed for dcgan"; return; } | ||
fi | ||
# 'classroom' much smaller than the default 'bedroom' dataset. | ||
DATACLASS="classroom" | ||
if [ ! -d "lsun/${DATACLASS}_train_lmdb" ]; then | ||
pushd lsun | ||
python download.py -c $DATACLASS || { error "couldn't download $DATACLASS for dcgan"; return; } | ||
unzip ${DATACLASS}_train_lmdb.zip || { error "couldn't unzip $DATACLASS"; return; } | ||
popd | ||
fi | ||
python main.py --dataset lsun --dataroot lsun --classes $DATACLASS --niter 1 $CUDA_FLAG || error "dcgan failed" | ||
} | ||
|
||
function fast_neural_style() { | ||
start | ||
if [ ! -d "saved_models" ]; then | ||
echo "downloading saved models for fast neural style" | ||
python download_saved_models.py | ||
fi | ||
test -d "saved_models" || { error "saved models not found"; return; } | ||
|
||
echo "running fast neural style model" | ||
python neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg --cuda $CUDA || error "neural_style.py failed" | ||
} | ||
|
||
function imagenet() { | ||
start | ||
if [[ ! -d "sample/val" || ! -d "sample/train" ]]; then | ||
mkdir -p sample/val/n | ||
mkdir -p sample/train/n | ||
wget "https://upload.wikimedia.org/wikipedia/commons/5/5a/Socks-clinton.jpg" || { error "couldn't download sample image for imagenet"; return; } | ||
mv Socks-clinton.jpg sample/train/n | ||
cp sample/train/n/* sample/val/n/ | ||
fi | ||
python main.py --epochs 1 sample/ || error "imagenet example failed" | ||
} | ||
|
||
function mnist() { | ||
start | ||
python main.py --epochs 1 || error "mnist example failed" | ||
} | ||
|
||
function mnist_hogwild() { | ||
start | ||
python main.py --epochs 1 $CUDA_FLAG || error "mnist hogwild failed" | ||
} | ||
|
||
function regression() { | ||
start | ||
python main.py --epochs 1 $CUDA_FLAG || error "regression failed" | ||
} | ||
|
||
function reinforcement_learning() { | ||
start | ||
python reinforce.py || error "reinforcement learning failed" | ||
} | ||
|
||
function snli() { | ||
start | ||
echo "installing 'en' model if not installed" | ||
python -m spacy download en || { error "couldn't download 'en' model needed for snli"; return; } | ||
echo "training..." | ||
python train.py --epochs 1 --no-bidirectional || error "couldn't train snli" | ||
} | ||
|
||
function super_resolution() { | ||
start | ||
python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 || error "super resolution failed" | ||
} | ||
|
||
function time_sequence_prediciton() { | ||
start | ||
python generate_sine_wave.py || { error "generate sine wave failed"; return; } | ||
python train.py || error "time sequence prediction training failed" | ||
} | ||
|
||
function vae() { | ||
start | ||
python main.py --epochs 1 || error "vae failed" | ||
} | ||
|
||
function word_language_model() { | ||
start | ||
python main.py --epochs 1 $CUDA_FLAG || error "word_language_model failed" | ||
} | ||
|
||
function clean() { | ||
cd $BASE_DIR | ||
rm -rf dcgan/_cache_lsun_classroom_train_lmdb dcgan/fake_samples_epoch_000.png dcgan/lsun/ dcgan/netD_epoch_0.pth dcgan/netG_epoch_0.pth dcgan/real_samples.png fast_neural_style/saved_models.zip fast_neural_style/saved_models/ imagenet/checkpoint.pth.tar imagenet/lsun/ imagenet/model_best.pth.tar imagenet/sample/ snli/.data/ snli/.vector_cache/ snli/results/ super_resolution/dataset/ super_resolution/model_epoch_1.pth word_language_model/model.pt || error "couldn't clean up some files" | ||
|
||
git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image" | ||
} | ||
|
||
function run_all() { | ||
dcgan | ||
fast_neural_style | ||
imagenet | ||
mnist | ||
mnist_hogwild | ||
regression | ||
reinforcement_learning | ||
snli | ||
super_resolution | ||
time_sequence_prediction | ||
vae | ||
word_language_model | ||
} | ||
|
||
# by default, run all examples | ||
if [ "" == "$EXAMPLES" ]; then | ||
run_all | ||
else | ||
for i in $(echo $EXAMPLES | sed "s/,/ /g") | ||
do | ||
$i | ||
done | ||
fi | ||
|
||
if [ "" == "$ERRORS" ]; then | ||
tput setaf 2 | ||
echo "Completed successfully" | ||
else | ||
tput setaf 1 | ||
echo "Some examples failed:" | ||
printf "$ERRORS" | ||
fi | ||
|
||
tput sgr0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
torch | ||
torchtext | ||
spacy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stuff like LSUN and Imagenet are really large datasets.
I wonder if we can do an environment variable like
$DATA_ROOT
to go look for them over there first. This way we can cache / mount them in CI