Feat: Add custom legends to plots #43

Merged: 7 commits, Jan 30, 2024
feat: add support for custom legend maps to all plotting scripts
RuanJohn committed Jan 17, 2024
commit 64690dcd5d5ee6d72bb7cc5cf6367bae67f786c6
17 changes: 17 additions & 0 deletions examples/simple_example.py
@@ -32,6 +32,13 @@
# Read in and process data
##############################
METRICS_TO_NORMALIZE = ["return"]
LEGEND_MAP = {
"algo_1": "Algorithm 1",
"algo_2": "Algorithm 2",
"algo_3": "Algorithm 3",
"algo_4": "Algorithm 4",
"algo_5": "Algorithm 5",
}

with open("examples/example_results.json") as f:
raw_data = json.load(f)
@@ -63,6 +70,7 @@
task_name=task,
metric_name="success_rate",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)

fig.figure.savefig(
@@ -75,6 +83,7 @@
environment_comparison_matrix,
metric_name="success_rate",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/success_rate_performance_profile.png", bbox_inches="tight"
@@ -85,6 +94,7 @@
metric_name="success_rate",
metrics_to_normalize=METRICS_TO_NORMALIZE,
save_tabular_as_latex=True,
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/success_rate_aggregate_scores.png", bbox_inches="tight"
@@ -99,6 +109,7 @@
["algo_1", "algo_3"],
["algo_2", "algo_4"],
],
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/success_rate_prob_of_improvement.png", bbox_inches="tight"
@@ -108,6 +119,7 @@
sample_effeciency_matrix,
metric_name="success_rate",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/success_rate_sample_effeciency_curve.png", bbox_inches="tight"
@@ -126,6 +138,7 @@
task_name=task,
metric_name="return",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)

fig.figure.savefig(f"examples/plots/env_1_{task}_agg_return.png", bbox_inches="tight")
@@ -136,6 +149,7 @@
environment_comparison_matrix,
metric_name="return",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)
fig.figure.savefig("examples/plots/return_performance_profile.png", bbox_inches="tight")

@@ -144,6 +158,7 @@
metric_name="return",
metrics_to_normalize=METRICS_TO_NORMALIZE,
save_tabular_as_latex=True,
legend_map=LEGEND_MAP,
)
fig.figure.savefig("examples/plots/return_aggregate_scores.png", bbox_inches="tight")

@@ -156,13 +171,15 @@
["algo_1", "algo_3"],
["algo_2", "algo_4"],
],
legend_map=LEGEND_MAP,
)
fig.figure.savefig("examples/plots/return_prob_of_improvement.png", bbox_inches="tight")

fig, _, _ = sample_efficiency_curves( # type: ignore
sample_effeciency_matrix,
metric_name="return",
metrics_to_normalize=METRICS_TO_NORMALIZE,
legend_map=LEGEND_MAP,
)
fig.figure.savefig(
"examples/plots/return_sample_effeciency_curve.png", bbox_inches="tight"
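Taken together, the example change defines a single LEGEND_MAP and threads it through every plotting call. Below is a minimal usage sketch of one such call; environment_comparison_matrix is assumed to come from the repo's existing data-processing utilities (its construction is not part of this diff), and plot_with_custom_legend is a hypothetical wrapper added here only for illustration.

from marl_eval.plotting_tools.plotting import performance_profiles

METRICS_TO_NORMALIZE = ["return"]

# Keys must match the algorithm names used in the results data; values are the
# labels shown in the plot legend. Matching is effectively case-insensitive,
# since the plotting functions upper-case both sides (see plotting.py below).
LEGEND_MAP = {
    "algo_1": "Algorithm 1",
    "algo_2": "Algorithm 2",
}


def plot_with_custom_legend(environment_comparison_matrix) -> None:
    """Sketch: pass the legend map through a single plotting call."""
    fig = performance_profiles(
        environment_comparison_matrix,
        metric_name="success_rate",
        metrics_to_normalize=METRICS_TO_NORMALIZE,
        legend_map=LEGEND_MAP,  # new optional argument added by this commit
    )
    fig.figure.savefig(
        "examples/plots/success_rate_performance_profile.png", bbox_inches="tight"
    )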
45 changes: 44 additions & 1 deletion marl_eval/plotting_tools/plotting.py
@@ -37,6 +37,7 @@ def performance_profiles(
dictionary: Dict[str, Dict[str, Any]],
metric_name: str,
metrics_to_normalize: List[str],
legend_map: Optional[Dict[str, str]] = None,
) -> Figure:
"""Produces performance profile plots.

@@ -45,6 +46,7 @@
for metric algorithm pairs.
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
legend_map: Dictionary that maps each algorithm to a custom legend label.

Returns:
fig: Matplotlib figure for storing.
@@ -64,6 +66,14 @@
data_dictionary = upper_algo_dict
algorithms = list(data_dictionary.keys())

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}
# Replace keys in data dict with corresponding key in legend map
data_dictionary = {
legend_map[algo]: value for algo, value in data_dictionary.items()
}
algorithms = list(data_dictionary.keys())

if metric_name in metrics_to_normalize:
xlabel = "Normalized " + " ".join(metric_name.split("_"))

@@ -95,6 +105,7 @@ def aggregate_scores(
rounding_decimals: Optional[int] = 2,
tabular_results_file_path: str = "./aggregated_score",
save_tabular_as_latex: Optional[bool] = False,
legend_map: Optional[Dict[str, str]] = None,
) -> Tuple[Figure, Dict[str, Dict[str, int]], Dict[str, Dict[str, float]]]:
"""Produces aggregated score plots.

@@ -106,6 +117,7 @@
rounding_decimals: number of decimal places to which the result values are rounded.
tabular_results_file_path: location to store the tabular results.
save_tabular_as_latex: store tabular results in latex format in a .txt file.
legend_map: Dictionary that maps each algorithm to a custom legend label.

Returns:
fig: Matplotlib figure for storing.
@@ -127,9 +139,16 @@
# Upper case all algorithm names
upper_algo_dict = {algo.upper(): value for algo, value in data_dictionary.items()}
data_dictionary = upper_algo_dict

algorithms = list(data_dictionary.keys())

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}
# Replace keys in data dict with corresponding key in legend map
data_dictionary = {
legend_map[algo]: value for algo, value in data_dictionary.items()
}
algorithms = list(data_dictionary.keys())

aggregate_func = lambda x: np.array( # noqa: E731
[
metrics.aggregate_median(x),
@@ -226,14 +245,17 @@ def probability_of_improvement(
metric_name: str,
metrics_to_normalize: List[str],
algorithms_to_compare: List[List],
legend_map: Optional[Dict[str, str]] = None,
) -> Figure:
"""Produces probability of improvement plots.

Args:
dictionary: Dictionary containing 2D arrays of normalised absolute metric scores
for metric algorithm pairs.
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
algorithms_to_compare: 2D list containing pairs of algorithms to be compared.
legend_map: Dictionary that maps each algorithm to a custom legend label.

Returns:
fig: Matplotlib figure for storing.
@@ -257,6 +279,17 @@
[pair[0].upper(), pair[1].upper()] for pair in algorithms_to_compare
]

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}
# Replace keys in data dict with corresponding key in legend map
data_dictionary = {
legend_map[algo]: value for algo, value in data_dictionary.items()
}
# Make sure that the algorithms to compare are also in the legend map
algorithms_to_compare = [
[legend_map[pair[0]], legend_map[pair[1]]] for pair in algorithms_to_compare
]

algorithm_pairs = {}
for pair in algorithms_to_compare:
algorithm_pairs[",".join(pair)] = (
@@ -276,6 +309,7 @@ def sample_efficiency_curves(
dictionary: Dict[str, Dict[str, Any]],
metric_name: str,
metrics_to_normalize: List[str],
legend_map: Optional[Dict[str, str]] = None,
xlabel: str = "Timesteps",
) -> Tuple[Figure, Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""Produces sample efficiency curve plots.
@@ -285,6 +319,7 @@
metric scores for metric algorithm pairs.
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
legend_map: Dictionary that maps each algorithm to a custom legend label.
xlabel: Label for x-axis.

Returns:
@@ -312,6 +347,14 @@
data_dictionary = upper_algo_dict
algorithms = list(data_dictionary.keys())

if legend_map is not None:
legend_map = {algo.upper(): value for algo, value in legend_map.items()}
# Replace keys in data dict with corresponding key in legend map
data_dictionary = {
legend_map[algo]: value for algo, value in data_dictionary.items()
}
algorithms = list(data_dictionary.keys())

# Find lowest values from amount of runs that have completed
# across all algorithms
run_lengths = [data_dictionary[algo].shape[2] for algo in data_dictionary]
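The block added to each plotting function follows the same pattern: upper-case the legend-map keys (the algorithm names in the data dictionary were already upper-cased), swap the dictionary keys for the custom labels, and rebuild the algorithm list; probability_of_improvement additionally remaps its comparison pairs. The sketch below is a standalone illustration of that pattern, not code from the module, and the helper name apply_legend_map is hypothetical.

from typing import Any, Dict, List, Optional


def apply_legend_map(
    data_dictionary: Dict[str, Any],
    legend_map: Optional[Dict[str, str]],
    algorithms_to_compare: Optional[List[List[str]]] = None,
):
    """Illustrative helper mirroring the per-function logic added in this commit."""
    if legend_map is None:
        return data_dictionary, algorithms_to_compare

    # Algorithm names in the data dictionary are already upper-cased, so the
    # legend-map keys must be upper-cased before matching.
    legend_map = {algo.upper(): label for algo, label in legend_map.items()}

    # Replace each algorithm key with its custom legend label. As in the diff,
    # an algorithm that is missing from the legend map raises a KeyError here.
    data_dictionary = {
        legend_map[algo]: value for algo, value in data_dictionary.items()
    }

    # probability_of_improvement also remaps the comparison pairs so that they
    # refer to the renamed keys.
    if algorithms_to_compare is not None:
        algorithms_to_compare = [
            [legend_map[a], legend_map[b]] for a, b in algorithms_to_compare
        ]

    return data_dictionary, algorithms_to_compare

In the functions themselves the remapped dictionary is then used to rebuild algorithms = list(data_dictionary.keys()), so the rest of the plotting code picks up the custom labels without further changes.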
4 changes: 3 additions & 1 deletion marl_eval/utils/data_processing_utils.py
@@ -255,7 +255,9 @@ def _compare_values(
]
normed_metric_array = (
metric_array - metric_global_min
) / (metric_global_max - metric_global_min)
) / (
metric_global_max - metric_global_min + 1e-6
)
processed_data[env][task][algorithm][run][step][
f"norm_{metric}"
] = normed_metric_array.tolist()
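The data_processing_utils.py change adds a small epsilon (1e-6) to the min-max normalisation denominator, so a metric whose global minimum equals its global maximum no longer divides by zero. Below is a self-contained sketch of the guarded formula; the helper name safe_min_max_normalise is hypothetical, while the actual change is inline in _compare_values.

import numpy as np


def safe_min_max_normalise(
    metric_array: np.ndarray, global_min: float, global_max: float
) -> np.ndarray:
    """Min-max normalise scores while guarding against a zero-width range.

    With the epsilon in the denominator, a constant metric normalises to ~0
    instead of producing NaNs, and non-degenerate ranges are shifted only
    negligibly.
    """
    return (metric_array - global_min) / (global_max - global_min + 1e-6)


# Example: every run scored exactly 0.5, so the global range is zero.
scores = np.array([0.5, 0.5, 0.5])
print(safe_min_max_normalise(scores, scores.min(), scores.max()))  # ~[0. 0. 0.]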