
Support batching for rbapinns #589

Open · wants to merge 1 commit into dev

Conversation

@GiovanniCanali (Collaborator) commented on Jun 15, 2025

Description

This PR fixes #588

This PR adds support for batching when using the RBAPINN solver class.
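
As a rough usage sketch (the import paths, argument names, and the eta/gamma values below are assumptions about the PINA API, not taken from this PR):

    from pina import Trainer
    from pina.solver import RBAPINN  # module path may differ across PINA versions

    # `problem` and `model` are assumed to be defined as in any other PINA example.
    solver = RBAPINN(problem=problem, model=model, eta=0.001, gamma=0.999)

    # With this change, mini-batch training should work for RBAPINN as well.
    trainer = Trainer(solver=solver, batch_size=32, max_epochs=1000)
    trainer.train()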

Checklist

  • Code follows the project’s Code Style Guidelines
  • Tests have been added or updated
  • Documentation has been updated if necessary
  • Pull request is linked to an open issue

@GiovanniCanali self-assigned this Jun 15, 2025
@GiovanniCanali added the enhancement (New feature or request) label Jun 15, 2025
@GiovanniCanali marked this pull request as ready for review June 15, 2025 10:53

@dario-coscia (Collaborator) left a comment:

I think the logic for batching the weights is very clean; this actually solves the problem nicely!

I have left some comments on the code. I am wondering about future maintainability; I think we can clean it up a bit more.

"""
super().__init__(
model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
loss=loss,
loss=torch.nn.MSELoss(reduction="none"),

Collaborator:
Can we do this internally? The user will almost surely forget to pass reduction="none" when passing a custom loss.
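
As a rough illustration of what handling this internally could look like (a sketch only; the helper name `_per_point_loss` is hypothetical and not part of PINA):

    import copy

    import torch


    def _per_point_loss(loss=None):
        # Default to MSE when the user does not provide a loss.
        loss = loss if loss is not None else torch.nn.MSELoss()
        # Force per-point (unreduced) outputs without mutating the user's object.
        loss = copy.deepcopy(loss)
        if hasattr(loss, "reduction"):
            loss.reduction = "none"
        return loss


    # A user-supplied loss keeps working; RBA weighting still gets per-point values.
    loss_fn = _per_point_loss(torch.nn.L1Loss())
    print(loss_fn.reduction)  # "none"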

        # Validate range for gamma
        assert (
            0 < gamma < 1
        ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"

Collaborator:
Also a check on eta?
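
For example, a check mirroring the gamma assertion could look like this (assuming eta only needs to be a positive scalar; the exact admissible range is up to the authors):

        # Validate range for eta
        assert eta > 0, f"Invalid range: expected eta > 0, got {eta=}"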

-       for condition_name in problem.conditions:
-           self.weights[condition_name] = 0
+       # Initialize the weight of each point to 0
+       self.weights = {

Collaborator:
Here I would use register_buffer (see the PyTorch docs), so that we can restore the training from a checkpoint without problems.
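
A minimal sketch of the register_buffer idea, with a toy module and made-up condition names, just to show that buffered weights land in the state_dict and survive checkpointing:

    import torch


    class WeightStore(torch.nn.Module):
        def __init__(self, points_per_condition):
            super().__init__()
            for name, n_points in points_per_condition.items():
                # Buffers are saved/loaded with the module but are not trained.
                self.register_buffer(f"_rba_weight_{name}", torch.zeros(n_points))

        def weight(self, name):
            return getattr(self, f"_rba_weight_{name}")


    store = WeightStore({"physics": 128, "boundary": 32})
    # The RBA weights now appear in the checkpoint and are restored on load.
    print(list(store.state_dict().keys()))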

        return super().on_train_start()

    def _vect_to_scalar(self, loss_value):
        ...

    def training_step(self, batch, batch_idx, **kwargs):

Collaborator:
I don't like the fact that we are touching the training/val/test steps or the optimization cycle. We split these up in #542 to make the code more maintainable. Can we avoid it?

From what I can see, the train/val/test steps can be kept the same (we don't need to take out the global averaging of the losses; imho the user can decide whether to use it or not in addition to the RBA weights).

        self.store_log("test_loss", loss, self.get_batch_size(batch))
        return loss

    def _optimization_cycle(self, batch, batch_idx, **kwargs):

Collaborator:
I would put this logic in loss_phys. It shouldn't be a big deal, and that way we only need to override loss_phys.
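
A rough sketch of what an RBA-weighted loss_phys override could look like (compute_residual, current_condition_name, current_batch_indices, and _loss_fn are placeholders for whatever bookkeeping the solver actually uses; the update rule is the usual RBA one, w <- gamma * w + eta * |r| / max|r|):

    import torch


    def loss_phys(self, samples, equation):
        # Pointwise PDE residual on the sampled points.
        residual = self.compute_residual(samples=samples, equation=equation)

        # Residual-based attention update, restricted to the points in this batch.
        r = residual.detach().abs().flatten()
        weights = self.weights[self.current_condition_name]  # per-point weights
        idx = self.current_batch_indices                      # indices of the batch points
        weights[idx] = self.gamma * weights[idx] + self.eta * r / (r.max() + 1e-12)

        # Unreduced loss, weighted per point, then averaged: the training/val/test
        # steps and the optimization cycle stay untouched.
        loss = self._loss_fn(residual, torch.zeros_like(residual)).flatten()
        return (weights[idx] * loss).mean()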

Labels: enhancement (New feature or request)
Projects: None yet
Development: Successfully merging this pull request may close these issues:
Support batching for Residual Based Attention PINNs (RBAPINNs)
3 participants