-
Notifications
You must be signed in to change notification settings - Fork 62
feat: constraints #679
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: constraints #679
Changes from all commits
2c6540f
6167efe
d21f530
ddb5792
4ebb24c
c17ba16
150db15
6db750e
404bf29
4efcb65
fab7655
c776180
cef67e4
9c58d2f
b09ba93
9c83af4
f948e2d
d03f8f7
187f811
4326767
e3dd47d
4f2ca23
2e93b86
a3f5dca
415dc64
8788132
799e0ea
1dd96f9
4eb408f
658b05f
a452c4e
6bb742c
b59c792
6749a15
6d55e72
6e4b6a5
608569b
16346a7
40900b0
32f9319
c132a07
90e3f0b
d031d18
4c80a0d
e54ced2
ab284e2
4d10171
d617cf7
10a6923
399a5aa
710ba19
3184882
5418abb
e8bfbe1
3cdcdc6
6ca0a11
2117437
e875292
073f09a
a19f589
accad5d
e9d29d3
e814fc1
f6e5b13
dd51ed9
11f8ce5
134fe6b
a53f2d1
125cbb0
ac0ec2f
b451fd0
24baa30
04efcca
6682708
ba2ce96
9997463
5fb7bad
90511e4
bfa340c
0a1ca0e
f3b85bb
d90ef35
8e9dcc2
f43429b
9345df8
b3567e3
5585ff4
21ced4e
45669d9
48f0ef8
9ffc4ff
f8643ef
672bf14
32711fc
b470651
4a2bf7a
956624e
4777ff0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,190 @@ | ||
| # Copyright 2025 MOSTLY AI | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """constraint transformation utilities.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import logging | ||
| from pathlib import Path | ||
|
|
||
| import pandas as pd | ||
|
|
||
| from mostlyai.sdk._data.constraints.types import ( | ||
| ConstraintHandler, | ||
| FixedCombinationsHandler, | ||
| InequalityHandler, | ||
| ) | ||
| from mostlyai.sdk.client._constraint_types import ( | ||
| FixedCombinations, | ||
| Inequality, | ||
| convert_constraint_config_to_typed, | ||
| ) | ||
| from mostlyai.sdk.domain import Generator | ||
|
|
||
| _LOG = logging.getLogger(__name__) | ||
|
|
||
| # type alias for constraint types | ||
| ConstraintType = FixedCombinations | Inequality | ||
|
|
||
|
|
||
def _create_constraint_handler(constraint: ConstraintType, table=None) -> ConstraintHandler:
    """build the handler matching the concrete type of *constraint*.

    FixedCombinations handlers need no table metadata; Inequality handlers
    receive *table* so they can inspect column types. raises ValueError for
    a constraint type without a registered handler.
    """
    if isinstance(constraint, FixedCombinations):
        return FixedCombinationsHandler(constraint)
    if isinstance(constraint, Inequality):
        return InequalityHandler(constraint, table=table)
    raise ValueError(f"unknown constraint type: {type(constraint)}")
|
|
||
|
|
||
class ConstraintTranslator:
    """bridges between the user-facing schema and the internal constraint schema.

    each configured constraint contributes one handler; transformations are
    applied by chaining every handler in order.
    """

    def __init__(self, constraints: list[ConstraintType], table=None):
        self.constraints = constraints
        self.table = table
        # one handler per constraint; the handlers hold the column logic
        self.handlers = [_create_constraint_handler(c, table=table) for c in constraints]

    def to_internal(self, df: pd.DataFrame) -> pd.DataFrame:
        """map a dataframe from the user schema to the internal schema."""
        for h in self.handlers:
            df = h.to_internal(df)
        return df

    def to_original(self, df: pd.DataFrame) -> pd.DataFrame:
        """map a dataframe from the internal schema back to the user schema."""
        for h in self.handlers:
            df = h.to_original(df)
        return df

    def get_all_column_names(self, original_column_names: list[str]) -> list[str]:
        """return the original column names plus every internal constraint column."""
        internal = [name for h in self.handlers for name in h.get_internal_column_names()]
        return list(original_column_names) + internal

    def get_encoding_types(self) -> dict[str, str]:
        """return the merged encoding types contributed by all handlers."""
        merged: dict[str, str] = {}
        for h in self.handlers:
            merged.update(h.get_encoding_types())
        return merged

    @staticmethod
    def from_generator_config(
        generator: Generator,
        table_name: str,
    ) -> ConstraintTranslator | None:
        """build a translator for *table_name* from a generator config, or None.

        returns None when the generator has no constraints, the table is not
        part of the generator, or no constraint targets the requested table.
        """
        if not generator.constraints:
            return None

        table = next((t for t in generator.tables if t.name == table_name), None)
        if not table:
            return None

        # keep only the typed constraints that target this table
        typed_constraints = [
            tc
            for tc in (convert_constraint_config_to_typed(c) for c in generator.constraints)
            if tc.table_name == table_name
        ]
        if not typed_constraints:
            return None

        # hand the table to the translator so handlers can inspect column types
        return ConstraintTranslator(typed_constraints, table=table)
