Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions exp1_decision_boundary/generate_data_rogue_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,16 @@ def plot_data(X, y, rogue_point_idx=None):
framealpha=0.95,
facecolor='white',
edgecolor='lightgray',
loc='best',
fontsize=10
loc='upper center',
bbox_to_anchor=(0.5, -0.10),
ncol=4,
fontsize=14
)

# Add labels and title
plt.xlabel('Feature 1', fontsize=12)
plt.ylabel('Feature 2', fontsize=12)
plt.title('Cluster Distribution', fontsize=14, fontweight='bold')

plt.xlabel('Feature 1', fontsize=16)
plt.ylabel('Feature 2', fontsize=16)

# Improve ticks
plt.tick_params(direction='out', length=6, width=1)

Expand Down
13 changes: 3 additions & 10 deletions exp1_decision_boundary/generate_data_rogue_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,12 @@ def plot_data(X, y, rogue_point_idx=None):
facecolor='white',
edgecolor='lightgray',
loc='best',
fontsize=10
fontsize=14
)

# Add labels and title
plt.xlabel('Feature 1', fontsize=12)
plt.ylabel('Feature 2', fontsize=12)
plt.title('Cluster Distribution', fontsize=14, fontweight='bold')
plt.xlabel('Feature 1', fontsize=14)
plt.ylabel('Feature 2', fontsize=14)

# Improve ticks
plt.tick_params(direction='out', length=6, width=1)
Expand Down Expand Up @@ -156,7 +155,6 @@ def save_fig(plot, filename, folder):
# 0. No rogue point
X, y, rogue_point_idx = generate_data(centroids, stds, sizes)
plot = plot_data(X,y, rogue_point_idx)
plot.title('Simple synthetic dataset (no rogue points)', fontsize=14, fontweight='bold')
save_data(X, y, rogue_point_idx, os.path.join(data_folder, 'data_0.npz'))
save_fig(plot, 'data_0', data_plots_folder)
plt.close()
Expand All @@ -166,7 +164,6 @@ def save_fig(plot, filename, folder):
rogue_point = (np.mean(centroids, axis=0), np.array([1]))
X, y, rogue_point_idx = generate_data(centroids, stds, sizes, rogue_point=rogue_point)
plot = plot_data(X,y, rogue_point_idx)
plot.title('Simple synthetic dataset (rogue point with same distance to all centroids)', fontsize=14, fontweight='bold')
save_data(X, y, rogue_point_idx, os.path.join(data_folder, 'data_1.npz'))
save_fig(plot, 'data_1', data_plots_folder)
# clear plot
Expand All @@ -177,7 +174,6 @@ def save_fig(plot, filename, folder):
rogue_point = (centroids[1], np.array([1]))
X, y, rogue_point_idx = generate_data(centroids, stds, sizes, rogue_point=rogue_point)
plot = plot_data(X,y, rogue_point_idx)
plot.title('Simple synthetic dataset (rogue point with same centroid as the class with a same label)', fontsize=14, fontweight='bold')
save_data(X, y, rogue_point_idx, os.path.join(data_folder, 'data_2.npz'))
save_fig(plot, 'data_2', data_plots_folder)
plt.close()
Expand All @@ -188,7 +184,6 @@ def save_fig(plot, filename, folder):
rogue_point = (centroids[2], np.array([1]))
X, y, rogue_point_idx = generate_data(centroids, stds, sizes, rogue_point=rogue_point)
plot = plot_data(X,y, rogue_point_idx)
plot.title('Simple synthetic dataset (rogue point with same centroid as the class with a different label)', fontsize=14, fontweight='bold')
save_data(X, y, rogue_point_idx, os.path.join(data_folder, 'data_3.npz'))
save_fig(plot, 'data_3', data_plots_folder)
plt.close()
Expand All @@ -198,7 +193,6 @@ def save_fig(plot, filename, folder):
rogue_point = (np.array([-8, 8]), np.array([1]))
X, y, rogue_point_idx = generate_data(centroids, stds, sizes, rogue_point=rogue_point)
plot = plot_data(X,y, rogue_point_idx)
plot.title('Simple synthetic dataset (rogue point far away from its centroid, but probably in the same decision boundary)', fontsize=14, fontweight='bold')
save_data(X, y, rogue_point_idx, os.path.join(data_folder, 'data_4.npz'))
save_fig(plot, 'data_4', data_plots_folder)
plt.close()
Expand All @@ -209,7 +203,6 @@ def save_fig(plot, filename, folder):
rogue_point = (np.array([1, -8]), np.array([1]))
X, y, rogue_point_idx = generate_data(centroids, stds, sizes, rogue_point=rogue_point)
plot = plot_data(X,y, rogue_point_idx)
plot.title('Simple synthetic dataset (rogue point far away from its centroid, but probably in the same decision boundary)', fontsize=14, fontweight='bold')
save_data(X, y, rogue_point_idx, os.path.join(data_folder, 'data_5.npz'))
save_fig(plot, 'data_5', data_plots_folder)
plt.close()
Expand Down
Loading