diff --git a/copulas/sdmetrics.py b/copulas/sdmetrics.py
new file mode 100644
index 00000000..d9090f71
--- /dev/null
+++ b/copulas/sdmetrics.py
@@ -0,0 +1,479 @@
+"""Visualization methods for SDMetrics."""
+
+import pandas as pd
+import plotly.express as px
+import plotly.figure_factory as ff
+from pandas.api.types import is_datetime64_dtype
+
+from copulas.utils import get_missing_percentage, is_datetime
+from copulas.utils2 import PlotConfig
+
+
+def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
+ """Generate a bar plot of the real and synthetic data.
+
+ Args:
+ real_column (pandas.Series):
+ The real data for the desired column.
+ synthetic_column (pandas.Series):
+ The synthetic data for the desired column.
+ plot_kwargs (dict, optional):
+ Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
+ provided this way will overwrite defaults.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
+ histogram_kwargs = {
+ 'x': 'values',
+ 'barmode': 'group',
+ 'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN],
+ 'pattern_shape': 'Data',
+ 'pattern_shape_sequence': ['', '/'],
+ 'histnorm': 'probability density',
+ }
+ histogram_kwargs.update(plot_kwargs)
+ fig = px.histogram(
+ all_data,
+ **histogram_kwargs
+ )
+
+ return fig
+
+
+def _generate_heatmap_plot(all_data, columns):
+ """Generate heatmap plot for discrete data.
+
+ Args:
+ all_data (pandas.DataFrame):
+ The real and synthetic data for the desired column pair containing a
+ ``Data`` column that indicates whether is real or synthetic.
+ columns (list):
+ A list of the columns being plotted.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ fig = px.density_heatmap(
+ all_data,
+ x=columns[0],
+ y=columns[1],
+ facet_col='Data',
+ histnorm='probability'
+ )
+
+ fig.update_layout(
+ title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
+ coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]},
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+
+ fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1] + ' Data'))
+
+ return fig
+
+
+def _generate_box_plot(all_data, columns):
+ """Generate a box plot for mixed discrete and continuous column data.
+
+ Args:
+ all_data (pandas.DataFrame):
+ The real and synthetic data for the desired column pair containing a
+ ``Data`` column that indicates whether is real or synthetic.
+ columns (list):
+ A list of the columns being plotted.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ fig = px.box(
+ all_data,
+ x=columns[0],
+ y=columns[1],
+ color='Data',
+ color_discrete_map={
+ 'Real': PlotConfig.DATACEBO_DARK,
+ 'Synthetic': PlotConfig.DATACEBO_GREEN
+ },
+ )
+
+ fig.update_layout(
+ title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
+ plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+
+ return fig
+
+
+def _generate_scatter_plot(all_data, columns):
+ """Generate a scatter plot for column pair plot.
+
+ Args:
+ all_data (pandas.DataFrame):
+ The real and synthetic data for the desired column pair containing a
+ ``Data`` column that indicates whether is real or synthetic.
+ columns (list):
+ A list of the columns being plotted.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ fig = px.scatter(
+ all_data,
+ x=columns[0],
+ y=columns[1],
+ color='Data',
+ color_discrete_map={
+ 'Real': PlotConfig.DATACEBO_DARK,
+ 'Synthetic': PlotConfig.DATACEBO_GREEN
+ },
+ symbol='Data'
+ )
+
+ fig.update_layout(
+ title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
+ plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+
+ return fig
+
+
+def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}):
+ """Plot the real and synthetic data as a distplot.
+
+ Args:
+ real_data (pandas.DataFrame):
+ The real data for the desired column.
+ synthetic_data (pandas.DataFrame):
+ The synthetic data for the desired column.
+ plot_kwargs (dict, optional):
+ Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
+ provided this way will overwrite defaults.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ default_distplot_kwargs = {
+ 'show_hist': False,
+ 'show_rug': False,
+ 'colors': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]
+ }
+
+ fig = ff.create_distplot(
+ [real_data['values'], synthetic_data['values']],
+ ['Real', 'Synthetic'],
+ **{**default_distplot_kwargs, **plot_kwargs}
+ )
+
+ return fig
+
+
+def _generate_column_plot(real_column,
+ synthetic_column,
+ plot_type,
+ plot_kwargs={},
+ plot_title=None,
+ x_label=None):
+ """Generate a plot of the real and synthetic data.
+
+ Args:
+ real_column (pandas.Series):
+ The real data for the desired column.
+ synthetic_column (pandas.Series):
+ The synthetic data for the desired column.
+ plot_type (str):
+ The type of plot to use. Must be one of 'bar' or 'distplot'.
+ hist_kwargs (dict, optional):
+ Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
+ provided this way will overwrite defaults.
+ plot_title (str, optional):
+ Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}'
+ x_label (str, optional):
+ Label to use for x-axis. Defaults to 'Category'.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ if plot_type not in ['bar', 'distplot']:
+ raise ValueError(
+ "Unrecognized plot_type '{plot_type}'. Pleas use one of 'bar' or 'distplot'"
+ )
+
+ column_name = real_column.name if hasattr(real_column, 'name') else ''
+
+ missing_data_real = get_missing_percentage(real_column)
+ missing_data_synthetic = get_missing_percentage(synthetic_column)
+
+ real_data = pd.DataFrame({'values': real_column.copy().dropna()})
+ real_data['Data'] = 'Real'
+ synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
+ synthetic_data['Data'] = 'Synthetic'
+
+ is_datetime_sdtype = False
+ if is_datetime64_dtype(real_column.dtype):
+ is_datetime_sdtype = True
+ real_data['values'] = real_data['values'].astype('int64')
+ synthetic_data['values'] = synthetic_data['values'].astype('int64')
+
+ trace_args = {}
+
+ if plot_type == 'bar':
+ fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs)
+ elif plot_type == 'distplot':
+ x_label = x_label or 'Value'
+ fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs)
+ trace_args = {'fill': 'tozeroy'}
+
+ for i, name in enumerate(['Real', 'Synthetic']):
+ fig.update_traces(
+ x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x,
+ hovertemplate=f'{name}
Frequency: %{{y}}',
+ selector={'name': name},
+ **trace_args
+ )
+
+ show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0
+ annotations = [] if not show_missing_values else [
+ {
+ 'xref': 'paper',
+ 'yref': 'paper',
+ 'x': 1.0,
+ 'y': 1.05,
+ 'showarrow': False,
+ 'text': (
+ f'*Missing Values: Real Data ({missing_data_real}%), '
+ f'Synthetic Data ({missing_data_synthetic}%)'
+ ),
+ },
+ ]
+
+ if not plot_title:
+ plot_title = f"Real vs. Synthetic Data for column '{column_name}'"
+
+ if not x_label:
+ x_label = 'Category'
+
+ fig.update_layout(
+ title=plot_title,
+ xaxis_title=x_label,
+ yaxis_title='Frequency',
+ plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+ annotations=annotations,
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+ return fig
+
+
+def _generate_cardinality_plot(real_data,
+ synthetic_data,
+ parent_primary_key,
+ child_foreign_key,
+ plot_type='bar'):
+ plot_title = (
+ f"Relationship (child foreign key='{child_foreign_key}' and parent "
+ f"primary key='{parent_primary_key}')"
+ )
+ x_label = '# of Children (per Parent)'
+
+ plot_kwargs = {}
+ if plot_type == 'bar':
+ max_cardinality = max(max(real_data), max(synthetic_data))
+ min_cardinality = min(min(real_data), min(synthetic_data))
+ plot_kwargs = {
+ 'nbins': max_cardinality - min_cardinality + 1
+ }
+
+ return _generate_column_plot(real_data, synthetic_data, plot_type,
+ plot_kwargs, plot_title, x_label)
+
+
+def _get_cardinality(parent_table, child_table, parent_primary_key, child_foreign_key):
+ """Return the cardinality of the parent-child relationship.
+
+ Args:
+ parent_table (pandas.DataFrame):
+ The parent table.
+ child_table (pandas.DataFrame):
+ The child table.
+ parent_primary_key (string):
+ The name of the primary key column in the parent table.
+ child_foreign_key (string):
+ The name of the foreign key column in the child table.
+
+ Returns:
+ pandas.DataFrame
+ """
+ child_counts = child_table[child_foreign_key].value_counts().rename('# children')
+ cardinalities = child_counts.reindex(parent_table[parent_primary_key], fill_value=0).to_frame()
+
+ return cardinalities.sort_values('# children')['# children']
+
+
+def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_table_name,
+ child_foreign_key, parent_primary_key, plot_type='bar'):
+ """Return a plot of the cardinality of the parent-child relationship.
+
+ Args:
+ real_data (dict):
+ The real data.
+ synthetic_data (dict):
+ The synthetic data.
+ child_table_name (string):
+ The name of the child table.
+ parent_table_name (string):
+ The name of the parent table.
+ child_foreign_key (string):
+ The name of the foreign key column in the child table.
+ parent_primary_key (string):
+ The name of the primary key column in the parent table.
+ plot_type (string, optional):
+ The plot type to use to plot the cardinality. Must be either 'bar' or 'distplot'.
+ Defaults to 'bar'.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ if plot_type not in ['bar', 'distplot']:
+ raise ValueError(
+ f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].")
+
+ real_cardinality = _get_cardinality(
+ real_data[parent_table_name], real_data[child_table_name],
+ parent_primary_key, child_foreign_key
+ )
+ synth_cardinality = _get_cardinality(
+ synthetic_data[parent_table_name],
+ synthetic_data[child_table_name],
+ parent_primary_key, child_foreign_key
+ )
+
+ fig = _generate_cardinality_plot(
+ real_cardinality,
+ synth_cardinality,
+ parent_primary_key,
+ child_foreign_key,
+ plot_type=plot_type
+ )
+
+ return fig
+
+
+def get_column_plot(real_data, synthetic_data, column_name, plot_type=None):
+ """Return a plot of the real and synthetic data for a given column.
+
+ Args:
+ real_data (pandas.DataFrame):
+ The real table data.
+ synthetic_data (pandas.DataFrame):
+ The synthetic table data.
+ column_name (str):
+ The name of the column.
+ plot_type (str or None):
+ The plot to be used. Can choose between ``distplot``, ``bar`` or ``None``. If ``None`
+ select between ``distplot`` or ``bar`` depending on the data that the column contains,
+ ``distplot`` for datetime and numerical values and ``bar`` for categorical.
+ Defaults to ``None``.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ if plot_type not in ['bar', 'distplot', None]:
+ raise ValueError(
+ f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]."
+ )
+
+ if column_name not in real_data.columns:
+ raise ValueError(f"Column '{column_name}' not found in real table data.")
+ if column_name not in synthetic_data.columns:
+ raise ValueError(f"Column '{column_name}' not found in synthetic table data.")
+
+ real_column = real_data[column_name]
+ if plot_type is None:
+ column_is_datetime = is_datetime(real_data[column_name])
+ dtype = real_column.dropna().infer_objects().dtype.kind
+ if column_is_datetime or dtype in ('i', 'f'):
+ plot_type = 'distplot'
+ else:
+ plot_type = 'bar'
+
+ real_column = real_data[column_name]
+ synthetic_column = synthetic_data[column_name]
+
+ fig = _generate_column_plot(real_column, synthetic_column, plot_type)
+
+ return fig
+
+
+def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None):
+ """Return a plot of the real and synthetic data for a given column pair.
+
+ Args:
+ real_data (pandas.DataFrame):
+ The real table data.
+ synthetic_column (pandas.Dataframe):
+ The synthetic table data.
+ column_names (list[string]):
+ The names of the two columns to plot.
+ plot_type (str or None):
+ The plot to be used. Can choose between ``box``, ``heatmap``, ``scatter`` or ``None``.
+ If ``None` select between ``box``, ``heatmap`` or ``scatter`` depending on the data
+ that the column contains, ``scatter`` used for datetime and numerical values,
+ ``heatmap`` for categorical and ``box`` for a mix of both. Defaults to ``None``.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ if len(column_names) != 2:
+ raise ValueError('Must provide exactly two column names.')
+
+ if not set(column_names).issubset(real_data.columns):
+ raise ValueError(
+ f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.'
+ )
+
+ if not set(column_names).issubset(synthetic_data.columns):
+ raise ValueError(
+ f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} '
+ 'in synthetic data.'
+ )
+
+ if plot_type not in ['box', 'heatmap', 'scatter', None]:
+ raise ValueError(
+ f"Invalid plot_type '{plot_type}'. Please use one of "
+ "['box', 'heatmap', 'scatter', None]."
+ )
+
+ real_data = real_data[column_names]
+ synthetic_data = synthetic_data[column_names]
+ if plot_type is None:
+ plot_type = []
+ for column_name in column_names:
+ column = real_data[column_name]
+ dtype = column.dropna().infer_objects().dtype.kind
+ if dtype in ('i', 'f') or is_datetime(column):
+ plot_type.append('scatter')
+ else:
+ plot_type.append('heatmap')
+
+ if len(set(plot_type)) > 1:
+ plot_type = 'box'
+ else:
+ plot_type = plot_type.pop()
+
+ # Merge the real and synthetic data and add a flag ``Data`` to indicate each one.
+ columns = list(real_data.columns)
+ real_data = real_data.copy()
+ real_data['Data'] = 'Real'
+ synthetic_data = synthetic_data.copy()
+ synthetic_data['Data'] = 'Synthetic'
+ all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
+
+ if plot_type == 'scatter':
+ return _generate_scatter_plot(all_data, columns)
+ elif plot_type == 'heatmap':
+ return _generate_heatmap_plot(all_data, columns)
+
+ return _generate_box_plot(all_data, columns)
diff --git a/copulas/utils.py b/copulas/utils.py
deleted file mode 100644
index 30e56ebe..00000000
--- a/copulas/utils.py
+++ /dev/null
@@ -1,298 +0,0 @@
-"""SDMetrics utils to be used across all the project."""
-
-from collections import Counter
-from datetime import datetime
-
-import numpy as np
-import pandas as pd
-from sklearn.preprocessing import OneHotEncoder
-
-
-def nested_attrs_meta(nested):
- """Metaclass factory that defines a Metaclass with a dynamic attribute name."""
-
- class Metaclass(type):
- """Metaclass which pulls the attributes from a nested object using properties."""
-
- def __getattr__(cls, attr):
- """If cls does not have the attribute, try to get it from the nested object."""
- nested_obj = getattr(cls, nested)
- if hasattr(nested_obj, attr):
- return getattr(nested_obj, attr)
-
- raise AttributeError(f"type object '{cls.__name__}' has no attribute '{attr}'")
-
- @property
- def name(cls):
- return getattr(cls, nested).name
-
- @property
- def goal(cls):
- return getattr(cls, nested).goal
-
- @property
- def max_value(cls):
- return getattr(cls, nested).max_value
-
- @property
- def min_value(cls):
- return getattr(cls, nested).min_value
-
- return Metaclass
-
-
-def get_frequencies(real, synthetic):
- """Get percentual frequencies for each possible real categorical value.
-
- Given two iterators containing categorical data, this transforms it into
- observed/expected frequencies which can be used for statistical tests. It
- adds a regularization term to handle cases where the synthetic data contains
- values that don't exist in the real data.
-
- Args:
- real (list):
- A list of hashable objects.
- synthetic (list):
- A list of hashable objects.
-
- Yields:
- tuble[list, list]:
- The observed and expected frequencies (as a percent).
- """
- f_obs, f_exp = [], []
- real, synthetic = Counter(real), Counter(synthetic)
- for value in synthetic:
- if value not in real:
- real[value] += 1e-6 # Regularization to prevent NaN.
-
- for value in real:
- f_obs.append(synthetic[value] / sum(synthetic.values())) # noqa: PD011
- f_exp.append(real[value] / sum(real.values())) # noqa: PD011
-
- return f_obs, f_exp
-
-
-def get_missing_percentage(data_column):
- """Compute the missing value percentage of a column.
-
- Args:
- data_column (pandas.Series):
- The data of the desired column.
-
- Returns:
- pandas.Series:
- Percentage of missing values inside the column.
- """
- return round((data_column.isna().sum() / len(data_column)) * 100, 2)
-
-
-def get_cardinality_distribution(parent_column, child_column):
- """Compute the cardinality distribution of the (parent, child) pairing.
-
- Args:
- parent_column (pandas.Series):
- The parent column.
- child_column (pandas.Series):
- The child column.
-
- Returns:
- pandas.Series:
- The cardinality distribution.
- """
- child_df = pd.DataFrame({'child_counts': child_column.value_counts()})
- cardinality_df = pd.DataFrame({'parent': parent_column}).join(
- child_df, on='parent').fillna(0)
-
- return cardinality_df['child_counts']
-
-
-def is_datetime(data):
- """Determine if the input is a datetime type or not.
-
- Args:
- data (pandas.DataFrame, int or datetime):
- Input to evaluate.
-
- Returns:
- bool:
- True if the input is a datetime type, False if not.
- """
- return (
- pd.api.types.is_datetime64_any_dtype(data)
- or isinstance(data, pd.Timestamp)
- or isinstance(data, datetime)
- )
-
-
-class HyperTransformer():
- """HyperTransformer class.
-
- The ``HyperTransformer`` class contains a set of transforms to transform one or
- more columns based on each column's data type.
- """
-
- column_transforms = {}
- column_kind = {}
-
- def fit(self, data):
- """Fit the HyperTransformer to the given data.
-
- Args:
- data (pandas.DataFrame):
- The data to transform.
- """
- if not isinstance(data, pd.DataFrame):
- data = pd.DataFrame(data)
-
- for field in data:
- kind = data[field].dropna().infer_objects().dtype.kind
- self.column_kind[field] = kind
-
- if kind == 'i' or kind == 'f':
- # Numerical column.
- self.column_transforms[field] = {'mean': data[field].mean()}
- elif kind == 'b':
- # Boolean column.
- numeric = pd.to_numeric(data[field], errors='coerce').astype(float)
- self.column_transforms[field] = {'mode': numeric.mode().iloc[0]}
- elif kind == 'O':
- # Categorical column.
- col_data = pd.DataFrame({'field': data[field]})
- enc = OneHotEncoder()
- enc.fit(col_data)
- self.column_transforms[field] = {'one_hot_encoder': enc}
- elif kind == 'M':
- # Datetime column.
- nulls = data[field].isna()
- integers = pd.to_numeric(
- data[field], errors='coerce').to_numpy().astype(np.float64)
- integers[nulls] = np.nan
- self.column_transforms[field] = {'mean': pd.Series(integers).mean()}
-
- def transform(self, data):
- """Transform the given data based on the data type of each column.
-
- Args:
- data (pandas.DataFrame):
- The data to transform.
-
- Returns:
- pandas.DataFrame:
- The transformed data.
- """
- if not isinstance(data, pd.DataFrame):
- data = pd.DataFrame(data)
-
- for field in data:
- transform_info = self.column_transforms[field]
-
- kind = self.column_kind[field]
- if kind == 'i' or kind == 'f':
- # Numerical column.
- data[field] = data[field].fillna(transform_info['mean'])
- elif kind == 'b':
- # Boolean column.
- data[field] = pd.to_numeric(data[field], errors='coerce').astype(float)
- data[field] = data[field].fillna(transform_info['mode'])
- elif kind == 'O':
- # Categorical column.
- col_data = pd.DataFrame({'field': data[field]})
- out = transform_info['one_hot_encoder'].transform(col_data).toarray()
- transformed = pd.DataFrame(
- out, columns=[f'value{i}' for i in range(np.shape(out)[1])])
- data = data.drop(columns=[field])
- data = pd.concat([data, transformed.set_index(data.index)], axis=1)
- elif kind == 'M':
- # Datetime column.
- nulls = data[field].isna()
- integers = pd.to_numeric(
- data[field], errors='coerce').to_numpy().astype(np.float64)
- integers[nulls] = np.nan
- data[field] = pd.Series(integers)
- data[field] = data[field].fillna(transform_info['mean'])
-
- return data
-
- def fit_transform(self, data):
- """Fit and transform the given data based on the data type of each column.
-
- Args:
- data (pandas.DataFrame):
- The data to transform.
-
- Returns:
- pandas.DataFrame:
- The transformed data.
- """
- self.fit(data)
- return self.transform(data)
-
-
-def get_columns_from_metadata(metadata):
- """Get the column info from a metadata dict.
-
- Args:
- metadata (dict):
- The metadata dict.
-
- Returns:
- dict:
- The columns metadata.
- """
- return metadata.get('columns', {})
-
-
-def get_type_from_column_meta(column_metadata):
- """Get the type of a given column from the column metadata.
-
- Args:
- column_metadata (dict):
- The column metadata.
-
- Returns:
- string:
- The column type.
- """
- return column_metadata.get('sdtype', '')
-
-
-def get_alternate_keys(metadata):
- """Get the alternate keys from a metadata dict.
-
- Args:
- metadata (dict):
- The metadata dict.
-
- Returns:
- list:
- The list of alternate keys.
- """
- alternate_keys = []
- for alternate_key in metadata.get('alternate_keys', []):
- if isinstance(alternate_key, list):
- alternate_keys.extend(alternate_key)
- else:
- alternate_keys.append(alternate_key)
-
- return alternate_keys
-
-
-def strip_characters(list_character, a_string):
- """Strip characters from a column name.
-
- Args:
- list_character (list):
- The list of characters to strip.
- a_string (string):
- The string to be stripped.
-
- Returns:
- string:
- The string with the characters stripped.
- """
- result = a_string
- for character in list_character:
- if character in result:
- result = result.replace(character, '')
-
- return result
diff --git a/copulas/utils2.py b/copulas/utils2.py
deleted file mode 100644
index 560afcf6..00000000
--- a/copulas/utils2.py
+++ /dev/null
@@ -1,233 +0,0 @@
-"""Report utility methods."""
-
-import copy
-import itertools
-import warnings
-
-import numpy as np
-import pandas as pd
-from pandas.core.tools.datetimes import _guess_datetime_format_for_array
-
-from copulas.utils import (
- get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta, is_datetime)
-
-CONTINUOUS_SDTYPES = ['numerical', 'datetime']
-DISCRETE_SDTYPES = ['categorical', 'boolean']
-
-
-class PlotConfig:
- """Custom plot settings for visualizations."""
-
- GREEN = '#36B37E'
- RED = '#FF0000'
- ORANGE = '#F16141'
- DATACEBO_DARK = '#000036'
- DATACEBO_GREEN = '#01E0C9'
- DATACEBO_BLUE = '#03AFF1'
- BACKGROUND_COLOR = '#F5F5F8'
- FONT_SIZE = 18
-
-
-def convert_to_datetime(column_data, datetime_format=None):
- """Convert a column data to pandas datetime.
-
- Args:
- column_data (pandas.Series):
- The column data
- format (str):
- Optional string format of datetime. If ``None``, will attempt to infer the datetime
- format from the column data. Defaults to ``None``.
-
- Returns:
- pandas.Series:
- The converted column data.
- """
- if is_datetime(column_data):
- return column_data
-
- if datetime_format is None:
- datetime_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy())
-
- return pd.to_datetime(column_data, format=datetime_format)
-
-
-def convert_datetime_columns(real_column, synthetic_column, col_metadata):
- """Convert a real and a synthetic column to pandas datetime.
-
- Args:
- real_data (pandas.Series):
- The real column data
- synthetic_column (pandas.Series):
- The synthetic column data
- col_metadata:
- The metadata associated with the column
-
- Returns:
- (pandas.Series, pandas.Series):
- The converted real and synthetic column data.
- """
- datetime_format = col_metadata.get('format') or col_metadata.get('datetime_format')
- return (convert_to_datetime(real_column, datetime_format),
- convert_to_datetime(synthetic_column, datetime_format))
-
-
-def discretize_table_data(real_data, synthetic_data, metadata):
- """Create a copy of the real and synthetic data with discretized data.
-
- Convert numerical and datetime columns to discrete values, and label them
- as categorical.
-
- Args:
- real_data (pandas.DataFrame):
- The real data.
- synthetic_data (pandas.DataFrame):
- The synthetic data.
- metadata (dict)
- The metadata.
-
- Returns:
- (pandas.DataFrame, pandas.DataFrame, dict):
- The binned real and synthetic data, and the updated metadata.
- """
- binned_real = real_data.copy()
- binned_synthetic = synthetic_data.copy()
- binned_metadata = copy.deepcopy(metadata)
-
- for column_name, column_meta in get_columns_from_metadata(metadata).items():
- sdtype = get_type_from_column_meta(column_meta)
-
- if sdtype in ('numerical', 'datetime'):
- real_col = real_data[column_name]
- synthetic_col = synthetic_data[column_name]
- if sdtype == 'datetime':
- datetime_format = column_meta.get('format') or column_meta.get('datetime_format')
- if real_col.dtype == 'O' and datetime_format:
- real_col = pd.to_datetime(real_col, format=datetime_format)
- synthetic_col = pd.to_datetime(synthetic_col, format=datetime_format)
-
- real_col = pd.to_numeric(real_col)
- synthetic_col = pd.to_numeric(synthetic_col)
-
- bin_edges = np.histogram_bin_edges(real_col.dropna())
- binned_real_col = np.digitize(real_col, bins=bin_edges)
- binned_synthetic_col = np.digitize(synthetic_col, bins=bin_edges)
-
- binned_real[column_name] = binned_real_col
- binned_synthetic[column_name] = binned_synthetic_col
- get_columns_from_metadata(binned_metadata)[column_name] = {'sdtype': 'categorical'}
-
- return binned_real, binned_synthetic, binned_metadata
-
-
-def _get_non_id_columns(metadata, binned_metadata):
- valid_sdtypes = ['numerical', 'categorical', 'boolean', 'datetime']
- alternate_keys = get_alternate_keys(metadata)
- non_id_columns = []
- for column, column_meta in get_columns_from_metadata(binned_metadata).items():
- is_key = column == metadata.get('primary_key', '') or column in alternate_keys
- if get_type_from_column_meta(column_meta) in valid_sdtypes and not is_key:
- non_id_columns.append(column)
-
- return non_id_columns
-
-
-def discretize_and_apply_metric(real_data, synthetic_data, metadata, metric, keys_to_skip=[]):
- """Discretize the data and apply the given metric.
-
- Args:
- real_data (pandas.DataFrame):
- The real data.
- synthetic_data (pandas.DataFrame):
- The synthetic data.
- metadata (dict)
- The metadata.
- metric (sdmetrics.single_table.MultiColumnPairMetric):
- The column pair metric to apply.
- keys_to_skip (list[tuple(str)] or None):
- A list of keys for which to skip computing the metric.
-
- Returns:
- dict:
- The metric results.
- """
- metric_results = {}
-
- binned_real, binned_synthetic, binned_metadata = discretize_table_data(
- real_data, synthetic_data, metadata)
-
- non_id_cols = _get_non_id_columns(metadata, binned_metadata)
- for columns in itertools.combinations(non_id_cols, r=2):
- sorted_columns = tuple(sorted(columns))
- if (
- sorted_columns not in keys_to_skip and
- (sorted_columns[1], sorted_columns[0]) not in keys_to_skip
- ):
- result = metric.column_pairs_metric.compute_breakdown(
- binned_real[list(sorted_columns)],
- binned_synthetic[list(sorted_columns)],
- )
- metric_results[sorted_columns] = result
- metric_results[sorted_columns] = result
-
- return metric_results
-
-
-def aggregate_metric_results(metric_results):
- """Aggregate the scores and errors in a metric results mapping.
-
- Args:
- metric_results (dict):
- The metric results to aggregate.
-
- Returns:
- (float, int):
- The average of the metric scores, and the number of errors.
- """
- if len(metric_results) == 0:
- return np.nan, 0
-
- metric_scores = []
- num_errors = 0
-
- for _, breakdown in metric_results.items():
- metric_score = breakdown.get('score', np.nan)
- if not np.isnan(metric_score):
- metric_scores.append(metric_score)
- if 'error' in breakdown:
- num_errors += 1
-
- return np.mean(metric_scores), num_errors
-
-
-def _validate_categorical_values(real_data, synthetic_data, metadata, table=None):
- """Get categorical values found in synthetic data but not real data for all columns.
-
- Args:
- real_data (pd.DataFrame):
- The real data.
- synthetic_data (pd.DataFrame):
- The synthetic data.
- metadata (dict):
- The metadata.
- table (str, optional):
- The name of the current table, if one exists
- """
- if table:
- warning_format = ('Unexpected values ({values}) in column "{column}" '
- f'and table "{table}"')
- else:
- warning_format = 'Unexpected values ({values}) in column "{column}"'
-
- columns = get_columns_from_metadata(metadata)
- for column, column_meta in columns.items():
- column_type = get_type_from_column_meta(column_meta)
- if column_type == 'categorical':
- extra_categories = [
- value for value in synthetic_data[column].unique()
- if value not in real_data[column].unique()
- ]
- if extra_categories:
- value_list = '", "'.join(str(value) for value in extra_categories[:5])
- values = f'"{value_list}" + more' if len(
- extra_categories) > 5 else f'"{value_list}"'
- warnings.warn(warning_format.format(values=values, column=column))