Skip to content

Per Column Transformer #5

Open
Open
@AKuederle

Description

First suggestion for a `PerColumnTransformer`, a wrapper that applies a separate copy of a given transformer to each column of the data:

class PerColumnTransformer(BaseTransformer, TrainableTransformerMixin):
    """Apply a separate transformer instance to each column of the data.

    Parameters
    ----------
    transformer
        Template transformer. It is cloned once per column whenever no explicit ``per_col_transformer`` is
        available.
    per_col_transformer
        Optional list of ``(column_name, transformer)`` tuples defining which transformer instance is applied to
        which column.
        If ``None``, ``self_optimize`` populates it with one clone of ``transformer`` per column of the first
        provided dataset.

    Attributes
    ----------
    transformed_data_
        The per-column results of the last ``transform`` call, concatenated along the column axis.
    data
        The data passed to the last ``transform`` call.

    """

    transformer: BaseTransformer
    # The annotation name must match the attribute assigned in ``__init__`` (``per_col_transformer``); otherwise the
    # ``OptimizableParameter`` marker is attached to a name that does not exist on the instance.
    per_col_transformer: OptimizableParameter[Optional[List[Tuple[str, BaseTransformer]]]]

    def __init__(
        self, transformer: BaseTransformer, *, per_col_transformer: Optional[List[Tuple[str, BaseTransformer]]] = None
    ):
        self.transformer = transformer
        self.per_col_transformer = per_col_transformer

    def transform(self, data: SingleSensorData, **kwargs) -> Self:
        """Transform each column of ``data`` with its own transformer.

        Parameters
        ----------
        data
            The data to transform. Each selected column is passed individually to its transformer.
        kwargs
            Forwarded unchanged to the ``transform`` method of every per-column transformer.

        Returns
        -------
        self
            The instance with ``data`` and ``transformed_data_`` set.

        """
        self.data = data
        # We either use the per-column transformers that were already created (e.g. during ``self_optimize``) or we
        # fall back to the template transformer for every column of ``data``.
        # Note that nothing is stored back into the parameters here: this is the "action" method and it must not
        # modify parameters.
        per_col_transformer = self.per_col_transformer or [(col_name, self.transformer) for col_name in data.columns]
        # Every transformer is cloned exactly once right before use, so neither the stored per-column transformers
        # nor the template transformer are touched by the ``transform`` call.
        # NOTE(review): a single column is selected as ``data[col_name]`` here, while ``self_optimize`` selects
        # ``d[[col_name]]``. Confirm that the wrapped transformers handle both forms identically.
        transformed_data = {
            col_name: transformer.clone().transform(data[col_name], **kwargs).transformed_data_
            for col_name, transformer in per_col_transformer
        }
        self.transformed_data_ = pd.concat(transformed_data, axis=1)
        return self

    def self_optimize(self, data: Sequence[SingleSensorData], **kwargs) -> Self:
        """Optimize one transformer per column and store the result in ``per_col_transformer``.

        Parameters
        ----------
        data
            Sequence of datasets. For every column, the corresponding column of each dataset is passed to the
            ``self_optimize`` method of that column's transformer.
        kwargs
            Forwarded unchanged to the ``self_optimize`` method of every trainable per-column transformer.

        Returns
        -------
        self
            The instance with ``per_col_transformer`` replaced by the optimized transformers.

        """
        per_col_transformer = self.per_col_transformer
        # If there is no per_col_transformer set yet, we create one transformer for each column of the first dataset.
        if per_col_transformer is None:
            per_col_transformer = [(col_name, self.transformer.clone()) for col_name in data[0].columns]
        optimized_transformers = []
        for col_name, transformer in per_col_transformer:
            # Only trainable transformers are optimized; all others are carried over as they are.
            if isinstance(transformer, TrainableTransformerMixin):
                transformer = transformer.self_optimize([d[[col_name]] for d in data], **kwargs)
            optimized_transformers.append((col_name, transformer.clone()))

        self.per_col_transformer = optimized_transformers
        return self

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions