Skip to content

Commit

Permalink
angular_var angular_Bernouilli
Browse files Browse the repository at this point in the history
  • Loading branch information
allierc committed Jun 14, 2024
1 parent 715a05c commit 0a4209b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
24 changes: 21 additions & 3 deletions GNN_particles_PlotFigure.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,7 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo

fig = plt.figure(figsize=(8, 8))
tracking_index = 0
tracking_index_list = []
for k in trange(n_frames):
x = x_list[1][k].clone().detach()
distance = torch.sum(bc_dpos(x[:, None, 1:3] - x[None, :, 1:3]) ** 2, dim=2)
Expand All @@ -844,6 +845,7 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo
plt.scatter(true_index[index_particles[n]], reconstructed_index[index_particles[n]], s=1, c=cmap.color(n), alpha=0.05)

tracking_index += np.sum((to_numpy(min_index) - np.arange(len(min_index)) == 0)) / n_frames / n_particles * 100
tracking_index_list.append(np.sum((to_numpy(min_index) - np.arange(len(min_index)) == 0)))
x_list[1][k + 1][min_index, 0:1] = x_list[1][k][:, 0:1].clone().detach()

x_ = torch.stack(x_list[1])
Expand All @@ -864,7 +866,21 @@ def data_plot_attraction_repulsion_tracking(config_file, epoch_list, log_dir, lo
print(f'{len(indexes)} tracks')
logger.info(f'{len(indexes)} tracks')

fig,ax = fig_init()
tracking_index_list = np.array(tracking_index_list)
tracking_index_list = n_particles - tracking_index_list

fig,ax = fig_init(formatx='%.0f', formaty='%.0f')
plt.plot(np.arange(n_frames), tracking_index_list, color='k', linewidth=2)
plt.ylabel(r'tracking errors', fontsize=64)
plt.xlabel(r'frame', fontsize=64)
plt.tight_layout()
plt.savefig(f"./{log_dir}/results/tracking_error_{config_file}_{epoch}.tif", dpi=170.7)
plt.close()

print(f'tracking errors: {np.sum(tracking_index_list)}')
logger.info(f'tracking errors: {np.sum(tracking_index_list)}')

fig, ax = fig_init()
for k in trange(0,n_frames-2):
embedding = to_numpy(model.a[k*n_particles:(k+1)*n_particles,:].clone().detach())
for n in range(n_particle_types):
Expand Down Expand Up @@ -3983,7 +3999,7 @@ def data_plot(config_file, epoch_list, device):
plt.rcParams['text.usetex'] = True
rc('font', **{'family': 'serif', 'serif': ['Palatino']})
matplotlib.rcParams['savefig.pad_inches'] = 0
# matplotlib.use("Qt5Agg")
matplotlib.use("Qt5Agg")

l_dir = os.path.join('.', 'log')
log_dir = os.path.join(l_dir, 'try_{}'.format(config_file))
Expand Down Expand Up @@ -4072,7 +4088,9 @@ def data_plot(config_file, epoch_list, device):
# config_list = ['boids_16_256','boids_32_256','boids_64_256']
config_list = ['arbitrary_3_tracking']

epoch_list = ['0_500','0_1000','0_2000','0_5000','0_10000','0_20000','0_49000','0','0_corrected','1_500','1_1000','1_2000','1_5000','1_10000','1_20000','1_49000','1','1_corrected','2_500','2_1000','2_2000','2_5000','2_10000','2_20000','2_49000','2','2_corrected']
# epoch_list = ['0_500','0_1000','0_2000','0_5000','0_10000','0_20000','0_49000','0','0_corrected','1_500','1_1000','1_2000','1_5000','1_10000','1_20000','1_49000','1','1_corrected','2_500','2_1000','2_2000','2_5000','2_10000','2_20000','2_49000','2','2_corrected']

epoch_list = ['2']

for config_file in config_list:
config = ParticleGraphConfig.from_yaml(f'./config/{config_file}.yaml')
Expand Down
3 changes: 2 additions & 1 deletion src/ParticleGraph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class SimulationConfig(BaseModel):
cell_death_rate: list[float] = [-1]
min_radius: Annotated[float, Field(ge=0)] = 0
max_radius: Annotated[float, Field(gt=0)]
radius_pid: bool = False
angular_var: float = 0
angular_Bernouilli: float = 0
max_edges: float = 1.0E6
diffusion_coefficients: list[list[float]] = None
n_particles: int = 1000
Expand Down

0 comments on commit 0a4209b

Please sign in to comment.