Enable single sample processing #1380
base: mk/develop/1300_assemble_diffusion_model
Conversation
    else:
        assert samples_per_mini_epoch, "must specify samples_per_mini_epoch if repeat_data"
        self.len = samples_per_mini_epoch
Not sure if I understand correctly, but don't we want to run an epoch with e.g. samples_per_mini_epoch: 4096 while always repeating the same e.g. 4 samples? In that case we would need to introduce another config parameter, e.g. repeat_num_samples: 4 or repeat_num_idxs: [1, 2, 3, 4] (to define specific indices; not sure if we would need this).
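One way the proposed parameter could behave is sketched below. Note that repeat_num_samples and the tiling function are illustrative assumptions for discussion, not code from this PR:

```python
def build_epoch_indices(samples_per_mini_epoch: int, repeat_num_samples: int) -> list[int]:
    # Cycle through the first `repeat_num_samples` dataset indices until the
    # mini epoch is full, e.g. 4096 samples drawn only from indices 0..3.
    return [i % repeat_num_samples for i in range(samples_per_mini_epoch)]

indices = build_epoch_indices(4096, 4)
```

A repeat_num_idxs variant would simply replace `i % repeat_num_samples` with a lookup into the configured index list.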
I think this is addressed in line 278, but may not cover some edge cases.
Maybe the underlying question is whether we need to adjust start_date and end_date in the config to shorten the dataset, or whether we can also repeat samples randomly from the entire dataset. We can discuss in 5 minutes :)
MatKbauer left a comment:
Looks good. I have added some comments to retain (fsm + 1). Overall, I'd suggest we use this branch for our single sample experiments.
config/default_config.yml (Outdated)
    end_date_val: 202201010000
    end_date: 201401011200
    start_date_val: 201401010000
    end_date_val: 201401011200
Let's set the end_dates for train and val to 201401011800, such that we can keep (fsm + 1) below.
    forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs
    forecast_len = (
        self.len_hrs * (fsm)
    ) // self.step_hrs  # TODO: check if it should be fsm + 1
With the adjusted end_dates in the config, we can revert this to forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs, i.e., using (fsm + 1) instead of fsm.
Let's remove this comment
Added some comments and removed one edge case handling that is no longer necessary due to the extended data range.
    forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs
    print(f"idx range end is {idx_end}")
    print(f"len hrs is {self.len_hrs}, step hrs is {self.step_hrs}, fsm is {fsm}")
    forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs  # NOTE: why is it fsm + 1?
fsm := the maximum number of forecast steps to predict. The +1 accounts for a potential forecast_offset of 0 or 1, which guarantees that the targets are also always within the range defined by start_date and end_date.
Thanks for explaining!
@moritzhauschulz @MatKbauer If this PR is ready, let's try and get it merged.
Description
This PR enables the code to process a single sample, so we can overfit the diffusion model to that sample.
We introduce a --repeat_data flag. When it is active and the number of samples in the dataset is lower than the size of the mini epoch, the data is tiled to fill up the mini epoch. A check ensures that in this case the number of samples evenly divides the size of the mini epoch, so that the mini epoch stays balanced.
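The tiling described above might look roughly like this. This is a sketch under the assumption that the dataset exposes a flat list of sample indices; it is not the PR's exact code:

```python
def tile_indices(num_samples: int, samples_per_mini_epoch: int, repeat_data: bool) -> list[int]:
    # Without --repeat_data, just expose the dataset's own indices.
    if not repeat_data:
        return list(range(num_samples))
    # With --repeat_data, require an even division so every sample appears
    # the same number of times within the mini epoch (balanced mini epoch).
    assert samples_per_mini_epoch % num_samples == 0, (
        "number of samples must evenly divide samples_per_mini_epoch"
    )
    reps = samples_per_mini_epoch // num_samples
    return list(range(num_samples)) * reps
```

For the single-sample overfitting experiment, num_samples is 1, so the mini epoch becomes the same index repeated samples_per_mini_epoch times.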
Some additional checks are introduced to avoid misconfigurations, see code.
Some checks may need to be removed after the overfitting experiment for the diffusion model.
Not currently addressing #1370.
Issue Number
Closes #1379
Checklist before asking for review
Currently getting ModuleNotFoundError: No module named 'flash_attn' when running uv run train. Worked well before merging though.
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60