Feature/lightning #35
base: master
Conversation
…from the base Trainer, will go to Lightning
…g is now much cleaner and all of lightning functionality is in one file
ok with the last commit I feel a bit better about things.
nice! yeah, when you enable tensorboard and checkpoints the training slows down a bit. we actually have tests to ensure we don't incur a meaningful overhead over a vanilla pytorch script. lightning only adds 300ms per epoch (and no memory leaks, etc) https://github.com/PyTorchLightning/pytorch-lightning/blob/master/benchmarks/test_parity.py
next you should convert the datasets to datamodules. this means you can 100% decouple the model from the data :)
@williamFalcon ok done. few things I find a bit gross:
Alright, calling it here for today, I'm tired and still have some actual work to do. I'm pretty sure I don't understand how setup and its stage argument work.
ah i see the confusion! what is "stage" in setup? .fit() calls setup (for train, val data). Our data calls are lazy, so we defer initializing them as long as we can to not cause unnecessary overhead.

0.9.0 -> 1.0 feedback
Got through a good chunk of the refactors (check out the evaluate loop). Wrapping up the rest over the next few days to get the train loop to look pristine as well haha. But, if there are any weird issues you see with the API, or any other paradigms we might consider, happy to make the changes this week to enable any use cases we may have missed or get rid of any parts that feel gross.

Do the below only if you care about multi-gpu...
If you're training on 1 gpu, none of what i'm about to say matters. this only matters when making the code agnostic to n gpus.

prepare_data
This is to do something only once (ie: in a 100-GPU world, only on GPU 0. examples are: tokenize, download, etc...).

setup
This is another prep stage, but it's called on every GPU. This means that splitting or anything like that can be done here.

train, val, test dataloaders
These are lazily called... which means you don't pay the overhead of creating the data until you really need it (this is key for performance applications).

data for init
There's a case where your model might depend on information about the data (ie: vocab size, num_classes, etc). In this case, you can just hardcode this into the datamodule:

def __init__(...):
    self.num_classes = x
    self.vocab_size = y

or get the info in setup:

def setup(...):
    download()
    tokenize()
    self.vocab = count_vocab()

Normally, lightning calls prepare_data and setup for you automatically in training. However, depending on how you set it up (let's say you got the vocab size in setup), then you can manually call it after init:

dm = Datamodule()
# even if called manually, lightning makes sure it only happens on the correct devices
dm.prepare_data()
dm.setup()

model = LitModel(vocab_size=dm.vocab_size)
trainer = Trainer()
trainer.fit(model, dm)

I also don't love that you have to call prepare_data + setup yourself. Open to any ideas that you might have :)

Datamodule examples
SimCLR made agnostic of dataset (it's pretty cool you can use the same code and train on any dataset without changing your original base code).

Datamodule videos
And of course, the obligatory tutorial on datamodules by one of our team members.

Real-world tutorials
Here's one of our new tutorials as well on implementing SimCLR, which would be a more realistic complex example.
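To make the lifecycle above concrete, here is a minimal plain-Python sketch of it. This is NOT the real pytorch_lightning.LightningDataModule; the class name CharDataModule and the toy corpus are illustrative stand-ins that just show the order of the hooks.

```python
class CharDataModule:
    """Stand-in illustrating prepare_data -> setup -> lazy dataloaders."""

    def __init__(self, block_size=128):
        self.block_size = block_size
        self.vocab_size = None        # filled in by setup()
        self._train_dataset = None    # built lazily, only when requested

    def prepare_data(self):
        # runs once (e.g. on GPU 0 only): download, tokenize, write to disk
        self._raw_text = "hello world"  # stand-in for a downloaded corpus

    def setup(self, stage=None):
        # runs on every GPU: splits, vocab counting, etc.
        self.vocab = sorted(set(self._raw_text))
        self.vocab_size = len(self.vocab)

    def train_dataloader(self):
        # lazy: the dataset only gets built the first time the loader is asked for
        if self._train_dataset is None:
            self._train_dataset = [self.vocab.index(c) for c in self._raw_text]
        return self._train_dataset

# manual calls, mirroring the snippet above, so the model can read dm.vocab_size
dm = CharDataModule()
dm.prepare_data()
dm.setup()
print(dm.vocab_size)  # 8 distinct characters in "hello world"
```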
bench.py
Outdated
def setup(self, stage): # called for every GPU/machine
    if stage == 'train' or stage == 'fit':
        pass # nothing to do, the train_dataset is initialized in the constructor
    elif stage == 'val':
stage here is whether the setup is happening on fit or test...
BUT... to your point, do you think it might make more sense to change the stage to 'train', 'val', 'test'? the thing is that train, val are usually handled together (ie: split train into train/val and have a separate test set)
Sorry I'm pretty sure I totally misunderstood what a "stage" is and thought it referred to splits.
bench.py
Outdated
# -----------------------------------------------------------------------------

parser = argparse.ArgumentParser()
parser.add_argument('-x', '--num-epochs', type=int, default=5, help="number of epochs to train for")
Datamodules can do:
parser = argparse.ArgumentParser()
# enables whatever you have in your init in argparse :)
parser = CharDataModule.add_argparse_args(parser)
# enable all the trainer flags in argparse
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
# now you can init whatever objects automatically as well:
trainer = Trainer.from_argparse_args(args, any_flag_to_override=...)
dm = CharDataModule.from_argparse_args(args)
Which lets you do things like:
python main.py --gpus 2 --num_nodes 3 --batch_size 32
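For the "fake lightning" path, the same pattern can be mimicked with stdlib argparse and two classmethods. This is a hedged sketch of the idea, not Lightning's actual implementation; CharDataModule and its parameters are hypothetical:

```python
import argparse
import inspect

class CharDataModule:
    def __init__(self, batch_size=64, num_workers=0):
        self.batch_size = batch_size
        self.num_workers = num_workers

    @classmethod
    def add_argparse_args(cls, parser):
        # expose every __init__ keyword as a CLI flag, typed from its default
        for name, p in inspect.signature(cls.__init__).parameters.items():
            if name != 'self':
                parser.add_argument(f'--{name}', type=type(p.default), default=p.default)
        return parser

    @classmethod
    def from_argparse_args(cls, args):
        # keep only the parsed args that __init__ actually understands
        valid = inspect.signature(cls.__init__).parameters
        kwargs = {k: v for k, v in vars(args).items() if k in valid}
        return cls(**kwargs)

parser = argparse.ArgumentParser()
parser = CharDataModule.add_argparse_args(parser)
args = parser.parse_args(['--batch_size', '32'])  # stand-in for real CLI args
dm = CharDataModule.from_argparse_args(args)
print(dm.batch_size)  # 32
```

The nice property is that adding a constructor argument automatically adds the flag, so the CLI never drifts out of sync with the datamodule.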
Neat! I'll have to read more of the docs
bench.py
Outdated
def train_dataloader(self):
    loader = DataLoader(self.train_dataset, batch_size=self.batch_size,
                        shuffle=True, pin_memory=bool(self.pin_memory),
                        num_workers=self.num_workers)
btw... with multiple GPUs, num_workers > 0 in ddp_spawn mode is slow. This is a pytorch limitation: ddp_spawn generates subprocesses, and in each subprocess the dataloaders generate more subprocesses.
that's why we recommend ddp as the backend for multi-gpu, but unfortunately it can't be used from jupyter because notebooks have limitations as well haha.
basically, until we re-invent jupyter notebooks we are a bit stuck haha...
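In spawn/notebook contexts the safe default is therefore num_workers=0, which keeps loading in the main process. A minimal sketch with a toy dataset (the dataset itself is made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# toy dataset: 10 scalar samples
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))

# num_workers=0 means no worker subprocesses are forked per dataloader,
# so it is safe under ddp_spawn / notebooks, at the cost of loading parallelism
loader = DataLoader(dataset, batch_size=4, num_workers=0, shuffle=False)

batches = [b[0] for b in loader]
print(len(batches))  # 3 batches: sizes 4, 4, 2
```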
Yes, subtle point!
mingpt/fake_lightning.py
Outdated
if self.gpus > 0 and torch.cuda.is_available():
    logger.info("found CUDA device, shipping model to GPU")
    device = 'cuda'
    self.model = self.model.to(device)
if you are enabling gpus=x i would just do model.cuda(x),
so people can place models on a gpu indexed by the PCI_BUS_ID (you might need this flag enabled though:
export CUDA_DEVICE_ORDER=PCI_BUS_ID)
but for teaching purposes it may be overkill
Thank you, yes it's cleaner. I like the more general .to syntax because we'll have many different XPUs etc, feels a bit more future proof
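A small runnable sketch of the device-agnostic .to pattern being discussed here: the device string is chosen once at runtime, so the same code runs on CPU-only machines and GPUs alike (the nn.Linear model is just a placeholder):

```python
import torch
import torch.nn as nn

# pick the device once; .to() is a no-op if the tensors are already there
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(4, 2).to(device)

# batches must be moved with the same device string before the forward pass
x = torch.randn(8, 4).to(device)
y = model(x)
print(y.shape)  # torch.Size([8, 2])
```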
yeah, this is an optional abstraction (you can use any dataloader with lightning). we also just introduced it, so need to make the docs more clear. but it really is optional, so it's not a big deal to use dataloaders directly. it just makes the data more reusable
Okay I merged one more big refactor. Honestly I am starting to think this branch was a very bad idea. I thought I could make things clean but there is a lot of baggage that Lightning "leaks" in a number of places, e.g. w.r.t. model checkpointing, the use of Training/Eval Result structures, forcing me into relatively odd looking abstractions and half-measures. Anyway, thank you for your help @williamFalcon , I'll have to sleep on this a few days, read the Lightning docs more, and then maybe give it another shot some other time.
ok, i understand the confusion!

doc updates
I updated the docs to show results as an optional extension! Also split the docs into optional vs required. Only required APIs are:

Optional:

no results
We added the results object recently, but forgot to show in the docs that it is 100% optional.

def training_step(...):
    loss = ...

    # option 1
    return loss

    # if you also want to log
    return {'loss': loss, 'log': {'train_loss': loss}}

    # option 2 (optional):
    # results just make it more flexible/clean and add functionality
    result = TrainResult(loss)
    result.log('train_loss', loss, on_step=True, on_epoch=True)

checkpoints
Checkpoints store hyperparams, training state, etc... however, if you just want the plain python checkpoint:

ckpt = torch.load(path)
model.load_state_dict(ckpt['state_dict'])
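A minimal runnable sketch of that round-trip. The 'epoch' key and the in-memory buffer are illustrative; a real Lightning checkpoint carries more fields, but the weights always live under 'state_dict':

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(3, 1)

# a checkpoint is just a dict with the weights under 'state_dict'
buf = io.BytesIO()  # stands in for a file on disk
torch.save({'state_dict': model.state_dict(), 'epoch': 5}, buf)

# plain-python restore: load the dict, pull out the weights
buf.seek(0)
ckpt = torch.load(buf)
fresh = nn.Linear(3, 1)
fresh.load_state_dict(ckpt['state_dict'])
print(ckpt['epoch'])  # 5
```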
bench.py
Outdated
# ckpt_path = os.path.join(args.default_root_dir, 'model.pt')
# model.load_from_checkpoint(ckpt_path) # load the best checkpoint we found
# trainer.test(test_dataloader=test_dataloader)
Suggested change:
# Note: Lightning automatically loads the best checkpoint when you call .test()
trainer.test(test_dataloader=test_dataloader)
got it. it looks like test_dataloader is not a kwarg, it's test_dataloaders with an 's'. Similar to val_dataloaders, but not the same as train_dataloader without the s, it looks like. Some of the docs are inconsistent on the use of "s" btw, I think.
yes. we enable multiple dataloaders for val and test; support for train is coming.
not in research i'm used to, but turns out some people need two datasets to validate haha. go figure
some people always need something, which is why frameworks are so hard. Next thing you know you can't use a list of data loaders and have to introduce a DataLoaderSetManager object.
mingpt/fake_lightning.py
Outdated
class Result:
    """ very thin wrapper around a result of a train/val/test step of the model """
    def __init__(self, minimize=None, checkpoint_on=None):
        self.minimize = minimize
        self.checkpoint_on = checkpoint_on

    def log(self, key, val):
        setattr(self, key, val)

class TrainResult(Result):
    pass

class EvalResult(Result):
    pass

class LightningModule(nn.Module):

    def load_from_checkpoint(self, checkpoint_path):
        logger.info("loading the best model checkpoint from %s", checkpoint_path)
        state_dict = torch.load(checkpoint_path)
        self.load_state_dict(state_dict)

class Callback:
    pass
sorry! this is 100% optional. This is a new addition and I see we forgot to include the simple case and doc examples using a dict or the loss directly
got it, ok converted to use of dicts with latest commit
… expect this to do too much to latency
…perly under __main__
Ok I think things have improved quite a bit. In particular, my "fake lightning" has now been reduced all the way to

class LightningModule(nn.Module):
    pass

class Callback:
    pass

which is fun :) And I can train with the fake trainer or the lightning trainer and the code looks decent-ish.
I'm trying @williamFalcon , but I have somewhat mixed feelings about it. The APIs are now matched up and I can train the basic loop with either:
or
There is some overhead incurred, though it doesn't matter too much at the stage of a single GPU.
To merge, I would still have to:
bench.py