diff --git a/GNN_particles_PlotFigure.py b/GNN_particles_PlotFigure.py index 788d1973..28d52aee 100644 --- a/GNN_particles_PlotFigure.py +++ b/GNN_particles_PlotFigure.py @@ -351,9 +351,6 @@ def plot_embedding_func_cluster(model, config, config_file, embedding_cluster, c embedding = to_numpy(model.a[0:n_particles]) else: embedding = get_embedding(model.a, 1) - csv_ = embedding - np.save(f"./{log_dir}/results/embedding_{config_file}_{epoch}.npy", csv_) - np.savetxt(f"./{log_dir}/results/embedding_{config_file}_{epoch}.txt", csv_) if n_particle_types > 1000: plt.scatter(embedding[:, 0], embedding[:, 1], c=to_numpy(x[:, 5]) / n_particles, s=10, cmap=cc) @@ -422,9 +419,6 @@ def plot_embedding_func_cluster(model, config, config_file, embedding_cluster, c embedding = to_numpy(model.a) else: embedding = get_embedding(model.a, 1) - csv_ = embedding - np.save(f"./{log_dir}/results/embedding_{config_file}_{epoch}.npy", csv_) - np.savetxt(f"./{log_dir}/results/embedding_{config_file}_{epoch}.txt", csv_) if n_particle_types > 1000: plt.scatter(embedding[:, 0], embedding[:, 1], c=to_numpy(x[:, 5]) / n_particles, s=10, cmap=cc) @@ -851,8 +845,7 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo x_ = torch.stack(x_list[1]) x_ = torch.reshape(x_, (x_.shape[0] * x_.shape[1], x_.shape[2])) x_ = x_[0:(n_frames - 1) * n_particles] - x_ = to_numpy(x_[:, 0]) - indexes = np.unique(x_) + indexes = np.unique(to_numpy(x_[:, 0])) plt.xlabel(r'True particle index', fontsize=32) plt.ylabel(r'Particle index in next frame', fontsize=32) @@ -880,16 +873,15 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo print(f'tracking errors: {np.sum(tracking_index_list)}') logger.info(f'tracking errors: {np.sum(tracking_index_list)}') + embedding = to_numpy(model.a.clone().detach()) fig, ax = fig_init() - for k in trange(0,n_frames-2): - embedding = to_numpy(model.a[k*n_particles:(k+1)*n_particles,:].clone().detach()) - for n in range(n_particle_types): - plt.scatter(embedding[index_particles[n], 0], embedding[index_particles[n], 1], s=1, c=cmap.color(n), alpha=0.025) + for k in indexes: + plt.scatter(embedding[int(k), 0], embedding[int(k), 1], s=1, c=cmap.color(int(to_numpy(x_[int(k),5]))), alpha=0.25) plt.xlabel(r'$\ensuremath{\mathbf{a}}_{i0}$', fontsize=64) plt.ylabel(r'$\ensuremath{\mathbf{a}}_{i1}$', fontsize=64) - plt.tight_layout() plt.xlim([-40, 40]) plt.ylim([-40, 40]) + plt.tight_layout() plt.savefig(f"./{log_dir}/results/all_embedding_{config_file}_{epoch}.tif", dpi=170.7) plt.close() @@ -908,19 +900,19 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo p = torch.load(f'graphs_data/graphs_{dataset_name}/model_p.pt', map_location=device) rr = torch.tensor(np.linspace(0, max_radius, 1000)).to(device) rmserr_list = [] - for n in range(int(n_particles * (1 - config.training.particle_dropout))): - embedding_ = model_a_first[n, :] * torch.ones((1000, config.graph_model.embedding_dim), device=device) + for n in indexes: + embedding_ = model_a_first[int(n), :] * torch.ones((1000, config.graph_model.embedding_dim), device=device) in_features = torch.cat((rr[:, None] / max_radius, 0 * rr[:, None], rr[:, None] / max_radius, embedding_), dim=1) with torch.no_grad(): func = model.lin_edge(in_features.float()) func = func[:, 0] - true_func = model.psi(rr, p[to_numpy(type_list[n]).astype(int)].squeeze(), - p[to_numpy(type_list[n]).astype(int)].squeeze()) + true_func = model.psi(rr, p[int(to_numpy(x_[int(n),5]))].squeeze(), + p[int(to_numpy(x_[int(n),5]))].squeeze()) rmserr_list.append(torch.sqrt(torch.mean((func - true_func.squeeze()) ** 2))) plt.plot(to_numpy(rr), to_numpy(func), - color=cmap.color(to_numpy(type_list[n]).astype(int)), linewidth=8, alpha=0.1) + color=cmap.color(int(to_numpy(x_[int(n),5]))), linewidth=8, alpha=0.1) plt.xlabel(r'$d_{ij}$', fontsize=64) plt.ylabel(r'$f(\ensuremath{\mathbf{a}}_i, d_{ij})$', fontsize=64) plt.xlim([0, max_radius]) @@ -962,7 +954,6 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo plt.xlabel(r'$d_{ij}$', fontsize=64) plt.ylabel(r'$f(\ensuremath{\mathbf{a}}_i, d_{ij})$', fontsize=64) plt.tight_layout() - torch.save(plots, f"./{log_dir}/results/plots_true_{config_file}_{epoch}.pt") plt.close() for k in indexes: @@ -3996,10 +3987,10 @@ def data_video_training(config_file, epoch_list, log_dir, logger, device): def data_plot(config_file, epoch_list, device): - plt.rcParams['text.usetex'] = True - rc('font', **{'family': 'serif', 'serif': ['Palatino']}) + # plt.rcParams['text.usetex'] = True + # rc('font', **{'family': 'serif', 'serif': ['Palatino']}) matplotlib.rcParams['savefig.pad_inches'] = 0 - matplotlib.use("Qt5Agg") + # matplotlib.use("Qt5Agg") l_dir = os.path.join('.', 'log') log_dir = os.path.join(l_dir, 'try_{}'.format(config_file)) @@ -4088,21 +4079,13 @@ def data_plot(config_file, epoch_list, device): # config_list = ['boids_16_256','boids_32_256','boids_64_256'] config_list = ['arbitrary_3_tracking'] - # epoch_list = ['0_500','0_1000','0_2000','0_5000','0_10000','0_20000','0_49000','0','0_corrected','1_500','1_1000','1_2000','1_5000','1_10000','1_20000','1_49000','1','1_corrected','2_500','2_1000','2_2000','2_5000','2_10000','2_20000','2_49000','2','2_corrected'] + # epoch_list = - epoch_list = ['2'] + epoch_list = ['0_500','0_1000','0_2000','0_5000','0_10000','0_20000','0_49000','0','0_corrected','1_500','1_1000','1_2000','1_5000','1_10000','1_20000','1_49000','1','1_corrected','2_500','2_1000','2_2000','2_5000','2_10000','2_20000','2_49000','2','2_corrected','3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18','19','20'] for config_file in config_list: config = ParticleGraphConfig.from_yaml(f'./config/{config_file}.yaml') data_plot(config_file, epoch_list, device) - # data_plot_attraction_repulsion_short(config_file, device=device) - # data_plot_boids(config_file) - # data_plot_gravity(config_file) - # data_plot_RD(config_file,cc='viridis') - # data_plot_particle_field(config_file, mode='figures', cc='grey', device=device) - # data_plot_wave(config_file,cc='viridis') - # data_plot_signal(config_file,cc='viridis') - # data_video_validation(config_file,device=device) # data_video_training(config_file,device=device)