Skip to content

Commit

Permalink
Updated the examples to use pytorch 1.9.0 and pytorch-lightning 1.3.5 (
Browse files Browse the repository at this point in the history
…mlflow#4458)

* Updated the examples to use pytorch 1.9.0 and pytorch-lightning 1.3.5

Signed-off-by: KasirajanA <kasirajan.alayappan@ideas2it.com>
Signed-off-by: Kasirajan Alayappan <kasirajan.alayappan@ideas2it.com>

* Pytorch bert example fix for gpu run

Signed-off-by: Ubuntu <ubuntu@ip-172-31-26-167.us-east-2.compute.internal>
Signed-off-by: Kasirajan Alayappan <kasirajan.alayappan@ideas2it.com>

* Updated the torchvision version in pytorch examples

Signed-off-by: KasirajanA <kasirajan.alayappan@ideas2it.com>
Signed-off-by: Kasirajan Alayappan <kasirajan.alayappan@ideas2it.com>

* Skipping pytorch override pickle module testcase for torch>=1.9.0

Signed-off-by: Kasirajan Alayappan <kasirajan.alayappan@ideas2it.com>

* Skipping pytorch override pickle module testcase for torch>=1.9.0

Signed-off-by: Kasirajan Alayappan <kasirajan.alayappan@ideas2it.com>

* Revert "Skipping pytorch override pickle module testcase for torch>=1.9.0"

This reverts commit 7da47f8.

Signed-off-by: Kasirajan Alayappan <kasirajan.alayappan@ideas2it.com>

* Revert "Skipping pytorch override pickle module testcase for torch>=1.9.0"

This reverts commit 6277430.

Signed-off-by: Kasirajan Alayappan <kasirajan.alayappan@ideas2it.com>

* Addressing review comments

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Removing unused columns

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

* Setting default sample size to 2000

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>

Co-authored-by: Ubuntu <ubuntu@ip-172-31-26-167.us-east-2.compute.internal>
Co-authored-by: Shrinath Suresh <shrinath@ideas2it.com>
  • Loading branch information
3 people authored Jul 13, 2021
1 parent ec6a16e commit 48a5128
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 27 deletions.
6 changes: 3 additions & 3 deletions examples/pytorch/AxHyperOptimizationPTL/conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dependencies:
- pip
- pip:
- mlflow
- pytorch-lightning
- pytorch-lightning==1.3.5
- ax-platform
- torch>=1.6.0
- torchvision
- torchvision>=0.9.1
- torch==1.9.0
2 changes: 1 addition & 1 deletion examples/pytorch/BertNewsClassification/MLproject
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ entry_points:
batch_size: {type: int, default: 64}
num_workers: {type: int, default: 3}
learning_rate: {type: float, default: 0.001}
num_samples: {type: int, default: 15000}
num_samples: {type: int, default: 2000}
dataset: {type: str, default: "20newsgroups"}

command: |
Expand Down
13 changes: 6 additions & 7 deletions examples/pytorch/BertNewsClassification/bert_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, AdamW
from torchtext.utils import download_from_url, extract_archive
from torchtext.datasets.text_classification import URLS
import torchtext.datasets as td


def get_20newsgroups(num_samples):
Expand All @@ -32,9 +31,9 @@ def get_20newsgroups(num_samples):


def get_ag_news(num_samples):
dataset_tar = download_from_url(URLS["AG_NEWS"], root=".data")
extracted_files = extract_archive(dataset_tar)
train_csv_path = list(filter(lambda x: x.endswith("train.csv"), extracted_files))[0]
# reading the input
td.AG_NEWS(root="data", split=("train", "test"))
train_csv_path = "data/AG_NEWS/train.csv"
return (
pd.read_csv(train_csv_path, usecols=[0, 2], names=["label", "description"])
.assign(label=lambda df: df["label"] - 1) # make labels zero-based
Expand Down Expand Up @@ -401,12 +400,12 @@ def configure_optimizers(self):
early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=True)

checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(), save_top_k=1, verbose=True, monitor="val_loss", mode="min", prefix="",
dirpath=os.getcwd(), save_top_k=1, verbose=True, monitor="val_loss", mode="min",
)
lr_logger = LearningRateMonitor()

trainer = pl.Trainer.from_argparse_args(
args, callbacks=[lr_logger, early_stopping], checkpoint_callback=checkpoint_callback
args, callbacks=[lr_logger, early_stopping, checkpoint_callback], checkpoint_callback=True
)
trainer.fit(model, dm)
trainer.test()
8 changes: 4 additions & 4 deletions examples/pytorch/BertNewsClassification/conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- boto3
- transformers>=4.0.0
- pandas
- torch==1.6.0
- torchvision==0.7.0
- torchtext==0.7.0
- pytorch-lightning==1.0.2
- torchvision>=0.9.1
- torch==1.9.0
- torchtext==0.10.0
- pytorch-lightning==1.3.5
6 changes: 3 additions & 3 deletions examples/pytorch/IterativePruning/conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ dependencies:
- pip
- pip:
- mlflow
- torchvision
- torchvision>=0.9.1
- cloudpickle==1.6.0
- pytorch_lightning>=1.0.5
- pytorch_lightning==1.3.5
- ax-platform
- prettytable
- torch>=1.6.0
- torch==1.9.0
6 changes: 3 additions & 3 deletions examples/pytorch/MNIST/conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ dependencies:
- pip
- pip:
- mlflow
- torch==1.8.1
- torchvision==0.9.1
- pytorch-lightning==1.0.2
- torchvision>=0.9.1
- torch==1.9.0
- pytorch-lightning==1.3.5
4 changes: 2 additions & 2 deletions examples/pytorch/MNIST/mnist_autolog_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,12 @@ def configure_optimizers(self):
)

checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(), save_top_k=1, verbose=True, monitor="val_loss", mode="min", prefix="",
dirpath=os.getcwd(), save_top_k=1, verbose=True, monitor="val_loss", mode="min"
)
lr_logger = LearningRateMonitor()

trainer = pl.Trainer.from_argparse_args(
args, callbacks=[lr_logger, early_stopping], checkpoint_callback=checkpoint_callback
args, callbacks=[lr_logger, early_stopping, checkpoint_callback], checkpoint_callback=True
)
trainer.fit(model, dm)
trainer.test()
4 changes: 2 additions & 2 deletions examples/pytorch/torchscript/IrisClassification/conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ dependencies:
- sklearn
- cloudpickle==1.6.0
- boto3
- torch==1.6.0
- torchvision==0.7.0
- torchvision>=0.9.1
- torch==1.9.0
4 changes: 2 additions & 2 deletions examples/pytorch/torchscript/MNIST/conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ dependencies:
- mlflow
- cloudpickle==1.6.0
- boto3
- torch==1.6.0
- torchvision==0.7.0
- torchvision>=0.9.1
- torch==1.9.0

0 comments on commit 48a5128

Please sign in to comment.