Skip to content

Commit

Permalink
add n_multi_edge_message
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 4, 2025
1 parent fec6462 commit 9d9dc8f
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 11 deletions.
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
a_compress_rate: int = 0,
a_compress_e_rate: int = 1,
a_compress_use_split: bool = False,
n_multi_edge_message: int = 1,
axis_neuron: int = 4,
update_angle: bool = True,
update_style: str = "res_residual",
Expand Down Expand Up @@ -59,6 +60,9 @@ def __init__(
a_compress_use_split : bool, optional
Whether to split first sub-vectors instead of linear mapping during angular message compression.
The default value is False.
n_multi_edge_message : int, optional
The head number of multiple edge messages to update node feature.
Default is 1, indicating one head edge message.
axis_neuron : int, optional
The number of dimension of submatrix in the symmetrization ops.
update_angle : bool, optional
Expand Down Expand Up @@ -87,6 +91,7 @@ def __init__(
self.a_rcut_smth = a_rcut_smth
self.a_sel = a_sel
self.a_compress_rate = a_compress_rate
self.n_multi_edge_message = n_multi_edge_message
self.axis_neuron = axis_neuron
self.update_angle = update_angle
self.update_style = update_style
Expand Down Expand Up @@ -117,6 +122,7 @@ def serialize(self) -> dict:
"a_compress_rate": self.a_compress_rate,
"a_compress_e_rate": self.a_compress_e_rate,
"a_compress_use_split": self.a_compress_use_split,
"n_multi_edge_message": self.n_multi_edge_message,
"axis_neuron": self.axis_neuron,
"update_angle": self.update_angle,
"update_style": self.update_style,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def init_subclass_params(sub_data, sub_class):
a_compress_rate=self.repflow_args.a_compress_rate,
a_compress_e_rate=self.repflow_args.a_compress_e_rate,
a_compress_use_split=self.repflow_args.a_compress_use_split,
n_multi_edge_message=self.repflow_args.n_multi_edge_message,
axis_neuron=self.repflow_args.axis_neuron,
update_angle=self.repflow_args.update_angle,
activation_function=self.activation_function,
Expand Down
35 changes: 24 additions & 11 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
a_compress_rate: int = 0,
a_compress_use_split: bool = False,
a_compress_e_rate: int = 1,
n_multi_edge_message: int = 1,
axis_neuron: int = 4,
update_angle: bool = True, # angle
activation_function: str = "silu",
Expand Down Expand Up @@ -79,6 +80,8 @@ def __init__(
f"For a_compress_rate of {a_compress_rate}, a_dim must be divisible by {2 * a_compress_rate}. "
f"Currently, a_dim={a_dim} is not valid."
)
self.n_multi_edge_message = n_multi_edge_message
assert self.n_multi_edge_message >= 1, "n_multi_edge_message must >= 1!"
self.axis_neuron = axis_neuron
self.update_angle = update_angle
self.activation_function = activation_function
Expand Down Expand Up @@ -144,20 +147,21 @@ def __init__(
# node edge message
self.node_edge_linear = MLPLayer(
self.edge_info_dim,
n_dim,
self.n_multi_edge_message * n_dim,
precision=precision,
seed=child_seed(seed, 4),
)
if self.update_style == "res_residual":
self.n_residual.append(
get_residual(
n_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 5),
for head_index in range(self.n_multi_edge_message):
self.n_residual.append(
get_residual(
n_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(child_seed(seed, 5), head_index),
)
)
)

# edge self message
self.edge_self_linear = MLPLayer(
Expand Down Expand Up @@ -479,10 +483,18 @@ def forward(
)

# node edge message
# nb x nloc x nnei x n_dim
# nb x nloc x nnei x (h * n_dim)
node_edge_update = self.act(self.node_edge_linear(edge_info)) * sw.unsqueeze(-1)
node_edge_update = torch.sum(node_edge_update, dim=-2) / self.nnei
n_update_list.append(node_edge_update)
if self.n_multi_edge_message > 1:
# nb x nloc x nnei x h x n_dim
node_edge_update_mul_head = node_edge_update.view(
nb, nloc, self.n_multi_edge_message, self.n_dim
)
for head_index in range(self.n_multi_edge_message):
n_update_list.append(node_edge_update_mul_head[:, :, head_index, :])
else:
n_update_list.append(node_edge_update)
# update node_ebd
n_updated = self.list_update(n_update_list, "node")

Expand Down Expand Up @@ -670,6 +682,7 @@ def serialize(self) -> dict:
"a_compress_rate": self.a_compress_rate,
"a_compress_e_rate": self.a_compress_e_rate,
"a_compress_use_split": self.a_compress_use_split,
"n_multi_edge_message": self.n_multi_edge_message,
"axis_neuron": self.axis_neuron,
"activation_function": self.activation_function,
"update_angle": self.update_angle,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
a_compress_rate: int = 0,
a_compress_e_rate: int = 1,
a_compress_use_split: bool = False,
n_multi_edge_message: int = 1,
axis_neuron: int = 4,
update_angle: bool = True,
activation_function: str = "silu",
Expand Down Expand Up @@ -137,6 +138,9 @@ def __init__(
a_compress_use_split : bool, optional
Whether to split first sub-vectors instead of linear mapping during angular message compression.
The default value is False.
n_multi_edge_message : int, optional
The head number of multiple edge messages to update node feature.
Default is 1, indicating one head edge message.
axis_neuron : int, optional
The number of dimension of submatrix in the symmetrization ops.
update_angle : bool, optional
Expand Down Expand Up @@ -191,6 +195,7 @@ def __init__(
self.split_sel = self.sel
self.a_compress_rate = a_compress_rate
self.a_compress_e_rate = a_compress_e_rate
self.n_multi_edge_message = n_multi_edge_message
self.axis_neuron = axis_neuron
self.set_davg_zero = set_davg_zero
self.skip_stat = skip_stat
Expand Down Expand Up @@ -238,6 +243,7 @@ def __init__(
a_compress_rate=self.a_compress_rate,
a_compress_use_split=self.a_compress_use_split,
a_compress_e_rate=self.a_compress_e_rate,
n_multi_edge_message=self.n_multi_edge_message,
axis_neuron=self.axis_neuron,
update_angle=self.update_angle,
activation_function=self.activation_function,
Expand Down
11 changes: 11 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,10 @@ def dpa3_repflow_args():
"Whether to split first sub-vectors instead of linear mapping during angular message compression. "
"The default value is False."
)
doc_n_multi_edge_message = (
"The head number of multiple edge messages to update node feature. "
"Default is 1, indicating one head edge message."
)
doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops."
doc_update_angle = (
"Where to update the angle rep. If not, only node and edge rep will be used."
Expand Down Expand Up @@ -1506,6 +1510,13 @@ def dpa3_repflow_args():
default=False,
doc=doc_a_compress_use_split,
),
Argument(
"n_multi_edge_message",
int,
optional=True,
default=1,
doc=doc_n_multi_edge_message,
),
Argument(
"axis_neuron",
int,
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@ def test_consistency(
rus,
ruri,
acr,
nme,
prec,
ect,
) in itertools.product(
[True, False], # update_angle
["res_residual"], # update_style
["norm", "const"], # update_residual_init
[0, 1], # a_compress_rate
[1, 2], # n_multi_edge_message
["float64"], # precision
[False], # use_econf_tebd
):
Expand All @@ -76,6 +78,7 @@ def test_consistency(
a_rcut_smth=self.rcut_smth,
a_sel=nnei - 1,
a_compress_rate=acr,
n_multi_edge_message=nme,
axis_neuron=4,
update_angle=ua,
update_style=rus,
Expand Down Expand Up @@ -131,13 +134,15 @@ def test_jit(
rus,
ruri,
acr,
nme,
prec,
ect,
) in itertools.product(
[True, False], # update_angle
["res_residual"], # update_style
["norm", "const"], # update_residual_init
[0, 1], # a_compress_rate
[1, 2], # n_multi_edge_message
["float64"], # precision
[False], # use_econf_tebd
):
Expand All @@ -156,6 +161,7 @@ def test_jit(
a_rcut_smth=self.rcut_smth,
a_sel=nnei - 1,
a_compress_rate=acr,
n_multi_edge_message=nme,
axis_neuron=4,
update_angle=ua,
update_style=rus,
Expand Down
3 changes: 3 additions & 0 deletions source/tests/universal/dpmodel/descriptor/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def DescriptorParamDPA3(
update_residual=0.1,
update_residual_init="const",
update_angle=True,
n_multi_edge_message=1,
a_compress_rate=0,
precision="float64",
):
Expand All @@ -493,6 +494,7 @@ def DescriptorParamDPA3(
"a_rcut_smth": rcut_smth / 2,
"a_sel": sum(sel) // 4,
"a_compress_rate": a_compress_rate,
"n_multi_edge_message": n_multi_edge_message,
"axis_neuron": 4,
"update_angle": update_angle,
"update_style": update_style,
Expand Down Expand Up @@ -523,6 +525,7 @@ def DescriptorParamDPA3(
"exclude_types": ([], [[0, 1]]),
"update_angle": (True, False),
"a_compress_rate": (0, 1),
"n_multi_edge_message": (1, 2),
"env_protection": (0.0, 1e-8),
"precision": ("float64",),
}
Expand Down

0 comments on commit 9d9dc8f

Please sign in to comment.