Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions deepmd/model_format/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,18 @@ def call(
Returns
-------
descriptor
The descriptor. shape: nf x nloc x ng x axis_neuron
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
this descriptor returns None
h2
The rotationally equivariant pair-partical representation.
this descriptor returns None
sw
The smooth switch function.
"""
# nf x nloc x nnei x 4
rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd)
Expand All @@ -238,15 +249,17 @@ def call(
gg = self.cal_g(ss, tt)
# nf x nloc x ng x 4
gr += np.einsum("flni,flnj->flij", gg, tr)
# nf x nloc x ng x 4
gr /= self.nnei
gr1 = gr[:, :, : self.axis_neuron, :]
# nf x nloc x ng x ng1
grrg = np.einsum("flid,fljd->flij", gr, gr1)
# nf x nloc x (ng x ng1)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
return grrg
return grrg, gr[..., 1:], None, None, ww

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
return {
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
Expand All @@ -271,6 +284,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "DescrptSeA":
"""Deserialize from dict."""
data = copy.deepcopy(data)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
Expand Down
35 changes: 33 additions & 2 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,42 @@ def forward(
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
):
"""Compute the descriptor.

Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.

Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
shape: nf x nloc x nnei x ng
h2
The rotationally equivariant pair-partical representation.
shape: nf x nloc x nnei x 3
sw
The smooth switch function. shape: nf x nloc x nnei

"""
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
g1, env_mat, diff, rot_mat, sw = self.se_atten(
g1, g2, h2, rot_mat, sw = self.se_atten(
nlist,
extended_coord,
extended_atype,
Expand All @@ -149,4 +179,5 @@ def forward(
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, env_mat, diff, rot_mat, sw

return g1, rot_mat, g2, h2, sw
32 changes: 31 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,36 @@ def forward(
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
):
"""Compute the descriptor.

Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, mapps extended region index to local region.

Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
shape: nf x nloc x nnei x ng
h2
The rotationally equivariant pair-partical representation.
shape: nf x nloc x nnei x 3
sw
The smooth switch function. shape: nf x nloc x nnei

"""
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
# nlists
Expand Down Expand Up @@ -372,4 +402,4 @@ def forward(
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, g2, h2, rot_mat, sw
return g1, rot_mat, g2, h2, sw
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def forward(
# (nb x nloc) x ng2 x 3
rot_mat = torch.permute(h2g2, (0, 1, 3, 2))

return g1, g2, h2, rot_mat.view(-1, self.dim_emb, 3), sw
return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw

def compute_input_stats(self, merged):
"""Update mean and stddev for descriptor elements."""
Expand Down
46 changes: 39 additions & 7 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,42 @@ def get_data_process_key(cls, config):

def forward(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
coord_ext: torch.Tensor,
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
):
return self.sea.forward(nlist, extended_coord, extended_atype, None, mapping)
"""Compute the descriptor.

Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.

Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
this descriptor returns None
h2
The rotationally equivariant pair-partical representation.
this descriptor returns None
sw
The smooth switch function.

"""
return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -389,7 +419,7 @@ def forward(
del extended_atype_embd, mapping
nloc = nlist.shape[1]
atype = extended_atype[:, :nloc]
dmatrix, diff, _ = prod_env_mat_se_a(
dmatrix, diff, sw = prod_env_mat_se_a(
extended_coord,
nlist,
atype,
Expand Down Expand Up @@ -438,12 +468,14 @@ def forward(
result = torch.matmul(
xyz_scatter_1, xyz_scatter_2
) # shape is [nframes*nall, self.filter_neuron[-1], self.axis_neuron]
result = result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([-1, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron),
None,
None,
result,
rot_mat,
None,
None,
sw,
)


Expand Down
7 changes: 3 additions & 4 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,8 @@ def forward(
self.rcut,
self.rcut_smth,
)
dmatrix = dmatrix.view(
-1, self.ndescrpt
) # shape is [nframes*nall, self.ndescrpt]
# [nfxnlocxnnei, self.ndescrpt]
dmatrix = dmatrix.view(-1, self.ndescrpt)
nlist_mask = nlist != -1
nlist[nlist == -1] = 0
sw = torch.squeeze(sw, -1)
Expand Down Expand Up @@ -328,7 +327,7 @@ def forward(
return (
result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron),
ret.view(-1, nloc, self.nnei, self.filter_neuron[-1]),
diff,
dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:],
rot_mat.view(-1, self.filter_neuron[-1], 3),
sw,
)
Expand Down
69 changes: 15 additions & 54 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Descriptor,
)
from deepmd.pt.model.task import (
DenoiseNet,
Fitting,
)

Expand Down Expand Up @@ -93,40 +92,20 @@ def __init__(
sampled=sampled,
)

# Fitting
if fitting_net:
fitting_net["type"] = fitting_net.get("type", "ener")
if self.descriptor_type not in ["se_e2_a"]:
fitting_net["ntypes"] = 1
else:
fitting_net["ntypes"] = self.descriptor.get_ntype()
fitting_net["use_tebd"] = False
fitting_net["embedding_width"] = self.descriptor.dim_out

self.grad_force = "direct" not in fitting_net["type"]
if not self.grad_force:
fitting_net["out_dim"] = self.descriptor.dim_emb
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
self.fitting_net = Fitting(**fitting_net)
fitting_net["type"] = fitting_net.get("type", "ener")
if self.descriptor_type not in ["se_e2_a"]:
fitting_net["ntypes"] = 1
else:
self.fitting_net = None
self.grad_force = False
if not self.split_nlist:
self.coord_denoise_net = DenoiseNet(
self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb
)
elif self.combination:
self.coord_denoise_net = DenoiseNet(
self.descriptor.dim_out,
self.ntypes - 1,
self.descriptor.dim_emb_list,
self.prefactor,
)
else:
self.coord_denoise_net = DenoiseNet(
self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb
)
fitting_net["ntypes"] = self.descriptor.get_ntype()
fitting_net["use_tebd"] = False
fitting_net["embedding_width"] = self.descriptor.dim_out

self.grad_force = "direct" not in fitting_net["type"]
if not self.grad_force:
fitting_net["out_dim"] = self.descriptor.dim_emb
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
self.fitting_net = Fitting(**fitting_net)

def get_fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -178,31 +157,13 @@ def forward_atomic(
atype = extended_atype[:, :nloc]
if self.do_grad():
extended_coord.requires_grad_(True)
descriptor, env_mat, diff, rot_mat, sw = self.descriptor(
descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
)
assert descriptor is not None
# energy, force
if self.fitting_net is not None:
fit_ret = self.fitting_net(
descriptor, atype, atype_tebd=None, rot_mat=rot_mat
)
# denoise
else:
nlist_list = [nlist]
if not self.split_nlist:
nnei_mask = nlist != -1
elif self.combination:
nnei_mask = []
for item in nlist_list:
nnei_mask_item = item != -1
nnei_mask.append(nnei_mask_item)
else:
env_mat = env_mat[-1]
diff = diff[-1]
nnei_mask = nlist_list[-1] != -1
fit_ret = self.coord_denoise_net(env_mat, diff, nnei_mask, descriptor, sw)
fit_ret = self.fitting_net(descriptor, atype, atype_tebd=None, rot_mat=rot_mat)
return fit_ret
3 changes: 3 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,11 @@ def forward(
inputs
) # Shape is [nframes, nloc, m1]
assert list(vec_out.size()) == [nframes, nloc, self.out_dim]
# (nf x nloc) x 1 x od
vec_out = vec_out.view(-1, 1, self.out_dim)
assert rot_mat is not None
# (nf x nloc) x od x 3
rot_mat = rot_mat.view(-1, self.out_dim, 3)
vec_out = (
torch.bmm(vec_out, rot_mat).squeeze(-2).view(nframes, nloc, 3)
) # Shape is [nframes, nloc, 3]
Expand Down
3 changes: 2 additions & 1 deletion source/tests/common/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,4 +367,5 @@ def test_self_consistency(
em1 = DescrptSeA.deserialize(em0.serialize())
mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist)
mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist)
np.testing.assert_allclose(mm0, mm1)
for ii in [0, 1, 4]:
np.testing.assert_allclose(mm0[ii], mm1[ii])
2 changes: 2 additions & 0 deletions source/tests/pt/test_permutation_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test(
)


@unittest.skip("support of the denoise is temporally disabled")
class TestDenoiseModelDPA1(unittest.TestCase, PermutationDenoiseTest):
def setUp(self):
model_params = copy.deepcopy(model_dpa1)
Expand All @@ -74,6 +75,7 @@ def setUp(self):
self.model = get_model(model_params, sampled).to(env.DEVICE)


@unittest.skip("support of the denoise is temporally disabled")
class TestDenoiseModelDPA2(unittest.TestCase, PermutationDenoiseTest):
def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
Expand Down
2 changes: 2 additions & 0 deletions source/tests/pt/test_rot_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test(
)


@unittest.skip("support of the denoise is temporally disabled")
class TestDenoiseModelDPA1(unittest.TestCase, RotDenoiseTest):
def setUp(self):
model_params = copy.deepcopy(model_dpa1)
Expand All @@ -105,6 +106,7 @@ def setUp(self):
self.model = get_model(model_params, sampled).to(env.DEVICE)


@unittest.skip("support of the denoise is temporally disabled")
class TestDenoiseModelDPA2(unittest.TestCase, RotDenoiseTest):
def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
Expand Down
Loading