Feat: Add custom legends to plots #43

Merged: 7 commits, Jan 30, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -88,6 +88,6 @@ repos:
- id: codespell
name: codespell
description: Checks for common misspellings in text files.
entry: codespell --ignore-words=.codespell-ignore.txt
entry: codespell --ignore-words=.codespell-ignore.txt --skip=*.ipynb
language: python
types: [text]
33 changes: 23 additions & 10 deletions examples/quickstart.ipynb

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions examples/simple_example.py
@@ -32,6 +32,13 @@
# Read in and process data
##############################
METRICS_TO_NORMALIZE = ["return"]
LEGEND_MAP = {
"algo_1": "Algorithm 1",
"algo_2": "Algorithm 2",
"algo_3": "Algorithm 3",
"algo_4": "Algorithm 4",
"algo_5": "Algorithm 5",
}

with open("examples/example_results.json") as f:
raw_data = json.load(f)
@@ -63,6 +70,7 @@
task_name=task,
metric_name="success_rate",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)

fig.figure.savefig(
@@ -75,6 +83,7 @@
environment_comparison_matrix,
metric_name="success_rate",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/success_rate_performance_profile.png", bbox_inches="tight"
@@ -85,6 +94,7 @@
metric_name="success_rate",
metrics_to_normalize=METRICS_TO_NORMALIZE,
save_tabular_as_latex=True,
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/success_rate_aggregate_scores.png", bbox_inches="tight"
@@ -99,6 +109,7 @@
["algo_1", "algo_3"],
["algo_2", "algo_4"],
],
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/success_rate_prob_of_improvement.png", bbox_inches="tight"
@@ -108,6 +119,7 @@
sample_effeciency_matrix,
metric_name="success_rate",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/success_rate_sample_effeciency_curve.png", bbox_inches="tight"
@@ -126,6 +138,7 @@
task_name=task,
metric_name="return",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)

fig.figure.savefig(f"examples/plots/env_1_{task}_agg_return.png", bbox_inches="tight")
@@ -136,6 +149,7 @@
environment_comparison_matrix,
metric_name="return",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)
fig.figure.savefig("examples/plots/return_performance_profile.png", bbox_inches="tight")

@@ -144,6 +158,7 @@
metric_name="return",
metrics_to_normalize=METRICS_TO_NORMALIZE,
save_tabular_as_latex=True,
legend_map=LEGEND_MAP,
)
fig.figure.savefig("examples/plots/return_aggregate_scores.png", bbox_inches="tight")

@@ -156,13 +171,15 @@
["algo_1", "algo_3"],
["algo_2", "algo_4"],
],
legend_map=LEGEND_MAP,
)
fig.figure.savefig("examples/plots/return_prob_of_improvement.png", bbox_inches="tight")

fig, _, _ = sample_efficiency_curves( # type: ignore
sample_effeciency_matrix,
metric_name="return",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/return_sample_effeciency_curve.png", bbox_inches="tight"
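
For orientation, a minimal sketch of how the new `legend_map` argument is used (illustrative only, not part of this diff). `environment_comparison_matrix` is assumed to be the result matrix built by the data-processing step earlier in `simple_example.py`, and the import path is inferred from the file layout below:

```python
from marl_eval.plotting_tools.plotting import performance_profiles

LEGEND_MAP = {"algo_1": "Algorithm 1", "algo_2": "Algorithm 2"}

# With legend_map, legend entries read "Algorithm 1", "Algorithm 2";
# without it, they fall back to the upper-cased raw keys "ALGO_1", "ALGO_2".
fig = performance_profiles(
    environment_comparison_matrix,  # assumed: built earlier in simple_example.py
    metric_name="success_rate",
    metrics_to_normalize=["return"],
    legend_map=LEGEND_MAP,
)
fig.figure.savefig("examples/plots/success_rate_performance_profile.png")
```
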
25 changes: 22 additions & 3 deletions marl_eval/plotting_tools/plot_utils.py
@@ -33,6 +33,8 @@ def plot_single_task_curve(
ax: Optional[Axes] = None,
labelsize: str = "xx-large",
ticklabelsize: str = "xx-large",
legend_map: Optional[Dict] = None,
run_times: Optional[Dict] = None,
**kwargs: Any,
) -> Axes:
"""Plots an aggregate metric with CIs as a function of environment frames.
@@ -53,6 +55,10 @@
ax: `matplotlib.axes` object.
labelsize: Font size of the x-axis label.
ticklabelsize: Font size of the ticks.
legend_map: Dictionary that maps each algorithm to a label in the legend.
If None, then this mapping is created based on `algorithms`.
run_times: Dictionary that maps each algorithm to the number of seconds it
took to run. If None, then environment steps will be displayed.
**kwargs: Arbitrary keyword arguments.

Returns:
@@ -68,24 +74,37 @@
color_palette = sns.color_palette(color_palette, n_colors=len(algorithms))
colors = dict(zip(algorithms, color_palette))

marker = kwargs.pop("marker", "o")
linewidth = kwargs.pop("linewidth", 2)

for algorithm in algorithms:
x_axis_len = len(aggregated_data[algorithm]["mean"])

# Set x-axis values to match evaluation interval steps.
x_axis_values = np.arange(x_axis_len) * extra_info["evaluation_interval"]

if run_times is not None:
x_axis_values = np.linspace(0, run_times[algorithm] / 60, x_axis_len)

metric_values = np.array(aggregated_data[algorithm]["mean"])
confidence_interval = np.array(aggregated_data[algorithm]["ci"])
lower, upper = (
metric_values - confidence_interval,
metric_values + confidence_interval,
)

if legend_map is not None:
algorithm_name = legend_map[algorithm]
else:
algorithm_name = algorithm

ax.plot(
x_axis_values,
metric_values,
color=colors[algorithm],
marker=kwargs.pop("marker", "o"),
linewidth=kwargs.pop("linewidth", 2),
label=algorithm,
marker=marker,
linewidth=linewidth,
label=algorithm_name,
)
ax.fill_between(
x_axis_values, y1=lower, y2=upper, color=colors[algorithm], alpha=0.2
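
To make the new `run_times` behaviour concrete, here is a small standalone sketch (not part of this diff) of the x-axis computation in `plot_single_task_curve`, with assumed toy numbers. When `run_times` is given, the axis spans 0 to the run time in minutes instead of evaluation steps:

```python
import numpy as np

x_axis_len = 5                   # number of logged evaluation points
evaluation_interval = 10_000     # steps between evaluations (assumed value)
run_times = {"ALGO_1": 3600.0}   # wall-clock seconds for the run (assumed value)

# Default behaviour: x-axis in environment steps.
steps_axis = np.arange(x_axis_len) * evaluation_interval
# -> [    0 10000 20000 30000 40000]

# With run_times: evenly spaced points from 0 to the total run time in minutes.
minutes_axis = np.linspace(0, run_times["ALGO_1"] / 60, x_axis_len)
# -> [ 0. 15. 30. 45. 60.]
```

In `plot_single_task` (further below), passing `run_times` also switches the x-axis label to "Time (Minutes)".
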
61 changes: 60 additions & 1 deletion marl_eval/plotting_tools/plotting.py
@@ -37,6 +37,7 @@ def performance_profiles(
dictionary: Dict[str, Dict[str, Any]],
metric_name: str,
metrics_to_normalize: List[str],
legend_map: Optional[Dict[str, str]] = None,
) -> Figure:
"""Produces performance profile plots.

@@ -45,6 +46,7 @@
for metric algorithm pairs.
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
legend_map: Dictionary that maps each algorithm to a custom legend label.

Returns:
fig: Matplotlib figure for storing.
@@ -64,6 +66,14 @@
data_dictionary = upper_algo_dict
algorithms = list(data_dictionary.keys())

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}
# Replace keys in data dict with corresponding key in legend map
data_dictionary = {
legend_map[algo]: value for algo, value in data_dictionary.items()
}
algorithms = list(data_dictionary.keys())

if metric_name in metrics_to_normalize:
xlabel = "Normalized " + " ".join(metric_name.split("_"))

@@ -95,6 +105,7 @@ def aggregate_scores(
rounding_decimals: Optional[int] = 2,
tabular_results_file_path: str = "./aggregated_score",
save_tabular_as_latex: Optional[bool] = False,
legend_map: Optional[Dict[str, str]] = None,
) -> Tuple[Figure, Dict[str, Dict[str, int]], Dict[str, Dict[str, float]]]:
"""Produces aggregated score plots.

@@ -106,6 +117,7 @@
rounding_decimals:number up to which the results values are rounded.
tabular_results_file_path: location to store the tabular results.
save_tabular_as_latex: store tabular results in latex format in a .txt file.
legend_map: Dictionary that maps each algorithm to a custom legend label.

Returns:
fig: Matplotlib figure for storing.
@@ -127,9 +139,16 @@
# Upper case all algorithm names
upper_algo_dict = {algo.upper(): value for algo, value in data_dictionary.items()}
data_dictionary = upper_algo_dict

algorithms = list(data_dictionary.keys())

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}
# Replace keys in data dict with corresponding key in legend map
data_dictionary = {
legend_map[algo]: value for algo, value in data_dictionary.items()
}
algorithms = list(data_dictionary.keys())

aggregate_func = lambda x: np.array( # noqa: E731
[
metrics.aggregate_median(x),
@@ -226,14 +245,17 @@ def probability_of_improvement(
metric_name: str,
metrics_to_normalize: List[str],
algorithms_to_compare: List[List],
legend_map: Optional[Dict[str, str]] = None,
) -> Figure:
"""Produces probability of improvement plots.

Args:
dictionary: Dictionary containing 2D arrays of normalised absolute metric scores
for metric algorithm pairs.
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
algorithms_to_compare: 2D list containing pairs of algorithms to be compared.
legend_map: Dictionary that maps each algorithm to a custom legend label.

Returns:
fig: Matplotlib figure for storing.
@@ -257,6 +279,17 @@
[pair[0].upper(), pair[1].upper()] for pair in algorithms_to_compare
]

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}
# Replace keys in data dict with corresponding key in legend map
data_dictionary = {
legend_map[algo]: value for algo, value in data_dictionary.items()
}
# Make sure that the algorithms to compare are also in the legend map
algorithms_to_compare = [
[legend_map[pair[0]], legend_map[pair[1]]] for pair in algorithms_to_compare
]

algorithm_pairs = {}
for pair in algorithms_to_compare:
algorithm_pairs[",".join(pair)] = (
@@ -276,6 +309,7 @@ def sample_efficiency_curves(
dictionary: Dict[str, Dict[str, Any]],
metric_name: str,
metrics_to_normalize: List[str],
legend_map: Optional[Dict[str, str]] = None,
xlabel: str = "Timesteps",
) -> Tuple[Figure, Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""Produces sample efficiency curve plots.
@@ -285,6 +319,7 @@
metric scores for metric algorithm pairs.
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
legend_map: Dictionary that maps each algorithm to a custom legend label.
xlabel: Label for x-axis.

Returns:
@@ -312,6 +347,14 @@
data_dictionary = upper_algo_dict
algorithms = list(data_dictionary.keys())

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}
# Replace keys in data dict with corresponding key in legend map
data_dictionary = {
legend_map[algo]: value for algo, value in data_dictionary.items()
}
algorithms = list(data_dictionary.keys())

# Find lowest values from amount of runs that have completed
# across all algorithms
run_lengths = [data_dictionary[algo].shape[2] for algo in data_dictionary]
@@ -356,6 +399,8 @@ def plot_single_task(
metric_name: str,
metrics_to_normalize: List[str],
xlabel: str = "Timesteps",
legend_map: Optional[Dict[str, str]] = None,
run_times: Optional[Dict[str, float]] = None,
) -> Figure:
"""Produces aggregated plot for a single task in an environment.

@@ -366,6 +411,10 @@
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
xlabel: Label for x-axis.
legend_map: Dictionary that maps each algorithm to a label in the legend.
If None, then this mapping is created based on `algorithms`.
run_times: Dictionary that maps each algorithm to the number of seconds it
took to run. If None, then environment steps will be displayed.
"""
metric_name, task_name, environment_name, metrics_to_normalize = lower_case_inputs(
metric_name, task_name, environment_name, metrics_to_normalize
@@ -393,6 +442,13 @@
algorithms = list(task_mean_ci_data.keys())
algorithms.remove("extra")

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}

if run_times is not None:
run_times = {algo.upper(): value for algo, value in run_times.items()}
xlabel = "Time (Minutes)"

fig = plot_single_task_curve(
task_mean_ci_data,
algorithms=algorithms,
@@ -401,6 +457,9 @@
legend=algorithms,
figsize=(15, 8),
color_palette=cc.glasbey_category10,
legend_map=legend_map,
run_times=run_times,
marker="",
)

return fig
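
One detail worth highlighting from `probability_of_improvement` above: when a `legend_map` is supplied, the algorithm pairs being compared are remapped as well, so the pair labels stay consistent with the re-keyed data dictionary. A standalone sketch with toy keys, mirroring the added comprehension:

```python
legend_map = {
    "ALGO_1": "Algorithm 1",
    "ALGO_2": "Algorithm 2",
    "ALGO_3": "Algorithm 3",
    "ALGO_4": "Algorithm 4",
}
algorithms_to_compare = [["ALGO_1", "ALGO_3"], ["ALGO_2", "ALGO_4"]]

# Replace each raw key with its display name from the legend map.
algorithms_to_compare = [
    [legend_map[pair[0]], legend_map[pair[1]]] for pair in algorithms_to_compare
]
# -> [['Algorithm 1', 'Algorithm 3'], ['Algorithm 2', 'Algorithm 4']]
```

Similarly, `plot_single_task` upper-cases the keys of both `legend_map` and `run_times` before handing them to `plot_single_task_curve`, and switches the x-axis label to "Time (Minutes)" when run times are given.
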
4 changes: 3 additions & 1 deletion marl_eval/utils/data_processing_utils.py
@@ -255,7 +255,9 @@ def _compare_values(
]
normed_metric_array = (
metric_array - metric_global_min
) / (metric_global_max - metric_global_min)
) / (
metric_global_max - metric_global_min + 1e-6
)
processed_data[env][task][algorithm][run][step][
f"norm_{metric}"
] = normed_metric_array.tolist()
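
The epsilon added to the denominator above guards against division by zero when a metric's global maximum equals its global minimum. A quick standalone sketch with assumed toy values:

```python
import numpy as np

metric_array = np.array([5.0, 5.0, 5.0])  # toy case: the metric never varies
metric_global_min = metric_array.min()     # 5.0
metric_global_max = metric_array.max()     # 5.0

# Without the epsilon the denominator is exactly zero, which produces
# nan values (and a runtime warning) for the normalised metric.
normed_metric_array = (metric_array - metric_global_min) / (
    metric_global_max - metric_global_min + 1e-6
)
print(normed_metric_array.tolist())  # -> [0.0, 0.0, 0.0] instead of nans
```
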