-
Notifications
You must be signed in to change notification settings - Fork 289
Enhance TabPFNRegressor to handle constant target values #263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Enhance TabPFNRegressor to handle constant target values #263
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM!
logits=torch.full((X.shape[0], 1), self.constant_value_)
I think the logits shouldn't be equal to self.constant_value_
, right? I'm not sure what they should be equal to though, perhaps all ones or something else uniform? (I guess self.constant_value_ could also work then, but might be confusing).
Thanks for the review @LeoGrin! |
Thanks @anuragg1209! |
Hi @LeoGrin, I like the idea of changing the |
Thanks @anuragg1209! Do we still need to set the constant mean, mode etc manually now? I would have guessed that it would automatically work with the new criterion defined in fit (except for the |
The reason we are setting the mean, mode, etc. manually is because if we use a criterion-based approach, even with tight borders, numerical operations introduce tiny floating point differences from the constant value. The current approach of setting manually also works fine as we are only dealing with the edge case of constant y, but if we want to completely remove setting mean, mode, etc, manually, then we have to introduce a tolerance-based test. |
Thanks @anuragg1209! In this case maybe it actually makes more sense to put most things in |
Hi @noahho,
Fix #246.
New Tests:
test_constant_target
to verify that theTabPFNRegressor
correctly predicts a constant value when the targety
is constant.