Skip to content

Commit

Permalink
update visualization methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ASSANDHOLE committed Nov 1, 2022
1 parent aa9a756 commit 1c252b0
Showing 2 changed files with 68 additions and 5 deletions.
11 changes: 8 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
@@ -151,19 +151,24 @@ def main_NSGA():
n_evals_moea = np.insert(n_evals_moea, 0, 0)
igd_moea = np.insert(igd_moea, 0, igd[0])

visualize_pf(pf=pf, label='Sorrogate PF', color='green', scale=[0.5]*3, pf_true=pf_true)
visualize_pf(pf=moea_pf, label='NSGA-II PF', color='blue', scale=[0.5]*3, pf_true=pf_true)
visualize_pf(pf=pf, label='Sorrogate PF', color='green', scale=[0.5]*3, pf_true=pf_true,
show=SHOW_PLOT, save_path=f'{PREFIX}sor_pf.png', save_alternative=True)
visualize_pf(pf=moea_pf, label='NSGA-II PF', color='blue', scale=[0.5]*3, pf_true=pf_true,
show=SHOW_PLOT, save_path=f'{PREFIX}moea_pf.png', save_alternative=True)

func_evals = [max_pts_num*np.arange(len(igd)), n_evals_moea]
igds = [igd, igd_moea]
colors = ['black', 'blue']
labels = ["Our Surrogate Model", "NSGA-II"]
visualize_igd(func_evals, igds, colors, labels)
visualize_igd(func_evals, igds, colors, labels,
show=SHOW_PLOT, save_path=f'{PREFIX}igd.png', save_alternative=True)
plt.show()


if __name__ == '__main__':
# main()
# main_sinewave()
PREFIX = 'imgs/'
SHOW_PLOT = False
main_NSGA()

62 changes: 60 additions & 2 deletions visualization/visualization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from typing import List, Literal, Any

import numpy as np
@@ -54,7 +55,54 @@ def plot_loss(loss: List[Any] | np.ndarray, label: str) -> int:
plt.show()


def visualize_pf(pf, label, color, scale=None, pf_true=None):
def file_exist_or_alternate(file_path: str, save_alternative=True) -> str | None:
"""
Check if the file exists, if it does and `save_alternative`, return the next available file name.
Parameters
----------
file_path : str
The file path
save_alternative : bool
Whether to return the next available file name if the file exists
Returns
-------
str | None
The file path if the file does not exist;
the next available file name if the file exists and `save_alternative` is True.
None if the file exists and `save_alternative` is False.
"""
if os.path.exists(file_path):
if save_alternative:
file_dir, file_name = os.path.split(file_path)
file_name, suffix = os.path.splitext(file_name)
i = 1
while os.path.exists(os.path.join(file_dir, f'{file_name}_{i}{suffix}')):
i += 1
return os.path.join(file_dir, f'{file_name}_{i}{suffix}')
else:
return None
else:
return file_path


def show_and_save(show: bool, save_path: str | None, save_alternative: bool) -> str | None:
if show:
plt.show()
if save_path is not None:
save_path = file_exist_or_alternate(save_path, save_alternative)
if save_path is not None:
plt.savefig(save_path)

return save_path


def visualize_pf(pf, label, color, scale=None, pf_true=None, show=True, save_path=None, save_alternative=True):
# if `save_path` is None, the figure will not be saved.

# if `save_alternative_name` is True, the figure will be saved with the next available name
# in the format of `save_path_dir/save_path_name_{i}.suffix` where `i` is the next available number.
plt.figure(figsize=(8, 6))
ax = plt.axes(projection='3d')
ax.scatter3D(pf[:, 0], pf[:, 1], pf[:, 2], color=color, label=label)
@@ -66,9 +114,14 @@ def visualize_pf(pf, label, color, scale=None, pf_true=None):
ax.set_zlim(0, scale[2])
ax.legend(loc='best')
ax.set(xlabel="F_1", ylabel="F_2", zlabel="F_3")
res = show_and_save(show, save_path, save_alternative)
if res is not None:
print(f'Figure saved to {res}')
if save_path is not None and res is None:
print(f'Figure already exists at {save_path}')


def visualize_igd(func_evals, igds, colors, labels):
def visualize_igd(func_evals, igds, colors, labels, show=True, save_path=None, save_alternative=True):
plt.figure(figsize=(8, 6))
for i in range(len(igds)):
plt.plot(func_evals[i], igds[i], color=colors[i], lw=0.7, label=labels[i])
@@ -79,3 +132,8 @@ def visualize_igd(func_evals, igds, colors, labels):
plt.ylabel("IGD")
plt.yscale("log")
plt.legend()
res = show_and_save(show, save_path, save_alternative)
if res is not None:
print(f'Figure saved to {res}')
if save_path is not None and res is None:
print(f'Figure already exists at {save_path}')

0 comments on commit 1c252b0

Please sign in to comment.