Skip to content

Commit

Permalink
Fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 13, 2023
1 parent 37207f3 commit b1a9998
Showing 1 changed file with 11 additions and 27 deletions.
38 changes: 11 additions & 27 deletions copulas/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,14 @@ def _generate_scatter_plot(all_data, columns):
return fig


def _generate_column_plot(real_column,
synthetic_column,
plot_kwargs={},
plot_title=None,
x_label=None):
def _generate_column_plot(real_column, synthetic_column):
"""Generate a plot of the real and synthetic data.
Args:
real_column (pandas.Series):
The real data for the desired column.
synthetic_column (pandas.Series):
The synthetic data for the desired column.
hist_kwargs (dict, optional):
Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
provided this way will overwrite defaults.
plot_title (str, optional):
Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}'
x_label (str, optional):
Label to use for x-axis. Defaults to 'Category'.
Returns:
plotly.graph_objects._figure.Figure
Expand All @@ -112,7 +101,7 @@ def _generate_column_plot(real_column,

trace_args = {}

fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs)
fig = _generate_column_bar_plot(real_data, synthetic_data)

for i, name in enumerate(['Real', 'Synthetic']):
fig.update_traces(
Expand All @@ -122,15 +111,10 @@ def _generate_column_plot(real_column,
**trace_args
)

if not plot_title:
plot_title = f"Real vs. Synthetic Data for column '{column_name}'"

if not x_label:
x_label = 'Category'

plot_title = f"Real vs. Synthetic Data for column '{column_name}'"
fig.update_layout(
title=plot_title,
xaxis_title=x_label,
xaxis_title='Category',
yaxis_title='Frequency',
plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
annotations=[],
Expand Down Expand Up @@ -195,7 +179,7 @@ def compare_1d(real, synth):
if not isinstance(synth, pd.Series):
synth = pd.Series(synth)

return _generate_column_plot(real, synth, plot_type='bar')
return _generate_column_plot(real, synth)


def scatter_2d(data, columns=None):
Expand All @@ -217,7 +201,7 @@ def scatter_2d(data, columns=None):
return _generate_scatter_plot(data, columns)


def compare_2d_(real, synth, columns=None):
def compare_2d(real, synth, columns=None):
"""Return a plot of the real and synthetic data for a given column pair.
Args:
Expand All @@ -231,8 +215,8 @@ def compare_2d_(real, synth, columns=None):
Returns:
plotly.graph_objects._figure.Figure
"""
real_data = real_data[columns]
synthetic_data = synthetic_data[columns]
real_data = real[columns]
synthetic_data = synth[columns]
columns = list(real_data.columns)
real_data['Data'] = 'Real'
synthetic_data['Data'] = 'Synthetic'
Expand All @@ -241,7 +225,7 @@ def compare_2d_(real, synth, columns=None):
return _generate_scatter_plot(all_data, columns)


def scatter_3d_plotly(data, columns=None):
def scatter_3d(data, columns=None):
"""Return a 3D scatter plot of the data.
Args:
Expand Down Expand Up @@ -289,7 +273,7 @@ def compare_3d(real, synth, columns=None):
"""
columns = columns or real.columns

fig = scatter_3d_plotly(real[columns])
fig = scatter_3d_plotly(synth[columns], fig=fig)
fig = scatter_3d(real[columns])
fig = scatter_3d(synth[columns], fig=fig)

return fig

0 comments on commit b1a9998

Please sign in to comment.