Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 31, 2023
1 parent c614376 commit ddca46c
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion copulas/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def _generate_1d_plot(data, title, labels, colors):
plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
font={'size': PlotConfig.FONT_SIZE},
showlegend=True if labels[0] else False,
xaxis_title='value',
yaxis_title='frequency'
)

return fig
Expand All @@ -70,6 +72,13 @@ def dist_1d(data, title=None, label=None):
Returns:
plotly.graph_objects._figure.Figure
"""
if not title:
title = 'Data'
if isinstance(data, pd.DataFrame):
title += f" for column '{data.columns[0]}'"
elif isinstance(data, pd.Series) and data.name:
title += f" for column '{data.name}'"

return _generate_1d_plot(
data=[data],
title=title,
Expand All @@ -92,6 +101,13 @@ def compare_1d(real, synth, title=None):
Returns:
plotly.graph_objects._figure.Figure
"""
if not title:
title = 'Real vs. Synthetic Data'
if isinstance(real, pd.DataFrame):
title += f" for column '{real.columns[0]}'"
elif isinstance(real, pd.Series) and real.name:
title += f" for column '{real.name}'"

return _generate_1d_plot(
data=[real, synth],
title=title,
Expand Down Expand Up @@ -162,6 +178,13 @@ def scatter_2d(data, columns=None, title=None):
data = data.copy()
data['Data'] = 'Real'

if not title:
title = 'Data'
if columns:
title += f" for columns '{columns[0]}' and '{columns[1]}'"
elif isinstance(data, pd.DataFrame):
title += f" for columns '{data.columns[0]}' and '{data.columns[1]}'"

return _generate_scatter_2d_plot(
data=data,
columns=columns,
Expand Down Expand Up @@ -191,6 +214,13 @@ def compare_2d(real, synth, columns=None, title=None):
synth['Data'] = 'Synthetic'
data = pd.concat([real, synth], axis=0, ignore_index=True)

if not title:
title = 'Real vs. Synthetic Data'
if columns:
title += f" for columns '{columns[0]}' and '{columns[1]}'"
elif isinstance(data, pd.DataFrame):
title += f" for columns '{data.columns[0]}' and '{data.columns[1]}'"

return _generate_scatter_2d_plot(
data=data,
columns=columns,
Expand Down Expand Up @@ -256,7 +286,7 @@ def scatter_3d(data, columns=None, title=None):
Args:
data (pandas.DataFrame):
The table data. Must have at least 3 columns.
column_names (list[string]):
columns (list[string]):
The names of the three columns to plot.
title (str):
The title of the plot.
Expand All @@ -267,6 +297,14 @@ def scatter_3d(data, columns=None, title=None):
data = data.copy()
data['Data'] = 'Real'

if not title:
title = 'Data'
if columns:
title += f" for columns '{columns[0]}', '{columns[1]}' and '{columns[2]}'"
elif isinstance(data, pd.DataFrame):
title += \
f" for columns '{data.columns[0]}', '{data.columns[1]}' and '{data.columns[2]}'"

return _generate_scatter_3d_plot(
data=data,
columns=columns,
Expand All @@ -293,6 +331,14 @@ def compare_3d(real, synth, columns=None, title=None):
synth['Data'] = 'Synthetic'
data = pd.concat([real, synth], axis=0, ignore_index=True)

if not title:
title = 'Real vs. Synthetic Data'
if columns:
title += f" for columns '{columns[0]}', '{columns[1]}' and '{columns[2]}'"
elif isinstance(data, pd.DataFrame):
title += \
f" for columns '{data.columns[0]}', '{data.columns[1]}' and '{data.columns[2]}'"

return _generate_scatter_3d_plot(
data=data,
columns=columns,
Expand Down

0 comments on commit ddca46c

Please sign in to comment.