
Commit 8713821: update tracking

allierc committed Jun 12, 2024
1 parent c6f0a07 · commit 8713821
Showing 36 changed files with 116 additions and 159 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -24,6 +24,7 @@ archive/*/*
 *.ijm
 *.svg
 *.npz
+*.csv
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
2 changes: 1 addition & 1 deletion GNN_particles_Ntype.py
@@ -45,7 +45,7 @@
 # config_list = ['boids_64_256']
 # config_list = ['boids_16_256_test_cell_division_c']
 # config_list = ['boids_16_256_division_death_model_2']
-config_list = ['arbitrary_3_no_tracking']
+config_list = ['arbitrary_3_tracking']
 # config_list = ['arbitrary_64']
 
 seed_list = np.arange(10)
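For orientation, the entry script selects one config name and a list of seeds. A minimal sketch of the driver pattern this implies, with data_generate/data_train as hypothetical stand-ins for whatever GNN_particles_Ntype.py actually calls:

import numpy as np

# Illustrative driver loop over configs and seeds; the called functions
# are hypothetical stand-ins, not the repository's actual API.
config_list = ['arbitrary_3_tracking']
seed_list = np.arange(10)

for config_file in config_list:
    for seed in seed_list:
        print(f'running {config_file} with seed {seed}')
        # data_generate(config_file, seed=seed)  # hypothetical
        # data_train(config_file, seed=seed)     # hypothetical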
55 changes: 55 additions & 0 deletions config/arbitrary_3_tracking.yaml
@@ -0,0 +1,55 @@
+description: 'attraction-repulsion with 3 types particles'
+dataset: 'arbitrary_3'
+
+simulation:
+  params: [[1.6233, 1.0413, 1.6012, 1.5615], [1.7667, 1.8308, 1.0855, 1.9055], [1.7226, 1.7850, 1.0584, 1.8579]]
+  min_radius: 0
+  max_radius: 0.075
+  n_particles: 4800
+  n_particle_types: 3
+  n_interactions: 3
+  n_frames: 250
+  sigma: 0.005
+  delta_t: 0.1
+  dpos_init: 0
+  boundary: 'periodic'
+
+graph_model:
+  particle_model_name: 'PDE_A'
+  mesh_model_name: ''
+  prediction: 'first_derivative'
+  input_size: 5
+  output_size: 3
+  hidden_dim: 128
+  n_mp_layers: 5
+  aggr_type: 'mean'
+  embedding_dim: 2
+  update_type: 'none'
+
+plotting:
+  colormap: 'tab10'
+  arrow_length: 2
+
+training:
+  n_epochs: 10
+  has_no_tracking: True
+  batch_size: 1
+  small_init_batch_size: True
+  n_runs: 2
+  sparsity: 'replace_embedding_function'
+  cluster_method: 'distance_plot'
+  data_augmentation: True
+  data_augmentation_loop: 10
+  fix_cluster_embedding: True
+  learning_rate_start: 1.0E-4
+  learning_rate_end: 5.0E-5
+  learning_rate_embedding_start: 1.0E-4
+  learning_rate_embedding_end: 1.0E-4
+  device: 'auto'
+  coeff_loss1: 1
+  coeff_loss2: 1
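As an aside, a minimal sketch of reading this config outside the repository's own loader, assuming only PyYAML; the key names are taken from the file above:

import yaml

# Load the new config and inspect the flags this commit revolves around.
# Assumes PyYAML; ParticleGraph's own config loader may differ.
with open('config/arbitrary_3_tracking.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['dataset'])                      # 'arbitrary_3'
print(cfg['training']['has_no_tracking'])  # True -> routes to data_train_tracking
print(cfg['simulation']['n_particles'])    # 4800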
4 changes: 0 additions & 4 deletions hall_of_fame_2024-06-12_131801.203.csv
This file was deleted.
4 changes: 0 additions & 4 deletions hall_of_fame_2024-06-12_131801.203.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_131801.203.pkl
6 changes: 0 additions & 6 deletions hall_of_fame_2024-06-12_132058.813.csv
This file was deleted.
6 changes: 0 additions & 6 deletions hall_of_fame_2024-06-12_132058.813.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_132058.813.pkl
9 changes: 0 additions & 9 deletions hall_of_fame_2024-06-12_132813.045.csv
This file was deleted.
9 changes: 0 additions & 9 deletions hall_of_fame_2024-06-12_132813.045.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_132813.045.pkl
7 changes: 0 additions & 7 deletions hall_of_fame_2024-06-12_133248.406.csv
This file was deleted.
7 changes: 0 additions & 7 deletions hall_of_fame_2024-06-12_133248.406.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_133248.406.pkl
Empty file.
10 changes: 0 additions & 10 deletions hall_of_fame_2024-06-12_133530.033.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_133530.033.pkl
Binary file removed hall_of_fame_2024-06-12_133828.466.pkl
7 changes: 0 additions & 7 deletions hall_of_fame_2024-06-12_133901.030.csv
This file was deleted.
Empty file.
Binary file removed hall_of_fame_2024-06-12_133901.030.pkl
8 changes: 0 additions & 8 deletions hall_of_fame_2024-06-12_134405.912.csv
This file was deleted.
8 changes: 0 additions & 8 deletions hall_of_fame_2024-06-12_134405.912.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_134405.912.pkl
7 changes: 0 additions & 7 deletions hall_of_fame_2024-06-12_134611.221.csv
This file was deleted.
7 changes: 0 additions & 7 deletions hall_of_fame_2024-06-12_134611.221.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_134611.221.pkl
Empty file.
7 changes: 0 additions & 7 deletions hall_of_fame_2024-06-12_134928.940.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_134928.940.pkl
10 changes: 0 additions & 10 deletions hall_of_fame_2024-06-12_135115.739.csv
This file was deleted.
10 changes: 0 additions & 10 deletions hall_of_fame_2024-06-12_135115.739.csv.bkup
This file was deleted.
Binary file removed hall_of_fame_2024-06-12_135115.739.pkl
87 changes: 57 additions & 30 deletions src/ParticleGraph/models/graph_trainer.py
@@ -5,7 +5,7 @@
 
 def data_train(config, config_file, device):
 
-    # matplotlib.use("Qt5Agg")
+    matplotlib.use("Qt5Agg")
 
     seed = config.training.seed
 
@@ -28,7 +28,7 @@ def data_train(config, config_file, device):
     elif has_cell_division:
         data_train_cell(config, config_file, device)
     elif has_no_tracking:
-        data_train_no_tracking(config, config_file, device)
+        data_train_tracking(config, config_file, device)
     else:
         data_train_particles(config, config_file, device)
 
@@ -417,7 +417,7 @@ def data_train_particles(config, config_file, device):
     plt.close()
 
 
-def data_train_no_tracking(config, config_file, device):
+def data_train_tracking(config, config_file, device):
    print('')
 
    simulation_config = config.simulation
@@ -579,44 +579,73 @@ def data_train_no_tracking(config, config_file, device):
         plt.savefig(f"./{log_dir}/tmp_training/embedding/before_particle_{dataset_name}_{epoch}_{N}.tif",dpi=87)
         plt.close()
 
-        for k in range(n_frames-1):
-            x = x_list[k].clone().detach()
-            distance = torch.sum(bc_dpos(x[:, None, 1:3] - x[None, :, 1:3]) ** 2, dim=2)
-            adj_t = ((distance < max_radius ** 2) & (distance > min_radius ** 2)).float() * 1
-            edges = adj_t.nonzero().t().contiguous()
-            dataset = data.Data(x=x[:, :], edge_index=edges)
-            y = x_list[k + 1].clone().detach()
-            y = y[:, 1:3]
-            with torch.no_grad():
-                pred, logvar, sigma = model(dataset, training=False, vnorm=vnorm, phi=torch.zeros(1, device=device))
-            x_pred = bc_pos(x[:, 1:3] + pred * delta_t)
-            distance = torch.sum(bc_dpos(x[:, None, 1:3] - y[None, :, :]) ** 2, dim=2)
-            if epoch%2 == 0:
-
-                result = distance.min(dim=0)
-                min_value = result.values
-                min_index = result.indices
-                print('from cell to track training')
-                logger.info('from cell to track training')
-
-                if epoch%2 == 0:
-                    x_list[k+1][min_index, 0:1] = x_list[k][:, 0:1].clone().detach()
-                else:
-                    x_list[k]=index_l[k].clone().detach()
-
-            with torch.no_grad():
-                model.a[(k+1)*n_particles:(k+2)*n_particles,:] = model.a[k*n_particles + min_index,:].clone().detach()
+        for k in trange(n_frames-1):
+            x = x_list[k].clone().detach()
+            distance = torch.sum(bc_dpos(x[:, None, 1:3] - x[None, :, 1:3]) ** 2, dim=2)
+            adj_t = ((distance < max_radius ** 2) & (distance > min_radius ** 2)).float() * 1
+            edges = adj_t.nonzero().t().contiguous()
+            dataset = data.Data(x=x[:, :], edge_index=edges)
+            y = x_list[k + 1].clone().detach()
+            y = y[:, 1:3]
+            with torch.no_grad():
+                pred, logvar, sigma = model(dataset, training=False, vnorm=vnorm, phi=torch.zeros(1, device=device))
+            x_pred = bc_pos(x[:, 1:3] + pred * delta_t)
+            distance = torch.sum(bc_dpos(x[:, None, 1:3] - y[None, :, :]) ** 2, dim=2)
+
+            result = distance.min(dim=0)
+            min_value = result.values
+            min_index = result.indices
+
+            x_list[k+1][min_index, 0:1] = x_list[k][:, 0:1].clone().detach()
+
+        if epoch%2 == 0:
+            print('from cell to track training')
+            logger.info('from cell to track training')
+            x_ = torch.stack(x_list)
+            x_ = torch.reshape(x_, (x_.shape[0] * x_.shape[1], x_.shape[2]))
+            x_ = x_[0:n_frames*n_particles]
+            x_ = to_numpy(x_[:,0])
+            indexes = np.unique(x_)
+
+            for k in indexes:
+                pos = np.argwhere(x_ == k)
+                if len(pos>0):
+                    pos=pos[:,0]
+                    model_a = torch.median(model.a[pos,:])
+                    model_a = model_a.clone().detach()
+                    model_a = model_a.repeat(len(pos),1)
+                    with torch.no_grad():
+                        model.a[pos,:] = model_a
+
+            print(f'{len(np.unique(x_))} tracks, first track index: {np.min(x_)}, last track index: {np.max(x_)}')
+            logger.info(f'{len(np.unique(x_))} tracks, first track index: {np.min(x_)}, last track index: {np.max(x_)}')
+            x_ = []
+
+        else:
+
+            print('from track to cell training')
+            logger.info('from track to cell training')
+
+            x_ = torch.stack(x_list)
+            x_ = torch.reshape(x_, (x_.shape[0] * x_.shape[1], x_.shape[2]))
+            x_ = x_[0:n_frames * n_particles]
+            x_ = to_numpy(x_[:, 0])
+
+            for k in indexes:
+                pos = np.argwhere(x_ == k)
+                if len(pos > 0):
+                    pos = pos[:, 0]
+                    model_a = model.a[pos[0], :]
+                    model_a = model_a.clone().detach()
+                    model_a = model_a.repeat(len(pos), 1)
+                    with torch.no_grad():
+                        model.a[pos, :] = model_a
+
+            for k in range(n_frames):
+                x_list[k][:, 0] = index_l[k].clone().detach()
 
         fig = plt.figure(figsize=(8, 8))
         for k in range(0,n_frames-2,n_frames//10):
             embedding = to_numpy(model.a[k*n_particles:(k+1)*n_particles,:].clone().detach())
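In outline, the per-frame loop above links particle identities by nearest-neighbour matching between consecutive frames: the model predicts a displacement, and each frame-t+1 particle inherits the ID of the closest frame-t particle. A minimal self-contained sketch of that step, using plain coordinates (the repository's periodic-boundary helpers bc_pos/bc_dpos are omitted, and all names are illustrative):

import torch

def link_frames(pos_t, vel_pred, pos_next, delta_t):
    # Predicted next positions, computed as in the diff (x_pred).
    x_pred = pos_t + vel_pred * delta_t
    # Pairwise squared distances between frames; as in the diff, the match
    # is made on the frame-t positions themselves.
    d2 = torch.sum((pos_t[:, None, :] - pos_next[None, :, :]) ** 2, dim=2)
    # For each frame-t+1 particle, the index of the nearest frame-t particle.
    return d2.min(dim=0).indices

# Toy usage: track IDs propagate along the returned match.
pos_t = torch.rand(5, 2)
pos_next = pos_t + 0.01 * torch.randn(5, 2)
min_index = link_frames(pos_t, torch.zeros(5, 2), pos_next, delta_t=0.1)
ids_t = torch.arange(5)
ids_next = ids_t[min_index]  # gather form of the diff's x_list[k+1][min_index, 0:1] = x_list[k][:, 0:1]
print(ids_next)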
@@ -625,12 +654,10 @@ def data_train_no_tracking(config, config_file, device):
         plt.xticks([])
         plt.yticks([])
         plt.tight_layout()
-        plt.savefig(f"./{log_dir}/tmp_training/particle/embedding{dataset_name}_{epoch}_{N}.tif",dpi=87)
+        plt.savefig(f"./{log_dir}/tmp_training/embedding/after_particle_{dataset_name}_{epoch}_{N}.tif", dpi=87)
         plt.close()
 
 
-
-
 def data_train_cell(config, config_file, device):
     print('')
 
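The alternating scheme above is the core of the commit: on even epochs the learned per-particle embeddings are consolidated along each recovered track, and on odd epochs they are pushed back onto the ground-truth particle indices. A minimal sketch of the even-epoch consolidation, assuming a flat (n_frames * n_particles, emb_dim) embedding table with one track ID per row; note the committed code collapses each track to a single scalar torch.median, whereas this sketch takes the median per embedding dimension:

import numpy as np
import torch

def consolidate_embeddings(a, track_ids):
    # a: (n_frames * n_particles, emb_dim) embedding table, modified in place.
    # track_ids: flat integer array giving the track ID of each row of a.
    with torch.no_grad():
        for tid in np.unique(track_ids):
            rows = np.flatnonzero(track_ids == tid)
            med = a[rows].median(dim=0).values  # per-dimension median over the track
            a[rows] = med

# Toy usage: 3 frames x 2 particles, 2-D embeddings, two tracks.
a = torch.randn(6, 2)
ids = np.array([0, 1, 0, 1, 0, 1])
consolidate_embeddings(a, ids)
print(a)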
4 changes: 2 additions & 2 deletions src/ParticleGraph/utils.py
@@ -246,7 +246,7 @@ def create_log_dir(config, config_file):
     os.makedirs(os.path.join(log_dir, 'tmp_training/particle'), exist_ok=True)
     os.makedirs(os.path.join(log_dir, 'tmp_training/field'), exist_ok=True)
     os.makedirs(os.path.join(log_dir, 'tmp_training/function'), exist_ok=True)
-    # os.makedirs(os.path.join(log_dir, 'tmp_training/embedding/siren'), exist_ok=True)
+    os.makedirs(os.path.join(log_dir, 'tmp_training/embedding'), exist_ok=True)
     if config.training.n_ghosts > 0:
         os.makedirs(os.path.join(log_dir, 'tmp_training/ghost'), exist_ok=True)
     files = glob.glob(f"{log_dir}/results/*")
@@ -261,7 +261,7 @@ def create_log_dir(config, config_file):
     files = glob.glob(f"{log_dir}/tmp_training/function/*")
     for f in files:
         os.remove(f)
-    files = glob.glob(f"{log_dir}/tmp_training/siren/*")
+    files = glob.glob(f"{log_dir}/tmp_training/embedding/*")
     for f in files:
         os.remove(f)
     files = glob.glob(f"{log_dir}/tmp_training/ghost/*")
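The utils change pairs the new tmp_training/embedding directory with the same create-then-clear pattern used for the other plot folders. A compact sketch of that pattern (illustrative helper, not the repository's API):

import glob
import os

def reset_dir(path):
    # Ensure the folder exists, then clear files left over from a previous run.
    os.makedirs(path, exist_ok=True)
    for f in glob.glob(f'{path}/*'):
        os.remove(f)

reset_dir('log/tmp_training/embedding')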
