mnist example update to new compile stage API #831

Merged · 9 commits · Jul 24, 2023
setup training workflow
eddogola authored and Eddy Ogola Onyango committed Jul 20, 2023
commit ee7a0c1fe985e910621347b555e2a509b43fee94
examples/mnist/new_pippy_mnist.py (78 changes: 72 additions & 6 deletions)
@@ -1,17 +1,21 @@
from tqdm import tqdm
import argparse
import os

import torch
from torch import nn
import torch.optim as optim
import torch.distributed as dist
from torchvision import datasets, transforms
from torch.nn.functional import cross_entropy
from torch.utils.data import DistributedSampler, DataLoader
from torch.utils.data import DataLoader

from pippy.microbatch import sum_reducer, TensorChunkSpec
from pippy.IR import LossWrapper, PipeSplitWrapper
from pippy.compile import compile_stage

USE_TQDM = bool(int(os.getenv("USE_TQDM", 1)))
LR_VERBOSE = bool(int(os.getenv("LR_VERBOSE", 1)))


def run_worker(args):
    # define transforms
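
The diff cuts off right after this comment, so the transform, dataset, and dataloader setup is not visible here. As a rough sketch of what that elided section typically looks like for an MNIST example (the normalization constants and every name other than train_dataloader and valid_dataloader are assumptions, not a copy of this file):

    # hedged sketch of the elided data setup; values are illustrative only
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_data = datasets.MNIST("./data", train=True, download=True, transform=transform)
    valid_data = datasets.MNIST("./data", train=False, download=True, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_data, batch_size=args.batch_size)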
@@ -65,12 +69,72 @@ def forward(self, input, target):
        device=args.device,
        group=None,
        example_inputs=[x, target],
        # output_chunk_spec={
        #     "loss": sum_reducer,
        #     "logits": TensorChunkSpec(0),
        # },
    )

    # setup optimizer
    optimizer = optim.Adam(stage.submod.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
    # setup lr scheduler
    lr_sched = optim.lr_scheduler.LinearLR(optimizer, verbose=LR_VERBOSE)

    loaders = {
        "train": train_dataloader,
        "valid": valid_dataloader,
    }

    batches_events_contexts = []

    for epoch in range(args.max_epochs):
        print(f"Epoch: {epoch + 1} of {args.max_epochs}")

        for k, dataloader in loaders.items():
            epoch_correct = 0
            epoch_all = 0
            for i, (x_batch, y_batch) in enumerate(tqdm(dataloader) if USE_TQDM else dataloader):
                x_batch = x_batch.to(args.device)
                y_batch = y_batch.to(args.device)

                if k == "train":
                    stage.train()
                    optimizer.zero_grad()

                    # only the first and last stages take direct inputs: the first
                    # consumes the images, the last consumes the targets for the loss
                    if args.rank == 0:
                        out = stage(x_batch)
                    elif args.rank == args.world_size - 1:
                        out = stage(y_batch)
                    else:
                        out = stage()

                    # outp, loss = stage(x_batch, y_batch)
                    # the model output only materializes on the last stage, so only
                    # that rank can accumulate accuracy statistics
                    if args.rank == args.world_size - 1:
                        preds = out.argmax(-1)
                        correct = (preds == y_batch).sum()
                        epoch_correct += correct.item()
                        epoch_all += len(y_batch)
                    optimizer.step()
                else:
                    stage.eval()
                    with torch.no_grad():
                        # mirror the per-stage call pattern used for training above
                        if args.rank == 0:
                            out = stage(x_batch)
                        elif args.rank == args.world_size - 1:
                            out = stage(y_batch)
                        else:
                            out = stage()
                        # outp, _ = stage(x_batch, y_batch)
                        if args.rank == args.world_size - 1:
                            preds = out.argmax(-1)
                            correct = (preds == y_batch).sum()
                            epoch_correct += correct.item()
                            epoch_all += len(y_batch)

print(f"Loader: {k} Accuracy: {epoch_correct / epoch_all}")

if k == "train":
lr_sched.step()
# if LR_VERBOSE:
# print(f"Pipe ") # should we have pp_ranks


    dist.barrier()
    print(f"Rank {args.rank} completed!")

@@ -91,8 +155,10 @@ def main(args=None):
"--master_port", type=str, default=os.getenv("MASTER_PORT", "29500")
)
parser.add_argument("--cuda", type=int, default=int(torch.cuda.is_available()))
parser.add_argument("--max_epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=10)
parser.add_argument("--chunks", type=int, default=4)
parser.add_argument("--visualize", type=int, default=1, choices=[0, 1])
args = parser.parse_args(args)
if args.cuda:
dev_id = args.rank % torch.cuda.device_count()
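
For reference, the second hunk above only shows the tail of the compile_stage call. Based on the visible keyword arguments and on how compile_stage is invoked in other PiPPy examples, the full call plausibly looks something like the sketch below; the leading positional arguments and the wrapped_model name are assumptions, not a verbatim copy of this file:

    # hedged reconstruction of the call whose tail appears in the diff above
    stage = compile_stage(
        wrapped_model,      # the split-annotated, loss-wrapped module (assumed name)
        args.rank,          # this process's pipeline stage index
        args.world_size,    # number of pipeline stages / worker processes
        args.chunks,        # micro-batches each mini-batch is split into
        device=args.device,
        group=None,
        example_inputs=[x, target],
    )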