Merged

Changes from all commits (98 commits)
2c6540f
feat: constraints (WIP)
michdr Nov 26, 2025
6167efe
wip
michdr Nov 26, 2025
d21f530
Merge branch 'main' into feat-constraints
michdr Nov 26, 2025
ddb5792
wip
michdr Nov 26, 2025
4ebb24c
wip
michdr Nov 26, 2025
c17ba16
simplify + test
michdr Nov 26, 2025
150db15
fix
michdr Nov 26, 2025
6db750e
add docstring
michdr Nov 26, 2025
404bf29
wip
michdr Nov 26, 2025
4efcb65
Merge branch 'main' into feat-constraints
michdr Nov 26, 2025
fab7655
rm preprocess_constraints step
michdr Nov 27, 2025
c776180
Merge branch 'main' into feat-constraints
michdr Nov 27, 2025
cef67e4
don't mutate generator
michdr Nov 27, 2025
9c58d2f
fix model report columns
michdr Nov 27, 2025
b09ba93
wip
michdr Nov 27, 2025
9c83af4
keep original columns
michdr Nov 28, 2025
f948e2d
wip
michdr Nov 28, 2025
d03f8f7
add constraint types
michdr Nov 28, 2025
187f811
wip
michdr Nov 28, 2025
4326767
add onehot
michdr Dec 1, 2025
e3dd47d
add strict boundaries
michdr Dec 1, 2025
4f2ca23
Merge branch 'main' into feat-constraints
michdr Dec 1, 2025
2e93b86
fix: keep columns
michdr Dec 2, 2025
a3f5dca
validate no conflicting inequality
michdr Dec 4, 2025
415dc64
feat: make FixedCombinationHandler separator robust with escaping
michdr Dec 4, 2025
8788132
validate columns' existence
michdr Dec 4, 2025
799e0ea
consider seed (imputation); better reconstruction logic for inequality
michdr Dec 4, 2025
1dd96f9
suppress sklearn PCA RuntimeWarnings
michdr Dec 5, 2025
4eb408f
fix (un)escaping; to be revisited
michdr Dec 5, 2025
658b05f
bug fixes
michdr Dec 5, 2025
a452c4e
move constraints to GeneratorConfig (top level)
michdr Dec 9, 2025
6bb742c
fix tests
michdr Dec 9, 2025
b59c792
fix other tests, adapting to new structure
michdr Dec 9, 2025
6749a15
Merge branch 'main' into feat-constraints
michdr Dec 9, 2025
6d55e72
bug fixes
michdr Dec 10, 2025
6e4b6a5
refactor + handle edge cases
michdr Dec 10, 2025
608569b
consolidations, simplifications + tests + docstrings to explain obscu…
michdr Dec 10, 2025
16346a7
slight cleanup + tests
michdr Dec 10, 2025
40900b0
Merge branch 'main' into feat-constraints
michdr Dec 10, 2025
32f9319
simplify
michdr Dec 10, 2025
c132a07
support language
michdr Dec 11, 2025
90e3f0b
wip
michdr Dec 11, 2025
d031d18
temp skip
michdr Dec 12, 2025
4c80a0d
fix unit test
michdr Dec 12, 2025
e54ced2
adapt docstring + parametrize unit tests
michdr Dec 15, 2025
ab284e2
misc improvements
michdr Dec 15, 2025
4d10171
address warnings
michdr Dec 15, 2025
d617cf7
merged_name standardization
michdr Dec 15, 2025
10a6923
wip
michdr Dec 15, 2025
399a5aa
reduce clutter
michdr Dec 15, 2025
710ba19
misc cleanup
michdr Dec 15, 2025
3184882
squeeze into one e2e test
michdr Dec 15, 2025
5418abb
do not rm any columns
michdr Dec 15, 2025
e8bfbe1
fix failing test
michdr Dec 16, 2025
3cdcdc6
consider model_type in from_generator_config
michdr Dec 16, 2025
6ca0a11
Merge branch 'main' into feat-constraints
michdr Dec 16, 2025
2117437
typo
michdr Dec 16, 2025
e875292
reduce scope for constraints v1
michdr Dec 16, 2025
073f09a
rm from domain.py
michdr Dec 16, 2025
a19f589
further reduce LOC
michdr Dec 16, 2025
accad5d
internal col naming convention
michdr Dec 16, 2025
e9d29d3
cleanup execution
michdr Dec 17, 2025
e814fc1
mv logic to step_generate_data
michdr Dec 17, 2025
f6e5b13
wip
michdr Dec 17, 2025
dd51ed9
wip
michdr Dec 17, 2025
11f8ce5
cleanup
michdr Dec 17, 2025
134fe6b
improve datetime delta + misc
michdr Dec 17, 2025
a53f2d1
adapt e2e test
michdr Dec 17, 2025
125cbb0
improve ineq datetime precision + adapt tests
michdr Dec 17, 2025
ac0ec2f
wip
michdr Dec 17, 2025
b451fd0
wip
michdr Dec 17, 2025
24baa30
wip
michdr Dec 17, 2025
04efcca
move config validators
michdr Dec 17, 2025
6682708
rewrite unit tests + misc
michdr Dec 17, 2025
ba2ce96
refine unit tests
michdr Dec 17, 2025
9997463
follow suggestions
michdr Dec 17, 2025
5fb7bad
misc fix
michdr Dec 17, 2025
90511e4
following pr review
michdr Dec 18, 2025
bfa340c
adapt inequality na handling
michdr Dec 18, 2025
0a1ca0e
wip
michdr Dec 18, 2025
f3b85bb
adapt inequality changes + test
michdr Dec 18, 2025
d90ef35
fix e2e test
michdr Dec 18, 2025
8e9dcc2
fix misc
michdr Dec 18, 2025
f43429b
wip
michdr Dec 29, 2025
9345df8
upd api part 1
michdr Dec 29, 2025
b3567e3
rm remaining validations (for now)
michdr Dec 29, 2025
5585ff4
wip
michdr Dec 30, 2025
21ced4e
adapt test
michdr Dec 30, 2025
45669d9
wip
michdr Dec 30, 2025
48f0ef8
accept both cases for constraint config
michdr Dec 30, 2025
9ffc4ff
Merge branch 'main' into feat-constraints
michdr Jan 7, 2026
f8643ef
improvements part 1
michdr Jan 7, 2026
672bf14
improvements part 2
michdr Jan 7, 2026
32711fc
refactor
michdr Jan 8, 2026
b470651
FixedCombinations
michdr Jan 8, 2026
4a2bf7a
ensure inequality constraint
michdr Jan 8, 2026
956624e
refine inequality reconstruction rules
michdr Jan 8, 2026
4777ff0
misc
michdr Jan 8, 2026
190 changes: 190 additions & 0 deletions mostlyai/sdk/_data/constraints/transformations.py
@@ -0,0 +1,190 @@
# Copyright 2025 MOSTLY AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""constraint transformation utilities."""

from __future__ import annotations

import json
import logging
from pathlib import Path

import pandas as pd

from mostlyai.sdk._data.constraints.types import (
ConstraintHandler,
FixedCombinationsHandler,
InequalityHandler,
)
from mostlyai.sdk.client._constraint_types import (
FixedCombinations,
Inequality,
convert_constraint_config_to_typed,
)
from mostlyai.sdk.domain import Generator

_LOG = logging.getLogger(__name__)

# type alias for constraint types
ConstraintType = FixedCombinations | Inequality


def _create_constraint_handler(constraint: ConstraintType, table=None) -> ConstraintHandler:
"""factory function to create appropriate handler for a constraint."""
if isinstance(constraint, FixedCombinations):
return FixedCombinationsHandler(constraint)
elif isinstance(constraint, Inequality):
return InequalityHandler(constraint, table=table)
else:
raise ValueError(f"unknown constraint type: {type(constraint)}")


class ConstraintTranslator:
"""translates data between user schema and internal schema for constraints."""

def __init__(self, constraints: list[ConstraintType], table=None):
self.constraints = constraints
self.table = table
self.handlers = [_create_constraint_handler(c, table=table) for c in constraints]

def to_internal(self, df: pd.DataFrame) -> pd.DataFrame:
"""transform dataframe from user schema to internal schema."""
for handler in self.handlers:
df = handler.to_internal(df)
return df

def to_original(self, df: pd.DataFrame) -> pd.DataFrame:
"""transform dataframe from internal schema back to user schema."""
for handler in self.handlers:
df = handler.to_original(df)
return df

def get_all_column_names(self, original_column_names: list[str]) -> list[str]:
"""get list of all column names (original and internal constraint columns)."""
all_column_names = list(original_column_names)
for handler in self.handlers:
all_column_names.extend(handler.get_internal_column_names())
return all_column_names

def get_encoding_types(self) -> dict[str, str]:
"""get combined encoding types for all internal columns."""
encoding_types = {}
for handler in self.handlers:
encoding_types.update(handler.get_encoding_types())
return encoding_types

@staticmethod
def from_generator_config(
generator: Generator,
table_name: str,
) -> ConstraintTranslator | None:
"""create constraint translator from generator configuration for a specific table."""
if not generator.constraints:
return None

table = next((t for t in generator.tables if t.name == table_name), None)
if not table:
return None

# convert constraints to typed objects and filter by table_name
typed_constraints = []
for constraint in generator.constraints:
typed_constraint = convert_constraint_config_to_typed(constraint)
if typed_constraint.table_name == table_name:
typed_constraints.append(typed_constraint)

if not typed_constraints:
return None

# pass table to translator so handlers can check column types
constraint_translator = ConstraintTranslator(typed_constraints, table=table)
return constraint_translator


def preprocess_constraints_for_training(
*,
generator: Generator,
workspace_dir: Path,
target_table_name: str,
) -> list[str] | None:
"""preprocess constraint transformations for training data:
- transform constraints from user schema to internal schema (if any)
- update tgt-meta (encoding-types) and tgt-data with internal columns (if any)
- return list of all column names (original and internal constraint columns) for use in training
"""
target_table = next((t for t in generator.tables if t.name == target_table_name), None)
if not target_table:
_LOG.debug(f"table {target_table_name} not found in generator")
return None

if not generator.constraints:
return None

# convert constraints to typed objects and filter by table_name
typed_constraints = []
for constraint in generator.constraints:
typed_constraint = convert_constraint_config_to_typed(constraint)
if typed_constraint.table_name == target_table_name:
typed_constraints.append(typed_constraint)

if not typed_constraints:
return None

_LOG.info(f"preprocessing constraints for table {target_table_name}")
# pass table to translator so handlers can check column types
constraint_translator = ConstraintTranslator(typed_constraints, table=target_table)

tgt_data_dir = workspace_dir / "OriginalData" / "tgt-data"
if not tgt_data_dir.exists():
_LOG.warning(f"data directory not found: {tgt_data_dir}")
return None

parquet_files = sorted(list(tgt_data_dir.glob("part.*.parquet")))
for parquet_file in parquet_files:
df = pd.read_parquet(parquet_file)
df_transformed = constraint_translator.to_internal(df)
df_transformed.to_parquet(parquet_file, index=True)

original_columns = [c.name for c in target_table.columns] if target_table.columns else []
_update_meta_with_internal_columns(workspace_dir, target_table_name, constraint_translator, parquet_files)
all_column_names = constraint_translator.get_all_column_names(original_columns)
return all_column_names


def _update_meta_with_internal_columns(
workspace_dir: Path,
table_name: str,
constraint_translator: ConstraintTranslator,
parquet_files: list[Path],
) -> None:
"""update tgt-meta to reflect internal column structure after transformation."""
if not parquet_files:
return

meta_dir = workspace_dir / "OriginalData" / "tgt-meta"
meta_dir.mkdir(parents=True, exist_ok=True)
encoding_types_file = meta_dir / "encoding-types.json"

if encoding_types_file.exists():
with open(encoding_types_file) as f:
encoding_types = json.load(f)

encoding_types.update(constraint_translator.get_encoding_types())

with open(encoding_types_file, "w") as f:
json.dump(encoding_types, f, indent=2)

_LOG.debug(f"updated encoding-types.json with internal columns for {table_name}")
else:
_LOG.error(f"encoding-types.json not found to update internal columns for {table_name}")
25 changes: 25 additions & 0 deletions mostlyai/sdk/_data/constraints/types/__init__.py
@@ -0,0 +1,25 @@
# Copyright 2025 MOSTLY AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""constraint type handlers."""

from mostlyai.sdk._data.constraints.types.base import ConstraintHandler
from mostlyai.sdk._data.constraints.types.fixed_combinations import FixedCombinationsHandler
from mostlyai.sdk._data.constraints.types.inequality import InequalityHandler

__all__ = [
"ConstraintHandler",
"FixedCombinationsHandler",
"InequalityHandler",
]
54 changes: 54 additions & 0 deletions mostlyai/sdk/_data/constraints/types/base.py
@@ -0,0 +1,54 @@
# Copyright 2025 MOSTLY AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""base constraint handler class."""

from __future__ import annotations

from abc import ABC, abstractmethod

import pandas as pd


class ConstraintHandler(ABC):
"""abstract base class for constraint handlers."""

@abstractmethod
def get_internal_column_names(self) -> list[str]:
"""return list of internal column names created by this handler."""
pass

@abstractmethod
def to_internal(self, df: pd.DataFrame) -> pd.DataFrame:
"""transform dataframe (in-place) from user schema to internal schema."""
pass

@abstractmethod
def to_original(self, df: pd.DataFrame) -> pd.DataFrame:
"""transform dataframe (in-place) from internal schema back to user schema."""
pass

@abstractmethod
def get_encoding_types(self) -> dict[str, str]:
"""return encoding types for internal columns."""
pass

def _validate_columns(self, df: pd.DataFrame, columns: list[str]) -> None:
"""validate that all required columns exist in the dataframe."""
missing_cols = set(columns) - set(df.columns)
if missing_cols:
raise ValueError(
f"Columns {sorted(missing_cols)} required by {self.__class__.__name__} "
f"not found in dataframe. Available columns: {sorted(df.columns)}"
)
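To make the contract concrete, here is a hypothetical handler sketch against the base class above (not part of the diff): derive internal columns in to_internal, undo them in to_original, and report names and encoding types to the training pipeline. UppercaseHandler is a toy example; the TABULAR_CATEGORICAL encoding string is taken from the FixedCombinationsHandler below.

import pandas as pd

from mostlyai.sdk._data.constraints.types.base import ConstraintHandler


class UppercaseHandler(ConstraintHandler):
    """toy handler that mirrors a column into an uppercased internal column."""

    def __init__(self, column: str):
        self.column = column
        self.internal_name = f"__CONSTRAINT_UPPER_{column.upper()}__"

    def get_internal_column_names(self) -> list[str]:
        return [self.internal_name]

    def to_internal(self, df: pd.DataFrame) -> pd.DataFrame:
        self._validate_columns(df, [self.column])  # raises on missing columns
        df[self.internal_name] = df[self.column].str.upper()
        return df

    def to_original(self, df: pd.DataFrame) -> pd.DataFrame:
        # drop the derived column; original data is untouched
        return df.drop(columns=[self.internal_name], errors="ignore")

    def get_encoding_types(self) -> dict[str, str]:
        return {self.internal_name: "TABULAR_CATEGORICAL"}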
95 changes: 95 additions & 0 deletions mostlyai/sdk/_data/constraints/types/fixed_combinations.py
@@ -0,0 +1,95 @@
# Copyright 2025 MOSTLY AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""fixed combinations constraint handler."""

from __future__ import annotations

import hashlib
import json
import logging

import pandas as pd

from mostlyai.sdk._data.constraints.types.base import ConstraintHandler
from mostlyai.sdk.client._constraint_types import FixedCombinations

_LOG = logging.getLogger(__name__)


def _generate_internal_column_name(prefix: str, columns: list[str]) -> str:
"""generate a deterministic internal column name."""
key = "|".join(columns)
hash_suffix = hashlib.md5(key.encode()).hexdigest()[:8]
columns_str = "_".join(col.upper() for col in columns)
return f"__CONSTRAINT_{prefix}_{columns_str}_{hash_suffix}__"


class FixedCombinationsHandler(ConstraintHandler):
"""handler for FixedCombinations constraints."""

def __init__(self, constraint: FixedCombinations):
self.constraint = constraint
self.table_name = constraint.table_name
self.columns = constraint.columns
self.merged_name = _generate_internal_column_name("FC", self.columns)

def get_internal_column_names(self) -> list[str]:
return [self.merged_name]

def to_internal(self, df: pd.DataFrame) -> pd.DataFrame:
self._validate_columns(df, self.columns)

def merge_row(row):
values = [row[col] if pd.notna(row[col]) else None for col in self.columns]
# JSON serialization handles all escaping automatically
return json.dumps(values, ensure_ascii=False)

df[self.merged_name] = df.apply(merge_row, axis=1)
return df

def to_original(self, df: pd.DataFrame) -> pd.DataFrame:
if self.merged_name in df.columns:

def split_row(merged_value: str) -> list[str]:
if pd.isna(merged_value):
return [""] * len(self.columns)
elif merged_value == "_RARE_":
return ["_RARE_"] * len(self.columns)
try:
values = json.loads(merged_value)
return [str(v) if v is not None else "" for v in values]
Review comment: NA values converted to empty strings in round-trip (Medium Severity)

In to_original, the split_row function converts None values (which represent original NA/null values in the JSON-serialized data) to empty strings "" instead of preserving them as proper null values. The expression str(v) if v is not None else "" loses null information. When categorical columns contain NA values, they will become empty strings after the round-trip transformation, causing data loss. The same issue exists on line 67, where all columns get empty strings when merged_value is NA. A hardened sketch follows after this file.

Review comment: Missing element count validation causes crash on malformed data (Medium Severity)

The split_row function in to_original parses JSON values but doesn't validate that the number of elements matches len(self.columns). If the synthetic model generates a malformed merged value with fewer elements than expected (e.g., '["a"]' when 3 columns are expected), the resulting split_df will have fewer columns. When the loop accesses split_df[i] where i exceeds the actual column count, a KeyError is raised. This could crash the generation pipeline when processing synthetic data.

except json.JSONDecodeError:
_LOG.error(f"failed to decode JSON for {merged_value}; using empty values")
return [""] * len(self.columns)

split_values = df[self.merged_name].apply(split_row)
split_df = pd.DataFrame(split_values.tolist(), index=df.index)

# preserve original index
original_index = df.index
split_df.index = original_index

# assign to original columns
for i, col in enumerate(self.columns):
df[col] = split_df[i].values

# drop the merged column
df = df.drop(columns=[self.merged_name])
return df

def get_encoding_types(self) -> dict[str, str]:
# always use TABULAR encoding for constraints, regardless of model_type
# constraints merge columns which requires categorical encoding
return {self.merged_name: "TABULAR_CATEGORICAL"}
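The two review comments above flag real round-trip hazards. Here is a hedged sketch of a hardened split_row (not a committed fix): it keeps nulls as pd.NA instead of coercing them to empty strings, and validates the element count so malformed merged values degrade gracefully rather than raising a KeyError downstream. split_row_safe is a hypothetical name, not part of the PR.

import json

import pandas as pd


def split_row_safe(merged_value, columns: list[str]) -> list:
    n = len(columns)
    if pd.isna(merged_value):
        return [pd.NA] * n  # keep nulls as nulls instead of ""
    if merged_value == "_RARE_":
        return ["_RARE_"] * n
    try:
        values = json.loads(merged_value)
    except (json.JSONDecodeError, TypeError):
        return [pd.NA] * n
    if not isinstance(values, list):
        return [pd.NA] * n
    if len(values) != n:
        # malformed synthetic value: pad with None, then truncate to width n
        values = (values + [None] * n)[:n]
    return [str(v) if v is not None else pd.NA for v in values]

Inside the handler this would be invoked as df[self.merged_name].apply(lambda v: split_row_safe(v, self.columns)), leaving the rest of to_original unchanged.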