
Conversation


@noahho noahho commented Jul 15, 2025

Motivation and Context

This change introduces two major improvements. First, it directly responds to the request in Issue #374 for raw logits from TabPFNClassifier. Explaining model outputs (e.g., with Shapley values) is more numerically stable and often more intuitive on an additive scale like logits. This PR adds a predict_logits method to facilitate these use cases.

Second, this work includes a significant refactoring of the internal forward and prediction methods for both TabPFNClassifier and TabPFNRegressor. This simplifies the forward pass interface, especially for the regressor, and provides a key benefit: a substantial reduction in memory usage for TabPFNRegressor. By processing ensemble outputs sequentially instead of stacking them, the regressor can now handle larger datasets and a higher number of estimators more efficiently.
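The memory win comes from never materializing all ensemble outputs at once. A minimal NumPy sketch of the idea (illustrative only; TabPFN's actual forward pass operates on torch tensors):

```python
import numpy as np

def averaged_logits_stacked(member_outputs):
    # Old pattern: stack every ensemble member's output, then average.
    # Peak memory grows linearly with the number of estimators.
    return np.stack(list(member_outputs), axis=0).mean(axis=0)

def averaged_logits_sequential(member_outputs):
    # New pattern: keep a running sum so only one member's output
    # (plus the accumulator) is alive at any time.
    total, count = None, 0
    for out in member_outputs:
        total = out.copy() if total is None else total + out
        count += 1
    return total / count
```

Both forms produce the same average; the sequential one just trades the stacked intermediate for a single accumulator.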


Public API Changes

  • [ ] No Public API changes
  • [x] Yes, Public API changes (Details below)

Details of Public API Changes:

  • TabPFNClassifier.predict_logits(X: XType) -> np.ndarray: A new public method that returns the raw, unnormalized logits for the input samples X.

  • TabPFNRegressor.forward(...) Signature Change (Breaking): The forward method for the regressor has been simplified. It now returns only a single tensor of logits, instead of the previous tuple (averaged_logits, outputs, borders). This is a breaking change for users who call forward() directly (e.g., in finetuning scripts). The new interface is simpler and more memory-efficient.

  • Internal Refactoring: The predict and predict_proba methods in both classes have been refactored to use the new internal logic. Their public signatures remain unchanged, ensuring backward compatibility for standard prediction workflows.
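For orientation, the intended relationship between the two classifier outputs can be sketched in plain NumPy (illustrative only; `logits` stands in for the array returned by predict_logits):

```python
import numpy as np

def softmax(logits, axis=-1):
    # Numerically stable softmax: shift by the row max before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# For a fitted classifier under the same configuration,
# predict_proba(X) should be numerically close to
# softmax(predict_logits(X)).
```

Working on the logit scale avoids the saturation at 0 and 1 that makes probability-space attributions (e.g., Shapley values) numerically touchy.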


How Has This Been Tested?

This PR includes comprehensive new and updated test cases to ensure the correctness and consistency of the new functionality and refactored code paths:

  • predict_logits Consistency: A new test, test_predict_logits_and_consistency, has been added to tests/test_classifier_interface.py. This highly parametrized test verifies that predict_logits returns output of the correct shape and type. Crucially, it asserts that applying softmax to the raw logits yields probabilities numerically close to those from predict_proba, ensuring consistency across configurations.

  • Regressor Forward Pass: A new test, test_forward_predict_logit_consistency, was added to tests/test_regressor_interface.py to validate that the new, memory-efficient forward pass produces logits identical to those used by the high-level predict method, confirming correctness after the refactoring.

  • Behavioral Tests: New tests like test_softmax_temperature_impact_on_logits_magnitude and test_balance_probabilities_alters_proba_output confirm that the modular post-processing pipeline in the classifier behaves as expected.

  • Finetuning: Finetuning tests for the regressor (e.g., in tests/test_finetuning_regressor.py and examples/finetune_regressor.py) have been updated to align with the new, simplified forward method signature.


Checklist

  • The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes).
  • An entry has been added to CHANGELOG.md (if relevant for users).
  • The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

Copilot AI review requested due to automatic review settings July 15, 2025 16:52

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @noahho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the TabPFN library by exposing raw model logits for both classification and regression tasks, providing users with greater flexibility in how they consume model predictions. This new functionality is underpinned by a substantial refactoring of the internal prediction pipelines within both the classifier and regressor classes, leading to more modular, readable, and maintainable code. The changes streamline the application of various post-processing steps like temperature scaling, averaging, and class balancing.

Highlights

  • New predict_logits Methods: Introduced predict_logits methods for both TabPFNClassifier and TabPFNRegressor, allowing users to directly retrieve the raw, unnormalized model logits. This provides more granular control over model outputs.
  • Refactored Prediction Pipelines: Centralized and modularized the internal prediction logic for both classifier and regressor into new private _raw_predict methods. This consolidates input validation, preprocessing, and the application of post-processing steps, reducing code duplication and improving maintainability.
  • Enhanced forward Method in Classifier: The forward method in TabPFNClassifier now accepts a return_logits flag and utilizes a new, explicit post-processing pipeline composed of dedicated helper functions (_apply_temperature, _average_across_estimators, _apply_softmax, _apply_balancing). This makes the flow of data transformation clearer and more flexible.
  • Default n_estimators Update: The default number of estimators for TabPFNClassifier has been increased from 4 to 8, potentially improving out-of-the-box performance or stability.
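The helper names above come from the PR; a self-contained sketch of how such a post-processing pipeline composes (NumPy stand-ins, not the actual TabPFN implementations):

```python
import numpy as np

def _apply_temperature(logits, temperature):
    # Soften (T > 1) or sharpen (T < 1) the distribution before normalizing.
    return logits / temperature

def _average_across_estimators(logits):
    # logits: (n_estimators, n_samples, n_classes) -> ensemble mean.
    return logits.mean(axis=0)

def _apply_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def _apply_balancing(proba, class_prior):
    # Divide out the training-class prior, then renormalize each row.
    rebalanced = proba / class_prior
    return rebalanced / rebalanced.sum(axis=-1, keepdims=True)

def predict_proba_pipeline(raw_logits, temperature, class_prior):
    # Explicit, ordered steps: temperature -> average -> softmax -> balancing.
    logits = _apply_temperature(raw_logits, temperature)
    logits = _average_across_estimators(logits)
    proba = _apply_softmax(logits)
    return _apply_balancing(proba, class_prior)
```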

This comment was marked as outdated.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the ability to get raw logits from both the TabPFNClassifier and TabPFNRegressor by adding a predict_logits method to each. This is achieved through a significant and well-executed refactoring of the prediction logic, which centralizes preprocessing and model forwarding into new _raw_predict methods. The post-processing steps have also been broken down into smaller, more modular helper functions, which greatly improves code clarity and maintainability.

I've identified a few areas for improvement:

  • A potentially unintentional change to a default hyperparameter.
  • A regression in the user-facing documentation for the predict method in the regressor.
  • A potential numerical instability issue due to taking the logarithm of probabilities that could be zero.

Overall, this is a great change that improves the flexibility and structure of the codebase. Addressing the points above will help ensure its quality and robustness.
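On the third point, the standard guard is to clip probabilities before taking the log; a minimal sketch (the eps value is an assumption, not the project's chosen constant):

```python
import numpy as np

def safe_log_proba(proba, eps=1e-12):
    # Clip zero probabilities so the log yields a large negative
    # number instead of -inf, which would otherwise poison
    # downstream sums and gradients.
    return np.log(np.clip(proba, eps, 1.0))
```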

@noahho noahho requested a review from klemens-floege July 15, 2025 17:45

noahho commented Jul 16, 2025

@gemini-code-assist review

@noahho noahho requested a review from Copilot July 16, 2025 08:55

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new predict_logits method for the TabPFNClassifier and significantly refactors the internal prediction logic for both the classifier and the regressor. The refactoring for TabPFNRegressor is particularly noteworthy as it introduces a memory-efficient forward pass by processing ensemble outputs sequentially, which is a great improvement for handling larger datasets.

The changes are well-structured, and the introduction of a pipeline pattern in TabPFNClassifier.forward and helper methods in TabPFNRegressor greatly improves code clarity and maintainability. The new tests for the classifier are comprehensive and validate the new functionality and its consistency with existing methods.

My main feedback is focused on improving the robustness of the TabPFNRegressor by enhancing error messages and, most importantly, adding specific tests to validate the new memory-efficient forward pass implementation. While the existing tests cover the predict method's output, they don't directly verify the correctness of the new internal logic, which is a critical part of this change.

Overall, this is a high-quality pull request that brings valuable improvements to the library. The suggested changes will help ensure the long-term stability and correctness of the refactored code.


Copilot AI left a comment


Pull Request Overview

This PR adds support for returning raw logits in the classifier, refactors the internal prediction and forward logic in both classifier and regressor to be more modular and memory‐efficient, and updates tests and examples to match these signature changes.

  • Introduces TabPFNClassifier.predict_logits for retrieving unnormalized logits.
  • Refactors forward and predict in TabPFNRegressor for sequential processing and lower memory usage.
  • Updates tests (test_regressor_interface.py, test_finetuning_regressor.py, test_classifier_interface.py) and examples to align with new method signatures and behaviors.

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
tests/test_regressor_interface.py Wrapped tests in a class, updated to new forward signature, added dtype casts
tests/test_finetuning_regressor.py Adjusted unpacking to match new regressor .forward return
tests/test_classifier_interface.py Wrapped tests in a class, added tests for predict_logits, reorganized setup
src/tabpfn/regressor.py Refactored predict, added _raw_predict, new memory‐efficient forward
src/tabpfn/classifier.py Added _raw_predict, predict_logits, refactored forward into pipeline steps
examples/finetune_regressor.py Changed example to match regressor’s new forward signature
CHANGELOG.md Added entries for the new predict_logits method and breaking regressor change

```python
layer._patched_for_onnx = True

@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
def test_onnx_exportable_cpu(self, X_y: tuple[np.ndarray, np.ndarray]) -> None:
```

Copilot AI Jul 16, 2025


This test will error out if ONNX is not installed. Consider adding pytest.importorskip('onnx') at the start of the test to skip it when ONNX is unavailable.

Suggested change:

```python
def test_onnx_exportable_cpu(self, X_y: tuple[np.ndarray, np.ndarray]) -> None:
    pytest.importorskip("onnx")
```


@klemens-floege klemens-floege left a comment


Nice job! I put some comments but otherwise LGTM :)

…directly by user, so the default should be `False` here.

```python
# --- Pre-flight Checks and Assertions ---
# This import is only needed for type checking
from tabpfn.inference import InferenceEngineBatchedNoPreprocessing
```


Should we not move the import to the top of the file, since lazy loading is not required here?

```python
assert not np.allclose(
    proba_no_balance, proba_balance, atol=1e-5
), "Probabilities did not change when balance_probabilities was toggled."
```


Suggested change:

```python
# Manually check rebalancing
class_prob_in_train = (
    model_no_balance.class_counts_ / model_no_balance.class_counts_.sum()
)
proba_manually_balanced = proba_no_balance / class_prob_in_train
proba_manually_balanced /= proba_manually_balanced.sum(axis=1, keepdims=True)
np.testing.assert_allclose(
    proba_balance,
    proba_manually_balanced,
    atol=1e-5,
    rtol=1e-5,
    err_msg="The balanced probabilities from the model do not match the manually calculated ones.",
)
```

@noahho noahho changed the title from "Add return_logits function for classifier" to "Regression return logits" on Jul 22, 2025
@noahho noahho force-pushed the regressor-classifier-logits branch from d660e96 to 508284a on September 4, 2025 17:37