Skip to content

Commit

Permalink
add no_sym
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 14, 2025
1 parent 9dbcb8f commit f44e5e1
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 36 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
n_update_has_a_first_sum: bool = False,
auto_batchsize: int = 0,
optim_update: bool = True,
no_sym: bool = False,
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
Expand Down Expand Up @@ -151,6 +152,7 @@ def __init__(
self.unet_norm = unet_norm
self.auto_batchsize = auto_batchsize
self.optim_update = optim_update
self.no_sym = no_sym

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def init_subclass_params(sub_data, sub_class):
auto_batchsize=self.repflow_args.auto_batchsize,
optim_update=self.repflow_args.optim_update,
skip_stat=self.repflow_args.skip_stat,
no_sym=self.repflow_args.no_sym,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
Expand Down
79 changes: 43 additions & 36 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
only_e_bn: bool = False,
bn_moment: float = 0.1,
optim_update: bool = True,
no_sym: bool = False,
activation_function: str = "silu",
update_style: str = "res_residual",
update_residual: float = 0.1,
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
self.n_update_has_a = n_update_has_a
self.n_update_has_a_first_sum = n_update_has_a_first_sum
self.optim_update = optim_update
self.no_sym = no_sym

assert update_residual_init in [
"norm",
Expand Down Expand Up @@ -259,23 +261,26 @@ def __init__(
)

# node sym (grrg + drrd)
self.n_sym_dim = n_dim * self.axis_neuron + e_dim * self.axis_neuron
self.node_sym_linear = MLPLayer(
self.n_sym_dim,
n_dim,
precision=precision,
seed=child_seed(seed, 2),
)
if self.update_style == "res_residual":
self.n_residual.append(
get_residual(
n_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 3),
)
if not self.no_sym:
self.n_sym_dim = n_dim * self.axis_neuron + e_dim * self.axis_neuron
self.node_sym_linear = MLPLayer(
self.n_sym_dim,
n_dim,
precision=precision,
seed=child_seed(seed, 2),
)
if self.update_style == "res_residual":
self.n_residual.append(
get_residual(
n_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 3),
)
)
else:
self.node_sym_linear = None

# node edge message
self.node_edge_linear = MLPLayer(
Expand Down Expand Up @@ -818,28 +823,30 @@ def forward(

nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)

# node sym (grrg + drrd)
node_sym_list: list[torch.Tensor] = []
node_sym_list.append(
self.symmetrization_op(
edge_ebd,
h2,
nlist_mask,
sw,
self.axis_neuron,
if not self.no_sym:
assert self.node_sym_linear is not None
# node sym (grrg + drrd)
node_sym_list: list[torch.Tensor] = []
node_sym_list.append(
self.symmetrization_op(
edge_ebd,
h2,
nlist_mask,
sw,
self.axis_neuron,
)
)
)
node_sym_list.append(
self.symmetrization_op(
nei_node_ebd,
h2,
nlist_mask,
sw,
self.axis_neuron,
node_sym_list.append(
self.symmetrization_op(
nei_node_ebd,
h2,
nlist_mask,
sw,
self.axis_neuron,
)
)
)
node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1)))
n_update_list.append(node_sym)
node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1)))
n_update_list.append(node_sym)

if not self.optim_update:
# nb x nloc x nnei x (n_dim * 2 + e_dim)
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
env_protection: float = 0.0,
precision: str = "float64",
skip_stat: bool = True,
no_sym: bool = False,
pre_ln: bool = False,
only_e_ln: bool = False,
pre_bn: bool = False,
Expand Down Expand Up @@ -237,6 +238,7 @@ def __init__(
self.n_attn_head = n_attn_head
self.auto_batchsize = auto_batchsize
self.optim_update = optim_update
self.no_sym = no_sym

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -449,6 +451,7 @@ def __init__(
only_e_bn=self.only_e_bn,
bn_moment=self.bn_moment,
optim_update=self.optim_update,
no_sym=self.no_sym,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,6 +1721,12 @@ def dpa3_repflow_args():
optional=True,
default="None",
),
Argument(
"no_sym",
bool,
optional=True,
default=False,
),
]


Expand Down

0 comments on commit f44e5e1

Please sign in to comment.