diff --git a/copulas/visualization.py b/copulas/visualization.py
index be3f64cf..60883c34 100644
--- a/copulas/visualization.py
+++ b/copulas/visualization.py
@@ -51,6 +51,8 @@ def _generate_1d_plot(data, title, labels, colors):
         plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
         font={'size': PlotConfig.FONT_SIZE},
         showlegend=True if labels[0] else False,
+        xaxis_title='value',
+        yaxis_title='frequency'
     )
 
     return fig
@@ -70,6 +72,13 @@ def dist_1d(data, title=None, label=None):
     Returns:
         plotly.graph_objects._figure.Figure
     """
+    if not title:
+        title = 'Data'
+        if isinstance(data, pd.DataFrame):
+            title += f" for column '{data.columns[0]}'"
+        elif isinstance(data, pd.Series) and data.name:
+            title += f" for column '{data.name}'"
+
     return _generate_1d_plot(
         data=[data],
         title=title,
@@ -92,6 +101,13 @@ def compare_1d(real, synth, title=None):
     Returns:
         plotly.graph_objects._figure.Figure
     """
+    if not title:
+        title = 'Real vs. Synthetic Data'
+        if isinstance(real, pd.DataFrame):
+            title += f" for column '{real.columns[0]}'"
+        elif isinstance(real, pd.Series) and real.name:
+            title += f" for column '{real.name}'"
+
     return _generate_1d_plot(
         data=[real, synth],
         title=title,
@@ -162,6 +178,13 @@ def scatter_2d(data, columns=None, title=None):
     data = data.copy()
     data['Data'] = 'Real'
 
+    if not title:
+        title = 'Data'
+        if columns:
+            title += f" for columns '{columns[0]}' and '{columns[1]}'"
+        elif isinstance(data, pd.DataFrame):
+            title += f" for columns '{data.columns[0]}' and '{data.columns[1]}'"
+
     return _generate_scatter_2d_plot(
         data=data,
         columns=columns,
@@ -191,6 +214,13 @@ def compare_2d(real, synth, columns=None, title=None):
     synth['Data'] = 'Synthetic'
     data = pd.concat([real, synth], axis=0, ignore_index=True)
 
+    if not title:
+        title = 'Real vs. Synthetic Data'
+        if columns:
+            title += f" for columns '{columns[0]}' and '{columns[1]}'"
+        elif isinstance(data, pd.DataFrame):
+            title += f" for columns '{data.columns[0]}' and '{data.columns[1]}'"
+
     return _generate_scatter_2d_plot(
         data=data,
         columns=columns,
@@ -256,7 +286,7 @@ def scatter_3d(data, columns=None, title=None):
     Args:
         data (pandas.DataFrame):
             The table data. Must have at least 3 columns.
-        column_names (list[string]):
+        columns (list[string]):
             The names of the three columns to plot.
         title (str):
             The title of the plot.
@@ -267,6 +297,14 @@ def scatter_3d(data, columns=None, title=None):
     data = data.copy()
     data['Data'] = 'Real'
 
+    if not title:
+        title = 'Data'
+        if columns:
+            title += f" for columns '{columns[0]}', '{columns[1]}' and '{columns[2]}'"
+        elif isinstance(data, pd.DataFrame):
+            title += \
+                f" for columns '{data.columns[0]}', '{data.columns[1]}' and '{data.columns[2]}'"
+
     return _generate_scatter_3d_plot(
         data=data,
         columns=columns,
@@ -293,6 +331,14 @@ def compare_3d(real, synth, columns=None, title=None):
     synth['Data'] = 'Synthetic'
     data = pd.concat([real, synth], axis=0, ignore_index=True)
 
+    if not title:
+        title = 'Real vs. Synthetic Data'
+        if columns:
+            title += f" for columns '{columns[0]}', '{columns[1]}' and '{columns[2]}'"
+        elif isinstance(data, pd.DataFrame):
+            title += \
+                f" for columns '{data.columns[0]}', '{data.columns[1]}' and '{data.columns[2]}'"
+
     return _generate_scatter_3d_plot(
         data=data,
         columns=columns,