diff --git a/GNN_particles_Ntype.py b/GNN_particles_Ntype.py index 4f802426..ef82a799 100644 --- a/GNN_particles_Ntype.py +++ b/GNN_particles_Ntype.py @@ -45,7 +45,8 @@ # config_list = ['boids_64_256'] # config_list = ['boids_16_256_test_cell_division_c'] # config_list = ['boids_16_256_division_death_model_2'] - config_list = ['arbitrary_3_tracking'] + # config_list = ['arbitrary_3_tracking'] + config_list = ['boids_16_256_tracking'] # config_list = ['arbitrary_64'] seed_list = np.arange(10) diff --git a/config/arbitrary_3_tracking.yaml b/config/arbitrary_3_tracking.yaml index d6b73339..dd4583f5 100644 --- a/config/arbitrary_3_tracking.yaml +++ b/config/arbitrary_3_tracking.yaml @@ -40,7 +40,7 @@ training: sparsity: 'replace_embedding_function' cluster_method: 'distance_plot' data_augmentation: True - data_augmentation_loop: 1 + data_augmentation_loop: 200 fix_cluster_embedding: True device: 'auto' coeff_loss1: 1 diff --git a/config/boids_16_256_tracking.yaml b/config/boids_16_256_tracking.yaml new file mode 100644 index 00000000..dfd1b13e --- /dev/null +++ b/config/boids_16_256_tracking.yaml @@ -0,0 +1,46 @@ +description: 'Boids 16 different types' +dataset: 'boids_16_256' + +simulation: + params: [[27.6, 92.5, 48.2], [32.0, 51.8, 29.8], [23.6, 35.0, 13.5], [3.3, 76.4, 13.0], [94.0, 78.7, 30.8,], [81.4, 34.6, 2.2], [3.1, 40.5, 17.0], [88.8, 7.2, 29.8], [32.7, 76.8, 26.1], [14.5, 56.8, 27.6], [63.3, 99.9, 13.9], [32.5, 25.2, 24.4], [97.6, 56.5, 12.1], [62.2, 9.1, 28.3], [76.4, 52.9, 32.0], [48.1, 54.4, 30.7]] + min_radius: 0.001 + max_radius: 0.04 + n_particles: 1792 + n_particle_types: 16 + n_interactions: 16 + has_cell_division: False + n_frames: 8000 + sigma: 0.005 + delta_t: 0.5 + dpos_init: 5.0E-4 + boundary: 'periodic' + start_frame: 0 + +graph_model: + particle_model_name: 'PDE_B' + mesh_model_name: '' + prediction: '2nd_derivative' + input_size: 9 + output_size: 2 + hidden_dim: 256 + n_mp_layers: 5 + aggr_type: 'mean' + embedding_dim: 2 + update_type: 'none' + +plotting: + colormap: 'tab20' + arrow_length: 5 + +training: + n_epochs: 20 + has_no_tracking: True + batch_size: 1 + small_init_batch_size: True + n_runs: 2 + noise_level: 0 + data_augmentation: True + data_augmentation_loop: 100 + fix_cluster_embedding: True + device: 'auto' + diff --git a/src/ParticleGraph/models/graph_trainer.py b/src/ParticleGraph/models/graph_trainer.py index 48c01f45..b8ba8fe3 100644 --- a/src/ParticleGraph/models/graph_trainer.py +++ b/src/ParticleGraph/models/graph_trainer.py @@ -226,7 +226,7 @@ def data_train_particles(config, config_file, device): loss = ((pred - y_batch) / (y_batch)).norm(2) / 1E9 visualize_embedding = True - if visualize_embedding & (((epoch < 3 ) & (N % 500 == 0)) | (N==0)): + if visualize_embedding & (((epoch < 3 ) & (N%(Niter//100) == 0)) | (N==0)): plot_training(config=config, dataset_name=dataset_name, log_dir=log_dir, epoch=epoch, N=N, x=x, model=model, n_nodes=0, n_node_types=0, index_nodes=0, dataset_num=1, index_particles=index_particles, n_particles=n_particles, @@ -609,7 +609,7 @@ def data_train_tracking(config, config_file, device): # loss2 = 0 * pred.norm(2) / (vnorm**2) * config.training.coeff_loss2 visualize_embedding = True - if visualize_embedding & (((epoch < 3 ) & (N % 500 == 0)) | (N==0)): + if visualize_embedding & (((epoch < 3 ) & (N % (Niter//100) == 0)) | (N==0)): print(N) fig = plt.figure(figsize=(8, 8)) plt.scatter(to_numpy(x[:, 1]), to_numpy(x[:, 2]), s=10, c='k', alpha=0.05) @@ -666,7 +666,7 @@ def data_train_tracking(config, config_file, device): plt.xticks([]) plt.yticks([]) plt.tight_layout() - plt.savefig(f"./{log_dir}/tmp_training/embedding/before_particle_{dataset_name}_{epoch}_{N}.tif",dpi=87) + plt.savefig(f"./{log_dir}/tmp_training/before_particle_{dataset_name}_{epoch}_{N}.tif",dpi=87) plt.close() if epoch%2 == 0: @@ -739,7 +739,7 @@ def data_train_tracking(config, config_file, device): plt.xticks([]) plt.yticks([]) plt.tight_layout() - plt.savefig(f"./{log_dir}/tmp_training/embedding/after_particle_{dataset_name}_{epoch}_{N}.tif", dpi=87) + plt.savefig(f"./{log_dir}/tmp_training/after_particle_{dataset_name}_{epoch}_{N}.tif", dpi=87) plt.close()