diff --git a/copulas/sdmetrics.py b/copulas/sdmetrics.py deleted file mode 100644 index bafe1d52..00000000 --- a/copulas/sdmetrics.py +++ /dev/null @@ -1,480 +0,0 @@ -"""Visualization methods for SDMetrics.""" - -import pandas as pd -import plotly.express as px -import plotly.figure_factory as ff -from pandas.api.types import is_datetime64_dtype - -from copulas.utils import get_missing_percentage, is_datetime -from copulas.utils2 import PlotConfig - - -def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): - """Generate a bar plot of the real and synthetic data. - - Args: - real_column (pandas.Series): - The real data for the desired column. - synthetic_column (pandas.Series): - The synthetic data for the desired column. - plot_kwargs (dict, optional): - Dictionary of keyword arguments to pass to px.histogram. Keyword arguments - provided this way will overwrite defaults. - - Returns: - plotly.graph_objects._figure.Figure - """ - all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) - histogram_kwargs = { - 'x': 'values', - # 'color': 'Data', - 'barmode': 'group', - 'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], - 'pattern_shape': 'Data', - 'pattern_shape_sequence': ['', '/'], - 'histnorm': 'probability density', - } - histogram_kwargs.update(plot_kwargs) - fig = px.histogram( - all_data, - **histogram_kwargs - ) - - return fig - - -def _generate_heatmap_plot(all_data, columns): - """Generate heatmap plot for discrete data. - - Args: - all_data (pandas.DataFrame): - The real and synthetic data for the desired column pair containing a - ``Data`` column that indicates whether is real or synthetic. - columns (list): - A list of the columns being plotted. - - Returns: - plotly.graph_objects._figure.Figure - """ - fig = px.density_heatmap( - all_data, - x=columns[0], - y=columns[1], - facet_col='Data', - histnorm='probability' - ) - - fig.update_layout( - title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", - coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]}, - font={'size': PlotConfig.FONT_SIZE}, - ) - - fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1] + ' Data')) - - return fig - - -def _generate_box_plot(all_data, columns): - """Generate a box plot for mixed discrete and continuous column data. - - Args: - all_data (pandas.DataFrame): - The real and synthetic data for the desired column pair containing a - ``Data`` column that indicates whether is real or synthetic. - columns (list): - A list of the columns being plotted. - - Returns: - plotly.graph_objects._figure.Figure - """ - fig = px.box( - all_data, - x=columns[0], - y=columns[1], - color='Data', - color_discrete_map={ - 'Real': PlotConfig.DATACEBO_DARK, - 'Synthetic': PlotConfig.DATACEBO_GREEN - }, - ) - - fig.update_layout( - title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", - plot_bgcolor=PlotConfig.BACKGROUND_COLOR, - font={'size': PlotConfig.FONT_SIZE}, - ) - - return fig - - -def _generate_scatter_plot(all_data, columns): - """Generate a scatter plot for column pair plot. - - Args: - all_data (pandas.DataFrame): - The real and synthetic data for the desired column pair containing a - ``Data`` column that indicates whether is real or synthetic. - columns (list): - A list of the columns being plotted. - - Returns: - plotly.graph_objects._figure.Figure - """ - fig = px.scatter( - all_data, - x=columns[0], - y=columns[1], - color='Data', - color_discrete_map={ - 'Real': PlotConfig.DATACEBO_DARK, - 'Synthetic': PlotConfig.DATACEBO_GREEN - }, - symbol='Data' - ) - - fig.update_layout( - title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", - plot_bgcolor=PlotConfig.BACKGROUND_COLOR, - font={'size': PlotConfig.FONT_SIZE}, - ) - - return fig - - -def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}): - """Plot the real and synthetic data as a distplot. - - Args: - real_data (pandas.DataFrame): - The real data for the desired column. - synthetic_data (pandas.DataFrame): - The synthetic data for the desired column. - plot_kwargs (dict, optional): - Dictionary of keyword arguments to pass to px.histogram. Keyword arguments - provided this way will overwrite defaults. - - Returns: - plotly.graph_objects._figure.Figure - """ - default_distplot_kwargs = { - 'show_hist': False, - 'show_rug': False, - 'colors': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN] - } - - fig = ff.create_distplot( - [real_data['values'], synthetic_data['values']], - ['Real', 'Synthetic'], - **{**default_distplot_kwargs, **plot_kwargs} - ) - - return fig - - -def _generate_column_plot(real_column, - synthetic_column, - plot_type, - plot_kwargs={}, - plot_title=None, - x_label=None): - """Generate a plot of the real and synthetic data. - - Args: - real_column (pandas.Series): - The real data for the desired column. - synthetic_column (pandas.Series): - The synthetic data for the desired column. - plot_type (str): - The type of plot to use. Must be one of 'bar' or 'distplot'. - hist_kwargs (dict, optional): - Dictionary of keyword arguments to pass to px.histogram. Keyword arguments - provided this way will overwrite defaults. - plot_title (str, optional): - Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}' - x_label (str, optional): - Label to use for x-axis. Defaults to 'Category'. - - Returns: - plotly.graph_objects._figure.Figure - """ - if plot_type not in ['bar', 'distplot']: - raise ValueError( - "Unrecognized plot_type '{plot_type}'. Pleas use one of 'bar' or 'distplot'" - ) - - column_name = real_column.name if hasattr(real_column, 'name') else '' - - missing_data_real = get_missing_percentage(real_column) - missing_data_synthetic = get_missing_percentage(synthetic_column) - - real_data = pd.DataFrame({'values': real_column.copy().dropna()}) - real_data['Data'] = 'Real' - synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) - synthetic_data['Data'] = 'Synthetic' - - is_datetime_sdtype = False - if is_datetime64_dtype(real_column.dtype): - is_datetime_sdtype = True - real_data['values'] = real_data['values'].astype('int64') - synthetic_data['values'] = synthetic_data['values'].astype('int64') - - trace_args = {} - - if plot_type == 'bar': - fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs) - elif plot_type == 'distplot': - x_label = x_label or 'Value' - fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs) - trace_args = {'fill': 'tozeroy'} - - for i, name in enumerate(['Real', 'Synthetic']): - fig.update_traces( - x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x, - hovertemplate=f'{name}
Frequency: %{{y}}', - selector={'name': name}, - **trace_args - ) - - show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0 - annotations = [] if not show_missing_values else [ - { - 'xref': 'paper', - 'yref': 'paper', - 'x': 1.0, - 'y': 1.05, - 'showarrow': False, - 'text': ( - f'*Missing Values: Real Data ({missing_data_real}%), ' - f'Synthetic Data ({missing_data_synthetic}%)' - ), - }, - ] - - if not plot_title: - plot_title = f"Real vs. Synthetic Data for column '{column_name}'" - - if not x_label: - x_label = 'Category' - - fig.update_layout( - title=plot_title, - xaxis_title=x_label, - yaxis_title='Frequency', - plot_bgcolor=PlotConfig.BACKGROUND_COLOR, - annotations=annotations, - font={'size': PlotConfig.FONT_SIZE}, - ) - return fig - - -def _generate_cardinality_plot(real_data, - synthetic_data, - parent_primary_key, - child_foreign_key, - plot_type='bar'): - plot_title = ( - f"Relationship (child foreign key='{child_foreign_key}' and parent " - f"primary key='{parent_primary_key}')" - ) - x_label = '# of Children (per Parent)' - - plot_kwargs = {} - if plot_type == 'bar': - max_cardinality = max(max(real_data), max(synthetic_data)) - min_cardinality = min(min(real_data), min(synthetic_data)) - plot_kwargs = { - 'nbins': max_cardinality - min_cardinality + 1 - } - - return _generate_column_plot(real_data, synthetic_data, plot_type, - plot_kwargs, plot_title, x_label) - - -def _get_cardinality(parent_table, child_table, parent_primary_key, child_foreign_key): - """Return the cardinality of the parent-child relationship. - - Args: - parent_table (pandas.DataFrame): - The parent table. - child_table (pandas.DataFrame): - The child table. - parent_primary_key (string): - The name of the primary key column in the parent table. - child_foreign_key (string): - The name of the foreign key column in the child table. - - Returns: - pandas.DataFrame - """ - child_counts = child_table[child_foreign_key].value_counts().rename('# children') - cardinalities = child_counts.reindex(parent_table[parent_primary_key], fill_value=0).to_frame() - - return cardinalities.sort_values('# children')['# children'] - - -def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_table_name, - child_foreign_key, parent_primary_key, plot_type='bar'): - """Return a plot of the cardinality of the parent-child relationship. - - Args: - real_data (dict): - The real data. - synthetic_data (dict): - The synthetic data. - child_table_name (string): - The name of the child table. - parent_table_name (string): - The name of the parent table. - child_foreign_key (string): - The name of the foreign key column in the child table. - parent_primary_key (string): - The name of the primary key column in the parent table. - plot_type (string, optional): - The plot type to use to plot the cardinality. Must be either 'bar' or 'distplot'. - Defaults to 'bar'. - - Returns: - plotly.graph_objects._figure.Figure - """ - if plot_type not in ['bar', 'distplot']: - raise ValueError( - f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].") - - real_cardinality = _get_cardinality( - real_data[parent_table_name], real_data[child_table_name], - parent_primary_key, child_foreign_key - ) - synth_cardinality = _get_cardinality( - synthetic_data[parent_table_name], - synthetic_data[child_table_name], - parent_primary_key, child_foreign_key - ) - - fig = _generate_cardinality_plot( - real_cardinality, - synth_cardinality, - parent_primary_key, - child_foreign_key, - plot_type=plot_type - ) - - return fig - - -def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): - """Return a plot of the real and synthetic data for a given column. - - Args: - real_data (pandas.DataFrame): - The real table data. - synthetic_data (pandas.DataFrame): - The synthetic table data. - column_name (str): - The name of the column. - plot_type (str or None): - The plot to be used. Can choose between ``distplot``, ``bar`` or ``None``. If ``None` - select between ``distplot`` or ``bar`` depending on the data that the column contains, - ``distplot`` for datetime and numerical values and ``bar`` for categorical. - Defaults to ``None``. - - Returns: - plotly.graph_objects._figure.Figure - """ - if plot_type not in ['bar', 'distplot', None]: - raise ValueError( - f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]." - ) - - if column_name not in real_data.columns: - raise ValueError(f"Column '{column_name}' not found in real table data.") - if column_name not in synthetic_data.columns: - raise ValueError(f"Column '{column_name}' not found in synthetic table data.") - - real_column = real_data[column_name] - if plot_type is None: - column_is_datetime = is_datetime(real_data[column_name]) - dtype = real_column.dropna().infer_objects().dtype.kind - if column_is_datetime or dtype in ('i', 'f'): - plot_type = 'distplot' - else: - plot_type = 'bar' - - real_column = real_data[column_name] - synthetic_column = synthetic_data[column_name] - - fig = _generate_column_plot(real_column, synthetic_column, plot_type) - - return fig - - -def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None): - """Return a plot of the real and synthetic data for a given column pair. - - Args: - real_data (pandas.DataFrame): - The real table data. - synthetic_column (pandas.Dataframe): - The synthetic table data. - column_names (list[string]): - The names of the two columns to plot. - plot_type (str or None): - The plot to be used. Can choose between ``box``, ``heatmap``, ``scatter`` or ``None``. - If ``None` select between ``box``, ``heatmap`` or ``scatter`` depending on the data - that the column contains, ``scatter`` used for datetime and numerical values, - ``heatmap`` for categorical and ``box`` for a mix of both. Defaults to ``None``. - - Returns: - plotly.graph_objects._figure.Figure - """ - if len(column_names) != 2: - raise ValueError('Must provide exactly two column names.') - - if not set(column_names).issubset(real_data.columns): - raise ValueError( - f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' - ) - - if not set(column_names).issubset(synthetic_data.columns): - raise ValueError( - f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' - 'in synthetic data.' - ) - - if plot_type not in ['box', 'heatmap', 'scatter', None]: - raise ValueError( - f"Invalid plot_type '{plot_type}'. Please use one of " - "['box', 'heatmap', 'scatter', None]." - ) - - real_data = real_data[column_names] - synthetic_data = synthetic_data[column_names] - if plot_type is None: - plot_type = [] - for column_name in column_names: - column = real_data[column_name] - dtype = column.dropna().infer_objects().dtype.kind - if dtype in ('i', 'f') or is_datetime(column): - plot_type.append('scatter') - else: - plot_type.append('heatmap') - - if len(set(plot_type)) > 1: - plot_type = 'box' - else: - plot_type = plot_type.pop() - - # Merge the real and synthetic data and add a flag ``Data`` to indicate each one. - columns = list(real_data.columns) - real_data = real_data.copy() - real_data['Data'] = 'Real' - synthetic_data = synthetic_data.copy() - synthetic_data['Data'] = 'Synthetic' - all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) - - if plot_type == 'scatter': - return _generate_scatter_plot(all_data, columns) - elif plot_type == 'heatmap': - return _generate_heatmap_plot(all_data, columns) - - return _generate_box_plot(all_data, columns)