Skip to content

Commit

Permalink
update tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
allierc committed Jun 13, 2024
1 parent c77f5b7 commit 07555e3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 15 deletions.
4 changes: 2 additions & 2 deletions GNN_particles_Ntype.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
# config_list = ['boids_64_256']
# config_list = ['boids_16_256_test_cell_division_c']
# config_list = ['boids_16_256_division_death_model_2']
# config_list = ['arbitrary_3_tracking']
config_list = ['boids_16_256_tracking']
config_list = ['arbitrary_3_tracking']
# config_list = ['boids_16_256_tracking']
# config_list = ['arbitrary_64']

seed_list = np.arange(10)
Expand Down
2 changes: 1 addition & 1 deletion config/arbitrary_3_tracking.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ plotting:
ylim: [-5E-5, 5E-5]

training:
n_epochs: 1
n_epochs: 20
has_no_tracking: True
batch_size: 1
small_init_batch_size: True
Expand Down
29 changes: 22 additions & 7 deletions src/ParticleGraph/models/graph_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def data_train_tracking(config, config_file, device):

print('Create models ...')
model, bc_pos, bc_dpos = choose_training_model(config, device)
# print('Loading existing model ...')
print('Loading existing model ...')
# net = f"./log/try_{config_file}/models/best_model_with_1_graphs_0.pt"
# state_dict = torch.load(net,map_location=device)
# model.load_state_dict(state_dict['model_state_dict'])
Expand Down Expand Up @@ -548,8 +548,7 @@ def data_train_tracking(config, config_file, device):
sin_phi = torch.sin(phi)

run = 1 + np.random.randint(n_runs - 1)
k = np.random.randint(n_frames - 2)

k = np.random.randint(n_frames)
x = x_list[run][k].clone().detach()

if has_ghost:
Expand Down Expand Up @@ -620,7 +619,7 @@ def data_train_tracking(config, config_file, device):
plt.tight_layout()
plt.savefig(f"./{log_dir}/tmp_training/particle/{dataset_name}_{epoch}_{N}.tif")
plt.close()
model.a_current = model.a

