From 964ed1386cd26cb93b5e51b5058526c995cd81bc Mon Sep 17 00:00:00 2001 From: minkeymouse Date: Wed, 22 Jan 2025 13:01:57 +0900 Subject: [PATCH] import error fixed --- .../plugins/generic/plugin_syn_seq.py | 254 ++++++++------- .../plugins/generic/plugin_syn_seq.ipynb | 291 ++++++++++-------- 2 files changed, 294 insertions(+), 251 deletions(-) diff --git a/src/synthcity/plugins/generic/plugin_syn_seq.py b/src/synthcity/plugins/generic/plugin_syn_seq.py index 56285c47..bf0c3680 100644 --- a/src/synthcity/plugins/generic/plugin_syn_seq.py +++ b/src/synthcity/plugins/generic/plugin_syn_seq.py @@ -25,17 +25,11 @@ class Syn_SeqPlugin(Plugin): A plugin wrapping the 'Syn_Seq' aggregator in the synthcity Plugin interface. Steps: - 1) In .fit(), if the user passes a DataFrame, we wrap it in Syn_SeqDataLoader, then encode the data. - 2) We keep separate: - - self._orig_schema => schema from the original data - - self._enc_schema => schema from the encoded data + 1) In .fit(), if the user passes a DataFrame, we wrap it in Syn_SeqDataLoader, then call .encode(). + 2) We build or refine the domain in `_domain_rebuild`, using the original vs. converted dtype info, + turning them into constraints, then into distributions. 3) The aggregator trains column-by-column on the encoded data. - 4) In .generate(), we re-check constraints (including user constraints) referencing - the original schema. Then we decode back to the original DataFrame structure. - - Additional note: We add `_remap_special_value_rules` to handle user rules referencing - special values that belong in a 'cat' column, e.g. if user writes - ("NUM_CIGAR", "<", 0) but that data is actually in "NUM_CIGAR_cat". + 4) For .generate(), we re-check constraints (including user constraints) and decode back to the original format. """ @staticmethod @@ -48,7 +42,7 @@ def type() -> str: @staticmethod def hyperparameter_space(**kwargs: Any) -> List: - # No tunable hyperparameters here + # No tunable hyperparameters for demonstration return [] @validate_arguments(config=dict(arbitrary_types_allowed=True)) @@ -68,22 +62,18 @@ def __init__( compress_dataset=compress_dataset, sampling_strategy=sampling_strategy, ) - # Two separate schema references - self._orig_schema: Optional[Schema] = None - self._enc_schema: Optional[Schema] = None - self._data_info: Optional[Dict] = None - self._enc_data_info: Optional[Dict] = None + self._schema: Optional[Schema] = None + self._training_schema: Optional[Schema] = None + self._data_info: Optional[Dict] = None + self._training_data_info: Optional[Dict] = {} self.model: Optional[Syn_Seq] = None @validate_arguments(config=dict(arbitrary_types_allowed=True)) def fit(self, X: Union[DataLoader, pd.DataFrame], *args: Any, **kwargs: Any) -> Any: """ Wrap a plain DataFrame into Syn_SeqDataLoader if needed, then encode the data. - Build up: - - self._orig_schema from the original data - - self._enc_schema from the encoded data - Then train the aggregator column-by-column on encoded data. + Build up the schema from original data vs. encoded data, and train the aggregator. 
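+
+        Example (a minimal sketch, assuming the standard synthcity `Plugins`
+        loader shown in the accompanying notebook; `df` is any pandas
+        DataFrame and the rule below is purely illustrative):
+
+            from synthcity.plugins import Plugins
+
+            syn_model = Plugins().get("syn_seq")
+            syn_model.fit(df)  # plain DataFrames are wrapped in Syn_SeqDataLoader
+
+            # rules follow Dict[str, List[Tuple[str, str, Any]]]:
+            # target column -> list of (feature, op, value) conditions
+            rules = {"target": [("target", ">", 0)]}
+            syn_df = syn_model.generate(count=len(df), rules=rules).dataframe()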
""" # If plain DataFrame, wrap in Syn_SeqDataLoader if isinstance(X, pd.DataFrame): @@ -95,98 +85,90 @@ def fit(self, X: Union[DataLoader, pd.DataFrame], *args: Any, **kwargs: Any) -> enable_reproducible_results(self.random_state) self._data_info = X.info() - # Build schema for the original data - self._orig_schema = Schema( + # Build schema for the *original* data + self._schema = Schema( data=X, sampling_strategy=self.sampling_strategy, random_state=self.random_state, ) # Encode the data - X_encoded, self._enc_data_info = X.encode() - # Build a schema from the encoded data - self._enc_schema = Schema( + X_encoded, self._training_data_info = X.encode() + + # Build an initial schema from the *encoded* data + base_schema = Schema( data=X_encoded, sampling_strategy=self.sampling_strategy, random_state=self.random_state, ) + # Rebuild domain from original vs. converted dtype logic + self._training_schema = self._domain_rebuild(X_encoded, base_schema) + # aggregator training + output = self._fit(X_encoded, *args, **kwargs) + self.fitted = True + return output + + def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "Syn_SeqPlugin": + """ + Train the aggregator column-by-column using the encoded DataLoader. + """ self.model = Syn_Seq( random_state=self.random_state, strict=self.strict, sampling_patience=self.sampling_patience, ) - self.model.fit_col(X_encoded, *args, **kwargs) - self.fitted = True + self.model.fit_col(X, *args, **kwargs) return self - def training_schema(self) -> Schema: - """Return the *encoded* schema used for aggregator training.""" - if self._enc_schema is None: - raise RuntimeError("No encoded schema found. Fit the model first.") - return self._enc_schema - - def schema(self) -> Schema: - """Return the original data's schema.""" - if self._orig_schema is None: - raise RuntimeError("No original schema found. Fit the model first.") - return self._orig_schema - - def _remap_special_value_rules( - self, - rules_dict: Dict[str, List[Tuple[str, str, Any]]], - enc_schema: Schema - ) -> Dict[str, List[Tuple[str, str, Any]]]: + def _domain_rebuild(self, X_encoded: DataLoader, base_schema: Schema) -> Schema: """ - If user wrote rules referencing special values (like -8) on numeric columns, - we switch them to the corresponding _cat column. For example: - - "NUM_CIGAR" has special_value list [-8]. - - user sets rule (NUM_CIGAR, "<", 0). - => Actually these negative codes are stored in "NUM_CIGAR_cat". - So we redirect the rule to "NUM_CIGAR_cat". + Build new domain using feature_params & constraint_to_distribution. + + For each column in the encoded data, gather basic "dtype" constraints, + transform them into Distribution objects, and create a new schema. """ - if not rules_dict: - return rules_dict + enc_info = X_encoded.info() - # gather the special_value map from the *encoded* info - # (we assume `_enc_data_info["special_value"]` has your col->list). 
- special_map = {} - if self._enc_data_info and "special_value" in self._enc_data_info: - special_map = self._enc_data_info["special_value"] + syn_order = enc_info.get("syn_order", []) + orig_map = enc_info.get("original_dtype", {}) + conv_map = enc_info.get("converted_type", {}) - # build base->cat mapping - base_to_cat = {} - for col in enc_schema.domain: - if col.endswith("_cat"): - base_col = col[:-4] - base_to_cat[base_col] = col + domain: Dict[str, Any] = {} - new_rules = {} - for target_col, cond_list in rules_dict.items(): - actual_target_col = target_col - # If the target_col is known to have special values => rename - if target_col in special_map and target_col in base_to_cat: - actual_target_col = base_to_cat[target_col] + # For each column in syn_order, figure out the dtype constraints + for col in syn_order: + col_rules = [] - new_cond_list = [] - for (feat_col, op, val) in cond_list: - new_feat = feat_col - if feat_col in special_map and feat_col in base_to_cat: - # if val is in special_map[feat_col], direct to cat - # or if user is comparing negative codes, etc. - if any(v == val for v in special_map[feat_col]): - new_feat = base_to_cat[feat_col] - new_cond_list.append((new_feat, op, val)) + original_dt = orig_map.get(col, "").lower() + converted_dt = conv_map.get(col, "").lower() - new_rules[actual_target_col] = new_cond_list + # Example logic (you can adapt): + if col.endswith("_cat"): + # definitely treat as category + col_rules.append((col, "dtype", "category")) + elif ("int" in original_dt or "float" in original_dt) and ("category" in converted_dt): + col_rules.append((col, "dtype", "category")) + elif ("object" in original_dt or "category" in original_dt) and ("numeric" in converted_dt): + col_rules.append((col, "dtype", "float")) + elif "date" in original_dt: + col_rules.append((col, "dtype", "int")) + else: + col_rules.append((col, "dtype", "float")) - return new_rules + # Build a local Constraints for this single column + single_constraints = Constraints(rules=col_rules) + # Then transform into a Distribution + dist = constraint_to_distribution(single_constraints, col) + domain[col] = dist - @validate_arguments(config=dict(arbitrary_types_allowed=True)) - def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "Syn_SeqPlugin": - # we do not directly use this `_fit`, it’s overshadowed by .fit() above - raise NotImplementedError("Use .fit()") + # Now build new Schema with that domain + new_schema = Schema(domain=domain) + new_schema.sampling_strategy = base_schema.sampling_strategy + new_schema.random_state = base_schema.random_state + + return new_schema @validate_arguments(config=dict(arbitrary_types_allowed=True)) def generate( @@ -198,47 +180,42 @@ def generate( **kwargs: Any, ) -> DataLoader: """ - Generate synthetic data from aggregator. - - 1) Combine constraints from original schema with user constraints - 2) Possibly remap user rules -> cat columns for special_value - 3) aggregator -> encoded DataFrame - 4) decode back to original - 5) final constraints check using original schema + Generate synthetic data by sampling from the aggregator, applying constraints, + and decoding back to the original schema. """ if not self.fitted: raise RuntimeError("Must .fit() plugin before calling .generate()") - if self._orig_schema is None or self._enc_schema is None: + if self._schema is None: raise RuntimeError("No schema found. 
Fit the model first.") if random_state is not None: enable_reproducible_results(random_state) if count is None: - if self._data_info is not None: - count = self._data_info["len"] - else: - raise ValueError("Cannot determine 'count' for generation") + count = self._data_info["len"] has_gen_cond = ("cond" in kwargs) and (kwargs["cond"] is not None) if has_gen_cond and not self.expecting_conditional: raise RuntimeError( - "Got generation conditional, but aggregator wasn't trained conditionally" + "Got inference conditional, but aggregator wasn't trained with a conditional" ) - # Combine constraints from the original schema - gen_constraints = self._orig_schema.as_constraints() + # Combine constraints from training schema with user constraints + gen_constraints = self.training_schema().as_constraints() if constraints is not None: gen_constraints = gen_constraints.extend(constraints) - # aggregator generation on encoded schema - data_syn = self._generate(count, gen_constraints, rules=rules, **kwargs) + # Build a schema from these constraints + syn_schema = Schema.from_constraints(gen_constraints) + + # aggregator call + data_syn = self._generate(count, syn_schema, rules=rules, **kwargs) # decode from the encoded data back to original data_syn = data_syn.decode() - # final constraints check using the *original* schema - final_constraints = self._orig_schema.as_constraints() + # final constraints check + final_constraints = self.schema().as_constraints() if constraints is not None: final_constraints = final_constraints.extend(constraints) @@ -251,28 +228,73 @@ def generate( def _generate( self, count: int, - gen_constraints: Constraints, + syn_schema: Schema, rules: Optional[Dict[str, List[Tuple[str, str, Any]]]] = None, **kwargs: Any, ) -> DataLoader: """ - 1) Possibly remap user rules for special_value -> cat columns - 2) aggregator => produce an *encoded* DataFrame - 3) Force columns' dtypes (encoded schema) + Internal aggregator generation logic: + - Possibly remap rules to reference _cat columns if they specify special values + - Let aggregator do column-by-column generation + - Force the columns' dtypes according to syn_schema """ if not self.model: - raise RuntimeError("No aggregator model found") + raise RuntimeError("Aggregator not found for syn_seq plugin") - # Remap user rules to handle special values in _cat + # Remap user rules to handle special values in _cat columns if rules is not None: - rules = self._remap_special_value_rules(rules, self._enc_schema) + rules = self._remap_special_value_rules(rules, syn_schema) - # aggregator generate + # Generate the data df_syn = self.model.generate_col(count, rules=rules, max_iter_rules=10) - # Ensure correct dtypes (encoded) - df_syn = self._enc_schema.adapt_dtypes(df_syn) + # Ensure correct dtypes + df_syn = syn_schema.adapt_dtypes(df_syn) + + return create_from_info(df_syn, self._data_info) + + def _remap_special_value_rules( + self, + rules_dict: Dict[str, List[Tuple[str, str, Any]]], + syn_schema: Schema + ) -> Dict[str, List[Tuple[str, str, Any]]]: + """ + If user wrote rules referencing special values (like -0.04) on numeric columns, + we switch them to the corresponding _cat column. This is a simple version: + If 'val' is in 'special_value[feat_col]', rename feat_col -> feat_col_cat. + """ + if not rules_dict: + return rules_dict + + special_map = self._training_data_info.get("special_value", {}) + + # build a base->cat map + # e.g. 
if 'bp' => 'bp_cat' is in your domain + base_to_cat = {} + for col in syn_schema.domain: + if col.endswith("_cat"): + base_col = col[:-4] + base_to_cat[base_col] = col + + new_rules = {} + for target_col, cond_list in rules_dict.items(): + actual_target_col = target_col + # If target_col is known to have special values => rename + if target_col in special_map and target_col in base_to_cat: + actual_target_col = base_to_cat[target_col] + + new_cond_list = [] + for (feat_col, op, val) in cond_list: + new_feat = feat_col + # If feat_col references special values => rename + if feat_col in special_map and feat_col in base_to_cat: + if val in special_map[feat_col]: + new_feat = base_to_cat[feat_col] + new_cond_list.append((new_feat, op, val)) + + new_rules[actual_target_col] = new_cond_list + + return new_rules - # Return as DataLoader with the *encoded* info - return create_from_info(df_syn, self._enc_data_info) -plugin = Syn_SeqPlugin +# Register plugin for the library +plugin = Syn_SeqPlugin \ No newline at end of file diff --git a/tutorials/plugins/generic/plugin_syn_seq.ipynb b/tutorials/plugins/generic/plugin_syn_seq.ipynb index 2827dee0..2d9ab4fa 100644 --- a/tutorials/plugins/generic/plugin_syn_seq.ipynb +++ b/tutorials/plugins/generic/plugin_syn_seq.ipynb @@ -2,20 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[KeOps] Warning : \n", - " The default C++ compiler could not be found on your system.\n", - " You need to either define the CXX environment variable or a symlink to the g++ command.\n", - " For example if g++-8 is the command you can do\n", - " import os\n", - " os.environ['CXX'] = 'g++-8'\n", - " \n", + "[KeOps] Warning : omp.h header is not in the path, disabling OpenMP. To fix this, you can set the environment\n", + " variable OMP_PATH to the location of the header before importing keopscore or pykeops,\n", + " e.g. using os.environ: import os; os.environ['OMP_PATH'] = '/path/to/omp/header'\n", "[KeOps] Warning : Cuda libraries were not detected on the system or could not be loaded ; using cpu only mode\n" ] } @@ -37,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -58,17 +54,17 @@ "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 0.027809 0.173816 \n", "\n", " s4 s5 s6 target test \n", - "0 -0.002592 0.019907 -0.017646 151.0 8 \n", - "1 -0.039493 -0.068332 -0.092204 75.0 11 \n", - "2 -0.002592 0.002861 -0.025930 141.0 9 \n", - "3 0.034309 0.022688 -0.009362 206.0 22 \n", - "4 -0.002592 -0.031988 -0.046641 135.0 27 \n", + "0 -0.002592 0.019907 -0.017646 151.0 5 \n", + "1 -0.039493 -0.068332 -0.092204 75.0 14 \n", + "2 -0.002592 0.002861 -0.025930 141.0 2 \n", + "3 0.034309 0.022688 -0.009362 206.0 1 \n", + "4 -0.002592 -0.031988 -0.046641 135.0 17 \n", ".. ... ... ... ... ... 
\n", - "437 -0.002592 0.031193 0.007207 178.0 14 \n", - "438 0.034309 -0.018114 0.044485 104.0 4 \n", - "439 -0.011080 -0.046883 0.015491 132.0 24 \n", - "440 0.026560 0.044529 -0.025930 220.0 21 \n", - "441 -0.039493 -0.004222 0.003064 57.0 0 \n", + "437 -0.002592 0.031193 0.007207 178.0 11 \n", + "438 0.034309 -0.018114 0.044485 104.0 2 \n", + "439 -0.011080 -0.046883 0.015491 132.0 21 \n", + "440 0.026560 0.044529 -0.025930 220.0 20 \n", + "441 -0.039493 -0.004222 0.003064 57.0 2 \n", "\n", "[442 rows x 12 columns]\n" ] @@ -82,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -100,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -152,7 +148,7 @@ " 0.019907\n", " -0.017646\n", " 151.0\n", - " 8\n", + " 5\n", " \n", " \n", " 1\n", @@ -167,7 +163,7 @@ " -0.068332\n", " -0.092204\n", " 75.0\n", - " 11\n", + " 14\n", " \n", " \n", " 2\n", @@ -182,7 +178,7 @@ " 0.002861\n", " -0.025930\n", " 141.0\n", - " 9\n", + " 2\n", " \n", " \n", " 3\n", @@ -197,7 +193,7 @@ " 0.022688\n", " -0.009362\n", " 206.0\n", - " 22\n", + " 1\n", " \n", " \n", " 4\n", @@ -212,7 +208,7 @@ " -0.031988\n", " -0.046641\n", " 135.0\n", - " 27\n", + " 17\n", " \n", " \n", " ...\n", @@ -242,7 +238,7 @@ " 0.031193\n", " 0.007207\n", " 178.0\n", - " 14\n", + " 11\n", " \n", " \n", " 438\n", @@ -257,7 +253,7 @@ " -0.018114\n", " 0.044485\n", " 104.0\n", - " 4\n", + " 2\n", " \n", " \n", " 439\n", @@ -272,7 +268,7 @@ " -0.046883\n", " 0.015491\n", " 132.0\n", - " 24\n", + " 21\n", " \n", " \n", " 440\n", @@ -287,7 +283,7 @@ " 0.044529\n", " -0.025930\n", " 220.0\n", - " 21\n", + " 20\n", " \n", " \n", " 441\n", @@ -302,7 +298,7 @@ " -0.004222\n", " 0.003064\n", " 57.0\n", - " 0\n", + " 2\n", " \n", " \n", "\n", @@ -324,22 +320,22 @@ "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 0.027809 0.173816 \n", "\n", " s4 s5 s6 target test \n", - "0 -0.002592 0.019907 -0.017646 151.0 8 \n", - "1 -0.039493 -0.068332 -0.092204 75.0 11 \n", - "2 -0.002592 0.002861 -0.025930 141.0 9 \n", - "3 0.034309 0.022688 -0.009362 206.0 22 \n", - "4 -0.002592 -0.031988 -0.046641 135.0 27 \n", + "0 -0.002592 0.019907 -0.017646 151.0 5 \n", + "1 -0.039493 -0.068332 -0.092204 75.0 14 \n", + "2 -0.002592 0.002861 -0.025930 141.0 2 \n", + "3 0.034309 0.022688 -0.009362 206.0 1 \n", + "4 -0.002592 -0.031988 -0.046641 135.0 17 \n", ".. ... ... ... ... ... 
\n", - "437 -0.002592 0.031193 0.007207 178.0 14 \n", - "438 0.034309 -0.018114 0.044485 104.0 4 \n", - "439 -0.011080 -0.046883 0.015491 132.0 24 \n", - "440 0.026560 0.044529 -0.025930 220.0 21 \n", - "441 -0.039493 -0.004222 0.003064 57.0 0 \n", + "437 -0.002592 0.031193 0.007207 178.0 11 \n", + "438 0.034309 -0.018114 0.044485 104.0 2 \n", + "439 -0.011080 -0.046883 0.015491 132.0 21 \n", + "440 0.026560 0.044529 -0.025930 220.0 20 \n", + "441 -0.039493 -0.004222 0.003064 57.0 2 \n", "\n", "[442 rows x 12 columns]" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -350,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -411,50 +407,50 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[2025-01-21T16:00:03.626862+0900][17088][CRITICAL] module disabled: C:\\Users\\hsrhe\\Desktop\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n" + "[2025-01-22T13:01:45.197036+0900][2347][CRITICAL] module disabled: /Users/minkeychang/synthcity/src/synthcity/plugins/generic/plugin_goggle.py\n" ] }, { "data": { "text/plain": [ - "['arf',\n", - " 'privbayes',\n", - " 'survival_gan',\n", - " 'timevae',\n", - " 'image_cgan',\n", - " 'bayesian_network',\n", - " 'ddpm',\n", - " 'dpgan',\n", - " 'aim',\n", + "['ddpm',\n", " 'image_adsgan',\n", - " 'rtvae',\n", " 'timegan',\n", - " 'tvae',\n", - " 'survival_nflow',\n", - " 'survival_ctgan',\n", - " 'syn_seq',\n", - " 'nflow',\n", + " 'image_cgan',\n", " 'decaf',\n", - " 'great',\n", - " 'dummy_sampler',\n", - " 'marginal_distributions',\n", - " 'survae',\n", - " 'uniform_sampler',\n", + " 'survival_gan',\n", + " 'nflow',\n", " 'pategan',\n", - " 'adsgan',\n", + " 'survival_ctgan',\n", + " 'survival_nflow',\n", + " 'bayesian_network',\n", + " 'aim',\n", + " 'dpgan',\n", " 'radialgan',\n", + " 'marginal_distributions',\n", + " 'ctgan',\n", + " 'arf',\n", + " 'uniform_sampler',\n", + " 'tvae',\n", + " 'privbayes',\n", " 'fflows',\n", - " 'ctgan']" + " 'dummy_sampler',\n", + " 'great',\n", + " 'timevae',\n", + " 'syn_seq',\n", + " 'rtvae',\n", + " 'survae',\n", + " 'adsgan']" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -465,14 +461,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[2025-01-21T16:00:07.055404+0900][17088][CRITICAL] module disabled: C:\\Users\\hsrhe\\Desktop\\synthcity\\src\\synthcity\\plugins\\generic\\plugin_goggle.py\n" + "[2025-01-22T13:01:46.248079+0900][2347][CRITICAL] module disabled: /Users/minkeychang/synthcity/src/synthcity/plugins/generic/plugin_goggle.py\n" ] } ], @@ -482,16 +478,34 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 'bmi' with 'cart' ... Done!\n", + "Fitting 'age' with 'cart' ... Done!\n", + "Fitting 'test_cat' with 'cart' ... Done!\n", + "Fitting 'test' with 'cart' ... Done!\n", + "Fitting 'bp' with 'norm' ... Done!\n", + "Fitting 's1' with 'cart' ... Done!\n", + "Fitting 's2' with 'cart' ... Done!\n", + "Fitting 's3' with 'cart' ... Done!\n", + "Fitting 's4' with 'cart' ... Done!\n", + "Fitting 's5' with 'cart' ... Done!\n", + "Fitting 's6' with 'cart' ... 
Done!\n", + "Fitting 'target' with 'cart' ... Done!\n" + ] + }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -502,9 +516,28 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating 'bmi' => done.\n", + "Generating 'age' => done.\n", + "Generating 'test_cat' => done.\n", + "Generating 'test' => done.\n", + "Generating 'bp' => done.\n", + "Generating 's1' => done.\n", + "Generating 's2' => done.\n", + "Generating 's3' => done.\n", + "Generating 's4' => done.\n", + "Generating 's5' => done.\n", + "Generating 's6' => done.\n", + "Generating 'target' => done.\n" + ] + } + ], "source": [ "syn_df = syn_model.generate(nrows = len(X)).dataframe()" ] @@ -535,10 +568,9 @@ " \n", " \n", " \n", + " age\n", " sex\n", " bmi\n", - " age\n", - " test\n", " bp\n", " s1\n", " s2\n", @@ -547,16 +579,15 @@ " s5\n", " s6\n", " target\n", - " test_cat\n", + " test\n", " \n", " \n", " \n", " \n", " 0\n", + " 0.038076\n", " 0.050680\n", " 0.061696\n", - " 0.038076\n", - " 25.0\n", " 0.021872\n", " -0.044223\n", " -0.034821\n", @@ -565,14 +596,13 @@ " 0.019907\n", " -0.017646\n", " 151.0\n", - " -777\n", + " 5\n", " \n", " \n", " 1\n", + " -0.001882\n", " -0.044642\n", " -0.051474\n", - " -0.001882\n", - " 23.0\n", " -0.026328\n", " -0.008449\n", " -0.019163\n", @@ -581,14 +611,13 @@ " -0.068332\n", " -0.092204\n", " 75.0\n", - " -777\n", + " 14\n", " \n", " \n", " 2\n", + " 0.085299\n", " 0.050680\n", " 0.044451\n", - " 0.085299\n", - " 15.0\n", " -0.005670\n", " -0.045599\n", " -0.034194\n", @@ -597,14 +626,13 @@ " 0.002861\n", " -0.025930\n", " 141.0\n", - " -777\n", + " 2\n", " \n", " \n", " 3\n", + " -0.089063\n", " -0.044642\n", " -0.011595\n", - " -0.089063\n", - " 3.0\n", " -0.036656\n", " 0.012191\n", " 0.024991\n", @@ -613,14 +641,13 @@ " 0.022688\n", " -0.009362\n", " 206.0\n", - " -777\n", + " 1\n", " \n", " \n", " 4\n", + " 0.005383\n", " -0.044642\n", " -0.036385\n", - " 0.005383\n", - " 20.0\n", " 0.021872\n", " 0.003935\n", " 0.015596\n", @@ -629,7 +656,7 @@ " -0.031988\n", " -0.046641\n", " 135.0\n", - " -777\n", + " 17\n", " \n", " \n", " ...\n", @@ -645,14 +672,12 @@ " ...\n", " ...\n", " ...\n", - " ...\n", " \n", " \n", " 437\n", + " 0.041708\n", " 0.050680\n", " 0.019662\n", - " 0.041708\n", - " 27.0\n", " 0.059744\n", " -0.005697\n", " -0.002566\n", @@ -661,14 +686,13 @@ " 0.031193\n", " 0.007207\n", " 178.0\n", - " -777\n", + " 11\n", " \n", " \n", " 438\n", + " -0.005515\n", " 0.050680\n", " -0.015906\n", - " -0.005515\n", - " 25.0\n", " -0.067642\n", " 0.049341\n", " 0.079165\n", @@ -677,14 +701,13 @@ " -0.018114\n", " 0.044485\n", " 104.0\n", - " -777\n", + " 2\n", " \n", " \n", " 439\n", + " 0.041708\n", " 0.050680\n", " -0.015906\n", - " 0.041708\n", - " 27.0\n", " 0.017293\n", " -0.037344\n", " -0.013840\n", @@ -693,14 +716,13 @@ " -0.046883\n", " 0.015491\n", " 132.0\n", - " -777\n", + " 21\n", " \n", " \n", " 440\n", + " -0.045472\n", " -0.044642\n", " 0.039062\n", - " -0.045472\n", - " 17.0\n", " 0.001215\n", " 0.016318\n", " 0.015283\n", @@ -709,14 +731,13 @@ " 0.044529\n", " -0.025930\n", " 220.0\n", - " -777\n", + " 20\n", " \n", " \n", " 441\n", + " -0.045472\n", " -0.044642\n", " -0.073030\n", - " -0.045472\n", - " 26.0\n", " -0.081413\n", " 0.083740\n", " 0.027809\n", @@ -725,41 +746,41 @@ " -0.004222\n", " 0.003064\n", " 57.0\n", - " 
-777\n", + " 2\n", " \n", " \n", "\n", - "

442 rows × 13 columns

\n", + "

442 rows × 12 columns

\n", "" ], "text/plain": [ - " sex bmi age test bp s1 s2 \\\n", - "0 0.050680 0.061696 0.038076 25.0 0.021872 -0.044223 -0.034821 \n", - "1 -0.044642 -0.051474 -0.001882 23.0 -0.026328 -0.008449 -0.019163 \n", - "2 0.050680 0.044451 0.085299 15.0 -0.005670 -0.045599 -0.034194 \n", - "3 -0.044642 -0.011595 -0.089063 3.0 -0.036656 0.012191 0.024991 \n", - "4 -0.044642 -0.036385 0.005383 20.0 0.021872 0.003935 0.015596 \n", - ".. ... ... ... ... ... ... ... \n", - "437 0.050680 0.019662 0.041708 27.0 0.059744 -0.005697 -0.002566 \n", - "438 0.050680 -0.015906 -0.005515 25.0 -0.067642 0.049341 0.079165 \n", - "439 0.050680 -0.015906 0.041708 27.0 0.017293 -0.037344 -0.013840 \n", - "440 -0.044642 0.039062 -0.045472 17.0 0.001215 0.016318 0.015283 \n", - "441 -0.044642 -0.073030 -0.045472 26.0 -0.081413 0.083740 0.027809 \n", + " age sex bmi bp s1 s2 s3 \\\n", + "0 0.038076 0.050680 0.061696 0.021872 -0.044223 -0.034821 -0.043401 \n", + "1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 -0.019163 0.074412 \n", + "2 0.085299 0.050680 0.044451 -0.005670 -0.045599 -0.034194 -0.032356 \n", + "3 -0.089063 -0.044642 -0.011595 -0.036656 0.012191 0.024991 -0.036038 \n", + "4 0.005383 -0.044642 -0.036385 0.021872 0.003935 0.015596 0.008142 \n", + ".. ... ... ... ... ... ... ... \n", + "437 0.041708 0.050680 0.019662 0.059744 -0.005697 -0.002566 -0.028674 \n", + "438 -0.005515 0.050680 -0.015906 -0.067642 0.049341 0.079165 -0.028674 \n", + "439 0.041708 0.050680 -0.015906 0.017293 -0.037344 -0.013840 -0.024993 \n", + "440 -0.045472 -0.044642 0.039062 0.001215 0.016318 0.015283 -0.028674 \n", + "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 0.027809 0.173816 \n", "\n", - " s3 s4 s5 s6 target test_cat \n", - "0 -0.043401 -0.002592 0.019907 -0.017646 151.0 -777 \n", - "1 0.074412 -0.039493 -0.068332 -0.092204 75.0 -777 \n", - "2 -0.032356 -0.002592 0.002861 -0.025930 141.0 -777 \n", - "3 -0.036038 0.034309 0.022688 -0.009362 206.0 -777 \n", - "4 0.008142 -0.002592 -0.031988 -0.046641 135.0 -777 \n", - ".. ... ... ... ... ... ... \n", - "437 -0.028674 -0.002592 0.031193 0.007207 178.0 -777 \n", - "438 -0.028674 0.034309 -0.018114 0.044485 104.0 -777 \n", - "439 -0.024993 -0.011080 -0.046883 0.015491 132.0 -777 \n", - "440 -0.028674 0.026560 0.044529 -0.025930 220.0 -777 \n", - "441 0.173816 -0.039493 -0.004222 0.003064 57.0 -777 \n", + " s4 s5 s6 target test \n", + "0 -0.002592 0.019907 -0.017646 151.0 5 \n", + "1 -0.039493 -0.068332 -0.092204 75.0 14 \n", + "2 -0.002592 0.002861 -0.025930 141.0 2 \n", + "3 0.034309 0.022688 -0.009362 206.0 1 \n", + "4 -0.002592 -0.031988 -0.046641 135.0 17 \n", + ".. ... ... ... ... ... \n", + "437 -0.002592 0.031193 0.007207 178.0 11 \n", + "438 0.034309 -0.018114 0.044485 104.0 2 \n", + "439 -0.011080 -0.046883 0.015491 132.0 21 \n", + "440 0.026560 0.044529 -0.025930 220.0 20 \n", + "441 -0.039493 -0.004222 0.003064 57.0 2 \n", "\n", - "[442 rows x 13 columns]" + "[442 rows x 12 columns]" ] }, "execution_count": 10,