-
Notifications
You must be signed in to change notification settings - Fork 494
Improved analytics for tracking usage of different fit modes #646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,72 +16,72 @@ | |
|
|
||
| # Copyright (c) Prior Labs GmbH 2025. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import copy | ||
| import logging | ||
| import warnings | ||
| from collections.abc import Callable, Sequence | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Annotated, Any, Literal | ||
| from typing_extensions import Self, deprecated | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from sklearn import config_context | ||
| from sklearn.base import BaseEstimator, ClassifierMixin, check_is_fitted | ||
| from sklearn.preprocessing import LabelEncoder | ||
| from tabpfn_common_utils.telemetry import track_model_call | ||
| from tabpfn_common_utils.telemetry import track_model_call, set_init_params | ||
|
|
||
| from tabpfn.base import ( | ||
| ClassifierModelSpecs, | ||
| check_cpu_warning, | ||
| create_inference_engine, | ||
| determine_precision, | ||
| get_preprocessed_datasets_helper, | ||
| initialize_model_variables_helper, | ||
| initialize_telemetry, | ||
| ) | ||
| from tabpfn.constants import ( | ||
| PROBABILITY_EPSILON_ROUND_ZERO, | ||
| SKLEARN_16_DECIMAL_PRECISION, | ||
| ModelVersion, | ||
| XType, | ||
| YType, | ||
| ) | ||
| from tabpfn.inference import InferenceEngine, InferenceEngineBatchedNoPreprocessing | ||
| from tabpfn.inference_tuning import ( | ||
| ClassifierEvalMetrics, | ||
| ClassifierTuningConfig, | ||
| find_optimal_classification_thresholds, | ||
| find_optimal_temperature, | ||
| get_tuning_splits, | ||
| resolve_tuning_config, | ||
| ) | ||
| from tabpfn.model_loading import ( | ||
| ModelSource, | ||
| load_fitted_tabpfn_model, | ||
| prepend_cache_path, | ||
| save_fitted_tabpfn_model, | ||
| ) | ||
| from tabpfn.preprocessing import ( | ||
| ClassifierEnsembleConfig, | ||
| DatasetCollectionWithPreprocessing, | ||
| EnsembleConfig, | ||
| PreprocessorConfig, | ||
| ) | ||
| from tabpfn.preprocessors.preprocessing_helpers import get_ordinal_encoder | ||
| from tabpfn.utils import ( | ||
| DevicesSpecification, | ||
| balance_probas_by_class_counts, | ||
| fix_dtypes, | ||
| get_embeddings, | ||
| infer_categorical_features, | ||
| infer_random_state, | ||
| process_text_na_dataframe, | ||
| validate_X_predict, | ||
| validate_Xy_fit, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| import numpy.typing as npt | ||
|
|
@@ -479,6 +479,9 @@ | |
| self.tuning_config = tuning_config | ||
| initialize_telemetry() | ||
|
|
||
| # Only anonymously record `fit_mode` usage | ||
| set_init_params({"fit_mode": self.fit_mode}) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we do the same thing as for `model_path` and validate that it's a known fit mode? To avoid accidentally collecting PII. |
||
|
|
||
| @classmethod | ||
| def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self: | ||
| """Construct a classifier that uses the given version of the model. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This telemetry initialization logic, including the call to
`set_init_params`, is also present in `TabPFNRegressor.__init__`. To improve maintainability and reduce code duplication, consider creating a new helper function in `src/tabpfn/base.py` that encapsulates this logic. For example, you could create a function in
`base.py`. Then you could replace these lines in both
`TabPFNClassifier` and `TabPFNRegressor` with a call to that helper. This would centralize the telemetry setup and make it easier to add more parameters in the future.