diff --git a/copulas/sdmetrics.py b/copulas/sdmetrics.py new file mode 100644 index 00000000..d9090f71 --- /dev/null +++ b/copulas/sdmetrics.py @@ -0,0 +1,479 @@ +"""Visualization methods for SDMetrics.""" + +import pandas as pd +import plotly.express as px +import plotly.figure_factory as ff +from pandas.api.types import is_datetime64_dtype + +from copulas.utils import get_missing_percentage, is_datetime +from copulas.utils2 import PlotConfig + + +def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): + """Generate a bar plot of the real and synthetic data. + + Args: + real_column (pandas.Series): + The real data for the desired column. + synthetic_column (pandas.Series): + The synthetic data for the desired column. + plot_kwargs (dict, optional): + Dictionary of keyword arguments to pass to px.histogram. Keyword arguments + provided this way will overwrite defaults. + + Returns: + plotly.graph_objects._figure.Figure + """ + all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + histogram_kwargs = { + 'x': 'values', + 'barmode': 'group', + 'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], + 'pattern_shape': 'Data', + 'pattern_shape_sequence': ['', '/'], + 'histnorm': 'probability density', + } + histogram_kwargs.update(plot_kwargs) + fig = px.histogram( + all_data, + **histogram_kwargs + ) + + return fig + + +def _generate_heatmap_plot(all_data, columns): + """Generate heatmap plot for discrete data. + + Args: + all_data (pandas.DataFrame): + The real and synthetic data for the desired column pair containing a + ``Data`` column that indicates whether is real or synthetic. + columns (list): + A list of the columns being plotted. + + Returns: + plotly.graph_objects._figure.Figure + """ + fig = px.density_heatmap( + all_data, + x=columns[0], + y=columns[1], + facet_col='Data', + histnorm='probability' + ) + + fig.update_layout( + title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]}, + font={'size': PlotConfig.FONT_SIZE}, + ) + + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1] + ' Data')) + + return fig + + +def _generate_box_plot(all_data, columns): + """Generate a box plot for mixed discrete and continuous column data. + + Args: + all_data (pandas.DataFrame): + The real and synthetic data for the desired column pair containing a + ``Data`` column that indicates whether is real or synthetic. + columns (list): + A list of the columns being plotted. + + Returns: + plotly.graph_objects._figure.Figure + """ + fig = px.box( + all_data, + x=columns[0], + y=columns[1], + color='Data', + color_discrete_map={ + 'Real': PlotConfig.DATACEBO_DARK, + 'Synthetic': PlotConfig.DATACEBO_GREEN + }, + ) + + fig.update_layout( + title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + font={'size': PlotConfig.FONT_SIZE}, + ) + + return fig + + +def _generate_scatter_plot(all_data, columns): + """Generate a scatter plot for column pair plot. + + Args: + all_data (pandas.DataFrame): + The real and synthetic data for the desired column pair containing a + ``Data`` column that indicates whether is real or synthetic. + columns (list): + A list of the columns being plotted. + + Returns: + plotly.graph_objects._figure.Figure + """ + fig = px.scatter( + all_data, + x=columns[0], + y=columns[1], + color='Data', + color_discrete_map={ + 'Real': PlotConfig.DATACEBO_DARK, + 'Synthetic': PlotConfig.DATACEBO_GREEN + }, + symbol='Data' + ) + + fig.update_layout( + title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + font={'size': PlotConfig.FONT_SIZE}, + ) + + return fig + + +def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}): + """Plot the real and synthetic data as a distplot. + + Args: + real_data (pandas.DataFrame): + The real data for the desired column. + synthetic_data (pandas.DataFrame): + The synthetic data for the desired column. + plot_kwargs (dict, optional): + Dictionary of keyword arguments to pass to px.histogram. Keyword arguments + provided this way will overwrite defaults. + + Returns: + plotly.graph_objects._figure.Figure + """ + default_distplot_kwargs = { + 'show_hist': False, + 'show_rug': False, + 'colors': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN] + } + + fig = ff.create_distplot( + [real_data['values'], synthetic_data['values']], + ['Real', 'Synthetic'], + **{**default_distplot_kwargs, **plot_kwargs} + ) + + return fig + + +def _generate_column_plot(real_column, + synthetic_column, + plot_type, + plot_kwargs={}, + plot_title=None, + x_label=None): + """Generate a plot of the real and synthetic data. + + Args: + real_column (pandas.Series): + The real data for the desired column. + synthetic_column (pandas.Series): + The synthetic data for the desired column. + plot_type (str): + The type of plot to use. Must be one of 'bar' or 'distplot'. + hist_kwargs (dict, optional): + Dictionary of keyword arguments to pass to px.histogram. Keyword arguments + provided this way will overwrite defaults. + plot_title (str, optional): + Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}' + x_label (str, optional): + Label to use for x-axis. Defaults to 'Category'. + + Returns: + plotly.graph_objects._figure.Figure + """ + if plot_type not in ['bar', 'distplot']: + raise ValueError( + "Unrecognized plot_type '{plot_type}'. Pleas use one of 'bar' or 'distplot'" + ) + + column_name = real_column.name if hasattr(real_column, 'name') else '' + + missing_data_real = get_missing_percentage(real_column) + missing_data_synthetic = get_missing_percentage(synthetic_column) + + real_data = pd.DataFrame({'values': real_column.copy().dropna()}) + real_data['Data'] = 'Real' + synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) + synthetic_data['Data'] = 'Synthetic' + + is_datetime_sdtype = False + if is_datetime64_dtype(real_column.dtype): + is_datetime_sdtype = True + real_data['values'] = real_data['values'].astype('int64') + synthetic_data['values'] = synthetic_data['values'].astype('int64') + + trace_args = {} + + if plot_type == 'bar': + fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs) + elif plot_type == 'distplot': + x_label = x_label or 'Value' + fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs) + trace_args = {'fill': 'tozeroy'} + + for i, name in enumerate(['Real', 'Synthetic']): + fig.update_traces( + x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x, + hovertemplate=f'{name}
Frequency: %{{y}}', + selector={'name': name}, + **trace_args + ) + + show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0 + annotations = [] if not show_missing_values else [ + { + 'xref': 'paper', + 'yref': 'paper', + 'x': 1.0, + 'y': 1.05, + 'showarrow': False, + 'text': ( + f'*Missing Values: Real Data ({missing_data_real}%), ' + f'Synthetic Data ({missing_data_synthetic}%)' + ), + }, + ] + + if not plot_title: + plot_title = f"Real vs. Synthetic Data for column '{column_name}'" + + if not x_label: + x_label = 'Category' + + fig.update_layout( + title=plot_title, + xaxis_title=x_label, + yaxis_title='Frequency', + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + annotations=annotations, + font={'size': PlotConfig.FONT_SIZE}, + ) + return fig + + +def _generate_cardinality_plot(real_data, + synthetic_data, + parent_primary_key, + child_foreign_key, + plot_type='bar'): + plot_title = ( + f"Relationship (child foreign key='{child_foreign_key}' and parent " + f"primary key='{parent_primary_key}')" + ) + x_label = '# of Children (per Parent)' + + plot_kwargs = {} + if plot_type == 'bar': + max_cardinality = max(max(real_data), max(synthetic_data)) + min_cardinality = min(min(real_data), min(synthetic_data)) + plot_kwargs = { + 'nbins': max_cardinality - min_cardinality + 1 + } + + return _generate_column_plot(real_data, synthetic_data, plot_type, + plot_kwargs, plot_title, x_label) + + +def _get_cardinality(parent_table, child_table, parent_primary_key, child_foreign_key): + """Return the cardinality of the parent-child relationship. + + Args: + parent_table (pandas.DataFrame): + The parent table. + child_table (pandas.DataFrame): + The child table. + parent_primary_key (string): + The name of the primary key column in the parent table. + child_foreign_key (string): + The name of the foreign key column in the child table. + + Returns: + pandas.DataFrame + """ + child_counts = child_table[child_foreign_key].value_counts().rename('# children') + cardinalities = child_counts.reindex(parent_table[parent_primary_key], fill_value=0).to_frame() + + return cardinalities.sort_values('# children')['# children'] + + +def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_table_name, + child_foreign_key, parent_primary_key, plot_type='bar'): + """Return a plot of the cardinality of the parent-child relationship. + + Args: + real_data (dict): + The real data. + synthetic_data (dict): + The synthetic data. + child_table_name (string): + The name of the child table. + parent_table_name (string): + The name of the parent table. + child_foreign_key (string): + The name of the foreign key column in the child table. + parent_primary_key (string): + The name of the primary key column in the parent table. + plot_type (string, optional): + The plot type to use to plot the cardinality. Must be either 'bar' or 'distplot'. + Defaults to 'bar'. + + Returns: + plotly.graph_objects._figure.Figure + """ + if plot_type not in ['bar', 'distplot']: + raise ValueError( + f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].") + + real_cardinality = _get_cardinality( + real_data[parent_table_name], real_data[child_table_name], + parent_primary_key, child_foreign_key + ) + synth_cardinality = _get_cardinality( + synthetic_data[parent_table_name], + synthetic_data[child_table_name], + parent_primary_key, child_foreign_key + ) + + fig = _generate_cardinality_plot( + real_cardinality, + synth_cardinality, + parent_primary_key, + child_foreign_key, + plot_type=plot_type + ) + + return fig + + +def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): + """Return a plot of the real and synthetic data for a given column. + + Args: + real_data (pandas.DataFrame): + The real table data. + synthetic_data (pandas.DataFrame): + The synthetic table data. + column_name (str): + The name of the column. + plot_type (str or None): + The plot to be used. Can choose between ``distplot``, ``bar`` or ``None``. If ``None` + select between ``distplot`` or ``bar`` depending on the data that the column contains, + ``distplot`` for datetime and numerical values and ``bar`` for categorical. + Defaults to ``None``. + + Returns: + plotly.graph_objects._figure.Figure + """ + if plot_type not in ['bar', 'distplot', None]: + raise ValueError( + f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]." + ) + + if column_name not in real_data.columns: + raise ValueError(f"Column '{column_name}' not found in real table data.") + if column_name not in synthetic_data.columns: + raise ValueError(f"Column '{column_name}' not found in synthetic table data.") + + real_column = real_data[column_name] + if plot_type is None: + column_is_datetime = is_datetime(real_data[column_name]) + dtype = real_column.dropna().infer_objects().dtype.kind + if column_is_datetime or dtype in ('i', 'f'): + plot_type = 'distplot' + else: + plot_type = 'bar' + + real_column = real_data[column_name] + synthetic_column = synthetic_data[column_name] + + fig = _generate_column_plot(real_column, synthetic_column, plot_type) + + return fig + + +def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None): + """Return a plot of the real and synthetic data for a given column pair. + + Args: + real_data (pandas.DataFrame): + The real table data. + synthetic_column (pandas.Dataframe): + The synthetic table data. + column_names (list[string]): + The names of the two columns to plot. + plot_type (str or None): + The plot to be used. Can choose between ``box``, ``heatmap``, ``scatter`` or ``None``. + If ``None` select between ``box``, ``heatmap`` or ``scatter`` depending on the data + that the column contains, ``scatter`` used for datetime and numerical values, + ``heatmap`` for categorical and ``box`` for a mix of both. Defaults to ``None``. + + Returns: + plotly.graph_objects._figure.Figure + """ + if len(column_names) != 2: + raise ValueError('Must provide exactly two column names.') + + if not set(column_names).issubset(real_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' + ) + + if not set(column_names).issubset(synthetic_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' + 'in synthetic data.' + ) + + if plot_type not in ['box', 'heatmap', 'scatter', None]: + raise ValueError( + f"Invalid plot_type '{plot_type}'. Please use one of " + "['box', 'heatmap', 'scatter', None]." + ) + + real_data = real_data[column_names] + synthetic_data = synthetic_data[column_names] + if plot_type is None: + plot_type = [] + for column_name in column_names: + column = real_data[column_name] + dtype = column.dropna().infer_objects().dtype.kind + if dtype in ('i', 'f') or is_datetime(column): + plot_type.append('scatter') + else: + plot_type.append('heatmap') + + if len(set(plot_type)) > 1: + plot_type = 'box' + else: + plot_type = plot_type.pop() + + # Merge the real and synthetic data and add a flag ``Data`` to indicate each one. + columns = list(real_data.columns) + real_data = real_data.copy() + real_data['Data'] = 'Real' + synthetic_data = synthetic_data.copy() + synthetic_data['Data'] = 'Synthetic' + all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + + if plot_type == 'scatter': + return _generate_scatter_plot(all_data, columns) + elif plot_type == 'heatmap': + return _generate_heatmap_plot(all_data, columns) + + return _generate_box_plot(all_data, columns) diff --git a/copulas/utils.py b/copulas/utils.py deleted file mode 100644 index 30e56ebe..00000000 --- a/copulas/utils.py +++ /dev/null @@ -1,298 +0,0 @@ -"""SDMetrics utils to be used across all the project.""" - -from collections import Counter -from datetime import datetime - -import numpy as np -import pandas as pd -from sklearn.preprocessing import OneHotEncoder - - -def nested_attrs_meta(nested): - """Metaclass factory that defines a Metaclass with a dynamic attribute name.""" - - class Metaclass(type): - """Metaclass which pulls the attributes from a nested object using properties.""" - - def __getattr__(cls, attr): - """If cls does not have the attribute, try to get it from the nested object.""" - nested_obj = getattr(cls, nested) - if hasattr(nested_obj, attr): - return getattr(nested_obj, attr) - - raise AttributeError(f"type object '{cls.__name__}' has no attribute '{attr}'") - - @property - def name(cls): - return getattr(cls, nested).name - - @property - def goal(cls): - return getattr(cls, nested).goal - - @property - def max_value(cls): - return getattr(cls, nested).max_value - - @property - def min_value(cls): - return getattr(cls, nested).min_value - - return Metaclass - - -def get_frequencies(real, synthetic): - """Get percentual frequencies for each possible real categorical value. - - Given two iterators containing categorical data, this transforms it into - observed/expected frequencies which can be used for statistical tests. It - adds a regularization term to handle cases where the synthetic data contains - values that don't exist in the real data. - - Args: - real (list): - A list of hashable objects. - synthetic (list): - A list of hashable objects. - - Yields: - tuble[list, list]: - The observed and expected frequencies (as a percent). - """ - f_obs, f_exp = [], [] - real, synthetic = Counter(real), Counter(synthetic) - for value in synthetic: - if value not in real: - real[value] += 1e-6 # Regularization to prevent NaN. - - for value in real: - f_obs.append(synthetic[value] / sum(synthetic.values())) # noqa: PD011 - f_exp.append(real[value] / sum(real.values())) # noqa: PD011 - - return f_obs, f_exp - - -def get_missing_percentage(data_column): - """Compute the missing value percentage of a column. - - Args: - data_column (pandas.Series): - The data of the desired column. - - Returns: - pandas.Series: - Percentage of missing values inside the column. - """ - return round((data_column.isna().sum() / len(data_column)) * 100, 2) - - -def get_cardinality_distribution(parent_column, child_column): - """Compute the cardinality distribution of the (parent, child) pairing. - - Args: - parent_column (pandas.Series): - The parent column. - child_column (pandas.Series): - The child column. - - Returns: - pandas.Series: - The cardinality distribution. - """ - child_df = pd.DataFrame({'child_counts': child_column.value_counts()}) - cardinality_df = pd.DataFrame({'parent': parent_column}).join( - child_df, on='parent').fillna(0) - - return cardinality_df['child_counts'] - - -def is_datetime(data): - """Determine if the input is a datetime type or not. - - Args: - data (pandas.DataFrame, int or datetime): - Input to evaluate. - - Returns: - bool: - True if the input is a datetime type, False if not. - """ - return ( - pd.api.types.is_datetime64_any_dtype(data) - or isinstance(data, pd.Timestamp) - or isinstance(data, datetime) - ) - - -class HyperTransformer(): - """HyperTransformer class. - - The ``HyperTransformer`` class contains a set of transforms to transform one or - more columns based on each column's data type. - """ - - column_transforms = {} - column_kind = {} - - def fit(self, data): - """Fit the HyperTransformer to the given data. - - Args: - data (pandas.DataFrame): - The data to transform. - """ - if not isinstance(data, pd.DataFrame): - data = pd.DataFrame(data) - - for field in data: - kind = data[field].dropna().infer_objects().dtype.kind - self.column_kind[field] = kind - - if kind == 'i' or kind == 'f': - # Numerical column. - self.column_transforms[field] = {'mean': data[field].mean()} - elif kind == 'b': - # Boolean column. - numeric = pd.to_numeric(data[field], errors='coerce').astype(float) - self.column_transforms[field] = {'mode': numeric.mode().iloc[0]} - elif kind == 'O': - # Categorical column. - col_data = pd.DataFrame({'field': data[field]}) - enc = OneHotEncoder() - enc.fit(col_data) - self.column_transforms[field] = {'one_hot_encoder': enc} - elif kind == 'M': - # Datetime column. - nulls = data[field].isna() - integers = pd.to_numeric( - data[field], errors='coerce').to_numpy().astype(np.float64) - integers[nulls] = np.nan - self.column_transforms[field] = {'mean': pd.Series(integers).mean()} - - def transform(self, data): - """Transform the given data based on the data type of each column. - - Args: - data (pandas.DataFrame): - The data to transform. - - Returns: - pandas.DataFrame: - The transformed data. - """ - if not isinstance(data, pd.DataFrame): - data = pd.DataFrame(data) - - for field in data: - transform_info = self.column_transforms[field] - - kind = self.column_kind[field] - if kind == 'i' or kind == 'f': - # Numerical column. - data[field] = data[field].fillna(transform_info['mean']) - elif kind == 'b': - # Boolean column. - data[field] = pd.to_numeric(data[field], errors='coerce').astype(float) - data[field] = data[field].fillna(transform_info['mode']) - elif kind == 'O': - # Categorical column. - col_data = pd.DataFrame({'field': data[field]}) - out = transform_info['one_hot_encoder'].transform(col_data).toarray() - transformed = pd.DataFrame( - out, columns=[f'value{i}' for i in range(np.shape(out)[1])]) - data = data.drop(columns=[field]) - data = pd.concat([data, transformed.set_index(data.index)], axis=1) - elif kind == 'M': - # Datetime column. - nulls = data[field].isna() - integers = pd.to_numeric( - data[field], errors='coerce').to_numpy().astype(np.float64) - integers[nulls] = np.nan - data[field] = pd.Series(integers) - data[field] = data[field].fillna(transform_info['mean']) - - return data - - def fit_transform(self, data): - """Fit and transform the given data based on the data type of each column. - - Args: - data (pandas.DataFrame): - The data to transform. - - Returns: - pandas.DataFrame: - The transformed data. - """ - self.fit(data) - return self.transform(data) - - -def get_columns_from_metadata(metadata): - """Get the column info from a metadata dict. - - Args: - metadata (dict): - The metadata dict. - - Returns: - dict: - The columns metadata. - """ - return metadata.get('columns', {}) - - -def get_type_from_column_meta(column_metadata): - """Get the type of a given column from the column metadata. - - Args: - column_metadata (dict): - The column metadata. - - Returns: - string: - The column type. - """ - return column_metadata.get('sdtype', '') - - -def get_alternate_keys(metadata): - """Get the alternate keys from a metadata dict. - - Args: - metadata (dict): - The metadata dict. - - Returns: - list: - The list of alternate keys. - """ - alternate_keys = [] - for alternate_key in metadata.get('alternate_keys', []): - if isinstance(alternate_key, list): - alternate_keys.extend(alternate_key) - else: - alternate_keys.append(alternate_key) - - return alternate_keys - - -def strip_characters(list_character, a_string): - """Strip characters from a column name. - - Args: - list_character (list): - The list of characters to strip. - a_string (string): - The string to be stripped. - - Returns: - string: - The string with the characters stripped. - """ - result = a_string - for character in list_character: - if character in result: - result = result.replace(character, '') - - return result diff --git a/copulas/utils2.py b/copulas/utils2.py deleted file mode 100644 index 560afcf6..00000000 --- a/copulas/utils2.py +++ /dev/null @@ -1,233 +0,0 @@ -"""Report utility methods.""" - -import copy -import itertools -import warnings - -import numpy as np -import pandas as pd -from pandas.core.tools.datetimes import _guess_datetime_format_for_array - -from copulas.utils import ( - get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta, is_datetime) - -CONTINUOUS_SDTYPES = ['numerical', 'datetime'] -DISCRETE_SDTYPES = ['categorical', 'boolean'] - - -class PlotConfig: - """Custom plot settings for visualizations.""" - - GREEN = '#36B37E' - RED = '#FF0000' - ORANGE = '#F16141' - DATACEBO_DARK = '#000036' - DATACEBO_GREEN = '#01E0C9' - DATACEBO_BLUE = '#03AFF1' - BACKGROUND_COLOR = '#F5F5F8' - FONT_SIZE = 18 - - -def convert_to_datetime(column_data, datetime_format=None): - """Convert a column data to pandas datetime. - - Args: - column_data (pandas.Series): - The column data - format (str): - Optional string format of datetime. If ``None``, will attempt to infer the datetime - format from the column data. Defaults to ``None``. - - Returns: - pandas.Series: - The converted column data. - """ - if is_datetime(column_data): - return column_data - - if datetime_format is None: - datetime_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy()) - - return pd.to_datetime(column_data, format=datetime_format) - - -def convert_datetime_columns(real_column, synthetic_column, col_metadata): - """Convert a real and a synthetic column to pandas datetime. - - Args: - real_data (pandas.Series): - The real column data - synthetic_column (pandas.Series): - The synthetic column data - col_metadata: - The metadata associated with the column - - Returns: - (pandas.Series, pandas.Series): - The converted real and synthetic column data. - """ - datetime_format = col_metadata.get('format') or col_metadata.get('datetime_format') - return (convert_to_datetime(real_column, datetime_format), - convert_to_datetime(synthetic_column, datetime_format)) - - -def discretize_table_data(real_data, synthetic_data, metadata): - """Create a copy of the real and synthetic data with discretized data. - - Convert numerical and datetime columns to discrete values, and label them - as categorical. - - Args: - real_data (pandas.DataFrame): - The real data. - synthetic_data (pandas.DataFrame): - The synthetic data. - metadata (dict) - The metadata. - - Returns: - (pandas.DataFrame, pandas.DataFrame, dict): - The binned real and synthetic data, and the updated metadata. - """ - binned_real = real_data.copy() - binned_synthetic = synthetic_data.copy() - binned_metadata = copy.deepcopy(metadata) - - for column_name, column_meta in get_columns_from_metadata(metadata).items(): - sdtype = get_type_from_column_meta(column_meta) - - if sdtype in ('numerical', 'datetime'): - real_col = real_data[column_name] - synthetic_col = synthetic_data[column_name] - if sdtype == 'datetime': - datetime_format = column_meta.get('format') or column_meta.get('datetime_format') - if real_col.dtype == 'O' and datetime_format: - real_col = pd.to_datetime(real_col, format=datetime_format) - synthetic_col = pd.to_datetime(synthetic_col, format=datetime_format) - - real_col = pd.to_numeric(real_col) - synthetic_col = pd.to_numeric(synthetic_col) - - bin_edges = np.histogram_bin_edges(real_col.dropna()) - binned_real_col = np.digitize(real_col, bins=bin_edges) - binned_synthetic_col = np.digitize(synthetic_col, bins=bin_edges) - - binned_real[column_name] = binned_real_col - binned_synthetic[column_name] = binned_synthetic_col - get_columns_from_metadata(binned_metadata)[column_name] = {'sdtype': 'categorical'} - - return binned_real, binned_synthetic, binned_metadata - - -def _get_non_id_columns(metadata, binned_metadata): - valid_sdtypes = ['numerical', 'categorical', 'boolean', 'datetime'] - alternate_keys = get_alternate_keys(metadata) - non_id_columns = [] - for column, column_meta in get_columns_from_metadata(binned_metadata).items(): - is_key = column == metadata.get('primary_key', '') or column in alternate_keys - if get_type_from_column_meta(column_meta) in valid_sdtypes and not is_key: - non_id_columns.append(column) - - return non_id_columns - - -def discretize_and_apply_metric(real_data, synthetic_data, metadata, metric, keys_to_skip=[]): - """Discretize the data and apply the given metric. - - Args: - real_data (pandas.DataFrame): - The real data. - synthetic_data (pandas.DataFrame): - The synthetic data. - metadata (dict) - The metadata. - metric (sdmetrics.single_table.MultiColumnPairMetric): - The column pair metric to apply. - keys_to_skip (list[tuple(str)] or None): - A list of keys for which to skip computing the metric. - - Returns: - dict: - The metric results. - """ - metric_results = {} - - binned_real, binned_synthetic, binned_metadata = discretize_table_data( - real_data, synthetic_data, metadata) - - non_id_cols = _get_non_id_columns(metadata, binned_metadata) - for columns in itertools.combinations(non_id_cols, r=2): - sorted_columns = tuple(sorted(columns)) - if ( - sorted_columns not in keys_to_skip and - (sorted_columns[1], sorted_columns[0]) not in keys_to_skip - ): - result = metric.column_pairs_metric.compute_breakdown( - binned_real[list(sorted_columns)], - binned_synthetic[list(sorted_columns)], - ) - metric_results[sorted_columns] = result - metric_results[sorted_columns] = result - - return metric_results - - -def aggregate_metric_results(metric_results): - """Aggregate the scores and errors in a metric results mapping. - - Args: - metric_results (dict): - The metric results to aggregate. - - Returns: - (float, int): - The average of the metric scores, and the number of errors. - """ - if len(metric_results) == 0: - return np.nan, 0 - - metric_scores = [] - num_errors = 0 - - for _, breakdown in metric_results.items(): - metric_score = breakdown.get('score', np.nan) - if not np.isnan(metric_score): - metric_scores.append(metric_score) - if 'error' in breakdown: - num_errors += 1 - - return np.mean(metric_scores), num_errors - - -def _validate_categorical_values(real_data, synthetic_data, metadata, table=None): - """Get categorical values found in synthetic data but not real data for all columns. - - Args: - real_data (pd.DataFrame): - The real data. - synthetic_data (pd.DataFrame): - The synthetic data. - metadata (dict): - The metadata. - table (str, optional): - The name of the current table, if one exists - """ - if table: - warning_format = ('Unexpected values ({values}) in column "{column}" ' - f'and table "{table}"') - else: - warning_format = 'Unexpected values ({values}) in column "{column}"' - - columns = get_columns_from_metadata(metadata) - for column, column_meta in columns.items(): - column_type = get_type_from_column_meta(column_meta) - if column_type == 'categorical': - extra_categories = [ - value for value in synthetic_data[column].unique() - if value not in real_data[column].unique() - ] - if extra_categories: - value_list = '", "'.join(str(value) for value in extra_categories[:5]) - values = f'"{value_list}" + more' if len( - extra_categories) > 5 else f'"{value_list}"' - warnings.warn(warning_format.format(values=values, column=column))