|
|
||
|
|
||
def preprocess_constraints_for_training(
    *,
    generator: Generator,
    workspace_dir: Path,
    target_table_name: str,
) -> list[str] | None:
    """apply constraint transformations to the training data of one table.

    - transforms the table's data from user schema to internal schema (if any)
    - updates tgt-meta (encoding-types) and tgt-data with internal columns
    - returns all column names (original and internal constraint columns)
      for use in training, or None when there is nothing to do
    """
    target_table = next((t for t in generator.tables if t.name == target_table_name), None)
    if target_table is None:
        _LOG.debug(f"table {target_table_name} not found in generator")
        return None

    if not generator.constraints:
        return None

    # keep only the typed constraints that target this table
    typed_constraints = []
    for raw_constraint in generator.constraints:
        typed_constraint = convert_constraint_config_to_typed(raw_constraint)
        if typed_constraint.table_name == target_table_name:
            typed_constraints.append(typed_constraint)

    if not typed_constraints:
        return None

    _LOG.info(f"preprocessing constraints for table {target_table_name}")
    # the table is passed along so handlers can inspect column types
    translator = ConstraintTranslator(typed_constraints, table=target_table)

    data_dir = workspace_dir / "OriginalData" / "tgt-data"
    if not data_dir.exists():
        _LOG.warning(f"data directory not found: {data_dir}")
        return None

    # rewrite each data partition in place with the internal columns added
    partition_files = sorted(data_dir.glob("part.*.parquet"))
    for partition_file in partition_files:
        frame = pd.read_parquet(partition_file)
        translator.to_internal(frame).to_parquet(partition_file, index=True)

    original_columns = [c.name for c in target_table.columns] if target_table.columns else []
    _update_meta_with_internal_columns(workspace_dir, target_table_name, translator, partition_files)
    return translator.get_all_column_names(original_columns)
|
|
||
|
|
||
def _update_meta_with_internal_columns(
    workspace_dir: Path,
    table_name: str,
    constraint_translator: ConstraintTranslator,
    parquet_files: list[Path],
) -> None:
    """update tgt-meta so encoding-types covers the internal constraint columns."""
    if not parquet_files:
        # nothing was transformed, so the meta files need no update
        return

    meta_dir = workspace_dir / "OriginalData" / "tgt-meta"
    meta_dir.mkdir(parents=True, exist_ok=True)
    encoding_types_file = meta_dir / "encoding-types.json"

    if not encoding_types_file.exists():
        _LOG.error(f"encoding-types.json not found to update internal columns for {table_name}")
        return

    # merge the handlers' encoding types into the existing mapping and persist
    encoding_types = json.loads(encoding_types_file.read_text())
    encoding_types.update(constraint_translator.get_encoding_types())
    encoding_types_file.write_text(json.dumps(encoding_types, indent=2))
    _LOG.debug(f"updated encoding-types.json with internal columns for {table_name}")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # Copyright 2025 MOSTLY AI | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """constraint type handlers.""" | ||
|
|
||
| from mostlyai.sdk._data.constraints.types.base import ConstraintHandler | ||
| from mostlyai.sdk._data.constraints.types.fixed_combinations import FixedCombinationsHandler | ||
| from mostlyai.sdk._data.constraints.types.inequality import InequalityHandler | ||
|
|
||
# public API of the constraint type handlers package
__all__ = ["ConstraintHandler", "FixedCombinationsHandler", "InequalityHandler"]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| # Copyright 2025 MOSTLY AI | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """base constraint handler class.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import ABC, abstractmethod | ||
|
|
||
| import pandas as pd | ||
|
|
||
|
|
||
class ConstraintHandler(ABC):
    """interface that every constraint handler implements.

    a handler knows how to add its internal columns to a dataframe
    (``to_internal``), how to reconstruct the original columns from them
    (``to_original``), and which encoding types its internal columns use.
    """

    @abstractmethod
    def get_internal_column_names(self) -> list[str]:
        """return the names of the internal columns this handler creates."""

    @abstractmethod
    def to_internal(self, df: pd.DataFrame) -> pd.DataFrame:
        """transform dataframe (in-place) from user schema to internal schema."""

    @abstractmethod
    def to_original(self, df: pd.DataFrame) -> pd.DataFrame:
        """transform dataframe (in-place) from internal schema back to user schema."""

    @abstractmethod
    def get_encoding_types(self) -> dict[str, str]:
        """return the encoding type for each internal column."""

    def _validate_columns(self, df: pd.DataFrame, columns: list[str]) -> None:
        """raise ValueError unless every name in *columns* is a dataframe column."""
        absent = set(columns) - set(df.columns)
        if absent:
            raise ValueError(
                f"Columns {sorted(absent)} required by {self.__class__.__name__} "
                f"not found in dataframe. Available columns: {sorted(df.columns)}"
            )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| # Copyright 2025 MOSTLY AI | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """fixed combinations constraint handler.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import hashlib | ||
