fix training
k = np.random.randint(n_frames - 1)
allierc committed Jun 15, 2024
1 parent f92a36f commit dac8de1
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions src/ParticleGraph/models/graph_trainer.py
@@ -174,7 +174,7 @@ def data_train_particles(config, config_file, device):
         dataset_batch = []
         for batch in range(batch_size):
 
-            k = 1 + np.random.randint(n_frames - 2)
+            k = np.random.randint(n_frames - 2)
 
             x = x_list[run][k].clone().detach()
 
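The one-line change above (repeated below in `data_train_cell` and `data_train_particle_field`) shifts the sampled frame range: the old expression drew `k` uniformly from {1, ..., n_frames - 2} and never used frame 0, while the new one draws from {0, ..., n_frames - 3}, so the first frame becomes reachable and `k + 1` always indexes a stored next frame. A minimal sketch of the two ranges, assuming only NumPy and an illustrative `n_frames`:

```python
import numpy as np

n_frames = 100  # illustrative value, not from the commit

# Old sampling: k in {1, ..., n_frames - 2}; frame 0 is never drawn.
old_ks = 1 + np.random.randint(n_frames - 2, size=10_000)
assert old_ks.min() >= 1 and old_ks.max() <= n_frames - 2

# New sampling: k in {0, ..., n_frames - 3}; frame 0 is reachable
# and k + 1 <= n_frames - 2 always indexes a valid next frame.
new_ks = np.random.randint(n_frames - 2, size=10_000)
assert new_ks.min() >= 0 and new_ks.max() <= n_frames - 3
```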
@@ -674,6 +674,7 @@ def data_train_tracking(config, config_file, device):
     logger.info('from cell to track training')
 
     tracking_index = 0
+    tracking_index_list=[]
     for k in trange(n_frames):
         x = x_list[1][k].clone().detach()
         distance = torch.sum(bc_dpos(x[:, None, 1:3] - x[None, :, 1:3]) ** 2, dim=2)
@@ -695,6 +696,8 @@ def data_train_tracking(config, config_file, device):
         tracking_index += np.sum((to_numpy(min_index) - np.arange(len(min_index))==0)) / n_frames / n_particles *100
         x_list[1][k+1][min_index, 0:1] = x_list[1][k][:, 0:1].clone().detach()
 
+        tracking_index_list.append(len(x_pred) - np.sum((to_numpy(min_index) - np.arange(len(min_index)) == 0)))
+
     plt.xticks([])
     plt.yticks([])
     plt.tight_layout()
@@ -704,6 +707,17 @@ def data_train_tracking(config, config_file, device):
     print(f'tracking index: {tracking_index}')
     logger.info(f'tracking index: {tracking_index}')
 
+    print(f'tracking errors: {np.sum(tracking_index_list)}')
+    logger.info(f'tracking errors: {np.sum(tracking_index_list)}')
+
+    fig, ax = fig_init(formatx='%.0f', formaty='%.0f')
+    plt.plot(np.arange(n_frames), tracking_index_list, color='k', linewidth=2)
+    plt.ylabel(r'tracking errors', fontsize=64)
+    plt.xlabel(r'frame', fontsize=64)
+    plt.tight_layout()
+    plt.savefig(f"./{log_dir}/tmp_training/tracking_error_{config_file}_{epoch}_{N}.tif", dpi=170.7)
+    plt.close()
+
     x_ = torch.stack(x_list[1])
     x_ = torch.reshape(x_, (x_.shape[0] * x_.shape[1], x_.shape[2]))
     x_ = x_[0:(n_frames-1)*n_particles]
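The new `tracking_index_list` records, per frame, how many nearest-neighbour assignments fail to map a particle back to its own index, while the existing `tracking_index` keeps the percentage of correct matches averaged over frames. A hypothetical illustration of the per-frame count (the `min_index` values here are made up):

```python
import numpy as np

# min_index[i]: index of the predicted particle matched to particle i;
# a correct track keeps min_index[i] == i.
min_index = np.array([0, 2, 1, 3, 4])  # particles 1 and 2 swapped
n_pred = len(min_index)                # stands in for len(x_pred)

matches = np.sum(min_index - np.arange(n_pred) == 0)
frame_errors = n_pred - matches        # mismatched assignments -> 2
```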
@@ -995,14 +1009,12 @@ def data_train_cell_tracking(config, config_file, device):
             plt.savefig(f"./{log_dir}/tmp_training/particle/{dataset_name}_{epoch}_{N}.tif")
             plt.close()
 
-            x_ = x_list[1][n_frames - 1].clone().detach()
+            x_ = x_list[1][0].clone().detach()
             index_particles = get_index_particles(x_, n_particle_types, dimension)
             plot_training_cell(config=config, dataset_name=dataset_name, log_dir=log_dir,
                                epoch=epoch, N=N, model=model, index_particles=index_particles, n_particle_types=n_particle_types, type_list=type_list, ynorm=ynorm, cmap=cmap, device=device)
             torch.save({'model_state_dict': model.state_dict(),
                         'optimizer_state_dict': optimizer.state_dict()}, os.path.join(log_dir, 'models', f'best_model_with_{n_runs - 1}_graphs_{epoch}_{N}.pt'))
-            t, r, a = get_gpu_memory_map(device)
-            logger.info(f"GPU memory: total {t} reserved {r} allocated {a}")
 
         loss.backward()
         optimizer.step()
@@ -1024,6 +1036,9 @@ def data_train_cell_tracking(config, config_file, device):
         list_loss.append(total_loss / (N + 1) / n_particles)
         torch.save(list_loss, os.path.join(log_dir, 'loss.pt'))
 
+        t, r, a = get_gpu_memory_map(device)
+        logger.info(f"GPU memory: total {t} reserved {r} allocated {a}")
+
         if has_ghost:
             torch.save({'model_state_dict': ghosts_particles.state_dict(),
                         'optimizer_state_dict': optimizer_ghost_particles.state_dict()}, os.path.join(log_dir, 'models', f'best_ghost_particles_with_{n_runs - 1}_graphs_{epoch}.pt'))
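Moving the `get_gpu_memory_map` call out of the checkpoint block means GPU memory is logged at every reporting step rather than only when a model is saved. The helper itself is not shown in this diff; a plausible sketch of such a function, assuming PyTorch's standard CUDA memory API and megabyte units (both assumptions, not taken from the repo):

```python
import torch

def get_gpu_memory_map_sketch(device):
    # Hypothetical stand-in for the repo's get_gpu_memory_map:
    # returns (total, reserved, allocated) for a CUDA device, in MB.
    mb = 1024 ** 2
    total = torch.cuda.get_device_properties(device).total_memory // mb
    reserved = torch.cuda.memory_reserved(device) // mb
    allocated = torch.cuda.memory_allocated(device) // mb
    return total, reserved, allocated
```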
@@ -1045,6 +1060,7 @@ def data_train_cell_tracking(config, config_file, device):
     logger.info('from cell to track training')
 
     tracking_index = 0
+    tracking_index_list = []
     for k in trange(n_frames):
         x = x_list[1][k].clone().detach()
         distance = torch.sum(bc_dpos(x[:, None, 1:3] - x[None, :, 1:3]) ** 2, dim=2)
@@ -1066,12 +1082,25 @@ def data_train_cell_tracking(config, config_file, device):
         tracking_index += np.sum((to_numpy(min_index) - np.arange(len(min_index))==0)) / n_frames / n_particles *100
         x_list[1][k+1][min_index, 0:1] = x_list[1][k][:, 0:1].clone().detach()
 
+        tracking_index_list.append(len(x_pred) - np.sum((to_numpy(min_index) - np.arange(len(min_index)) == 0)))
+
     plt.xticks([])
     plt.yticks([])
     plt.tight_layout()
     plt.savefig(f"./{log_dir}/tmp_training/proxy_tracking_{dataset_name}_{epoch}_{N}.tif", dpi=87)
     plt.close()
 
+    fig, ax = fig_init(formatx='%.0f', formaty='%.0f')
+    plt.plot(np.arange(n_frames), tracking_index_list, color='k', linewidth=2)
+    plt.ylabel(r'tracking errors', fontsize=64)
+    plt.xlabel(r'frame', fontsize=64)
+    plt.tight_layout()
+    plt.savefig(f"./{log_dir}/tmp_training/tracking_error_{config_file}_{epoch}_{N}.tif", dpi=170.7)
+    plt.close()
+
+    print(f'tracking errors: {np.sum(tracking_index_list)}')
+    logger.info(f'tracking errors: {np.sum(tracking_index_list)}')
+
     print(f'tracking index: {tracking_index}')
     logger.info(f'tracking index: {tracking_index}')
 
@@ -1080,7 +1109,6 @@ def data_train_cell_tracking(config, config_file, device):
             x_ = x_list[1][k].clone().detach()
         else:
             x_ = torch.cat((x_,x_list[1][k]),0)
-            x_=torch.cat((x_,x_list[1][k]),1)
     x_ = to_numpy(x_[:,0])
     indexes = np.unique(x_)
 
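The deleted line concatenated each frame along `dim=1`, widening the feature axis and breaking the track-id lookup on column 0; the surviving `dim=0` call stacks frames row-wise, so `np.unique` over column 0 collects all track ids. A small shape check with hypothetical frame tensors of 5 particles by 3 columns:

```python
import torch

frame_a = torch.zeros(5, 3)  # hypothetical frame: rows of [track_id, x, y]
frame_b = torch.ones(5, 3)

stacked = torch.cat((frame_a, frame_b), 0)  # (10, 3): frames stacked row-wise
widened = torch.cat((frame_a, frame_b), 1)  # (5, 6): the removed line's effect

print(stacked.shape, widened.shape)
```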
@@ -1300,7 +1328,7 @@ def data_train_cell(config, config_file, device):
 
         for batch in range(batch_size):
 
-            k = 1 + np.random.randint(n_frames - 2)
+            k = np.random.randint(n_frames - 2)
 
             x = x_list[run][k].clone().detach()
 
@@ -1985,7 +2013,7 @@ def data_train_particle_field(config, config_file, device):
 
         for batch in range(batch_size):
 
-            k = 1 + np.random.randint(n_frames - 2)
+            k = np.random.randint(n_frames - 2)
 
             x = x_list[run][k].clone().detach()
             x_mesh = x_mesh_list[run][k].clone().detach()
