Issue: Inconsistent Atom Type Masking with Time Step
Describe the Question
File path: `TFG-Flow/diffusion/aaflow.py`, class `AAFlow`, method `sample_x_t`.
In this method, the masking of atom types appears to run in the opposite direction to the expected forward diffusion process. As `time_atom_types` increases (which I understand to represent later timesteps of the forward diffusion), the probability that an atom type is masked (replaced with `ATOMNAME_TO_INDEX['MASK']`) decreases instead of increasing.
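Stated formally, and assuming a linear schedule with independent per-atom masking (my assumption, not something confirmed from the repository), the forward process I would expect for atom $i$ is:

$$
q\left(a_t^{(i)} = \mathrm{MASK} \mid a_0^{(i)}\right) = \frac{t}{T}, \qquad q\left(a_t^{(i)} = a_0^{(i)} \mid a_0^{(i)}\right) = 1 - \frac{t}{T},
$$

so an atom is never masked at $t = 0$ and always masked at $t = T$.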
The relevant code snippet is:
```python
mask_atom = torch.rand(atom_types.shape, device=self.device) > repeat(time_atom_types / self.T, 'b -> b n', n=atom_types.shape[1])
atom_types_t = atom_types * ~mask_atom + ATOMNAME_TO_INDEX['MASK'] * mask_atom
```
Here, `time_atom_types / self.T` is a per-sample threshold that increases with the timestep. Because of the `>` comparison, `mask_atom` is True (and the atom type is masked) only when the uniform random number from `torch.rand` (drawn from [0, 1)) exceeds this threshold, so each atom is masked with probability `1 - time_atom_types / self.T`. As a result, `mask_atom` contains fewer True values as `time_atom_types` grows.
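A minimal standalone sketch that reproduces the behaviour empirically. It copies the masking expression from `sample_x_t`, but `T = 1000`, `MASK_INDEX = 99`, the atom ids 1..9, and the helper `mask_fraction` are hypothetical stand-ins (the real values live in TFG-Flow), and `einops.repeat(..., 'b -> b n')` is replaced by the equivalent `unsqueeze(1).expand(...)` to keep the example dependency-light. The `repo_comparison=False` branch simply flips `>` to `<`, i.e. masks with probability `t / T`; it is an illustration of the expected direction, not a confirmed fix.

```python
import torch

torch.manual_seed(0)

T = 1000          # assumed number of timesteps (stand-in for self.T)
MASK_INDEX = 99   # hypothetical stand-in for ATOMNAME_TO_INDEX['MASK']


def mask_fraction(t: int, repo_comparison: bool, n_atoms: int = 100_000) -> float:
    """Empirical fraction of atoms replaced by MASK_INDEX at timestep t.

    repo_comparison=True uses the `>` comparison from sample_x_t;
    repo_comparison=False uses `<` (the hypothetical expected direction).
    """
    atom_types = torch.randint(1, 10, (1, n_atoms))            # one molecule, atom ids 1..9
    time_atom_types = torch.tensor([t], dtype=torch.float)      # one timestep per batch element
    # Equivalent to repeat(time_atom_types / T, 'b -> b n', n=n_atoms)
    threshold = (time_atom_types / T).unsqueeze(1).expand(-1, n_atoms)
    u = torch.rand(atom_types.shape)                            # uniform on [0, 1)
    mask_atom = (u > threshold) if repo_comparison else (u < threshold)
    atom_types_t = atom_types * ~mask_atom + MASK_INDEX * mask_atom
    return (atom_types_t == MASK_INDEX).float().mean().item()


if __name__ == "__main__":
    for t in (0, 250, 500, 750, 1000):
        print(t, mask_fraction(t, True), mask_fraction(t, False))
```

With the repository's `>`, the masked fraction starts near 1 at `t = 0` and falls to 0 at `t = T`; with `<` it rises from 0 to 1. Note that the `<` variant only matches the expectation if `time_atom_types = self.T` is meant to be the fully masked end of the process; if `AAFlow` instead follows a flow-matching convention in which `t = 0` is pure noise and `t = T` is data, the original `>` would produce the intended schedule.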
Thank you in advance for your attention.