Skip to content

Commit

Permalink
update tracking
Browse files Browse the repository at this point in the history
first results
  • Loading branch information
allierc committed Jun 13, 2024
1 parent c9db579 commit c77f5b7
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
3 changes: 2 additions & 1 deletion GNN_particles_Ntype.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
# config_list = ['boids_64_256']
# config_list = ['boids_16_256_test_cell_division_c']
# config_list = ['boids_16_256_division_death_model_2']
config_list = ['arbitrary_3_tracking']
# config_list = ['arbitrary_3_tracking']
config_list = ['boids_16_256_tracking']
# config_list = ['arbitrary_64']

seed_list = np.arange(10)
Expand Down
2 changes: 1 addition & 1 deletion config/arbitrary_3_tracking.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ training:
sparsity: 'replace_embedding_function'
cluster_method: 'distance_plot'
data_augmentation: True
data_augmentation_loop: 1
data_augmentation_loop: 200
fix_cluster_embedding: True
device: 'auto'
coeff_loss1: 1
Expand Down
46 changes: 46 additions & 0 deletions config/boids_16_256_tracking.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
description: 'Boids 16 different types'
dataset: 'boids_16_256'

simulation:
params: [[27.6, 92.5, 48.2], [32.0, 51.8, 29.8], [23.6, 35.0, 13.5], [3.3, 76.4, 13.0], [94.0, 78.7, 30.8,], [81.4, 34.6, 2.2], [3.1, 40.5, 17.0], [88.8, 7.2, 29.8], [32.7, 76.8, 26.1], [14.5, 56.8, 27.6], [63.3, 99.9, 13.9], [32.5, 25.2, 24.4], [97.6, 56.5, 12.1], [62.2, 9.1, 28.3], [76.4, 52.9, 32.0], [48.1, 54.4, 30.7]]
min_radius: 0.001
max_radius: 0.04
n_particles: 1792
n_particle_types: 16
n_interactions: 16
has_cell_division: False
n_frames: 8000
sigma: 0.005
delta_t: 0.5
dpos_init: 5.0E-4
boundary: 'periodic'
start_frame: 0

graph_model:
particle_model_name: 'PDE_B'
mesh_model_name: ''
prediction: '2nd_derivative'
input_size: 9
output_size: 2
hidden_dim: 256
n_mp_layers: 5
aggr_type: 'mean'
embedding_dim: 2
update_type: 'none'

plotting:
colormap: 'tab20'
arrow_length: 5

training:
n_epochs: 20
has_no_tracking: True
batch_size: 1
small_init_batch_size: True
n_runs: 2
noise_level: 0
data_augmentation: True
data_augmentation_loop: 100
fix_cluster_embedding: True
device: 'auto'

8 changes: 4 additions & 4 deletions src/ParticleGraph/models/graph_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def data_train_particles(config, config_file, device):
loss = ((pred - y_batch) / (y_batch)).norm(2) / 1E9

visualize_embedding = True
if visualize_embedding & (((epoch < 3 ) & (N % 500 == 0)) | (N==0)):
if visualize_embedding & (((epoch < 3 ) & (N%(Niter//100) == 0)) | (N==0)):
plot_training(config=config, dataset_name=dataset_name, log_dir=log_dir,
epoch=epoch, N=N, x=x, model=model, n_nodes=0, n_node_types=0, index_nodes=0, dataset_num=1,
index_particles=index_particles, n_particles=n_particles,
Expand Down Expand Up @@ -609,7 +609,7 @@ def data_train_tracking(config, config_file, device):
# loss2 = 0 * pred.norm(2) / (vnorm**2) * config.training.coeff_loss2

visualize_embedding = True
if visualize_embedding & (((epoch < 3 ) & (N % 500 == 0)) | (N==0)):
if visualize_embedding & (((epoch < 3 ) & (N % (Niter//100) == 0)) | (N==0)):
print(N)
fig = plt.figure(figsize=(8, 8))
plt.scatter(to_numpy(x[:, 1]), to_numpy(x[:, 2]), s=10, c='k', alpha=0.05)
Expand Down Expand Up @@ -666,7 +666,7 @@ def data_train_tracking(config, config_file, device):
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.savefig(f"./{log_dir}/tmp_training/embedding/before_particle_{dataset_name}_{epoch}_{N}.tif",dpi=87)
plt.savefig(f"./{log_dir}/tmp_training/before_particle_{dataset_name}_{epoch}_{N}.tif",dpi=87)
plt.close()

if epoch%2 == 0:
Expand Down Expand Up @@ -739,7 +739,7 @@ def data_train_tracking(config, config_file, device):
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.savefig(f"./{log_dir}/tmp_training/embedding/after_particle_{dataset_name}_{epoch}_{N}.tif", dpi=87)
plt.savefig(f"./{log_dir}/tmp_training/after_particle_{dataset_name}_{epoch}_{N}.tif", dpi=87)
plt.close()


Expand Down

0 comments on commit c77f5b7

Please sign in to comment.