plot_training(config=config, dataset_name=dataset_name, log_dir=log_dir,
epoch=epoch, N=N, x=x, model=model, n_nodes=0, n_node_types=0, index_nodes=0, dataset_num=1,
index_particles=index_particles, n_particles=n_particles,
Expand Down Expand Up @@ -669,12 +668,14 @@ def data_train_tracking(config, config_file, device):
plt.savefig(f"./{log_dir}/tmp_training/before_particle_{dataset_name}_{epoch}_{N}.tif",dpi=87)
plt.close()

fig = plt.figure(figsize=(8, 8))
if epoch%2 == 0:

print('from cell to track training')
logger.info('from cell to track training')

for k in trange(n_frames-1):
tracking_index = 0
for k in trange(n_frames):
x = x_list[1][k].clone().detach()
distance = torch.sum(bc_dpos(x[:, None, 1:3] - x[None, :, 1:3]) ** 2, dim=2)
adj_t = ((distance < max_radius ** 2) & (distance > min_radius ** 2)).float() * 1
Expand All @@ -691,12 +692,22 @@ def data_train_tracking(config, config_file, device):
result = distance.min(dim=1)
min_value = result.values
min_index = result.indices

plt.scatter(np.arange(len(min_index)), to_numpy(min_index), s=10, c='k', alpha=0.05)
tracking_index += np.sum((to_numpy(min_index) - np.arange(len(min_index))==0)) / n_frames / n_particles *100
x_list[1][k+1][min_index, 0:1] = x_list[1][k][:, 0:1].clone().detach()

plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.savefig(f"./{log_dir}/tmp_training/proxy_tracking_{dataset_name}_{epoch}_{N}.tif", dpi=87)
plt.close()

print(f'tracking index: {tracking_index}')
logger.info(f'tracking index: {tracking_index}')

x_ = torch.stack(x_list[1])
x_ = torch.reshape(x_, (x_.shape[0] * x_.shape[1], x_.shape[2]))
x_ = x_[0:n_frames*n_particles]
x_ = x_[0:(n_frames-1)*n_particles]
x_ = to_numpy(x_[:,0])
indexes = np.unique(x_)

Expand Down Expand Up @@ -742,6 +753,10 @@ def data_train_tracking(config, config_file, device):
plt.savefig(f"./{log_dir}/tmp_training/after_particle_{dataset_name}_{epoch}_{N}.tif", dpi=87)
plt.close()

lr_embedding = train_config.learning_rate_embedding_start
lr = train_config.learning_rate_start
optimizer, n_total_params = set_trainable_parameters(model, lr_embedding, lr)
logger.info(f'Learning rates: {lr}, {lr_embedding}')



Expand Down
13 changes: 8 additions & 5 deletions src/ParticleGraph/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def plot_training (config, dataset_name, log_dir, epoch, N, x, index_particles,
else:
fig = plt.figure(figsize=(8, 8))
if has_no_tracking:
embedding = to_numpy(model.a_current)
embedding = to_numpy(model.a)
else:
embedding = get_embedding(model.a, 1)
for n in range(n_particle_types):
Expand All @@ -236,7 +236,10 @@ def plot_training (config, dataset_name, log_dir, epoch, N, x, index_particles,
ax = fig.add_subplot(1, 2, 1)
rr = torch.tensor(np.logspace(7, 9, 1000)).to(device)
for n in range(n_particles):
embedding_ = model.a[1, n, :] * torch.ones((1000, model_config.embedding_dim), device=device)
if has_no_tracking:
embedding_ = model.a[n, :] * torch.ones((1000, model_config.embedding_dim), device=device)
else:
embedding_ = model.a[1, n, :] * torch.ones((1000, model_config.embedding_dim), device=device)
in_features = torch.cat((rr[:, None] / simulation_config.max_radius, 0 * rr[:, None],
rr[:, None] / simulation_config.max_radius, 10 ** embedding_), dim=1)
with torch.no_grad():
Expand Down Expand Up @@ -270,7 +273,7 @@ def plot_training (config, dataset_name, log_dir, epoch, N, x, index_particles,
func_list = []
for n in range(n_particles):
if has_no_tracking:
embedding_ = model.a_current[n, :] * torch.ones((1000, model_config.embedding_dim), device=device)
embedding_ = model.a[n, :] * torch.ones((1000, model_config.embedding_dim), device=device)
else:
embedding_ = model.a[1, n, :] * torch.ones((1000, model_config.embedding_dim), device=device)
in_features = torch.cat((rr[:, None] / max_radius, 0 * rr[:, None],
Expand All @@ -283,7 +286,7 @@ def plot_training (config, dataset_name, log_dir, epoch, N, x, index_particles,
if n % 5 == 0:
plt.plot(to_numpy(rr), to_numpy(func) * to_numpy(ynorm),
color=cmap.color(int(n // (n_particles / n_particle_types))), linewidth=2)
plt.ylim([-1E-4, 1E-4])
# plt.ylim([-1E-4, 1E-4])
plt.xlim([-max_radius, max_radius])
# plt.xlabel(r'$x_j-x_i$', fontsize=64)
# plt.ylabel(r'$f_{ij}$', fontsize=64)
Expand Down Expand Up @@ -352,7 +355,7 @@ def plot_training (config, dataset_name, log_dir, epoch, N, x, index_particles,
rr = torch.tensor(np.linspace(0, simulation_config.max_radius, 200)).to(device)
for n in range(n_particles):
if has_no_tracking:
embedding_ = model.a_current[n, :] * torch.ones((200, model_config.embedding_dim), device=device)
embedding_ = model.a[n, :] * torch.ones((200, model_config.embedding_dim), device=device)
else:
embedding_ = model.a[1, n, :] * torch.ones((200, model_config.embedding_dim), device=device)
if (model_config.particle_model_name == 'PDE_A'):
Expand Down

0 comments on commit 07555e3

Please sign in to comment.