Regression return logits #387
base: main
Conversation
Summary of Changes
Hello @noahho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the TabPFN library by exposing raw model logits for both classification and regression tasks, providing users with greater flexibility in how they consume model predictions. This new functionality is underpinned by a substantial refactoring of the internal prediction pipelines within both the classifier and regressor classes, leading to more modular, readable, and maintainable code. The changes streamline the application of various post-processing steps like temperature scaling, averaging, and class balancing.
Highlights
- **New `predict_logits` methods**: Introduced `predict_logits` methods for both `TabPFNClassifier` and `TabPFNRegressor`, allowing users to directly retrieve the raw, unnormalized model logits. This provides more granular control over model outputs.
- **Refactored prediction pipelines**: Centralized and modularized the internal prediction logic for both classifier and regressor into new private `_raw_predict` methods. This consolidates input validation, preprocessing, and the application of post-processing steps, reducing code duplication and improving maintainability.
- **Enhanced `forward` method in classifier**: The `forward` method in `TabPFNClassifier` now accepts a `return_logits` flag and uses a new, explicit post-processing pipeline composed of dedicated helper functions (`_apply_temperature`, `_average_across_estimators`, `_apply_softmax`, `_apply_balancing`). This makes the flow of data transformation clearer and more flexible.
- **Default `n_estimators` update**: The default number of estimators for `TabPFNClassifier` has been increased from 4 to 8, potentially improving out-of-the-box performance or stability.
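The four post-processing steps named above (temperature scaling, averaging across estimators, softmax, class balancing) can be sketched with plain NumPy. This is an illustrative stand-in, not TabPFN's actual implementation; the function names and tensor shapes are assumptions based on the helper names listed in the summary:

```python
import numpy as np

def apply_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Scale logits by a softmax temperature (T < 1 sharpens, T > 1 flattens)."""
    return logits / temperature

def average_across_estimators(logits: np.ndarray) -> np.ndarray:
    """Average per-estimator logits of shape (n_estimators, n_samples, n_classes)."""
    return logits.mean(axis=0)

def apply_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the class axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def apply_balancing(proba: np.ndarray, class_counts: np.ndarray) -> np.ndarray:
    """Divide out the training class frequencies, then renormalize rows."""
    balanced = proba / (class_counts / class_counts.sum())
    return balanced / balanced.sum(axis=-1, keepdims=True)

# Pipeline order: temperature -> average -> softmax -> (optional) balancing
logits = np.random.randn(8, 5, 3)  # 8 estimators, 5 samples, 3 classes
proba = apply_softmax(average_across_estimators(apply_temperature(logits, 0.9)))
```

Returning the tensor before the softmax step is what `predict_logits` amounts to conceptually; everything after that point only rescales and normalizes.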
Code Review
This pull request introduces the ability to get raw logits from both the TabPFNClassifier and TabPFNRegressor by adding a predict_logits method to each. This is achieved through a significant and well-executed refactoring of the prediction logic, which centralizes preprocessing and model forwarding into new _raw_predict methods. The post-processing steps have also been broken down into smaller, more modular helper functions, which greatly improves code clarity and maintainability.
I've identified a few areas for improvement:
- A potentially unintentional change to a default hyperparameter.
- A regression in the user-facing documentation for the `predict` method in the regressor.
- A potential numerical instability issue due to taking the logarithm of probabilities that could be zero.
Overall, this is a great change that improves the flexibility and structure of the codebase. Addressing the points above will help ensure its quality and robustness.
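The third point above (taking the logarithm of probabilities that could be zero) is commonly guarded by clamping before the log. A minimal sketch of the idea; `safe_log_proba` and the `eps` value are illustrative, not names from this codebase:

```python
import numpy as np

def safe_log_proba(proba: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Clamp probabilities away from zero before taking the log.

    Without the clamp, a zero probability yields -inf, which can then
    propagate NaNs through subsequent arithmetic (e.g. 0 * -inf).
    """
    return np.log(np.clip(proba, eps, 1.0))

proba = np.array([[0.0, 0.3, 0.7]])
logs = safe_log_proba(proba)  # finite everywhere, even for the zero entry
```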
@gemini-code-assist review
Code Review
This pull request introduces a new predict_logits method for the TabPFNClassifier and significantly refactors the internal prediction logic for both the classifier and the regressor. The refactoring for TabPFNRegressor is particularly noteworthy as it introduces a memory-efficient forward pass by processing ensemble outputs sequentially, which is a great improvement for handling larger datasets.
The changes are well-structured, and the introduction of a pipeline pattern in TabPFNClassifier.forward and helper methods in TabPFNRegressor greatly improves code clarity and maintainability. The new tests for the classifier are comprehensive and validate the new functionality and its consistency with existing methods.
My main feedback is focused on improving the robustness of the TabPFNRegressor by enhancing error messages and, most importantly, adding specific tests to validate the new memory-efficient forward pass implementation. While the existing tests cover the predict method's output, they don't directly verify the correctness of the new internal logic, which is a critical part of this change.
Overall, this is a high-quality pull request that brings valuable improvements to the library. The suggested changes will help ensure the long-term stability and correctness of the refactored code.
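The memory-efficient forward pass praised above amounts to averaging ensemble outputs with a running accumulator instead of stacking them. A sketch of the idea in NumPy, under the assumption that each estimator yields a `(n_samples, n_bins)` array; the function name is hypothetical:

```python
import numpy as np

def averaged_logits_sequential(estimator_outputs, n_estimators):
    """Accumulate a running sum over estimator outputs, then divide.

    Peak memory stays at roughly one output tensor, instead of the full
    (n_estimators, n_samples, n_bins) stack that np.stack would build.
    """
    acc = None
    for out in estimator_outputs:  # may be a generator, consumed lazily
        acc = out.copy() if acc is None else acc + out
    return acc / n_estimators

outs = [np.random.randn(4, 10) for _ in range(8)]
seq = averaged_logits_sequential(iter(outs), len(outs))
stacked = np.stack(outs).mean(axis=0)  # the memory-hungry equivalent
```

The two paths are numerically equivalent up to floating-point rounding, which is exactly what a dedicated test for the new internal logic would assert.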
Pull Request Overview
This PR adds support for returning raw logits in the classifier, refactors the internal prediction and forward logic in both classifier and regressor to be more modular and memory‐efficient, and updates tests and examples to match these signature changes.
- Introduces `TabPFNClassifier.predict_logits` for retrieving unnormalized logits.
- Refactors `forward` and `predict` in `TabPFNRegressor` for sequential processing and lower memory usage.
- Updates tests (`test_regressor_interface.py`, `test_finetuning_regressor.py`, `test_classifier_interface.py`) and examples to align with new method signatures and behaviors.
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| `tests/test_regressor_interface.py` | Wrapped tests in a class, updated to new `forward` signature, added dtype casts |
| `tests/test_finetuning_regressor.py` | Adjusted unpacking to match new regressor `.forward` return |
| `tests/test_classifier_interface.py` | Wrapped tests in a class, added tests for `predict_logits`, reorganized setup |
| `src/tabpfn/regressor.py` | Refactored `predict`, added `_raw_predict`, new memory-efficient `forward` |
| `src/tabpfn/classifier.py` | Added `_raw_predict`, `predict_logits`, refactored `forward` into pipeline steps |
| `examples/finetune_regressor.py` | Changed example to match regressor's new `forward` signature |
| `CHANGELOG.md` | Added entries for the new `predict_logits` method and breaking regressor change |
tests/test_classifier_interface.py
Outdated
```python
layer._patched_for_onnx = True

@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
def test_onnx_exportable_cpu(self, X_y: tuple[np.ndarray, np.ndarray]) -> None:
```
Copilot AI · Jul 16, 2025
This test will error out if ONNX is not installed. Consider adding pytest.importorskip('onnx') at the start of the test to skip it when ONNX is unavailable.
```suggestion
def test_onnx_exportable_cpu(self, X_y: tuple[np.ndarray, np.ndarray]) -> None:
    pytest.importorskip("onnx")
```
klemens-floege left a comment
Nice job! I left some comments, but otherwise LGTM :)
```python
directly by user, so default should be False here.
# --- Pre-flight Checks and Assertions ---
# This import is only needed for type checking
from tabpfn.inference import InferenceEngineBatchedNoPreprocessing
```
Should we not move this import to the top of the file, since lazy loading is not required?
```python
assert not np.allclose(
    proba_no_balance, proba_balance, atol=1e-5
), "Probabilities did not change when balance_probabilities was toggled."
```
```suggestion
# Manually Check Rebalancing
class_prob_in_train = model_no_balance.class_counts_ / model_no_balance.class_counts_.sum()
proba_manually_balanced = proba_no_balance / class_prob_in_train
proba_manually_balanced /= proba_manually_balanced.sum(axis=1, keepdims=True)
np.testing.assert_allclose(
    proba_balance,
    proba_manually_balanced,
    atol=1e-5,
    rtol=1e-5,
    err_msg="The balanced probabilities from the model do not match the manually calculated ones.",
)
```
Force-pushed from d660e96 to 508284a (compare)
Motivation and Context
This change introduces two major improvements. First, it directly responds to the request in Issue #374 for raw logits from `TabPFNClassifier`. Explaining model outputs (e.g., with Shapley values) is more numerically stable and often more intuitive on an additive scale like logits. This PR adds a `predict_logits` method to facilitate these use cases.

Second, this work includes a significant refactoring of the internal `forward` and prediction methods for both `TabPFNClassifier` and `TabPFNRegressor`. This simplifies the `forward` pass interface, especially for the regressor, and provides a key benefit: a substantial reduction in memory usage for `TabPFNRegressor`. By processing ensemble outputs sequentially instead of stacking them, the regressor can now handle larger datasets and a higher number of estimators more efficiently.

Public API Changes
Details of Public API Changes:
- `TabPFNClassifier.predict_logits(X: XType) -> np.ndarray`: A new public method that returns the raw, unnormalized logits for the input samples `X`.
- `TabPFNRegressor.forward(...)` signature change (breaking): The `forward` method for the regressor has been simplified. It now returns only a single tensor of logits, instead of the previous tuple `(averaged_logits, outputs, borders)`. This is a breaking change for users who call `forward()` directly (e.g., in finetuning scripts). The new interface is simpler and more memory-efficient.
- Internal refactoring: The `predict` and `predict_proba` methods in both classes have been refactored to use the new internal logic. Their public signatures remain unchanged, ensuring backward compatibility for standard prediction workflows.

How Has This Been Tested?
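The core invariant behind the new API, that probabilities are a softmax over the raw logits (before optional class balancing), can be checked with plain NumPy. The `logits` array below is a hand-written stand-in for what `predict_logits` would return, not output from the actual model:

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax with the usual max-shift for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Stand-in for predict_logits(X) output, shape (n_samples, n_classes)
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 0.3]])
proba = softmax(logits)

# Rows are valid distributions, and the argmax is preserved, so
# predict() derived from either representation agrees.
assert np.allclose(proba.sum(axis=1), 1.0)
assert (np.argmax(proba, axis=1) == np.argmax(logits, axis=1)).all()
```

Working on the logit scale is also what makes additive explanations (e.g. Shapley values) better behaved, since the quantities being summed are unbounded rather than squeezed into [0, 1].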
This PR includes comprehensive new and updated test cases to ensure the correctness and consistency of the new functionality and refactored code paths:
- `predict_logits` consistency: A new test, `test_predict_logits_and_consistency`, has been added to `tests/test_classifier_interface.py`. This highly parametrized test verifies that `predict_logits` returns output of the correct shape and type. It crucially asserts that applying softmax to the raw logits yields results that are numerically close to the probabilities from `predict_proba`, ensuring consistency across various configurations.
- Regressor forward pass: A new test, `test_forward_predict_logit_consistency`, was added to `tests/test_regressor_interface.py` to validate that the new, memory-efficient `forward` pass produces logits identical to the high-level `predict` method, ensuring correctness after the refactoring.
- Behavioral tests: New tests like `test_softmax_temperature_impact_on_logits_magnitude` and `test_balance_probabilities_alters_proba_output` confirm that the modular post-processing pipeline in the classifier behaves as expected.
- Finetuning: Finetuning tests for the regressor (e.g., in `tests/test_finetuning_regressor.py` and `examples/finetune_regressor.py`) have been updated to align with the new, simplified `forward` method signature.

Checklist
- Updated `CHANGELOG.md` (if relevant for users).