73 changes: 42 additions & 31 deletions src/tabpfn/architectures/base/__init__.py
@@ -13,21 +13,19 @@
from tabpfn.architectures.base.config import ModelConfig
from tabpfn.architectures.base.transformer import PerFeatureTransformer
from tabpfn.architectures.encoders import (
InputNormalizationEncoderStep,
FeatureTransformEncoderStep,
LinearInputEncoderStep,
MLPInputEncoderStep,
MulticlassClassificationTargetEncoderStep,
NanHandlingEncoderStep,
NormalizeFeatureGroupsEncoderStep,
RemoveDuplicateFeaturesEncoderStep,
RemoveEmptyFeaturesEncoderStep,
SeqEncStep,
SequentialEncoder,
VariableNumFeaturesEncoderStep,
TorchPreprocessingPipeline,
TorchPreprocessingStep,
)

if TYPE_CHECKING:
from torch import nn

from tabpfn.architectures.interface import ArchitectureConfig


@@ -80,7 +78,7 @@ def get_architecture(
config=config,
# Things that were explicitly passed inside `build_model()`
encoder=get_encoder(
num_features=config.features_per_group,
num_features_per_group=config.features_per_group,
embedding_size=config.emsize,
remove_empty_features=config.remove_empty_features,
remove_duplicate_features=config.remove_duplicate_features,
@@ -113,7 +111,7 @@ def get_architecture(

def get_encoder( # noqa: PLR0913
*,
num_features: int,
num_features_per_group: int,
embedding_size: int,
remove_empty_features: bool,
remove_duplicate_features: bool,
@@ -127,45 +125,43 @@ def get_encoder(  # noqa: PLR0913
encoder_type: Literal["linear", "mlp"] = "linear",
encoder_mlp_hidden_dim: int | None = None,
encoder_mlp_num_layers: int = 2,
) -> nn.Module:
inputs_to_merge = {"main": {"dim": num_features}}
) -> TorchPreprocessingPipeline:
inputs_to_merge = {"main": {"dim": num_features_per_group}}

encoder_steps: list[SeqEncStep] = []
encoder_steps: list[TorchPreprocessingStep] = []
if remove_empty_features:
encoder_steps += [RemoveEmptyFeaturesEncoderStep()]

if remove_duplicate_features:
# TODO: This is a No-op currently. We cannot remove it
# because loading the state_dict of the model depends on it being present,
# currently. Fix this by making the state_dict loading agnostic of the
# presence of this step.
encoder_steps += [RemoveDuplicateFeaturesEncoderStep()]

encoder_steps += [NanHandlingEncoderStep(keep_nans=nan_handling_enabled)]

if nan_handling_enabled:
inputs_to_merge["nan_indicators"] = {"dim": num_features}
inputs_to_merge["nan_indicators"] = {"dim": num_features_per_group}

encoder_steps += [
VariableNumFeaturesEncoderStep(
num_features=num_features,
normalize_by_used_features=False,
in_keys=["nan_indicators"],
out_keys=["nan_indicators"],
),
]
if normalize_by_used_features:
encoder_steps += [_legacy_normalize_features_no_op(num_features_per_group)]

encoder_steps += [
InputNormalizationEncoderStep(
FeatureTransformEncoderStep(
normalize_on_train_only=normalize_on_train_only,
normalize_to_ranking=normalize_to_ranking,
normalize_x=normalize_x,
remove_outliers=remove_outliers,
),
]

encoder_steps += [
VariableNumFeaturesEncoderStep(
num_features=num_features,
normalize_by_used_features=normalize_by_used_features,
),
]
if normalize_by_used_features:
encoder_steps += [
NormalizeFeatureGroupsEncoderStep(
num_features_per_group=num_features_per_group,
),
]

num_input_features = sum(i["dim"] for i in inputs_to_merge.values())
if encoder_type == "mlp":
@@ -196,7 +192,7 @@ def get_encoder(  # noqa: PLR0913
f"Invalid encoder type: {encoder_type} (expected 'linear' or 'mlp')"
)

return SequentialEncoder(*encoder_steps, output_key="output")
return TorchPreprocessingPipeline(encoder_steps, output_key="output")


def get_y_encoder(
@@ -205,8 +201,8 @@ def get_y_encoder(
embedding_size: int,
nan_handling_y_encoder: bool,
max_num_classes: int,
) -> nn.Module:
steps: list[SeqEncStep] = []
) -> TorchPreprocessingPipeline:
steps: list[TorchPreprocessingStep] = []
inputs_to_merge = [{"name": "main", "dim": num_inputs}]
if nan_handling_y_encoder:
steps += [NanHandlingEncoderStep()]
@@ -223,4 +219,19 @@
out_keys=("output",),
),
]
return SequentialEncoder(*steps, output_key="output")
return TorchPreprocessingPipeline(steps, output_key="output")


def _legacy_normalize_features_no_op(
num_features_per_group: int,
) -> TorchPreprocessingStep:
"""Create a no-op step to normalize features.

This is a no-op currently. We need it to keep the state_dict of
the model compatible with previously saved checkpoints. Remove
in future versions.
"""
return NormalizeFeatureGroupsEncoderStep(
num_features_per_group=num_features_per_group,
normalize_by_used_features=False,
)
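
The main interface change in this file is that get_encoder and get_y_encoder now return a TorchPreprocessingPipeline built from an explicit list of TorchPreprocessingStep objects, instead of a SequentialEncoder built from varargs. A minimal sketch of that call-pattern difference, using toy stand-in classes (the real classes live in tabpfn.architectures.encoders and may differ in detail):

# Illustrative sketch only (assumptions): toy stand-ins showing the move from
# SequentialEncoder(*steps, output_key=...) to TorchPreprocessingPipeline(steps, output_key=...).
# ToyStep and ToyPipeline are hypothetical; they only mimic the call pattern.
import torch
from torch import nn


class ToyStep(nn.Module):
    """Stand-in for a TorchPreprocessingStep: reads one key, writes another."""

    def __init__(self, in_key: str, out_key: str) -> None:
        super().__init__()
        self.in_key = in_key
        self.out_key = out_key

    def forward(self, state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        state[self.out_key] = state[self.in_key] * 2.0  # placeholder transform
        return state


class ToyPipeline(nn.Module):
    """Stand-in for TorchPreprocessingPipeline: runs steps in order, returns one key."""

    def __init__(self, steps: list[nn.Module], output_key: str) -> None:
        super().__init__()
        self.steps = nn.ModuleList(steps)
        self.output_key = output_key

    def forward(self, state: dict[str, torch.Tensor]) -> torch.Tensor:
        for step in self.steps:
            state = step(state)
        return state[self.output_key]


# Old style (varargs): SequentialEncoder(step_a, step_b, output_key="output")
# New style (explicit list of steps):
pipeline = ToyPipeline(
    steps=[ToyStep("main", "hidden"), ToyStep("hidden", "output")],
    output_key="output",
)
print(pipeline({"main": torch.ones(2, 3)}))  # tensor of 4.0s, shape (2, 3)
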
53 changes: 32 additions & 21 deletions src/tabpfn/architectures/base/transformer.py
@@ -22,7 +22,7 @@
from tabpfn.architectures.encoders import (
LinearInputEncoderStep,
NanHandlingEncoderStep,
SequentialEncoder,
TorchPreprocessingPipeline,
)
from tabpfn.architectures.interface import Architecture
from tabpfn.errors import TabPFNValidationError
@@ -155,29 +155,39 @@ def __init__(  # noqa: D417, PLR0913
super().__init__()

if encoder is None:
encoder = SequentialEncoder(
LinearInputEncoderStep(
num_features=1,
emsize=config.emsize,
replace_nan_by_zero=False,
bias=True,
in_keys=("main",),
out_keys=("output",),
),
encoder = TorchPreprocessingPipeline(
steps=[
LinearInputEncoderStep(
num_features=1,
emsize=config.emsize,
replace_nan_by_zero=False,
bias=True,
in_keys=("main",),
out_keys=("output",),
)
],
output_key="output",
)

if y_encoder is None:
y_encoder = SequentialEncoder(
NanHandlingEncoderStep(),
LinearInputEncoderStep(
num_features=2,
emsize=config.emsize,
replace_nan_by_zero=False,
bias=True,
out_keys=("output",),
in_keys=("main", "nan_indicators"),
),
y_encoder = TorchPreprocessingPipeline(
steps=[
NanHandlingEncoderStep(
in_keys=("main",),
out_keys=("main", "nan_indicators"),
),
LinearInputEncoderStep(
num_features=2,
emsize=config.emsize,
replace_nan_by_zero=False,
bias=True,
out_keys=("output",),
in_keys=("main", "nan_indicators"),
),
],
output_key="output",
)

self.encoder = encoder
self.y_encoder = y_encoder
self.ninp = config.emsize
@@ -465,14 +475,15 @@ def forward(  # noqa: PLR0912, C901
extra_encoders_args = {}
if categorical_inds_to_use is not None and isinstance(
self.encoder,
SequentialEncoder,
TorchPreprocessingPipeline,
):
# Transform cat. features accordingly to correspond to following to merge
# of batch and feature_group dimensions below (i.e., concat lists)
extra_encoders_args["categorical_inds"] = sum(categorical_inds_to_use, []) # noqa: RUF017

for k in x:
x[k] = einops.rearrange(x[k], "b s f n -> s (b f) n")

embedded_x = einops.rearrange(
self.encoder(
x,
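
The default y_encoder now spells out its key flow explicitly: NanHandlingEncoderStep reads "main" and emits both "main" and "nan_indicators", which LinearInputEncoderStep then merges, hence its num_features=2. A toy sketch of that data flow, with hypothetical stand-in modules whose names and shapes are illustrative only (the real steps live in tabpfn.architectures.encoders):

# Toy sketch (assumptions): mimics the default y_encoder key flow after this change.
# A nan-handling step writes ("main", "nan_indicators"); the linear step consumes both.
import torch
from torch import nn


class ToyNanHandling(nn.Module):
    def forward(self, state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        y = state["main"]
        state["nan_indicators"] = torch.isnan(y).float()  # 1.0 where the target is missing
        state["main"] = torch.nan_to_num(y, nan=0.0)      # keep "main" finite for the linear layer
        return state


class ToyLinearMerge(nn.Module):
    def __init__(self, emsize: int) -> None:
        super().__init__()
        self.linear = nn.Linear(2, emsize)  # num_features=2: target value + nan indicator

    def forward(self, state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        merged = torch.cat([state["main"], state["nan_indicators"]], dim=-1)
        state["output"] = self.linear(merged)
        return state


state = {"main": torch.tensor([[1.0], [float("nan")], [3.0]])}
state = ToyLinearMerge(emsize=8)(ToyNanHandling()(state))
print(state["output"].shape)  # torch.Size([3, 8])
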
26 changes: 10 additions & 16 deletions src/tabpfn/architectures/encoders/__init__.py
@@ -1,35 +1,29 @@
from .pipeline_interfaces import (
InputEncoder,
SeqEncStep,
SequentialEncoder,
)
from .projections import (
LinearInputEncoderStep,
MLPInputEncoderStep,
TorchPreprocessingPipeline,
TorchPreprocessingStep,
)
from .steps import (
CategoricalInputEncoderPerFeatureEncoderStep,
FeatureTransformEncoderStep,
FrequencyFeatureEncoderStep,
InputNormalizationEncoderStep,
LinearInputEncoderStep,
MLPInputEncoderStep,
MulticlassClassificationTargetEncoderStep,
NanHandlingEncoderStep,
NormalizeFeatureGroupsEncoderStep,
RemoveDuplicateFeaturesEncoderStep,
RemoveEmptyFeaturesEncoderStep,
VariableNumFeaturesEncoderStep,
)

__all__ = (
"CategoricalInputEncoderPerFeatureEncoderStep",
"FeatureTransformEncoderStep",
"FrequencyFeatureEncoderStep",
"InputEncoder",
"InputNormalizationEncoderStep",
"LinearInputEncoderStep",
"MLPInputEncoderStep",
"MulticlassClassificationTargetEncoderStep",
"NanHandlingEncoderStep",
"NormalizeFeatureGroupsEncoderStep",
"RemoveDuplicateFeaturesEncoderStep",
"RemoveEmptyFeaturesEncoderStep",
"SeqEncStep",
"SequentialEncoder",
"VariableNumFeaturesEncoderStep",
"TorchPreprocessingPipeline",
"TorchPreprocessingStep",
)
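
For downstream code, the public import surface of tabpfn.architectures.encoders changes accordingly. Assuming a package build that includes this PR, the renamed names would be imported roughly like this (the old SequentialEncoder / SeqEncStep / InputNormalizationEncoderStep / VariableNumFeaturesEncoderStep exports are replaced by their new counterparts):

# Sketch (assumption: a tabpfn version containing this PR).
from tabpfn.architectures.encoders import (
    FeatureTransformEncoderStep,
    NormalizeFeatureGroupsEncoderStep,
    TorchPreprocessingPipeline,
    TorchPreprocessingStep,
)
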