Raise an error if sync_dist is set when logging torchmetrics #13903

Closed
rohitgr7 wants to merge 1 commit from the ref/tm_sync_dist branch

Conversation

rohitgr7 (Contributor)

What does this PR do?

`sync_dist` is silently ignored when torchmetrics instances are logged; this PR raises a MisconfigurationException in that case.
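For context, a minimal sketch (not taken from this PR) of the logging pattern the check targets; the `LitModel` class and `compute_loss` helper are hypothetical:

```python
import torchmetrics
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        # any torchmetrics instance works here; MeanMetric keeps the sketch small
        self.train_loss = torchmetrics.MeanMetric()

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        self.train_loss.update(loss)
        # `sync_dist=True` is currently silently ignored because the logged value
        # is a Metric, which synchronizes its states across processes on its own
        self.log("train_loss", self.train_loss, on_step=True, on_epoch=True, sync_dist=True)
        return loss
```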

Does your PR introduce any breaking changes? If yes, please list them.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@rohitgr7 added the refactor and logging (Related to the `LoggerConnector` and `log()`) labels on Jul 28, 2022
@rohitgr7 self-assigned this on Jul 28, 2022
@rohitgr7 added this to the pl:1.8 milestone on Jul 28, 2022
@github-actions bot added the pl (Generic label for PyTorch Lightning package) label on Jul 28, 2022
rohitgr7 (Contributor, Author) commented Jul 28, 2022

An alternate solution (which I think is better) would be to set these values (sync_dist, reduce_fx, sync_dist_group) on the torchmetrics instance internally, with an info message and without raising any such error. For this, we might have to make some changes to the Metric class.

Thoughts? @awaelchli @carmocca

carmocca (Contributor) left a comment

> An alternate solution (which I think is better) would be to set these values (sync_dist, reduce_fx, sync_dist_group) on the torchmetrics instance internally, with an info message and without raising any such error. For this, we might have to make some changes to the Metric class.

Why is it better? Failing should be fine and it makes it very clear that they shouldn't be passed.

> we might have to make some changes to the Metric class

Why?

Diff context for the inline comment below:

```python
            and not isinstance(value, Tensor)
            and (isinstance(value, Metric) or any(isinstance(v, Metric) for v in value.values()))
        ):
            raise MisconfigurationException(
```

I would personally put this in the LightningModule.log method where we have all the other validation checks
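A rough sketch of what such a guard could look like if it lived in LightningModule.log; the simplified signature and the message wording are assumptions, not the actual diff:

```python
from torchmetrics import Metric

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def log(self, name, value, *, sync_dist=False, **kwargs):
    # hypothetical placement of the check among the other validation done in `log`
    if sync_dist and isinstance(value, Metric):
        raise MisconfigurationException(
            f"`sync_dist=True` has no effect when logging the torchmetrics instance `{name}`;"
            " the metric already synchronizes its states internally."
        )
    ...  # the rest of the usual validation and bookkeeping
```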

rohitgr7 (Contributor, Author)
> An alternate solution (which I think is better) would be to set these values (sync_dist, reduce_fx, sync_dist_group) on the torchmetrics instance internally, with an info message and without raising any such error. For this, we might have to make some changes to the Metric class.

> Why is it better? Failing should be fine and it makes it very clear that they shouldn't be passed.

> we might have to make some changes to the Metric class

> Why?

Just for better UX. Otherwise the user has to take care of setting this in two different places if they are using both torchmetrics and tensors. And even if they are only using torchmetrics, it makes things easier for them, since PL will handle it and that is their intention anyway. It also keeps things simple, since both are Lightning products.

The changes that will be required are just the setters/getters for these arguments.

awaelchli (Contributor)
I look at this from a different angle.

The default for sync_dist in the log method is False, which means that when it is left unset, it indicates we will never communicate across processes to compute the metric. When logging a TorchMetric, however, this is clearly not the case. On the other hand, setting sync_dist=True explicitly and logging a TorchMetric does make sense, even if it doesn't trigger exactly the same logic; it only makes it appear that the flag is being "ignored". Raising an error here would prevent anyone from combining it with logging both a scalar and a TorchMetric. One can argue either way; I think it is up to us to determine whether this is a convenience or a misleading setting.

If we decide to produce an error like the one proposed in this PR, we would additionally need to change the default from False to None and also error when it gets set to False.

rohitgr7 (Contributor, Author) commented Aug 3, 2022

> Raising an error here would prevent anyone from combining it with logging both a scalar and a TorchMetric.

@awaelchli Not really, since they are checked per log value:

```python
self.log('some_val', scalar_val, sync_dist=True)                    # <- does not raise an error
self.log('some_tm_instance', torchmetric_instance, sync_dist=True)  # <- raises an error
```

which means users would just need to unset sync_dist for the second log call. I can add the key name to the error message to make it more explicit. And this can easily be caught even with fast_dev_run, since we don't disable logging, only the loggers.

Also, the only reason I opened this PR is to avoid the silent behavior here; I have seen concerns about it on Slack and in Discussions.

awaelchli (Contributor) commented Aug 3, 2022

> Not really, since they are checked per log value.

My comment was about log_dict, when different types of values get submitted in a dict together.

> If we decide to produce an error like the one proposed in this PR, we would additionally need to change the default from False to None and also error when it gets set to False.

Do you agree with this or not?

rohitgr7 (Contributor, Author) commented Aug 3, 2022

> My comment was about log_dict, when different types of values get submitted in a dict together.

Yes, in that case the user has to separate them and make two calls. That's why I suggested this: #13903 (comment), which would end up raising no errors.
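For illustration (not taken from the thread), the two-call split would look roughly like this, where `loss` is a plain tensor and `self.val_acc` a hypothetical torchmetrics attribute:

```python
# inside a LightningModule hook, e.g. validation_step
self.log_dict({"val_loss": loss}, sync_dist=True)   # plain tensors keep sync_dist=True
self.log_dict({"val_acc": self.val_acc})            # torchmetrics instances go in a separate call
```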

> If we decide to produce an error like the one proposed in this PR, we would additionally need to change the default from False to None and also error when it gets set to False.

Why do we need to raise an error when it's False? With DDP, we don't force users to always sync.

awaelchli (Contributor)
> Why do we need to raise an error when it's False? With DDP, we don't force users to always sync.

When using torchmetrics, it will sync whether you want it to or not, internally inside torchmetrics. One can choose whether it additionally happens on step or not. A user logging a metric and setting sync_dist=False is conceptually a contradiction.
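For reference, a minimal sketch of where torchmetrics controls this itself, using its constructor flags (argument names may vary slightly between torchmetrics versions):

```python
import torchmetrics

# The metric owns its DDP synchronization regardless of Lightning's `sync_dist`:
metric = torchmetrics.MeanMetric(
    dist_sync_on_step=False,  # if True, states are additionally synced at every update step
    sync_on_compute=True,     # states are synced when compute() is called
)
```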

Honestly, I wasn't sure about some of the statements I made earlier and had a hard time going through the source code behind our result objects. I don't feel very confident; perhaps some of these checks are already done internally, but I wouldn't know where to look.

rohitgr7 (Contributor, Author) commented Aug 5, 2022

> When using torchmetrics, it will sync whether you want it to or not.

True, sorry, I noticed this later. We can make it None by default.

> One can choose whether it additionally happens on step or not.

We can determine that based on on_step and on_epoch when sync_dist is not None (a sketch follows after this list):

  • If only on_step=True or only on_epoch=True, set the corresponding flag inside torchmetrics
  • If on_step=True and on_epoch=True, one of:
    • Set both flags
    • Set only sync_on_compute, to align with how we handle scalar metrics
    • Raise an error and ask the user to set it themselves, since we can't determine whether they want to sync on_step, on_epoch, or both
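A rough sketch of the mapping described above; writing these flags on an already-constructed Metric assumes the setters/getters mentioned earlier are added, and the flag names mirror the torchmetrics constructor arguments:

```python
from typing import Optional

from torchmetrics import Metric


def _apply_sync_flags(metric: Metric, sync_dist: Optional[bool], on_step: bool, on_epoch: bool) -> None:
    """Hypothetical helper: forward Lightning's `sync_dist` to the torchmetrics sync flags."""
    if sync_dist is None:
        # the user expressed no preference; leave the metric's own settings untouched
        return
    if on_step and not on_epoch:
        metric.dist_sync_on_step = sync_dist
    elif on_epoch and not on_step:
        metric.sync_on_compute = sync_dist
    else:
        # on_step=True and on_epoch=True: one of the options listed above, e.g. only
        # sync on compute, mirroring how scalar values are reduced at epoch end
        metric.sync_on_compute = sync_dist
```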

> perhaps some of these checks are already done internally, but I wouldn't know where to look.

If you are talking about:

  • In case torchmetrics are logged, what happens when sync_dist=True? It is ignored.
  • What happens at step level if sync_dist=True? Nothing.

carmocca (Contributor) commented Nov 8, 2022

This is a valuable improvement but low priority for the team. I created an issue (#15588) to track this idea so that somebody can pick it up. Closing the PR.

@carmocca closed this on Nov 8, 2022
@Borda deleted the ref/tm_sync_dist branch on December 9, 2022
Labels: logging (Related to the `LoggerConnector` and `log()`), pl (Generic label for PyTorch Lightning package), refactor
Projects: None yet

3 participants