[pre-commit.ci] pre-commit suggestions #525

Open · wants to merge 2 commits into base: main
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -27,14 +27,14 @@ repos:
       - id: detect-private-key
 
   - repo: https://github.com/PyCQA/docformatter
-    rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5
+    rev: v1.7.7 # todo: fix for docformatter after last 1.7.5
     hooks:
       - id: docformatter
         additional_dependencies: [tomli]
         args: ["--in-place"]
 
   - repo: https://github.com/executablebooks/mdformat
-    rev: 0.7.19
+    rev: 0.7.22
     hooks:
       - id: mdformat
         additional_dependencies:
@@ -48,7 +48,7 @@ repos:
           )
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.3
+    rev: v0.11.13
     hooks:
       - id: ruff
         args: ["--fix"]
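
Note: the hook bumps above move docformatter to the v1.7.7 tag, mdformat to 0.7.22, and ruff-pre-commit to v0.11.13. The Python edits in the remaining files appear to be the output of re-running the updated ruff formatting hook, which now keeps the assert condition on the `assert` line and parenthesizes only the failure message. A minimal sketch of the before/after layout (illustrative code, not taken from the repository):

# Illustrative only: a throwaway variable to show the two assert layouts.
values = [1, 2, 3]

# Old layout: the condition is wrapped in parentheses and the message trails the closing paren.
assert (
    len(values) > 0
), "values must contain at least one element"

# New layout, as produced by the newer formatter: the condition stays on the assert line
# and only the message is parenthesized.
assert len(values) > 0, (
    "values must contain at least one element"
)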
12 changes: 6 additions & 6 deletions src/pytorch_tabular/categorical_encoders.py
@@ -58,9 +58,9 @@ def transform(self, X):
             raise ValueError("`fit` method must be called before `transform`.")
         assert all(c in X.columns for c in self.cols)
         if self.handle_missing == "error":
-            assert (
-                not X[self.cols].isnull().any().any()
-            ), "`handle_missing` = `error` and missing values found in columns to encode."
+            assert not X[self.cols].isnull().any().any(), (
+                "`handle_missing` = `error` and missing values found in columns to encode."
+            )
         X_encoded = X.copy(deep=True)
         category_cols = X_encoded.select_dtypes(include="category").columns
         X_encoded[category_cols] = X_encoded[category_cols].astype("object")
@@ -153,9 +153,9 @@ def fit(self, X, y=None):
         """
         self._before_fit_check(X, y)
         if self.handle_missing == "error":
-            assert (
-                not X[self.cols].isnull().any().any()
-            ), "`handle_missing` = `error` and missing values found in columns to encode."
+            assert not X[self.cols].isnull().any().any(), (
+                "`handle_missing` = `error` and missing values found in columns to encode."
+            )
         for col in self.cols:
             map = Series(unique(X[col].fillna(NAN_CATEGORY)), name=col).reset_index().rename(columns={"index": "value"})
             map["value"] += 1
12 changes: 6 additions & 6 deletions src/pytorch_tabular/config/config.py
@@ -192,9 +192,9 @@ class DataConfig:
     )
 
     def __post_init__(self):
-        assert (
-            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
-        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
+        assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
+            "There should be at-least one feature defined in categorical, continuous, or date columns"
+        )
         _validate_choices(self)
         if os.name == "nt" and self.num_workers != 0:
             print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +255,9 @@ class InferredConfig:
 
     def __post_init__(self):
         if self.embedding_dims is not None:
-            assert all(
-                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
-            ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
+                "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            )
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
             self.embedded_cat_dim = 0
2 changes: 1 addition & 1 deletion src/pytorch_tabular/models/category_embedding/config.py
@@ -98,7 +98,7 @@ class CategoryEmbeddingModelConfig(ModelConfig):
     )
     use_batch_norm: bool = field(
         default=False,
-        metadata={"help": ("Flag to include a BatchNorm layer after each Linear Layer+DropOut." " Defaults to False")},
+        metadata={"help": ("Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False")},
     )
     initialization: str = field(
         default="kaiming",
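
The merged single-literal help string above (and the similar merges in the tabnet config, tabnet model, and gate config below) is equivalent to the old pair of adjacent literals: Python joins adjacent string literals with no separator, so the leading space of the second piece has to be carried into the merged literal. A quick check, illustrative only and not part of the diff:

# Adjacent literals are concatenated with no separator, so the two forms are byte-for-byte identical.
old_style = "Flag to include a BatchNorm layer after each Linear Layer+DropOut." " Defaults to False"
new_style = "Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False"
assert old_style == new_style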
36 changes: 18 additions & 18 deletions src/pytorch_tabular/models/common/layers/embeddings.py
@@ -84,12 +84,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert (
-            categorical_data.shape[1] == self.categorical_dim
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == self.categorical_dim, (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:
@@ -141,12 +141,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert categorical_data.shape[1] == len(
-            self.cat_embedding_layers
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == len(self.cat_embedding_layers), (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:
@@ -273,12 +273,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert categorical_data.shape[1] == len(
-            self.cat_embedding_layers
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == len(self.cat_embedding_layers), (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             cont_idx = torch.arange(self.continuous_dim, device=continuous_data.device).expand(
2 changes: 1 addition & 1 deletion src/pytorch_tabular/models/gate/config.py
@@ -173,7 +173,7 @@ def __post_init__(self):
         assert self.tree_depth > 0, "tree_depth should be greater than 0"
         # Either gflu_stages or num_trees should be greater than 0
         assert self.num_trees > 0, (
-            "`num_trees` must be greater than 0." "If you want a lighter model which performs better, use GANDALF."
+            "`num_trees` must be greater than 0.If you want a lighter model which performs better, use GANDALF."
         )
         super().__post_init__()
12 changes: 6 additions & 6 deletions src/pytorch_tabular/models/gate/gate_model.py
@@ -51,12 +51,12 @@ def __init__(
         embedding_dropout: float = 0.0,
     ):
         super().__init__()
-        assert (
-            binning_activation in self.BINARY_ACTIVATION_MAP.keys()
-        ), f"`binning_activation should be one of {self.BINARY_ACTIVATION_MAP.keys()}"
-        assert (
-            feature_mask_function in self.ACTIVATION_MAP.keys()
-        ), f"`feature_mask_function should be one of {self.ACTIVATION_MAP.keys()}"
+        assert binning_activation in self.BINARY_ACTIVATION_MAP.keys(), (
+            f"`binning_activation should be one of {self.BINARY_ACTIVATION_MAP.keys()}"
+        )
+        assert feature_mask_function in self.ACTIVATION_MAP.keys(), (
+            f"`feature_mask_function should be one of {self.ACTIVATION_MAP.keys()}"
+        )
 
         self.gflu_stages = gflu_stages
         self.num_trees = num_trees
6 changes: 3 additions & 3 deletions src/pytorch_tabular/models/mixture_density/config.py
@@ -87,9 +87,9 @@ class MDNConfig(ModelConfig):
     _probabilistic: bool = field(default=True)
 
     def __post_init__(self):
-        assert (
-            self.backbone_config_class not in INCOMPATIBLE_BACKBONES
-        ), f"{self.backbone_config_class} is not a supported backbone for MDN head"
+        assert self.backbone_config_class not in INCOMPATIBLE_BACKBONES, (
+            f"{self.backbone_config_class} is not a supported backbone for MDN head"
+        )
         assert self.head == "MixtureDensityHead"
         return super().__post_init__()
2 changes: 1 addition & 1 deletion src/pytorch_tabular/models/tabnet/config.py
@@ -90,7 +90,7 @@ class TabNetModelConfig(ModelConfig):
     )
     gamma: float = field(
         default=1.3,
-        metadata={"help": ("Float above 1, scaling factor for attention updates (usually between" " 1.0 to 2.0)")},
+        metadata={"help": ("Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0)")},
     )
     n_independent: int = field(
         default=2,
2 changes: 1 addition & 1 deletion src/pytorch_tabular/models/tabnet/tabnet_model.py
@@ -104,4 +104,4 @@ def _build_network(self):
         self._head = nn.Identity()
 
     def extract_embedding(self):
-        raise ValueError("Extracting Embeddings is not supported by Tabnet. Please use another" " compatible model")
+        raise ValueError("Extracting Embeddings is not supported by Tabnet. Please use another compatible model")
6 changes: 3 additions & 3 deletions src/pytorch_tabular/ssl_models/base_model.py
@@ -85,9 +85,9 @@ def __init__(
         self._setup_metrics()
 
     def _setup_encoder_decoder(self, encoder, encoder_config, decoder, decoder_config, inferred_config):
-        assert (encoder is not None) or (
-            encoder_config is not None
-        ), "Either encoder or encoder_config must be provided"
+        assert (encoder is not None) or (encoder_config is not None), (
+            "Either encoder or encoder_config must be provided"
+        )
         # assert (decoder is not None) or (decoder_config is not None),
         # "Either decoder or decoder_config must be provided"
         if encoder is not None:
6 changes: 3 additions & 3 deletions src/pytorch_tabular/ssl_models/common/layers.py
@@ -79,9 +79,9 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
         assert categorical_data.shape[1] == len(
             self._onehot_feat_idx + self._binary_feat_idx + self._embedding_feat_idx
         ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         # embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:
6 changes: 3 additions & 3 deletions src/pytorch_tabular/ssl_models/dae/config.py
@@ -125,9 +125,9 @@ class DenoisingAutoEncoderConfig(SSLModelConfig):
     def __post_init__(self):
         assert hasattr(self.encoder_config, "_backbone_name"), "encoder_config should have a _backbone_name attribute"
         if self.decoder_config is not None:
-            assert hasattr(
-                self.decoder_config, "_backbone_name"
-            ), "decoder_config should have a _backbone_name attribute"
+            assert hasattr(self.decoder_config, "_backbone_name"), (
+                "decoder_config should have a _backbone_name attribute"
+            )
         super().__post_init__()