Support batching for rbapinns #589
base: dev
Conversation
I think the logic for batching the weights is very clean; it actually solves the problem nicely!
I gave the code some thought, and I am wondering about future maintainability. I think we can clean it up further.
""" | ||
super().__init__( | ||
model=model, | ||
problem=problem, | ||
optimizer=optimizer, | ||
scheduler=scheduler, | ||
weighting=weighting, | ||
loss=loss, | ||
loss=torch.nn.MSELoss(reduction="none"), |
Can we do this internally? Users will almost certainly forget to pass `reduction="none"` when supplying a custom loss.
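A minimal sketch of what "doing it internally" could look like: the solver normalizes whatever loss the user passes so it yields per-point values. The helper name `_ensure_unreduced` is hypothetical, not part of the PINA API; it relies on the `reduction` attribute that standard `torch.nn` losses such as `MSELoss` expose.

```python
import torch

def _ensure_unreduced(loss_fn):
    """Hypothetical helper: force a user-supplied loss to return
    per-point values so the solver can apply RBA weights before
    reducing. Standard torch losses expose a `reduction` attribute."""
    if getattr(loss_fn, "reduction", None) not in (None, "none"):
        loss_fn.reduction = "none"
    return loss_fn

# The solver could call this once at __init__ on the user's loss.
loss = _ensure_unreduced(torch.nn.MSELoss())  # user passed default MSELoss
out = loss(torch.zeros(4), torch.ones(4))
# out is a per-point tensor of shape (4,), not a reduced scalar
```

This keeps the user-facing API unchanged while guaranteeing the unreduced loss the RBA update needs.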
```python
        # Validate range for gamma
        assert (
            0 < gamma < 1
        ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"
```
Could we also add a check on `eta`?
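A sketch of what the suggested check could look like, mirroring the existing `gamma` assertion. The helper name is hypothetical, and the constraint assumes `eta` is a strictly positive step size for the RBA weight update:

```python
def validate_rba_params(gamma, eta):
    """Hypothetical validation helper: the gamma check from the diff
    plus the suggested check on eta (assumed to be a positive step size)."""
    assert (
        0 < gamma < 1
    ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"
    assert eta > 0, f"Invalid range: expected eta > 0, got {eta=}"

validate_rba_params(gamma=0.999, eta=0.01)  # valid values pass silently
```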
```diff
-         for condition_name in problem.conditions:
-             self.weights[condition_name] = 0
+         # Initialize the weight of each point to 0
+         self.weights = {
```
Here I would use `register_buffer` (see the PyTorch docs). That way we can restore training from a checkpoint without problems.
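A minimal sketch of the suggestion: per-point weights registered as buffers end up in `state_dict()`, so they are saved and restored with checkpoints. The class and condition names are illustrative, not the PINA API.

```python
import torch

class RBAWeights(torch.nn.Module):
    """Illustrative sketch: store per-point RBA weights as buffers
    so they persist through checkpointing (condition names are
    hypothetical placeholders)."""

    def __init__(self, conditions, n_points):
        super().__init__()
        for name in conditions:
            # Buffers are saved/restored with the module's state_dict,
            # but are not trainable parameters seen by the optimizer.
            self.register_buffer(f"weight_{name}", torch.zeros(n_points))

m = RBAWeights(["boundary", "interior"], n_points=8)
# "weight_boundary" and "weight_interior" now appear in m.state_dict(),
# so torch.save / load_state_dict round-trips the RBA weights.
```

By contrast, weights kept in a plain Python dict would be silently lost on restart, resetting the RBA schedule.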
```diff
          return super().on_train_start()

-     def _vect_to_scalar(self, loss_value):
+     def training_step(self, batch, batch_idx, **kwargs):
```
I don't like the fact that we are touching the training/val/test steps or the optimization cycle. We split these out in #542 to make the code more maintainable. Can we avoid it?
From what I can see, the train/val/test steps can be kept the same (we don't need to remove the global averaging of the losses; imho the user can decide whether to use it in addition to the RBA weights).
```python
        self.store_log("test_loss", loss, self.get_batch_size(batch))
        return loss

    def _optimization_cycle(self, batch, batch_idx, **kwargs):
```
I would put this logic in `loss_phys`. It shouldn't be a big deal, and that way we only need to override `loss_phys`.
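A hedged sketch of the suggestion: confine the RBA logic to `loss_phys`, leaving the train/val/test steps and optimization cycle untouched. The class below is illustrative only (not the PINA API); the names `weights`, `gamma`, and `eta` follow the usual RBA update rule `w ← γ·w + η·|r| / max|r|`.

```python
import torch

class RBALossSketch:
    """Illustrative only: the per-point RBA weighting lives entirely
    inside loss_phys, so no training/val/test step needs overriding."""

    def __init__(self, n_points, gamma=0.999, eta=0.01):
        self.weights = torch.zeros(n_points)  # one weight per collocation point
        self.gamma = gamma
        self.eta = eta

    def loss_phys(self, residual):
        # Update the per-point weights from the normalized residual...
        r = residual.detach().abs()
        self.weights = self.gamma * self.weights + self.eta * r / r.max()
        # ...then return the RBA-weighted physics loss, reduced here.
        return (self.weights * residual**2).mean()

solver = RBALossSketch(n_points=4)
loss = solver.loss_phys(torch.tensor([1.0, 2.0, 3.0, 4.0]))
```

Because the weighting and reduction both happen inside `loss_phys`, the surrounding steps still see a scalar loss, exactly as after the #542 split.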
Description
This PR fixes #588. It adds support for batching while using the
RBAPINN
solver class.
Checklist