From 758a01c987062d44503b71e3c503aeb4fd8c3ba7 Mon Sep 17 00:00:00 2001
From: minkeymouse
Date: Tue, 18 Feb 2025 22:35:15 +0900
Subject: [PATCH] Incomplete base code 2

---
 prompt.txt                                    | 240 ++++++++++--------
 .../plugins/core/models/syn_seq/syn_seq.py    | 119 +++++----
 .../core/models/syn_seq/syn_seq_preprocess.py | 134 +++++-----
 3 files changed, 279 insertions(+), 214 deletions(-)

diff --git a/prompt.txt b/prompt.txt
index 16a2b6c6..f489332e 100644
--- a/prompt.txt
+++ b/prompt.txt
@@ -103,12 +103,22 @@ from typing import Optional, Dict, List, Any, Tuple

 class SynSeqPreprocessor:
     """
-    전처리(preprocess) & 후처리(postprocess) 클래스를 함수화하여 단계별로 깔끔하게 정리.
-
-    - max_categories 로직을 넣어 user_dtypes에 없는 컬럼은 auto로 category/numeric 판단
-    - 날짜(col_type == "date")이면 to_datetime
-    - 범주형(col_type == "category")이면 astype('category')
-    - numeric + special value -> (base_col, base_col_cat) 분리
+    A class to perform preprocessing and postprocessing for syn_seq.
+
+    Preprocessing:
+      - Records the original dtypes.
+      - Automatically assigns dtypes (date/category/numeric) when not provided.
+      - Converts date columns to datetime and category columns to 'category' dtype.
+      - For numeric columns with special values (user_special_values), creates a new
+        categorical column (named base_col_cat) that marks each cell:
+          * If the value is in the special list, the cell is mapped to the string
+            form of that special value.
+          * Missing cells map to the marker "NAN"; all other cells map to "NUMERIC".
+
+    Postprocessing:
+      - Merges back the split (base_col, base_col_cat) columns:
+          If the _cat value is one of the special values, the base column is set to
+          that value; "NAN" restores a missing value; "NUMERIC" keeps the generated value.
+      - Optionally applies user-provided rules sequentially to filter rows.
     """

     def __init__(
@@ -119,15 +129,15 @@ class SynSeqPreprocessor:
     ):
         """
        Args:
-            user_dtypes: {col: "date"/"category"/"numeric"} 등. (없으면 auto 결정)
-            user_special_values: {col: [특수값1, 특수값2, ...]}
-            max_categories: auto 판단 시, nunique <= max_categories 이면 category, else numeric
+            user_dtypes: {col: "date"/"category"/"numeric"}; if not provided, dtypes are auto-detected.
+            user_special_values: {col: [special_value1, special_value2, ...]}
+            max_categories: when auto-detecting dtypes, a column with nunique <= max_categories
+                is assigned 'category', otherwise 'numeric'.
         """
         self.user_dtypes = user_dtypes or {}
         self.user_special_values = user_special_values or {}
         self.max_categories = max_categories

-        # 내부 저장용
+        # Internal storage
         self.original_dtypes: Dict[str, str] = {}  # {col: original_dtype}
         self.split_map: Dict[str, str] = {}  # {base_col -> cat_col}
         self.detected_specials: Dict[str, List[Any]] = {}  # user special values

@@ -137,22 +147,24 @@ class SynSeqPreprocessor:
     # =========================================================================
     # PREPROCESS
    # =========================================================================
     def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
         """
-        1) 원본 dtype 기록
-        2) user_dtypes or auto 판단 -> date/category/numeric 세팅
-        3) numeric + special_value -> split
+        Preprocesses the DataFrame.
+          1) Record original dtypes.
+          2) Auto-assign or apply user-specified dtypes.
+          3) Convert date and category columns appropriately.
+          4) For numeric columns with special values, create a _cat column.
         """
         df = df.copy()

-        # (a) 원본 dtype 저장
+        # (a) Record original dtypes.
         self._record_original_dtypes(df)

-        # (b) user_dtypes 없는 컬럼은 auto -> category/numeric
+        # (b) Auto-assign dtypes for columns not specified in user_dtypes.
         self._auto_assign_dtypes(df)

-        # (c) user_dtypes 적용: date->datetime, category->astype('category'), numeric->그대로
+        # (c) Apply the specified dtypes.
        self._apply_user_dtypes(df)

-        # (d) numeric + special_value split
+        # (d) Split numeric columns that have special values into (base_col, base_col_cat).
         self._split_numeric_columns(df)

         return df
@@ -163,10 +175,10 @@ class SynSeqPreprocessor:

     def _auto_assign_dtypes(self, df: pd.DataFrame):
         """
-        user_dtypes에 명시가 없으면,
-        - nuniq <= max_categories -> 'category'
-        - else 'numeric'
-        - 만약 datetime64 타입이면 'date'로 지정
+        For columns not specified in user_dtypes, assign:
+          - 'date' if the column is a datetime type.
+          - 'category' if nunique <= max_categories.
+          - Otherwise, 'numeric'.
         """
         for col in df.columns:
             if col in self.user_dtypes:
                 continue
@@ -187,9 +199,10 @@ class SynSeqPreprocessor:

     def _apply_user_dtypes(self, df: pd.DataFrame):
         """
-        1) date -> pd.to_datetime
-        2) category -> astype('category')
-        3) numeric -> 그대로
+        Apply the user-specified or auto-assigned dtypes:
+          - Convert 'date' columns with pd.to_datetime.
+          - Convert 'category' columns with astype('category').
+          - Leave 'numeric' columns unchanged.
         """
         for col, dtype_str in self.user_dtypes.items():
             if col not in df.columns:
@@ -199,10 +212,17 @@ class SynSeqPreprocessor:
                 df[col] = pd.to_datetime(df[col], errors="coerce")
             elif dtype_str == "category":
                 df[col] = df[col].astype("category")
-            else:
-                pass
+            # numeric: no conversion

     def _split_numeric_columns(self, df: pd.DataFrame):
+        """
+        For each column in user_special_values:
+          - Create a new categorical column (base_col_cat) that reflects special values.
+          - Each value in the base column is mapped as follows:
+              If NaN -> the missing marker "NAN".
+              If in specials -> the string form of the special value.
+              Otherwise -> the normal marker "NUMERIC".
+        """
         for col, specials in self.user_special_values.items():
             if col not in df.columns:
                 continue
@@ -211,53 +231,54 @@ class SynSeqPreprocessor:
             self.split_map[col] = cat_col
             self.detected_specials[col] = specials

-            # Need to use user-defined mapping to create _cat columns and assign categories.
-            # Base column stays intact, only _cat columns created in front of base columns
-
+            # Remove existing cat_col if it exists.
             if cat_col in df.columns:
                 df.drop(columns=[cat_col], inplace=True)

             base_idx = df.columns.get_loc(col)
             df.insert(base_idx, cat_col, None)
-            df[cat_col] = df[col].apply(cat_mapper).astype("category")
+
+            def cat_mapper(x, specials, normal_marker="NUMERIC", missing_marker="NAN"):
+                if pd.isna(x):
+                    return missing_marker
+                elif x in specials:
+                    return str(x)
+                else:
+                    return normal_marker
+
+            df[cat_col] = df[col].apply(lambda x: cat_mapper(x, specials)).astype("category")

     # =========================================================================
     # POSTPROCESS
     # =========================================================================
     def postprocess(self, df: pd.DataFrame, rules: Optional[Dict[str, List[Tuple[str, str, Any]]]] = None) -> pd.DataFrame:
         """
-        합성 결과 후처리:
-        1) split된 (base_col, cat_col) 복원
-        2) (Enhanced) rules를 순서대로 적용하여, 규칙에 맞지 않는 행들을 제거한다.
-           - 만약 if-then 조건이 있다면 그 순서대로 평가한다.
-        (날짜 offset 복원은 없음)
+        Postprocesses the synthetic DataFrame:
+          1) Merge back split columns (base_col, base_col_cat), restoring special values
+             and NaNs recorded in the _cat column.
+          2) Apply user-provided rules sequentially to filter rows.
+        (Note: Date offset restoration is not performed.)
         """
         df = df.copy()

-        # Merge split columns
+        # Merge split columns.
         df = self._merge_splitted_cols(df)

-        # If rules are provided, apply them in the given order.
+        # Apply rules if provided.
        if rules is not None:
             df = self.apply_rules(df, rules)

         return df

     def _merge_splitted_cols(self, df: pd.DataFrame) -> pd.DataFrame:
-        for base_col, cat_col in self.split_map.items():
-            if base_col not in df.columns or cat_col not in df.columns:
-                continue
-
-            specials = self.detected_specials.get(base_col, [])
-
-            for i in range(len(df)):
-                if pd.isna(df.at[i, base_col]):
-                    cat_val = df.at[i, cat_col]
-                    try:
-                        possible_val = float(cat_val)
-                    except:
-                        possible_val = cat_val
-                    if possible_val in specials:
-                        df.at[i, base_col] = possible_val
-                    else:
-                        pass
-            df.drop(columns=[cat_col], inplace=True)
+        """
+        For each (base_col, cat_col) pair in split_map, restore base_col from the
+        markers in cat_col:
+          - "NUMERIC" -> keep the generated numeric value as is.
+          - "NAN"     -> set base_col to NaN.
+          - otherwise -> the marker is one of the special values; write it back to base_col.
+        Finally, drop cat_col.
+        """
+        for base_col, cat_col in self.split_map.items():
+            if base_col not in df.columns or cat_col not in df.columns:
+                continue
+            specials = self.detected_specials.get(base_col, [])
+            cat_vals = df[cat_col].astype(str)
+            df.loc[cat_vals == "NAN", base_col] = float("nan")
+            for sv in specials:
+                df.loc[cat_vals == str(sv), base_col] = sv
+            df.drop(columns=[cat_col], inplace=True)
         return df

     def apply_rules(self, df: pd.DataFrame, rules: Dict[str, List[Tuple[str, str, Any]]]) -> pd.DataFrame:
         """
@@ -273,14 +294,10 @@ class SynSeqPreprocessor:
         Returns:
             A new DataFrame with rows not satisfying the rules dropped.
         """
-        # Process each target column in the order of insertion (Python 3.7+ preserves insertion order)
         for target_col, rule_list in rules.items():
-            # For each rule in the list, filter out rows that do not satisfy the rule.
             for (col_feat, operator, rule_val) in rule_list:
-                # If the target column is not in df, skip.
                 if col_feat not in df.columns:
                     continue
-                # Build a condition based on the operator.
                 if operator in ["=", "=="]:
                     cond = (df[col_feat] == rule_val) | df[col_feat].isna()
                 elif operator == ">":
                     cond = (df[col_feat] > rule_val) | df[col_feat].isna()
                 elif operator == ">=":
                     cond = (df[col_feat] >= rule_val) | df[col_feat].isna()
                 elif operator == "<":
                     cond = (df[col_feat] < rule_val) | df[col_feat].isna()
                 elif operator == "<=":
                     cond = (df[col_feat] <= rule_val) | df[col_feat].isna()
                 else:
                     cond = pd.Series(True, index=df.index)
@@ -293,7 +310,6 @@ class SynSeqPreprocessor:
-                # Drop rows that do not satisfy the condition.
                 df = df.loc[cond].copy()
         return df
@@ -331,6 +347,7 @@ from synthcity.logger import info, warning
 from synthcity.plugins.core.dataloader import DataLoader
 from synthcity.plugins.core.models.syn_seq.syn_seq_encoder import Syn_SeqEncoder

+# Import the column-fitting and column-generating functions.
 from synthcity.plugins.core.models.syn_seq.methods import (
     syn_cart, generate_cart,
@@ -345,9 +362,7 @@ from synthcity.plugins.core.models.syn_seq.methods import (
     syn_swr, generate_swr,
 )
-# ------------------------------------------------------------------
 # Map method names to (training function, generation function)
-# ------------------------------------------------------------------
 METHOD_MAP: Dict[str, Tuple[Any, Any]] = {
     "cart": (syn_cart, generate_cart),
     "ctree": (syn_ctree, generate_ctree),
@@ -361,9 +376,7 @@ METHOD_MAP: Dict[str, Tuple[Any, Any]] = {
     "swr": (syn_swr, generate_swr),
 }

-# ------------------------------------------------------------------
-# Syn_Seq: Column-by-column aggregator for sequential synthesis.
-# ------------------------------------------------------------------
+
 class Syn_Seq:
     def __init__(
         self,
@@ -374,41 +387,38 @@ class Syn_Seq:
         """
         Args:
             random_state: Random seed.
             strict: (Unused now; rule-checking is handled later.)
             sampling_patience: (Unused now.)
""" self.random_state = random_state self.strict = strict self.sampling_patience = sampling_patience - self.special_values = Dict[str, List[Any]] = {} + self.special_values: Dict[str, List[Any]] = {} # mapping: col -> list of special values self._model_trained = False self._syn_order: List[str] = [] self._method_map: Dict[str, str] = {} self._varsel: Dict[str, List[str]] = {} self._col_models: Dict[str, Dict[str, Any]] = {} - # Store the real distribution for the first column and columns with special values. - self._stored_col_data: = None + # Store the real distribution for the first column and for columns with special values. + self._stored_col_data: Dict[str, np.ndarray] = {} - def fit_col(self, loader: DataLoader, *args: Any, **kwargs: Any) -> "Syn_Seq": + def fit_col(self, loader: Any, *args: Any, **kwargs: Any) -> "Syn_Seq": """ - Fit column-by-column using metadata from the loader. - 1) Retrieve info (syn_order, method, variable_selection). - 2) For columns ending with "_cat", force aggregator "cart". - 3) For the first column, store its real distribution. - 4) For each subsequent column, train its aggregator using preceding columns. + Fit column‐by‐column using metadata from the loader. """ info_dict = loader.info() training_data = loader.dataframe().copy() if training_data.empty: raise ValueError("No data => cannot fit Syn_Seq aggregator") + # Set syn_order, method mapping, variable selection, and special values. self._syn_order = info_dict.get("syn_order", list(training_data.columns)) self._method_map = info_dict.get("method", {}) - self.special_values = info_dict.get("special_vales", {}) + self.special_values = info_dict.get("special_values", {}) self._varsel = info_dict.get("variable_selection", {}) - # For auto-injected _cat columns, force method "cart" and mirror variable selection. + # For auto-injected _cat columns, force aggregator "cart" and mirror variable selection. for col in self._syn_order: if col.endswith("_cat"): self._method_map[col] = "cart" @@ -423,7 +433,15 @@ class Syn_Seq: # (3) Store the real distribution from the first column. first_col = self._syn_order[0] - self._stored_col_data = training_data[first_col].dropna().values + self._stored_col_data[first_col] = training_data[first_col].dropna().values + + # For columns with special values, store all non-null values that are NOT special. + for col, specials in self.special_values.items(): + # Filter rows where the column's value is not in specials. + filtered = training_data[~training_data[col].isin(specials)] + # Drop any NaNs and store the underlying values. + self._stored_col_data[col] = filtered[col].dropna().values + print(f"Fitting '{first_col}' => stored distribution from real data. Done.") # (4) For each subsequent column, train its aggregator. @@ -433,66 +451,84 @@ class Syn_Seq: preds_list = self._varsel.get(col, self._syn_order[:i]) y = training_data[col].values X = training_data[preds_list].values - mask = (~pd.isna(y)) + # If the column has special values, drop rows where y is one of those special values. + if col in self.special_values: + specials = self.special_values[col] + mask = mask & (~np.isin(y, specials)) X_ = X[mask] y_ = y[mask] - print(f"Fitting '{col}' with '{method_name}' ... ", end="", flush=True) - self._col_models[col] = self._fit_single_col(method_name, X_, y_) + try: + self._col_models[col] = self._fit_single_col(method_name, X_, y_) + except Exception as e: + print(f"Error fitting column {col}: {e}. 
+                try:
+                    self._col_models[col] = self._fit_single_col("swr", X_, y_)
+                except Exception as e2:
+                    print(f"Fallback swr also failed for {col}: {e2}. Storing None.", end=" ")
+                    self._col_models[col] = None

         print("Done!")
         self._model_trained = True
         return self

     def _fit_single_col(self, method_name: str, X: np.ndarray, y: np.ndarray) -> Dict[str, Any]:
+        """
+        Fit a single column using the specified method.
+        """
         fit_func, _ = METHOD_MAP[method_name]
         model = fit_func(y, X, random_state=self.random_state)
         return {"name": method_name, "fitted_model": model}

-    def generate_col(self, nrows: int) -> pd.DataFrame:
+    def generate_col(self, count: int) -> pd.DataFrame:
         """
-        Generate `nrows` rows, column by column.
-        (No rule checking is performed here.)
+        Generate `count` rows, column-by-column.
         """
         if not self._model_trained:
             raise RuntimeError("Syn_Seq aggregator not yet fitted")
-        if nrows <= 0:
+        if count <= 0:
             return pd.DataFrame(columns=self._syn_order)

-        gen_df = pd.DataFrame({col: [np.nan] * nrows for col in self._syn_order})
+        # Initialize a DataFrame with NaN values.
+        gen_df = pd.DataFrame({col: [np.nan] * count for col in self._syn_order})

-        # (1) Generate the first column.
+        # (1) Generate the first column using the stored real distribution.
         first_col = self._syn_order[0]
-        if self._stored_col_data is not None and len(self._stored_col_data[first_col]) > 0:
-            gen_df[first_col] = np.random.choice(self._stored_col_data[first_col], size=nrows, replace=True)
+        if len(self._stored_col_data.get(first_col, [])) > 0:
+            gen_df[first_col] = np.random.choice(self._stored_col_data[first_col], size=count, replace=True)
         else:
             gen_df[first_col] = 0
         print(f"Generating '{first_col}' => done.")

+        # (2) Generate subsequent columns.
         for col in self._syn_order[1:]:
             method_name = self._method_map.get(col, "cart")
             idx = self._syn_order.index(col)
             preds_list = self._varsel.get(col, self._syn_order[:idx])
-
-            Xsyn = gen_df[preds_list].values
-            ysyn = self._generate_single_col(method_name, Xsyn, col)
-            gen_df[col] = ysyn
+            if col in self.special_values:
+                # The companion _cat column precedes the base column in syn_order,
+                # so it is already generated; rows marked "NUMERIC" and rows marked
+                # as special are generated separately, then reassembled in order.
+                cat_col = f"{col}_cat"
+                num_mask = gen_df[cat_col].astype(str) == "NUMERIC"
+                ysyn = pd.Series(np.nan, index=gen_df.index, dtype=object)
+                if num_mask.any():
+                    ysyn[num_mask] = self._generate_single_col(
+                        method_name, gen_df.loc[num_mask, preds_list].values, col
+                    )
+                if (~num_mask).any():
+                    ysyn[~num_mask] = self._generate_single_col(
+                        method_name, gen_df.loc[~num_mask, preds_list].values, col
+                    )
+                ysyn = ysyn.values
+            else:
+                Xsyn = gen_df[preds_list].values
+                ysyn = self._generate_single_col(method_name, Xsyn, col)
+            gen_df[col] = ysyn
             print(f"Generating '{col}' => done.")

         return gen_df

     def _generate_single_col(self, method_name: str, Xsyn: np.ndarray, col: str) -> np.ndarray:
-
-        # Need the logic here which if generating the data throws an error due to issues like highly imbalanced data so previous variable has only one value, etc, we need to use the stored data to sample the values of that variable.
-        # This only occurrs for columns with special values so I tried to store the data of columns of special values.
-
-        # Following code existed before but I don't understand the usage.
-        # if col not in self._col_models:
-        #     if self._stored_col_data is not None and len(self._stored_col_data) > 0:
-        #         return np.random.choice(self._stored_col_data, size=len(Xsyn), replace=True)
-        #     else:
-        #         return np.zeros(len(Xsyn))
+        """
+        Generate synthetic values for a single column using the fitted model.
+ """ + if col not in self._col_models or self._col_models[col] is None: + if col in self._stored_col_data and len(self._stored_col_data[col]) > 0: + return np.random.choice(self._stored_col_data[col], size=len(Xsyn), replace=True) + else: + return np.zeros(len(Xsyn)) fit_info = self._col_models[col] _, generate_func = METHOD_MAP[fit_info["name"]] diff --git a/src/synthcity/plugins/core/models/syn_seq/syn_seq.py b/src/synthcity/plugins/core/models/syn_seq/syn_seq.py index 1c1c53bc..dfdfdec3 100644 --- a/src/synthcity/plugins/core/models/syn_seq/syn_seq.py +++ b/src/synthcity/plugins/core/models/syn_seq/syn_seq.py @@ -6,6 +6,7 @@ from synthcity.plugins.core.dataloader import DataLoader from synthcity.plugins.core.models.syn_seq.syn_seq_encoder import Syn_SeqEncoder + # Import the column-fitting and column-generating functions. from synthcity.plugins.core.models.syn_seq.methods import ( syn_cart, generate_cart, @@ -20,9 +21,7 @@ syn_swr, generate_swr, ) -# ------------------------------------------------------------------ # Map method names to (training function, generation function) -# ------------------------------------------------------------------ METHOD_MAP: Dict[str, Tuple[Any, Any]] = { "cart": (syn_cart, generate_cart), "ctree": (syn_ctree, generate_ctree), @@ -36,9 +35,7 @@ "swr": (syn_swr, generate_swr), } -# ------------------------------------------------------------------ -# Syn_Seq: Column-by-column aggregator for sequential synthesis. -# ------------------------------------------------------------------ + class Syn_Seq: def __init__( self, @@ -49,41 +46,38 @@ def __init__( """ Args: random_state: Random seed. - strict: (Unused now; rule-checking is handled later.) + strict: (Unused now; rule‐checking is handled later.) sampling_patience: (Unused now.) """ self.random_state = random_state self.strict = strict self.sampling_patience = sampling_patience - self.special_values = Dict[str, List[Any]] = {} + self.special_values: Dict[str, List[Any]] = {} # mapping: col -> list of special values self._model_trained = False self._syn_order: List[str] = [] self._method_map: Dict[str, str] = {} self._varsel: Dict[str, List[str]] = {} self._col_models: Dict[str, Dict[str, Any]] = {} - # Store the real distribution for the first column and columns with special values. - self._stored_col_data: = None + # Store the real distribution for the first column and for columns with special values. + self._stored_col_data: Dict[str, np.ndarray] = {} - def fit_col(self, loader: DataLoader, *args: Any, **kwargs: Any) -> "Syn_Seq": + def fit_col(self, loader: Any, *args: Any, **kwargs: Any) -> "Syn_Seq": """ - Fit column-by-column using metadata from the loader. - 1) Retrieve info (syn_order, method, variable_selection). - 2) For columns ending with "_cat", force aggregator "cart". - 3) For the first column, store its real distribution. - 4) For each subsequent column, train its aggregator using preceding columns. + Fit column‐by‐column using metadata from the loader. """ info_dict = loader.info() training_data = loader.dataframe().copy() if training_data.empty: raise ValueError("No data => cannot fit Syn_Seq aggregator") + # Set syn_order, method mapping, variable selection, and special values. 
        self._syn_order = info_dict.get("syn_order", list(training_data.columns))
         self._method_map = info_dict.get("method", {})
-        self.special_values = info_dict.get("special_vales", {})
+        self.special_values = info_dict.get("special_values", {})
         self._varsel = info_dict.get("variable_selection", {})

-        # For auto-injected _cat columns, force method "cart" and mirror variable selection.
+        # For auto-injected _cat columns, force aggregator "cart" and mirror variable selection.
         for col in self._syn_order:
             if col.endswith("_cat"):
                 self._method_map[col] = "cart"
@@ -98,7 +92,15 @@ def fit_col(self, loader: DataLoader, *args: Any, **kwargs: Any) -> "Syn_Seq":

         # (3) Store the real distribution from the first column.
         first_col = self._syn_order[0]
-        self._stored_col_data = training_data[first_col].dropna().values
+        self._stored_col_data[first_col] = training_data[first_col].dropna().values
+
+        # For columns with special values, store all non-null values that are NOT special.
+        for col, specials in self.special_values.items():
+            # Filter rows where the column's value is not in specials.
+            filtered = training_data[~training_data[col].isin(specials)]
+            # Drop any NaNs and store the underlying values.
+            self._stored_col_data[col] = filtered[col].dropna().values
+
         print(f"Fitting '{first_col}' => stored distribution from real data. Done.")

@@ -108,74 +110,85 @@ def fit_col(self, loader: DataLoader, *args: Any, **kwargs: Any) -> "Syn_Seq":
         # (4) For each subsequent column, train its aggregator.
         for i, col in enumerate(self._syn_order[1:], start=1):
             method_name = self._method_map.get(col, "cart")
             preds_list = self._varsel.get(col, self._syn_order[:i])
             y = training_data[col].values
             X = training_data[preds_list].values
-            mask = (~pd.isna(y))
+            mask = ~pd.isna(y)
+            # If the column has special values, also drop rows where y is one of them.
+            if col in self.special_values:
+                specials = self.special_values[col]
+                mask = mask & (~np.isin(y, specials))
             X_ = X[mask]
             y_ = y[mask]
-
             print(f"Fitting '{col}' with '{method_name}' ... ", end="", flush=True)
-            self._col_models[col] = self._fit_single_col(method_name, X_, y_)
+            try:
+                self._col_models[col] = self._fit_single_col(method_name, X_, y_)
+            except Exception as e:
+                print(f"Error fitting column {col}: {e}. Falling back to swr.", end=" ")
+                try:
+                    self._col_models[col] = self._fit_single_col("swr", X_, y_)
+                except Exception as e2:
+                    print(f"Fallback swr also failed for {col}: {e2}. Storing None.", end=" ")
+                    self._col_models[col] = None

         print("Done!")
         self._model_trained = True
         return self

     def _fit_single_col(self, method_name: str, X: np.ndarray, y: np.ndarray) -> Dict[str, Any]:
+        """
+        Fit a single column using the specified method.
+        """
         fit_func, _ = METHOD_MAP[method_name]
-        try:
-            model = fit_func(y, X, random_state=self.random_state)
-        except:
-            # Need the logic here which if generating the data throws an error due to issues like highly imbalanced data so previous variable has only one value, etc, we need to use the stored data to sample the values of that variable.
-            # This only occurrs for columns with special values so I tried to store the data of columns of special values.
-            # This means we need to add the fallback here in case of model fitting issue
-
+        model = fit_func(y, X, random_state=self.random_state)
         return {"name": method_name, "fitted_model": model}

-    def generate_col(self, nrows: int) -> pd.DataFrame:
+    def generate_col(self, count: int) -> pd.DataFrame:
         """
-        Generate `nrows` rows, column by column.
-        (No rule checking is performed here.)
+        Generate `count` rows, column-by-column.
""" if not self._model_trained: raise RuntimeError("Syn_Seq aggregator not yet fitted") - if nrows <= 0: + if count <= 0: return pd.DataFrame(columns=self._syn_order) - gen_df = pd.DataFrame({col: [np.nan] * nrows for col in self._syn_order}) + # Initialize a DataFrame with NaN values. + gen_df = pd.DataFrame({col: [np.nan] * count for col in self._syn_order}) - # (1) Generate the first column. + # (1) Generate the first column using the stored real distribution. first_col = self._syn_order[0] - if self._stored_col_data is not None and len(self._stored_col_data[first_col]) > 0: - gen_df[first_col] = np.random.choice(self._stored_col_data[first_col], size=nrows, replace=True) + if self._stored_col_data.get(first_col) is not None and len(self._stored_col_data[first_col]) > 0: + gen_df[first_col] = np.random.choice(self._stored_col_data[first_col], size=count, replace=True) else: gen_df[first_col] = 0 print(f"Generating '{first_col}' => done.") - + # (2) Generate subsequent columns. for col in self._syn_order[1:]: method_name = self._method_map.get(col, "cart") idx = self._syn_order.index(col) preds_list = self._varsel.get(col, self._syn_order[:idx]) - - Xsyn = gen_df[preds_list].values - ysyn = self._generate_single_col(method_name, Xsyn, col) - gen_df[col] = ysyn + if col in self.special_values: + Xsyn_num = gen_df[gen_df["f{col}_col"] == "NUMERIC"] + ysyn_num = self._generate_single_col(method_name, Xsyn_num, col) + Xsyn_special = gen_df[~(gen_df["f{col}_col"] == "NUMERIC")] + ysyn_special = self._generate_single_col(method_name, Xsyn_special, col) + Xsyn = pd.concat(Xsyn_num, Xsyn_special) + ysyn = pd.concat(ysyn_num, ysyn_special) + else: + Xsyn = gen_df[preds_list].values + ysyn = self._generate_single_col(method_name, Xsyn, col) + gen_df[col] = ysyn print(f"Generating '{col}' => done.") return gen_df def _generate_single_col(self, method_name: str, Xsyn: np.ndarray, col: str) -> np.ndarray: - - # Need the logic here which if generating the data throws an error due to issues like highly imbalanced data so previous variable has only one value, etc, we need to use the stored data to sample the values of that variable. - # This only occurrs for columns with special values so I tried to store the data of columns of special values. - # This means we need to add the fallback here in case of model fitting issue - - # Following code existed before but I don't understand the usage. - # if col not in self._col_models: - # if self._stored_col_data is not None and len(self._stored_col_data) > 0: - # return np.random.choice(self._stored_col_data, size=len(Xsyn), replace=True) - # else: - # return np.zeros(len(Xsyn)) - + """ + Generate synthetic values for a single column using the fitted model. 
+ """ + if col not in self._col_models or self._col_models[col] is None: + if col in self._stored_col_data and len(self._stored_col_data[col]) > 0: + return np.random.choice(self._stored_col_data[col], size=len(Xsyn), replace=True) + else: + return np.zeros(len(Xsyn)) + fit_info = self._col_models[col] _, generate_func = METHOD_MAP[fit_info["name"]] return generate_func(fit_info["fitted_model"], Xsyn) diff --git a/src/synthcity/plugins/core/models/syn_seq/syn_seq_preprocess.py b/src/synthcity/plugins/core/models/syn_seq/syn_seq_preprocess.py index 20b95279..98ab8464 100644 --- a/src/synthcity/plugins/core/models/syn_seq/syn_seq_preprocess.py +++ b/src/synthcity/plugins/core/models/syn_seq/syn_seq_preprocess.py @@ -7,12 +7,22 @@ class SynSeqPreprocessor: """ - 전처리(preprocess) & 후처리(postprocess) 클래스를 함수화하여 단계별로 깔끔하게 정리. - - - max_categories 로직을 넣어 user_dtypes에 없는 컬럼은 auto로 category/numeric 판단 - - 날짜(col_type == "date")이면 to_datetime - - 범주형(col_type == "category")이면 astype('category') - - numeric + special value -> (base_col, base_col_cat) 분리 + A class to perform preprocessing and postprocessing for syn_seq. + + Preprocessing: + - Records the original dtypes. + - Automatically assigns dtypes (date/category/numeric) when not provided. + - Converts date columns to datetime and category columns to 'category' dtype. + - For numeric columns with special values (user_special_values), creates a new + categorical column (named base_col_cat) that marks special values: + * If the value is in the special list, the cell is mapped to the special value. + * Otherwise, a numeric marker (set to len(specials)) is used. + + Postprocessing: + - Merges back the split (base_col, base_col_cat) columns: + If the base column is NaN and the corresponding _cat value is one of the special values, + then the base column is replaced with that special value. + - Optionally applies user-provided rules sequentially to filter rows. """ def __init__( @@ -23,15 +33,15 @@ def __init__( ): """ Args: - user_dtypes: {col: "date"/"category"/"numeric"} 등. (없으면 auto 결정) - user_special_values: {col: [특수값1, 특수값2, ...]} - max_categories: auto 판단 시, nunique <= max_categories 이면 category, else numeric + user_dtypes: {col: "date"/"category"/"numeric"}, if not provided, auto-detected. + user_special_values: {col: [special_value1, special_value2, ...]} + max_categories: When auto-detecting dtypes, if nunique <= max_categories, assign 'category', else 'numeric'. """ self.user_dtypes = user_dtypes or {} self.user_special_values = user_special_values or {} self.max_categories = max_categories - # 내부 저장용 + # Internal storage self.original_dtypes: Dict[str, str] = {} # {col: original_dtype} self.split_map: Dict[str, str] = {} # {base_col -> cat_col} self.detected_specials: Dict[str, List[Any]] = {} # user special values @@ -41,22 +51,24 @@ def __init__( # ========================================================================= def preprocess(self, df: pd.DataFrame) -> pd.DataFrame: """ - 1) 원본 dtype 기록 - 2) user_dtypes or auto 판단 -> date/category/numeric 세팅 - 3) numeric + special_value -> split + Preprocesses the DataFrame. + 1) Record original dtypes. + 2) Auto-assign or apply user-specified dtypes. + 3) Convert date and category columns appropriately. + 4) For numeric columns with special values, create a _cat column. """ df = df.copy() - # (a) 원본 dtype 저장 + # (a) Record original dtypes. self._record_original_dtypes(df) - # (b) user_dtypes 없는 컬럼은 auto -> category/numeric + # (b) Auto-assign dtypes for columns not specified in user_dtypes. 
        self._auto_assign_dtypes(df)

-        # (c) user_dtypes 적용: date->datetime, category->astype('category'), numeric->그대로
+        # (c) Apply the specified dtypes.
         self._apply_user_dtypes(df)

-        # (d) numeric + special_value split
+        # (d) Split numeric columns that have special values into (base_col, base_col_cat).
         self._split_numeric_columns(df)

         return df
@@ -67,10 +79,10 @@ def _record_original_dtypes(self, df: pd.DataFrame):

     def _auto_assign_dtypes(self, df: pd.DataFrame):
         """
-        user_dtypes에 명시가 없으면,
-        - nuniq <= max_categories -> 'category'
-        - else 'numeric'
-        - 만약 datetime64 타입이면 'date'로 지정
+        For columns not specified in user_dtypes, assign:
+          - 'date' if the column is a datetime type.
+          - 'category' if nunique <= max_categories.
+          - Otherwise, 'numeric'.
         """
         for col in df.columns:
             if col in self.user_dtypes:
                 continue
@@ -91,9 +103,10 @@ def _auto_assign_dtypes(self, df: pd.DataFrame):

     def _apply_user_dtypes(self, df: pd.DataFrame):
         """
-        1) date -> pd.to_datetime
-        2) category -> astype('category')
-        3) numeric -> 그대로
+        Apply the user-specified or auto-assigned dtypes:
+          - Convert 'date' columns with pd.to_datetime.
+          - Convert 'category' columns with astype('category').
+          - Leave 'numeric' columns unchanged.
         """
         for col, dtype_str in self.user_dtypes.items():
             if col not in df.columns:
@@ -103,10 +116,17 @@ def _apply_user_dtypes(self, df: pd.DataFrame):
                 df[col] = pd.to_datetime(df[col], errors="coerce")
             elif dtype_str == "category":
                 df[col] = df[col].astype("category")
-            else:
-                pass
+            # numeric: no conversion

     def _split_numeric_columns(self, df: pd.DataFrame):
+        """
+        For each column in user_special_values:
+          - Create a new categorical column (base_col_cat) that reflects special values.
+          - Each value in the base column is mapped as follows:
+              If NaN -> the missing marker "NAN".
+              If in specials -> the string form of the special value.
+              Otherwise -> the normal marker "NUMERIC".
+        """
         for col, specials in self.user_special_values.items():
             if col not in df.columns:
                 continue
@@ -115,53 +135,54 @@ def _split_numeric_columns(self, df: pd.DataFrame):
             self.split_map[col] = cat_col
             self.detected_specials[col] = specials

-            # Need to use user-defined mapping to create _cat columns and assign categories.
-            # Base column stays intact, only _cat columns created in front of base columns
-
+            # Remove existing cat_col if it exists.
             if cat_col in df.columns:
                 df.drop(columns=[cat_col], inplace=True)

             base_idx = df.columns.get_loc(col)
             df.insert(base_idx, cat_col, None)
-            df[cat_col] = df[col].apply(cat_mapper).astype("category")
+
+            def cat_mapper(x, specials, normal_marker="NUMERIC", missing_marker="NAN"):
+                if pd.isna(x):
+                    return missing_marker
+                elif x in specials:
+                    return str(x)
+                else:
+                    return normal_marker
+
+            df[cat_col] = df[col].apply(lambda x: cat_mapper(x, specials)).astype("category")

     # =========================================================================
     # POSTPROCESS
     # =========================================================================
     def postprocess(self, df: pd.DataFrame, rules: Optional[Dict[str, List[Tuple[str, str, Any]]]] = None) -> pd.DataFrame:
         """
-        합성 결과 후처리:
-        1) split된 (base_col, cat_col) 복원
-        2) (Enhanced) rules를 순서대로 적용하여, 규칙에 맞지 않는 행들을 제거한다.
-           - 만약 if-then 조건이 있다면 그 순서대로 평가한다.
-        (날짜 offset 복원은 없음)
+        Postprocesses the synthetic DataFrame:
+          1) Merge back split columns (base_col, base_col_cat), restoring special values
+             and NaNs recorded in the _cat column.
+          2) Apply user-provided rules sequentially to filter rows.
+        (Note: Date offset restoration is not performed.)
         """
         df = df.copy()

-        # Merge split columns
+        # Merge split columns.
         df = self._merge_splitted_cols(df)

-        # If rules are provided, apply them in the given order.
+        # Apply rules if provided.
         if rules is not None:
             df = self.apply_rules(df, rules)

         return df

     def _merge_splitted_cols(self, df: pd.DataFrame) -> pd.DataFrame:
-        for base_col, cat_col in self.split_map.items():
-            if base_col not in df.columns or cat_col not in df.columns:
-                continue
-
-            specials = self.detected_specials.get(base_col, [])
-
-            for i in range(len(df)):
-                if pd.isna(df.at[i, base_col]):
-                    cat_val = df.at[i, cat_col]
-                    try:
-                        possible_val = float(cat_val)
-                    except:
-                        possible_val = cat_val
-                    if possible_val in specials:
-                        df.at[i, base_col] = possible_val
-                    else:
-                        pass
-            df.drop(columns=[cat_col], inplace=True)
+        """
+        For each (base_col, cat_col) pair in split_map, restore base_col from the
+        markers in cat_col:
+          - "NUMERIC" -> keep the generated numeric value as is.
+          - "NAN"     -> set base_col to NaN.
+          - otherwise -> the marker is one of the special values; write it back to base_col.
+        Finally, drop cat_col.
+        """
+        for base_col, cat_col in self.split_map.items():
+            if base_col not in df.columns or cat_col not in df.columns:
+                continue
+            specials = self.detected_specials.get(base_col, [])
+            cat_vals = df[cat_col].astype(str)
+            df.loc[cat_vals == "NAN", base_col] = float("nan")
+            for sv in specials:
+                df.loc[cat_vals == str(sv), base_col] = sv
+            df.drop(columns=[cat_col], inplace=True)
         return df

     def apply_rules(self, df: pd.DataFrame, rules: Dict[str, List[Tuple[str, str, Any]]]) -> pd.DataFrame:
         """
@@ -177,14 +198,10 @@ def apply_rules(self, df: pd.DataFrame, rules: Dict[str, List[Tuple[str, str, An
         Returns:
             A new DataFrame with rows not satisfying the rules dropped.
         """
-        # Process each target column in the order of insertion (Python 3.7+ preserves insertion order)
         for target_col, rule_list in rules.items():
-            # For each rule in the list, filter out rows that do not satisfy the rule.
             for (col_feat, operator, rule_val) in rule_list:
-                # If the target column is not in df, skip.
                 if col_feat not in df.columns:
                     continue
-                # Build a condition based on the operator.
                 if operator in ["=", "=="]:
                     cond = (df[col_feat] == rule_val) | df[col_feat].isna()
                 elif operator == ">":
                     cond = (df[col_feat] > rule_val) | df[col_feat].isna()
@@ -197,6 +214,5 @@ def apply_rules(self, df: pd.DataFrame, rules: Dict[str, List[Tuple[str, str, An
                     cond = (df[col_feat] <= rule_val) | df[col_feat].isna()
                 else:
                     cond = pd.Series(True, index=df.index)
-                # Drop rows that do not satisfy the condition.
                 df = df.loc[cond].copy()
         return df
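
A minimal usage sketch of the preprocessor round trip described above (the column
name "age", the special code -9.0, and the rule are hypothetical examples; it assumes
the patched SynSeqPreprocessor is importable from the path shown in the diff):

    import pandas as pd
    from synthcity.plugins.core.models.syn_seq.syn_seq_preprocess import SynSeqPreprocessor

    # -9.0 acts as a "special" code that should not be modeled as a regular number.
    df = pd.DataFrame({"age": [25.0, 40.0, -9.0, None, 33.0]})
    pre = SynSeqPreprocessor(
        user_dtypes={"age": "numeric"},          # skip auto-detection for this column
        user_special_values={"age": [-9.0]},
    )

    out = pre.preprocess(df)
    # preprocess() inserts an "age_cat" column in front of "age" holding
    # "-9.0" for the special code, "NAN" for the missing cell, "NUMERIC" otherwise.
    print(out[["age_cat", "age"]])

    restored = pre.postprocess(out)              # merges age_cat back in and drops it
    # apply_rules keeps rows satisfying the condition OR having NaN in the feature
    # column, so the -9.0 row is dropped while the missing-age row survives.
    kept = pre.apply_rules(restored, {"age": [("age", ">", 0)]})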
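For reference, each METHOD_MAP entry pairs a training function with a generation
function, invoked as fit_func(y, X, random_state=...) and generate_func(fitted_model,
Xsyn). A toy pair satisfying that protocol (stand-in names only, not the real
synthcity method implementations):

    import numpy as np

    def toy_fit(y: np.ndarray, X: np.ndarray, random_state: int = 0):
        # A sampling-based "model" only needs the observed target values.
        return {"values": np.asarray(y), "rng": np.random.default_rng(random_state)}

    def toy_generate(fitted_model, Xsyn: np.ndarray) -> np.ndarray:
        # Produce one synthetic value per row of Xsyn by resampling training values.
        vals, rng = fitted_model["values"], fitted_model["rng"]
        return rng.choice(vals, size=len(Xsyn), replace=True)

    TOY_METHOD_MAP = {"toy": (toy_fit, toy_generate)}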