Skip to content

Commit

Permalink
plot figure tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
allierc committed Jun 16, 2024
1 parent c96f671 commit b8b82d9
Showing 1 changed file with 15 additions and 32 deletions.
47 changes: 15 additions & 32 deletions GNN_particles_PlotFigure.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,6 @@ def plot_embedding_func_cluster(model, config, config_file, embedding_cluster, c
embedding = to_numpy(model.a[0:n_particles])
else:
embedding = get_embedding(model.a, 1)
csv_ = embedding
np.save(f"./{log_dir}/results/embedding_{config_file}_{epoch}.npy", csv_)
np.savetxt(f"./{log_dir}/results/embedding_{config_file}_{epoch}.txt", csv_)
if n_particle_types > 1000:
plt.scatter(embedding[:, 0], embedding[:, 1], c=to_numpy(x[:, 5]) / n_particles, s=10,
cmap=cc)
Expand Down Expand Up @@ -422,9 +419,6 @@ def plot_embedding_func_cluster(model, config, config_file, embedding_cluster, c
embedding = to_numpy(model.a)
else:
embedding = get_embedding(model.a, 1)
csv_ = embedding
np.save(f"./{log_dir}/results/embedding_{config_file}_{epoch}.npy", csv_)
np.savetxt(f"./{log_dir}/results/embedding_{config_file}_{epoch}.txt", csv_)
if n_particle_types > 1000:
plt.scatter(embedding[:, 0], embedding[:, 1], c=to_numpy(x[:, 5]) / n_particles, s=10,
cmap=cc)
Expand Down Expand Up @@ -851,8 +845,7 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo
x_ = torch.stack(x_list[1])
x_ = torch.reshape(x_, (x_.shape[0] * x_.shape[1], x_.shape[2]))
x_ = x_[0:(n_frames - 1) * n_particles]
x_ = to_numpy(x_[:, 0])
indexes = np.unique(x_)
indexes = np.unique(to_numpy(x_[:, 0]))

plt.xlabel(r'True particle index', fontsize=32)
plt.ylabel(r'Particle index in next frame', fontsize=32)
Expand Down Expand Up @@ -880,16 +873,15 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo
print(f'tracking errors: {np.sum(tracking_index_list)}')
logger.info(f'tracking errors: {np.sum(tracking_index_list)}')

embedding = to_numpy(model.a.clone().detach())
fig, ax = fig_init()
for k in trange(0,n_frames-2):
embedding = to_numpy(model.a[k*n_particles:(k+1)*n_particles,:].clone().detach())
for n in range(n_particle_types):
plt.scatter(embedding[index_particles[n], 0], embedding[index_particles[n], 1], s=1, c=cmap.color(n), alpha=0.025)
for k in indexes:
plt.scatter(embedding[int(k), 0], embedding[int(k), 1], s=1, c=cmap.color(int(to_numpy(x_[int(k),5]))), alpha=0.25)
plt.xlabel(r'$\ensuremath{\mathbf{a}}_{i0}$', fontsize=64)
plt.ylabel(r'$\ensuremath{\mathbf{a}}_{i1}$', fontsize=64)
plt.tight_layout()
plt.xlim([-40, 40])
plt.ylim([-40, 40])
plt.tight_layout()
plt.savefig(f"./{log_dir}/results/all_embedding_{config_file}_{epoch}.tif", dpi=170.7)
plt.close()

Expand All @@ -908,19 +900,19 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo
p = torch.load(f'graphs_data/graphs_{dataset_name}/model_p.pt', map_location=device)
rr = torch.tensor(np.linspace(0, max_radius, 1000)).to(device)
rmserr_list = []
for n in range(int(n_particles * (1 - config.training.particle_dropout))):
embedding_ = model_a_first[n, :] * torch.ones((1000, config.graph_model.embedding_dim), device=device)
for n in indexes:
embedding_ = model_a_first[int(n), :] * torch.ones((1000, config.graph_model.embedding_dim), device=device)
in_features = torch.cat((rr[:, None] / max_radius, 0 * rr[:, None],
rr[:, None] / max_radius, embedding_), dim=1)
with torch.no_grad():
func = model.lin_edge(in_features.float())
func = func[:, 0]
true_func = model.psi(rr, p[to_numpy(type_list[n]).astype(int)].squeeze(),
p[to_numpy(type_list[n]).astype(int)].squeeze())
true_func = model.psi(rr, p[int(to_numpy(x_[int(n),5]))].squeeze(),
p[int(to_numpy(x_[int(n),5]))].squeeze())
rmserr_list.append(torch.sqrt(torch.mean((func - true_func.squeeze()) ** 2)))
plt.plot(to_numpy(rr),
to_numpy(func),
color=cmap.color(to_numpy(type_list[n]).astype(int)), linewidth=8, alpha=0.1)
color=cmap.color(int(to_numpy(x_[int(n),5]))), linewidth=8, alpha=0.1)
plt.xlabel(r'$d_{ij}$', fontsize=64)
plt.ylabel(r'$f(\ensuremath{\mathbf{a}}_i, d_{ij})$', fontsize=64)
plt.xlim([0, max_radius])
Expand Down Expand Up @@ -962,7 +954,6 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo
plt.xlabel(r'$d_{ij}$', fontsize=64)
plt.ylabel(r'$f(\ensuremath{\mathbf{a}}_i, d_{ij})$', fontsize=64)
plt.tight_layout()
torch.save(plots, f"./{log_dir}/results/plots_true_{config_file}_{epoch}.pt")
plt.close()

for k in indexes:
Expand Down Expand Up @@ -3996,10 +3987,10 @@ def data_video_training(config_file, epoch_list, log_dir, logger, device):


def data_plot(config_file, epoch_list, device):
plt.rcParams['text.usetex'] = True
rc('font', **{'family': 'serif', 'serif': ['Palatino']})
# plt.rcParams['text.usetex'] = True
# rc('font', **{'family': 'serif', 'serif': ['Palatino']})
matplotlib.rcParams['savefig.pad_inches'] = 0
matplotlib.use("Qt5Agg")
# matplotlib.use("Qt5Agg")

l_dir = os.path.join('.', 'log')
log_dir = os.path.join(l_dir, 'try_{}'.format(config_file))
Expand Down Expand Up @@ -4088,21 +4079,13 @@ def data_plot(config_file, epoch_list, device):
# config_list = ['boids_16_256','boids_32_256','boids_64_256']
config_list = ['arbitrary_3_tracking']

# epoch_list = ['0_500','0_1000','0_2000','0_5000','0_10000','0_20000','0_49000','0','0_corrected','1_500','1_1000','1_2000','1_5000','1_10000','1_20000','1_49000','1','1_corrected','2_500','2_1000','2_2000','2_5000','2_10000','2_20000','2_49000','2','2_corrected']
# epoch_list =

epoch_list = ['2']
epoch_list = ['0_500','0_1000','0_2000','0_5000','0_10000','0_20000','0_49000','0','0_corrected','1_500','1_1000','1_2000','1_5000','1_10000','1_20000','1_49000','1','1_corrected','2_500','2_1000','2_2000','2_5000','2_10000','2_20000','2_49000','2','2_corrected','3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18','19','20']

for config_file in config_list:
config = ParticleGraphConfig.from_yaml(f'./config/{config_file}.yaml')
data_plot(config_file, epoch_list, device)

# data_plot_attraction_repulsion_short(config_file, device=device)
# data_plot_boids(config_file)
# data_plot_gravity(config_file)
# data_plot_RD(config_file,cc='viridis')
# data_plot_particle_field(config_file, mode='figures', cc='grey', device=device)
# data_plot_wave(config_file,cc='viridis')
# data_plot_signal(config_file,cc='viridis')

# data_video_validation(config_file,device=device)
# data_video_training(config_file,device=device)

0 comments on commit b8b82d9

Please sign in to comment.