Question about the discrete mask part in sample_x_t function #2

@hyf3513OneGO

Description

Issue: Inconsistent Atom Type Masking with Time Step

Describe the Question
File path: TFG-Flow/diffusion/aaflow.py
In class AAFlow
In the sample_x_t method, the masking of atom types appears to behave inversely to the expected diffusion process: as time_atom_types increases (representing later timesteps in the forward diffusion), the probability that an atom type is masked (replaced with ATOMNAME_TO_INDEX['MASK']) decreases rather than increases.

The relevant code snippet is:

```python
mask_atom = torch.rand(atom_types.shape, device=self.device) > repeat(time_atom_types / self.T, 'b -> b n', n=atom_types.shape[1])
atom_types_t = atom_types * ~mask_atom + ATOMNAME_TO_INDEX['MASK'] * mask_atom
```

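Here, time_atom_types / self.T is a threshold that increases with the timestep. Because of the > comparison, mask_atom is True (and the atom type is masked) only when a random number is greater than this threshold. Since torch.rand samples uniformly from [0, 1), each atom is masked with probability 1 - time_atom_types / self.T: every atom is masked at t = 0 and none is masked at t = T. So mask_atom contains fewer True values as time_atom_types grows, which is the opposite of a forward diffusion process, where the masking probability should grow with t.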
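A minimal sketch to check this numerically, assuming PyTorch. It re-implements only the two quoted lines as a standalone function; `MASK_INDEX` is a hypothetical stand-in for ATOMNAME_TO_INDEX['MASK'], the einops `repeat` call is replaced by an equivalent `unsqueeze`/`expand`, and the `flip` flag and the `masking_rate` helper are illustrative names, not part of AAFlow.

```python
import torch

MASK_INDEX = 10  # hypothetical stand-in for ATOMNAME_TO_INDEX['MASK']


def mask_atom_types(atom_types, time_atom_types, T, flip=False):
    """Standalone sketch of the masking step quoted from sample_x_t.

    atom_types: (b, n) integer tensor; time_atom_types: (b,) timesteps.
    flip=False keeps the original `>` comparison; flip=True uses `<`.
    """
    n = atom_types.shape[1]
    # Same as einops repeat(time_atom_types / T, 'b -> b n', n=n)
    threshold = (time_atom_types / T).unsqueeze(1).expand(-1, n)
    u = torch.rand(atom_types.shape, device=atom_types.device)  # uniform on [0, 1)
    mask_atom = u < threshold if flip else u > threshold
    # Same arithmetic blend as the original snippet
    atom_types_t = atom_types * ~mask_atom + MASK_INDEX * mask_atom
    return atom_types_t, mask_atom


def masking_rate(t, T=100, n_atoms=10_000, flip=False):
    """Monte Carlo estimate of the fraction of atoms masked at timestep t."""
    atom_types = torch.randint(0, 10, (1, n_atoms))
    _, mask = mask_atom_types(atom_types, torch.tensor([t]), T, flip=flip)
    return mask.float().mean().item()


if __name__ == "__main__":
    torch.manual_seed(0)
    for t in (0, 25, 50, 75, 100):
        # Original `>` should track 1 - t/T; flipped `<` should track t/T.
        print(f"t={t:3d}  '>' rate={masking_rate(t):.3f}  "
              f"'<' rate={masking_rate(t, flip=True):.3f}")
```

With the original `>`, the estimated rate should follow 1 - t/T; with `<` it should follow t/T. Whether the `>` version is actually a bug depends on which end of the time axis AAFlow treats as clean data, which the issue does not establish: in discrete flow-matching formulations where t = T is the data and t = 0 is fully masked, a masking probability of 1 - t/T is the expected behavior, while under the forward-diffusion convention assumed above, the flipped comparison is the one that matches.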

Thank you in advance for your attention.

Metadata

Assignees: No one assigned
Labels: No labels
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests