import error fixed
minkeymouse committed Jan 22, 2025
1 parent c2f301f commit 964ed13
Showing 2 changed files with 294 additions and 251 deletions.
254 changes: 138 additions & 116 deletions src/synthcity/plugins/generic/plugin_syn_seq.py
@@ -25,17 +25,11 @@ class Syn_SeqPlugin(Plugin):
    A plugin wrapping the 'Syn_Seq' aggregator in the synthcity Plugin interface.
    Steps:
-      1) In .fit(), if the user passes a DataFrame, we wrap it in Syn_SeqDataLoader, then encode the data.
-      2) We keep separate:
-         - self._orig_schema => schema from the original data
-         - self._enc_schema  => schema from the encoded data
+      1) In .fit(), if the user passes a DataFrame, we wrap it in Syn_SeqDataLoader, then call .encode().
+      2) We build or refine the domain in `_domain_rebuild`, using the original vs. converted dtype info,
+         turning them into constraints, then into distributions.
       3) The aggregator trains column-by-column on the encoded data.
-      4) In .generate(), we re-check constraints (including user constraints) referencing
-         the original schema. Then we decode back to the original DataFrame structure.
-
-    Additional note: We add `_remap_special_value_rules` to handle user rules referencing
-    special values that belong in a 'cat' column, e.g. if the user writes
-    ("NUM_CIGAR", "<", 0) but that data is actually in "NUM_CIGAR_cat".
+      4) For .generate(), we re-check constraints (including user constraints) and decode back to the original format.
    """

    @staticmethod
@@ -48,7 +42,7 @@ def type() -> str:

    @staticmethod
    def hyperparameter_space(**kwargs: Any) -> List:
-        # No tunable hyperparameters here
+        # No tunable hyperparameters for demonstration
        return []

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
@@ -68,22 +62,18 @@ def __init__(
            compress_dataset=compress_dataset,
            sampling_strategy=sampling_strategy,
        )
-        # Two separate schema references
-        self._orig_schema: Optional[Schema] = None
-        self._enc_schema: Optional[Schema] = None
-        self._data_info: Optional[Dict] = None
-        self._enc_data_info: Optional[Dict] = None
+        self._schema: Optional[Schema] = None
+        self._training_schema: Optional[Schema] = None
+        self._data_info: Optional[Dict] = None
+        self._training_data_info: Optional[Dict] = {}
        self.model: Optional[Syn_Seq] = None

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def fit(self, X: Union[DataLoader, pd.DataFrame], *args: Any, **kwargs: Any) -> Any:
        """
        Wrap a plain DataFrame into Syn_SeqDataLoader if needed, then encode the data.
-        Build up:
-          - self._orig_schema from the original data
-          - self._enc_schema from the encoded data
-        Then train the aggregator column-by-column on encoded data.
+        Build up the schema from original data vs. encoded data, and train the aggregator.
        """
        # If plain DataFrame, wrap in Syn_SeqDataLoader
        if isinstance(X, pd.DataFrame):
@@ -95,98 +85,90 @@ def fit(self, X: Union[DataLoader, pd.DataFrame], *args: Any, **kwargs: Any) ->
            enable_reproducible_results(self.random_state)
        self._data_info = X.info()

-        # Build schema for the original data
-        self._orig_schema = Schema(
+        # Build schema for the *original* data
+        self._schema = Schema(
            data=X,
            sampling_strategy=self.sampling_strategy,
            random_state=self.random_state,
        )

-        # Encode the data
-        X_encoded, self._enc_data_info = X.encode()
-        # Build a schema from the encoded data
-        self._enc_schema = Schema(
+        X_encoded, self._training_data_info = X.encode()
+
+        # Build an initial schema from the *encoded* data
+        base_schema = Schema(
            data=X_encoded,
            sampling_strategy=self.sampling_strategy,
            random_state=self.random_state,
        )

+        # Rebuild domain from original vs. converted dtype logic
+        self._training_schema = self._domain_rebuild(X_encoded, base_schema)
+
+        # aggregator training
+        output = self._fit(X_encoded, *args, **kwargs)
+        self.fitted = True
+        return output
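
A sketch of the encode step used by .fit() above, assuming only what the assignment implies: Syn_SeqDataLoader.encode() returns an (encoded loader, info dict) pair, and the info dict carries the keys read by _domain_rebuild below. Column names are hypothetical:

    loader = Syn_SeqDataLoader(df)
    X_encoded, enc_info = loader.encode()
    # a numeric column "bp" with special codes may gain a companion
    # categorical column "bp_cat" in the encoded data
    order = enc_info.get("syn_order", [])   # column order used for fitting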

+    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "Syn_SeqPlugin":
+        """
+        Train the aggregator column-by-column using the encoded DataLoader.
+        """
        self.model = Syn_Seq(
            random_state=self.random_state,
            strict=self.strict,
            sampling_patience=self.sampling_patience,
        )
-        self.model.fit_col(X_encoded, *args, **kwargs)
-        self.fitted = True
+        self.model.fit_col(X, *args, **kwargs)
        return self

-    def training_schema(self) -> Schema:
-        """Return the *encoded* schema used for aggregator training."""
-        if self._enc_schema is None:
-            raise RuntimeError("No encoded schema found. Fit the model first.")
-        return self._enc_schema
-
-    def schema(self) -> Schema:
-        """Return the original data's schema."""
-        if self._orig_schema is None:
-            raise RuntimeError("No original schema found. Fit the model first.")
-        return self._orig_schema
-
-    def _remap_special_value_rules(
-        self,
-        rules_dict: Dict[str, List[Tuple[str, str, Any]]],
-        enc_schema: Schema,
-    ) -> Dict[str, List[Tuple[str, str, Any]]]:
+    def _domain_rebuild(self, X_encoded: DataLoader, base_schema: Schema) -> Schema:
        """
-        If the user wrote rules referencing special values (like -8) on numeric columns,
-        we switch them to the corresponding _cat column. For example:
-          - "NUM_CIGAR" has the special_value list [-8].
-          - The user sets the rule (NUM_CIGAR, "<", 0).
-            => These negative codes are actually stored in "NUM_CIGAR_cat",
-               so we redirect the rule to "NUM_CIGAR_cat".
+        Build a new domain using feature_params & constraint_to_distribution.
+        For each column in the encoded data, gather basic "dtype" constraints,
+        transform them into Distribution objects, and create a new schema.
        """
-        if not rules_dict:
-            return rules_dict
+        enc_info = X_encoded.info()

-        # gather the special_value map from the *encoded* info
-        # (we assume `_enc_data_info["special_value"]` has the col -> list mapping)
-        special_map = {}
-        if self._enc_data_info and "special_value" in self._enc_data_info:
-            special_map = self._enc_data_info["special_value"]
+        syn_order = enc_info.get("syn_order", [])
+        orig_map = enc_info.get("original_dtype", {})
+        conv_map = enc_info.get("converted_type", {})

-        # build base->cat mapping
-        base_to_cat = {}
-        for col in enc_schema.domain:
-            if col.endswith("_cat"):
-                base_col = col[:-4]
-                base_to_cat[base_col] = col
+        domain: Dict[str, Any] = {}

-        new_rules = {}
-        for target_col, cond_list in rules_dict.items():
-            actual_target_col = target_col
-            # If the target_col is known to have special values => rename
-            if target_col in special_map and target_col in base_to_cat:
-                actual_target_col = base_to_cat[target_col]
+        # For each column in syn_order, figure out the dtype constraints
+        for col in syn_order:
+            col_rules = []

-            new_cond_list = []
-            for (feat_col, op, val) in cond_list:
-                new_feat = feat_col
-                if feat_col in special_map and feat_col in base_to_cat:
-                    # if val is one of special_map[feat_col] (e.g. a negative code),
-                    # redirect the condition to the cat column
-                    if any(v == val for v in special_map[feat_col]):
-                        new_feat = base_to_cat[feat_col]
-                new_cond_list.append((new_feat, op, val))
+            original_dt = orig_map.get(col, "").lower()
+            converted_dt = conv_map.get(col, "").lower()

-            new_rules[actual_target_col] = new_cond_list
+            # Example logic (adapt as needed):
+            if col.endswith("_cat"):
+                # definitely treat as category
+                col_rules.append((col, "dtype", "category"))
+            elif ("int" in original_dt or "float" in original_dt) and ("category" in converted_dt):
+                col_rules.append((col, "dtype", "category"))
+            elif ("object" in original_dt or "category" in original_dt) and ("numeric" in converted_dt):
+                col_rules.append((col, "dtype", "float"))
+            elif "date" in original_dt:
+                col_rules.append((col, "dtype", "int"))
+            else:
+                col_rules.append((col, "dtype", "float"))

-        return new_rules
+            # Build a local Constraints object for this single column
+            single_constraints = Constraints(rules=col_rules)
+            # Then transform it into a Distribution
+            dist = constraint_to_distribution(single_constraints, col)
+            domain[col] = dist

-    @validate_arguments(config=dict(arbitrary_types_allowed=True))
-    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "Syn_SeqPlugin":
-        # We do not directly use this `_fit`; it is shadowed by .fit() above.
-        raise NotImplementedError("Use .fit()")
+        # Now build a new Schema with that domain
+        new_schema = Schema(domain=domain)
+        new_schema.sampling_strategy = base_schema.sampling_strategy
+        new_schema.random_state = base_schema.random_state
+
+        return new_schema
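
A worked instance of the dtype branches above, assuming a hypothetical column "bp" with original_dtype "int64" and converted_type "category"; it falls into the second branch and yields a categorical domain entry:

    col_rules = [("bp", "dtype", "category")]
    single_constraints = Constraints(rules=col_rules)
    dist = constraint_to_distribution(single_constraints, "bp")
    domain["bp"] = dist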

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def generate(
@@ -198,47 +180,42 @@ def generate(
        **kwargs: Any,
    ) -> DataLoader:
        """
-        Generate synthetic data from the aggregator:
-          1) Combine constraints from the original schema with user constraints
-          2) Possibly remap user rules -> cat columns for special_value handling
-          3) aggregator -> encoded DataFrame
-          4) decode back to the original structure
-          5) final constraints check using the original schema
+        Generate synthetic data by sampling from the aggregator, applying constraints,
+        and decoding back to the original schema.
        """
        if not self.fitted:
            raise RuntimeError("Must .fit() plugin before calling .generate()")
-        if self._orig_schema is None or self._enc_schema is None:
+        if self._schema is None:
            raise RuntimeError("No schema found. Fit the model first.")

        if random_state is not None:
            enable_reproducible_results(random_state)

        if count is None:
-            if self._data_info is not None:
-                count = self._data_info["len"]
-            else:
-                raise ValueError("Cannot determine 'count' for generation")
+            count = self._data_info["len"]

        has_gen_cond = ("cond" in kwargs) and (kwargs["cond"] is not None)
        if has_gen_cond and not self.expecting_conditional:
            raise RuntimeError(
-                "Got generation conditional, but aggregator wasn't trained conditionally"
+                "Got inference conditional, but aggregator wasn't trained with a conditional"
            )

-        # Combine constraints from the original schema
-        gen_constraints = self._orig_schema.as_constraints()
+        # Combine constraints from the training schema with user constraints
+        gen_constraints = self.training_schema().as_constraints()
        if constraints is not None:
            gen_constraints = gen_constraints.extend(constraints)

-        # aggregator generation on the encoded schema
-        data_syn = self._generate(count, gen_constraints, rules=rules, **kwargs)
+        # Build a schema from these constraints
+        syn_schema = Schema.from_constraints(gen_constraints)
+
+        # aggregator call
+        data_syn = self._generate(count, syn_schema, rules=rules, **kwargs)

        # decode from the encoded data back to the original
        data_syn = data_syn.decode()

-        # final constraints check using the *original* schema
-        final_constraints = self._orig_schema.as_constraints()
+        # final constraints check
+        final_constraints = self.schema().as_constraints()
        if constraints is not None:
            final_constraints = final_constraints.extend(constraints)
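
A condensed sketch of the constraint flow in the new generate() path, using a hypothetical user constraint on an "age" column; as_constraints(), extend(), and Schema.from_constraints() are exactly the calls used above:

    user_constraints = Constraints(rules=[("age", ">=", 0)])
    gen_constraints = self.training_schema().as_constraints().extend(user_constraints)
    syn_schema = Schema.from_constraints(gen_constraints)
    data_syn = self._generate(count, syn_schema, rules=rules)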

@@ -251,28 +228,73 @@ def _generate(
    def _generate(
        self,
        count: int,
-        gen_constraints: Constraints,
+        syn_schema: Schema,
        rules: Optional[Dict[str, List[Tuple[str, str, Any]]]] = None,
        **kwargs: Any,
    ) -> DataLoader:
        """
-        1) Possibly remap user rules for special_value -> cat columns
-        2) aggregator => produce an *encoded* DataFrame
-        3) Force the columns' dtypes (encoded schema)
+        Internal aggregator generation logic:
+          - Possibly remap rules to reference _cat columns if they specify special values
+          - Let the aggregator do column-by-column generation
+          - Force the columns' dtypes according to syn_schema
        """
        if not self.model:
-            raise RuntimeError("No aggregator model found")
+            raise RuntimeError("Aggregator not found for syn_seq plugin")

-        # Remap user rules to handle special values in _cat
+        # Remap user rules to handle special values in _cat columns
        if rules is not None:
-            rules = self._remap_special_value_rules(rules, self._enc_schema)
+            rules = self._remap_special_value_rules(rules, syn_schema)

-        # aggregator generate
+        # Generate the data
        df_syn = self.model.generate_col(count, rules=rules, max_iter_rules=10)
-        # Ensure correct dtypes (encoded)
-        df_syn = self._enc_schema.adapt_dtypes(df_syn)
+        # Ensure correct dtypes
+        df_syn = syn_schema.adapt_dtypes(df_syn)

+        return create_from_info(df_syn, self._data_info)
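
An illustration of the `rules` mapping consumed above, reusing the NUM_CIGAR example from the removed docstring: each key is a target column, each value a list of (feature, op, value) conditions passed through to the aggregator:

    rules = {"NUM_CIGAR": [("NUM_CIGAR", "<", 0)]}
    df_syn = self.model.generate_col(count, rules=rules, max_iter_rules=10)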

+    def _remap_special_value_rules(
+        self,
+        rules_dict: Dict[str, List[Tuple[str, str, Any]]],
+        syn_schema: Schema,
+    ) -> Dict[str, List[Tuple[str, str, Any]]]:
+        """
+        If the user wrote rules referencing special values (like -0.04) on numeric columns,
+        we switch them to the corresponding _cat column. This is a simple version:
+        if 'val' is in 'special_value[feat_col]', rename feat_col -> feat_col_cat.
+        """
+        if not rules_dict:
+            return rules_dict
+
+        special_map = self._training_data_info.get("special_value", {})
+
+        # build a base->cat map,
+        # e.g. if 'bp' => 'bp_cat' is in your domain
+        base_to_cat = {}
+        for col in syn_schema.domain:
+            if col.endswith("_cat"):
+                base_col = col[:-4]
+                base_to_cat[base_col] = col
+
+        new_rules = {}
+        for target_col, cond_list in rules_dict.items():
+            actual_target_col = target_col
+            # If target_col is known to have special values => rename
+            if target_col in special_map and target_col in base_to_cat:
+                actual_target_col = base_to_cat[target_col]
+
+            new_cond_list = []
+            for (feat_col, op, val) in cond_list:
+                new_feat = feat_col
+                # If feat_col references special values => rename
+                if feat_col in special_map and feat_col in base_to_cat:
+                    if val in special_map[feat_col]:
+                        new_feat = base_to_cat[feat_col]
+                new_cond_list.append((new_feat, op, val))
+
+            new_rules[actual_target_col] = new_cond_list
+
+        return new_rules
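
An example of the remapping above, using the 'bp' => 'bp_cat' scenario from the comments. Assuming special_value maps "bp" to [-0.04] and the domain contains "bp_cat" (the "==" operator is illustrative; the op string is passed through unchanged):

    rules_in  = {"bp": [("bp", "==", -0.04)]}
    rules_out = {"bp_cat": [("bp_cat", "==", -0.04)]}

Conditions whose value is not a special value keep their original column name.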

-        # Return as DataLoader with the *encoded* info
-        return create_from_info(df_syn, self._enc_data_info)

-plugin = Syn_SeqPlugin
+# Register plugin for the library
+plugin = Syn_SeqPlugin