| import json | ||
| import logging | ||
|
|
||
| import pandas as pd | ||
|
|
||
| from mostlyai.sdk._data.constraints.types.base import ConstraintHandler | ||
| from mostlyai.sdk.client._constraint_types import FixedCombinations | ||
|
|
||
| _LOG = logging.getLogger(__name__) | ||
|
|
||
|
|
||
def _generate_internal_column_name(prefix: str, columns: list[str]) -> str:
    """build a deterministic, collision-resistant internal column name.

    the name embeds the upper-cased source columns for readability plus an
    md5-derived suffix so that distinct column sets never collide.
    """
    digest = hashlib.md5("|".join(columns).encode()).hexdigest()
    readable = "_".join(c.upper() for c in columns)
    return f"__CONSTRAINT_{prefix}_{readable}_{digest[:8]}__"
|
|
||
|
|
||
class FixedCombinationsHandler(ConstraintHandler):
    """handler for FixedCombinations constraints.

    merges the constrained columns into a single JSON-encoded internal column
    for training, and splits that column back into the original columns when
    restoring the user schema.
    """

    def __init__(self, constraint: FixedCombinations):
        self.constraint = constraint
        self.table_name = constraint.table_name
        self.columns = constraint.columns
        # deterministic internal name, e.g. __CONSTRAINT_FC_A_B_<hash>__
        self.merged_name = _generate_internal_column_name("FC", self.columns)

    def get_internal_column_names(self) -> list[str]:
        """return the single merged column name this handler creates."""
        return [self.merged_name]

    def to_internal(self, df: pd.DataFrame) -> pd.DataFrame:
        """add the merged JSON column; missing values are encoded as JSON null."""
        self._validate_columns(df, self.columns)

        def merge_row(row):
            values = [row[col] if pd.notna(row[col]) else None for col in self.columns]
            # JSON serialization handles all escaping automatically
            return json.dumps(values, ensure_ascii=False)

        df[self.merged_name] = df.apply(merge_row, axis=1)
        return df

    def to_original(self, df: pd.DataFrame) -> pd.DataFrame:
        """split the merged column back into the original columns and drop it.

        JSON nulls round-trip back to missing values (None -> NaN) rather than
        empty strings, and malformed payloads (invalid JSON, or a list with a
        wrong number of elements) degrade to an all-missing row instead of
        crashing during reassignment.
        """
        if self.merged_name in df.columns:
            n_cols = len(self.columns)

            def split_row(merged_value: str) -> list:
                if pd.isna(merged_value):
                    return [None] * n_cols
                elif merged_value == "_RARE_":
                    # rare-category token applies to every source column
                    return ["_RARE_"] * n_cols
                try:
                    values = json.loads(merged_value)
                except json.JSONDecodeError:
                    _LOG.error(f"failed to decode JSON for {merged_value}; using empty values")
                    return [None] * n_cols
                if not isinstance(values, list) or len(values) != n_cols:
                    # guard against payloads with the wrong element count,
                    # which would otherwise crash the column reassignment below
                    _LOG.error(f"unexpected element count in {merged_value}; using empty values")
                    return [None] * n_cols
                # preserve JSON null as None so NA values round-trip correctly
                return [str(v) if v is not None else None for v in values]

            split_values = df[self.merged_name].apply(split_row)
            # every split_row result has exactly n_cols elements, so this
            # dataframe always has one column per original column
            split_df = pd.DataFrame(split_values.tolist(), index=df.index)

            # assign each split part back to its original column
            for i, col in enumerate(self.columns):
                df[col] = split_df[i].values

            # drop the merged helper column
            df = df.drop(columns=[self.merged_name])
        return df

    def get_encoding_types(self) -> dict[str, str]:
        # always use TABULAR encoding for constraints, regardless of model_type
        # constraints merge columns which requires categorical encoding
        return {self.merged_name: "TABULAR_CATEGORICAL"}
Uh oh!
There was an error while loading. Please reload this page.