From 8abff4c16135bc986a9e1022eac8f6af40f812eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 Feb 2024 10:07:57 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/model_format/__init__.py | 8 +- deepmd/model_format/dpa1.py | 23 +- deepmd/model_format/network.py | 58 ++--- deepmd/pt/model/descriptor/dpa1.py | 32 ++- deepmd/pt/model/descriptor/se_atten.py | 328 ++++++++++++++---------- deepmd/pt/model/network/mlp.py | 106 ++++---- source/tests/pt/test_descriptor_dpa1.py | 4 +- source/tests/pt/test_dpa1.py | 91 +++++-- 8 files changed, 388 insertions(+), 262 deletions(-) diff --git a/deepmd/model_format/__init__.py b/deepmd/model_format/__init__.py index 3aa28ec192..b1814c6cb5 100644 --- a/deepmd/model_format/__init__.py +++ b/deepmd/model_format/__init__.py @@ -4,6 +4,9 @@ PRECISION_DICT, NativeOP, ) +from .dpa1 import ( + DescrptDPA1, +) from .env_mat import ( EnvMat, ) @@ -11,11 +14,11 @@ InvarFitting, ) from .network import ( + EmbdLayer, EmbeddingNet, FittingNet, - NativeLayer, - EmbdLayer, LayerNorm, + NativeLayer, NativeNet, NetworkCollection, load_dp_model, @@ -37,7 +40,6 @@ from .se_e2_a import ( DescrptSeA, ) -from .dpa1 import DescrptDPA1 __all__ = [ "InvarFitting", diff --git a/deepmd/model_format/dpa1.py b/deepmd/model_format/dpa1.py index 829339838f..cd367ab5a7 100644 --- a/deepmd/model_format/dpa1.py +++ b/deepmd/model_format/dpa1.py @@ -21,9 +21,9 @@ EnvMat, ) from .network import ( + EmbdLayer, EmbeddingNet, NetworkCollection, - EmbdLayer, ) @@ -147,6 +147,7 @@ class DescrptDPA1(NativeOP): DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation. arXiv preprint arXiv:2208.08236. """ + def __init__( self, rcut: float, @@ -183,7 +184,7 @@ def __init__( if spin is not None: raise NotImplementedError("spin is not implemented") # TODO - if tebd_input_mode != 'concat': + if tebd_input_mode != "concat": raise NotImplementedError("tebd_input_mode != 'concat' not implemented") if not smooth: raise NotImplementedError("smooth == False not implemented") @@ -215,8 +216,10 @@ def __init__( self.concat_output_tebd = concat_output_tebd self.spin = spin - self.type_embedding = EmbdLayer(ntypes, tebd_dim, padding=True, precision=precision) - in_dim = 1 + self.tebd_dim * 2 if self.tebd_input_mode in ['concat'] else 1 + self.type_embedding = EmbdLayer( + ntypes, tebd_dim, padding=True, precision=precision + ) + in_dim = 1 + self.tebd_dim * 2 if self.tebd_input_mode in ["concat"] else 1 self.embeddings = NetworkCollection( ndim=0, ntypes=self.ntypes, @@ -255,8 +258,11 @@ def __getitem__(self, key): @property def dim_out(self): """Returns the output dimension of this descriptor.""" - return self.neuron[-1] * self.axis_neuron + self.tebd_dim * 2 \ - if self.concat_output_tebd else self.neuron[-1] * self.axis_neuron + return ( + self.neuron[-1] * self.axis_neuron + self.tebd_dim * 2 + if self.concat_output_tebd + else self.neuron[-1] * self.axis_neuron + ) def cal_g( self, @@ -302,7 +308,6 @@ def call( sw The smooth switch function. 
""" - # nf x nloc x nnei x 4 rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd) nf, nloc, nnei, _ = rr.shape @@ -318,7 +323,9 @@ def call( nlist_masked[nlist_masked == -1] = 0 index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) # nf x nloc x nnei x tebd_dim - atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape(nf, nloc, nnei, self.tebd_dim) + atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( + nf, nloc, nnei, self.tebd_dim + ) ng = self.neuron[-1] ss = rr[..., 0:1] ss = np.concatenate([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) diff --git a/deepmd/model_format/network.py b/deepmd/model_format/network.py index 508eb07d56..2fd9727a59 100644 --- a/deepmd/model_format/network.py +++ b/deepmd/model_format/network.py @@ -342,16 +342,17 @@ def __init__( ) -> None: self.padding = padding self.num_channel = num_channel + 1 if self.padding else num_channel - super().__init__(num_in=self.num_channel, - num_out=num_out, - bias=False, - use_timestep=False, - activation_function=None, - resnet=False, - precision=precision, - ) + super().__init__( + num_in=self.num_channel, + num_out=num_out, + bias=False, + use_timestep=False, + activation_function=None, + resnet=False, + precision=precision, + ) if self.padding: - self.w[-1] = 0. + self.w[-1] = 0.0 def serialize(self) -> dict: """Serialize the layer to a dict. @@ -361,9 +362,7 @@ def serialize(self) -> dict: dict The serialized layer. """ - data = { - "w": self.w - } + data = {"w": self.w} return { "padding": self.padding, "precision": self.precision, @@ -390,9 +389,7 @@ def deserialize(cls, data: dict) -> "EmbdLayer": padding=False, **data, ) - obj.w, = ( - variables["w"], - ) + (obj.w,) = (variables["w"],) obj.padding = padding obj.check_shape_consistency() return obj @@ -464,18 +461,19 @@ def __init__( self.eps = eps self.uni_init = uni_init self.num_in = num_in - super().__init__(num_in=1, - num_out=num_in, - bias=True, - use_timestep=False, - activation_function=None, - resnet=False, - precision=precision, - ) + super().__init__( + num_in=1, + num_out=num_in, + bias=True, + use_timestep=False, + activation_function=None, + resnet=False, + precision=precision, + ) self.w = self.w.squeeze(0) # keep the weight shape to be [num_in] if self.uni_init: - self.w = 1. - self.b = 0. + self.w = 1.0 + self.b = 0.0 def serialize(self) -> dict: """Serialize the layer to a dict. 
@@ -510,17 +508,13 @@ def deserialize(cls, data: dict) -> "LayerNorm": assert len(variables["w"].shape) == 1 if variables["b"] is not None: assert len(variables["b"].shape) == 1 - num_in, = variables["w"].shape + (num_in,) = variables["w"].shape obj = cls( num_in, **data, ) - obj.w, = ( - variables["w"], - ) - obj.b, = ( - variables["b"], - ) + (obj.w,) = (variables["w"],) + (obj.b,) = (variables["b"],) obj._check_shape_consistency() return obj diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 611c9b1179..db54132576 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -6,19 +6,25 @@ import torch +from deepmd.model_format import EnvMat as DPEnvMat from deepmd.pt.model.descriptor import ( Descriptor, ) +from deepmd.pt.model.network.mlp import ( + EmbdLayer, + NetworkCollection, +) from deepmd.pt.model.network.network import ( TypeEmbedNet, ) +from deepmd.pt.utils import ( + env, +) -from .se_atten import DescrptBlockSeAtten, NeighborGatedAttention -from deepmd.pt.model.network.mlp import EmbdLayer, NetworkCollection -from deepmd.model_format import ( - EnvMat as DPEnvMat, +from .se_atten import ( + DescrptBlockSeAtten, + NeighborGatedAttention, ) -from deepmd.pt.utils import env @Descriptor.register("dpa1") @@ -74,7 +80,7 @@ def __init__( normalize=normalize, temperature=temperature, old_impl=old_impl, - **kwargs + **kwargs, ) self.type_embedding_old = None self.type_embedding = None @@ -82,7 +88,9 @@ def __init__( if self.old_impl: self.type_embedding_old = TypeEmbedNet(ntypes, tebd_dim) else: - self.type_embedding = EmbdLayer(ntypes, tebd_dim, padding=True, precision=precision) + self.type_embedding = EmbdLayer( + ntypes, tebd_dim, padding=True, precision=precision + ) self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd @@ -195,9 +203,9 @@ def forward( return g1, rot_mat, g2, h2, sw def set_stat_mean_and_stddev( - self, - mean: torch.Tensor, - stddev: torch.Tensor, + self, + mean: torch.Tensor, + stddev: torch.Tensor, ) -> None: self.se_atten.mean = mean self.se_atten.stddev = stddev @@ -253,5 +261,7 @@ def deserialize(cls, data: dict) -> "DescrptDPA1": obj.se_atten["davg"] = t_cvt(variables["davg"]) obj.se_atten["dstd"] = t_cvt(variables["dstd"]) obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings) - obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( + attention_layers + ) return obj diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index ba7c533ff7..82681ed345 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -8,11 +8,6 @@ import torch import torch.nn as nn import torch.nn.functional as torch_func -from deepmd.pt.utils.env import ( - PRECISION_DICT, - DEFAULT_PRECISION, -) -from deepmd.pt.utils.utils import ActivationFn from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, @@ -21,18 +16,23 @@ from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat_se_a, ) +from deepmd.pt.model.network.mlp import ( + EmbeddingNet, + LayerNorm, + MLPLayer, + NetworkCollection, +) from deepmd.pt.model.network.network import ( NeighborWiseAttention, TypeFilter, ) -from deepmd.pt.model.network.mlp import EmbeddingNet, NetworkCollection, MLPLayer, LayerNorm - -from deepmd.model_format import ( - EnvMat as DPEnvMat, -) from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import 
( + DEFAULT_PRECISION, + PRECISION_DICT, +) @DescriptorBlock.register("se_atten") @@ -46,7 +46,7 @@ def __init__( neuron: list = [25, 50, 100], axis_neuron: int = 16, tebd_dim: int = 8, - tebd_input_mode: str = 'concat', + tebd_input_mode: str = "concat", # set_davg_zero: bool = False, set_davg_zero: bool = True, # TODO attn: int = 128, @@ -106,23 +106,30 @@ def __init__( self.nnei = sum(sel) self.ndescrpt = self.nnei * 4 if self.old_impl: - self.dpa1_attention = NeighborWiseAttention(self.attn_layer, self.nnei, self.filter_neuron[-1], - self.attn_dim, - dotr=self.attn_dotr, do_mask=self.attn_mask, - activation=self.activation_function, - scaling_factor=self.scaling_factor, - normalize=self.normalize, - temperature=self.temperature) + self.dpa1_attention = NeighborWiseAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + activation=self.activation_function, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + ) else: - self.dpa1_attention = NeighborGatedAttention(self.attn_layer, - self.nnei, - self.filter_neuron[-1], - self.attn_dim, - dotr=self.attn_dotr, - do_mask=self.attn_mask, - scaling_factor=self.scaling_factor, - normalize=self.normalize, - temperature=self.temperature) + self.dpa1_attention = NeighborGatedAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + ) wanted_shape = (self.ntypes, self.nnei, 4) mean = torch.zeros( @@ -133,18 +140,29 @@ def __init__( ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) - self.embd_input_dim = 1 + self.tebd_dim * 2 if self.tebd_input_mode in ['concat'] else 1 + self.embd_input_dim = ( + 1 + self.tebd_dim * 2 if self.tebd_input_mode in ["concat"] else 1 + ) self.filter_layers_old = None self.filter_layers = None if self.old_impl: filter_layers = [] - one = TypeFilter(0, self.nnei, self.filter_neuron, return_G=True, tebd_dim=self.tebd_dim, use_tebd=True, - tebd_mode=self.tebd_input_mode) + one = TypeFilter( + 0, + self.nnei, + self.filter_neuron, + return_G=True, + tebd_dim=self.tebd_dim, + use_tebd=True, + tebd_mode=self.tebd_input_mode, + ) filter_layers.append(one) self.filter_layers_old = torch.nn.ModuleList(filter_layers) else: - filter_layers = NetworkCollection(ndim=0, ntypes=len(sel), network_type="embedding_network") + filter_layers = NetworkCollection( + ndim=0, ntypes=len(sel), network_type="embedding_network" + ) filter_layers[0] = EmbeddingNet( self.embd_input_dim, self.filter_neuron, @@ -338,18 +356,26 @@ def forward( atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt) if self.old_impl: assert self.filter_layers_old is not None - dmatrix = dmatrix.view(-1, self.ndescrpt) # shape is [nframes*nall, self.ndescrpt] + dmatrix = dmatrix.view( + -1, self.ndescrpt + ) # shape is [nframes*nall, self.ndescrpt] gg = self.filter_layers_old[0]( dmatrix, atype_tebd=atype_tebd_nnei, nlist_tebd=atype_tebd_nlist, ) # shape is [nframes*nall, self.neei, out_size] - input_r = torch.nn.functional.normalize(dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1) - gg = self.dpa1_attention(gg, nlist_mask, input_r=input_r, - sw=sw) # shape is [nframes*nloc, self.neei, out_size] - inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute(0, 2, - 1) # shape is [nframes*natoms[0], 4, self.neei] - xyz_scatter = 
torch.matmul(inputs_reshape, gg) # shape is [nframes*natoms[0], 4, out_size] + input_r = torch.nn.functional.normalize( + dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute( + 0, 2, 1 + ) # shape is [nframes*natoms[0], 4, self.neei] + xyz_scatter = torch.matmul( + inputs_reshape, gg + ) # shape is [nframes*natoms[0], 4, out_size] else: assert self.filter_layers is not None dmatrix = dmatrix.view(-1, self.nnei, 4) @@ -357,16 +383,19 @@ def forward( # nfnl x nnei x 4 rr = dmatrix ss = rr[:, :, :1] - if self.tebd_input_mode in ['concat']: + if self.tebd_input_mode in ["concat"]: nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) # nfnl x nnei x (1 + tebd_dim * 2) ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2) # nfnl x nnei x ng gg = self.filter_layers._networks[0](ss) - input_r = torch.nn.functional.normalize(dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1) - gg = self.dpa1_attention(gg, nlist_mask, input_r=input_r, - sw=sw) # shape is [nframes*nloc, self.neei, out_size] + input_r = torch.nn.functional.normalize( + dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] # nfnl x 4 x ng xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) xyz_scatter = xyz_scatter / self.nnei @@ -386,20 +415,20 @@ def forward( class NeighborGatedAttention(nn.Module): - def __init__(self, - layer_num: int, - nnei: int, - embed_dim: int, - hidden_dim: int, - dotr: bool = False, - do_mask: bool = False, - scaling_factor: float = 1.0, - normalize: bool = True, - temperature: float = None, - precision: str = DEFAULT_PRECISION, - ): - """Construct a neighbor-wise attention net. 
- """ + def __init__( + self, + layer_num: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: float = None, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" super(NeighborGatedAttention, self).__init__() self.layer_num = layer_num self.nnei = nnei @@ -414,31 +443,37 @@ def __init__(self, self.network_type = NeighborGatedAttentionLayer attention_layers = [] for i in range(self.layer_num): - attention_layers.append(NeighborGatedAttentionLayer(nnei, - embed_dim, - hidden_dim, - dotr=dotr, - do_mask=do_mask, - scaling_factor=scaling_factor, - normalize=normalize, - temperature=temperature, - precision=precision)) + attention_layers.append( + NeighborGatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + precision=precision, + ) + ) self.attention_layers = nn.ModuleList(attention_layers) def forward( - self, - input_G, - nei_mask, - input_r: Optional[torch.Tensor] = None, - sw: Optional[torch.Tensor] = None, + self, + input_G, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, ): """ Args: input_G: Input G, [nframes * nloc, nnei, embed_dim] nei_mask: neighbor mask, [nframes * nloc, nnei] input_r: normalized radial, [nframes, nloc, nei, 3] - Returns: - out: Output G, [nframes * nloc, nnei, embed_dim] + + Returns + ------- + out: Output G, [nframes * nloc, nnei, embed_dim] """ out = input_G # https://github.com/pytorch/pytorch/issues/39165#issuecomment-635472592 @@ -475,6 +510,7 @@ def __setitem__(self, key, value): def serialize(self) -> dict: """Serialize the networks to a dict. + Returns ------- dict @@ -493,12 +529,13 @@ def serialize(self) -> dict: "normalize": self.normalize, "temperature": self.temperature, "precision": self.precision, - "attention_layers": [layer.serialize() for layer in self.attention_layers] + "attention_layers": [layer.serialize() for layer in self.attention_layers], } @classmethod def deserialize(cls, data: dict) -> "NeighborGatedAttention": """Deserialize the networks from a dict. + Parameters ---------- data : dict @@ -512,19 +549,19 @@ def deserialize(cls, data: dict) -> "NeighborGatedAttention": class NeighborGatedAttentionLayer(nn.Module): - def __init__(self, - nnei: int, - embed_dim: int, - hidden_dim: int, - dotr: bool = False, - do_mask: bool = False, - scaling_factor: float = 1.0, - normalize: bool = True, - temperature: float = None, - precision: str = DEFAULT_PRECISION, - ): - """Construct a neighbor-wise attention layer. 
- """ + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: float = None, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention layer.""" super(NeighborGatedAttentionLayer, self).__init__() self.nnei = nnei self.embed_dim = embed_dim @@ -535,24 +572,25 @@ def __init__(self, self.normalize = normalize self.temperature = temperature self.precision = precision - self.attention_layer = GatedAttentionLayer(nnei, - embed_dim, - hidden_dim, - dotr=dotr, - do_mask=do_mask, - scaling_factor=scaling_factor, - normalize=normalize, - temperature=temperature, - precision=precision, - ) + self.attention_layer = GatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + precision=precision, + ) self.attn_layer_norm = LayerNorm(self.embed_dim, precision=precision) def forward( - self, - x, - nei_mask, - input_r: Optional[torch.Tensor] = None, - sw: Optional[torch.Tensor] = None, + self, + x, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, ): residual = x x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) @@ -562,6 +600,7 @@ def forward( def serialize(self) -> dict: """Serialize the networks to a dict. + Returns ------- dict @@ -578,12 +617,13 @@ def serialize(self) -> dict: "temperature": self.temperature, "precision": self.precision, "attention_layer": self.attention_layer.serialize(), - "attn_layer_norm": self.attn_layer_norm.serialize() + "attn_layer_norm": self.attn_layer_norm.serialize(), } @classmethod def deserialize(cls, data: dict) -> "NeighborGatedAttentionLayer": """Deserialize the networks from a dict. + Parameters ---------- data : dict @@ -598,21 +638,21 @@ def deserialize(cls, data: dict) -> "NeighborGatedAttentionLayer": class GatedAttentionLayer(nn.Module): - def __init__(self, - nnei: int, - embed_dim: int, - hidden_dim: int, - dotr: bool = False, - do_mask: bool = False, - scaling_factor: float = 1.0, - normalize: bool = True, - temperature: float = None, - bias: bool = True, - smooth: bool = True, - precision: str = DEFAULT_PRECISION, - ): - """Construct a neighbor-wise attention net. 
- """ + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: float = None, + bias: bool = True, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" super(GatedAttentionLayer, self).__init__() self.nnei = nnei self.embed_dim = embed_dim @@ -629,26 +669,42 @@ def __init__(self, else: self.scaling = temperature self.normalize = normalize - self.in_proj = MLPLayer(embed_dim, hidden_dim * 3, bias=bias, use_timestep=False, bavg=0., stddev=1., - precision=precision) - self.out_proj = MLPLayer(hidden_dim, embed_dim, bias=bias, use_timestep=False, bavg=0., stddev=1., - precision=precision) + self.in_proj = MLPLayer( + embed_dim, + hidden_dim * 3, + bias=bias, + use_timestep=False, + bavg=0.0, + stddev=1.0, + precision=precision, + ) + self.out_proj = MLPLayer( + hidden_dim, + embed_dim, + bias=bias, + use_timestep=False, + bavg=0.0, + stddev=1.0, + precision=precision, + ) def forward( - self, - query, - nei_mask, - input_r: Optional[torch.Tensor] = None, - sw: Optional[torch.Tensor] = None, - attnw_shift: float = 20.0, + self, + query, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + attnw_shift: float = 20.0, ): """ Args: query: input G, [nframes * nloc, nnei, embed_dim] nei_mask: neighbor mask, [nframes * nloc, nnei] input_r: normalized radial, [nframes, nloc, nei, 3] - Returns: - type_embedding: + + Returns + ------- + type_embedding: """ q, k, v = self.in_proj(query).chunk(3, dim=-1) # [nframes * nloc, nnei, hidden_dim] @@ -669,11 +725,15 @@ def forward( # [nframes * nloc, nnei] assert sw is not None sw = sw.view([-1, self.nnei]) - attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[:, None, :] - attnw_shift + attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ + :, None, : + ] - attnw_shift else: - attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(1), float("-inf")) + attn_weights = attn_weights.masked_fill( + ~nei_mask.unsqueeze(1), float("-inf") + ) attn_weights = torch_func.softmax(attn_weights, dim=-1) - attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), float(0.0)) + attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) if self.smooth: assert sw is not None attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] @@ -687,6 +747,7 @@ def forward( def serialize(self) -> dict: """Serialize the networks to a dict. + Returns ------- dict @@ -707,12 +768,13 @@ def serialize(self) -> dict: "smooth": self.smooth, "precision": self.precision, "in_proj": self.in_proj.serialize(), - "out_proj": self.out_proj.serialize() + "out_proj": self.out_proj.serialize(), } @classmethod def deserialize(cls, data: dict) -> "GatedAttentionLayer": """Deserialize the networks from a dict. 
+ Parameters ---------- data : dict diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 5bd9c1a23d..2d79fedba0 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -16,11 +16,11 @@ device = env.DEVICE +from deepmd.model_format import EmbdLayer as DPEmbdLayer +from deepmd.model_format import LayerNorm as DPLayerNorm from deepmd.model_format import ( NativeLayer, ) -from deepmd.model_format import EmbdLayer as DPEmbdLayer -from deepmd.model_format import LayerNorm as DPLayerNorm from deepmd.model_format import NetworkCollection as DPNetworkCollection from deepmd.model_format import ( make_embedding_network, @@ -193,24 +193,25 @@ def check_load_param(ss): class EmbdLayer(MLPLayer): def __init__( - self, - num_channel, - num_out, - padding: bool = True, - stddev: float = 1., - precision: str = DEFAULT_PRECISION, + self, + num_channel, + num_out, + padding: bool = True, + stddev: float = 1.0, + precision: str = DEFAULT_PRECISION, ): self.padding = padding self.num_channel = num_channel + 1 if self.padding else num_channel - super().__init__(num_in=self.num_channel, - num_out=num_out, - bias=False, - use_timestep=False, - activation_function=None, - resnet=False, - stddev=stddev, - precision=precision, - ) + super().__init__( + num_in=self.num_channel, + num_out=num_out, + bias=False, + use_timestep=False, + activation_function=None, + resnet=False, + stddev=stddev, + precision=precision, + ) if self.padding: nn.init.zeros_(self.matrix.data[-1]) @@ -218,14 +219,14 @@ def dim_channel(self) -> int: return self.matrix.shape[0] def forward( - self, - xx: torch.Tensor, + self, + xx: torch.Tensor, ) -> torch.Tensor: """One Embedding layer used by DP model. Parameters ---------- - xx: torch.Tensor + xx : torch.Tensor The input of index. Returns @@ -274,36 +275,41 @@ def deserialize(cls, data: dict) -> "EmbdLayer": ) obj.padding = padding prec = PRECISION_DICT[obj.precision] - check_load_param = \ - lambda ss: nn.Parameter(data=torch.tensor(nl[ss], dtype=prec, device=device)) \ - if nl[ss] is not None else None + check_load_param = ( + lambda ss: nn.Parameter( + data=torch.tensor(nl[ss], dtype=prec, device=device) + ) + if nl[ss] is not None + else None + ) obj.matrix = check_load_param("matrix") return obj class LayerNorm(MLPLayer): def __init__( - self, - num_in, - eps: float = 1e-5, - uni_init: bool = True, - bavg: float = 0., - stddev: float = 1., - precision: str = DEFAULT_PRECISION, + self, + num_in, + eps: float = 1e-5, + uni_init: bool = True, + bavg: float = 0.0, + stddev: float = 1.0, + precision: str = DEFAULT_PRECISION, ): self.eps = eps self.uni_init = uni_init self.num_in = num_in - super().__init__(num_in=1, - num_out=num_in, - bias=True, - use_timestep=False, - activation_function=None, - resnet=False, - bavg=bavg, - stddev=stddev, - precision=precision, - ) + super().__init__( + num_in=1, + num_out=num_in, + bias=True, + use_timestep=False, + activation_function=None, + resnet=False, + bavg=bavg, + stddev=stddev, + precision=precision, + ) self.matrix = torch.nn.Parameter(self.matrix.squeeze(0)) if self.uni_init: nn.init.ones_(self.matrix.data) @@ -313,14 +319,14 @@ def dim_out(self) -> int: return self.matrix.shape[0] def forward( - self, - xx: torch.Tensor, + self, + xx: torch.Tensor, ) -> torch.Tensor: """One Layer Norm used by DP model. Parameters ---------- - xx: torch.Tensor + xx : torch.Tensor The input of index. Returns @@ -328,7 +334,9 @@ def forward( yy: torch.Tensor The output. 
""" - yy = torch_func.layer_norm(xx, tuple((self.num_in,)), self.matrix, self.bias, self.eps) + yy = torch_func.layer_norm( + xx, tuple((self.num_in,)), self.matrix, self.bias, self.eps + ) return yy def serialize(self) -> dict: @@ -365,9 +373,13 @@ def deserialize(cls, data: dict) -> "LayerNorm": precision=nl["precision"], ) prec = PRECISION_DICT[obj.precision] - check_load_param = \ - lambda ss: nn.Parameter(data=torch.tensor(nl[ss], dtype=prec, device=device)) \ - if nl[ss] is not None else None + check_load_param = ( + lambda ss: nn.Parameter( + data=torch.tensor(nl[ss], dtype=prec, device=device) + ) + if nl[ss] is not None + else None + ) obj.matrix = check_load_param("matrix") obj.bias = check_load_param("bias") return obj diff --git a/source/tests/pt/test_descriptor_dpa1.py b/source/tests/pt/test_descriptor_dpa1.py index 2caeb5890e..a7a0f9a4d6 100644 --- a/source/tests/pt/test_descriptor_dpa1.py +++ b/source/tests/pt/test_descriptor_dpa1.py @@ -12,7 +12,9 @@ DescrptBlockSeAtten, DescrptDPA1, ) -from deepmd.pt.model.network.mlp import EmbdLayer +from deepmd.pt.model.network.mlp import ( + EmbdLayer, +) from deepmd.pt.utils import ( env, ) diff --git a/source/tests/pt/test_dpa1.py b/source/tests/pt/test_dpa1.py index 400a611c05..ffd5a25bd5 100644 --- a/source/tests/pt/test_dpa1.py +++ b/source/tests/pt/test_dpa1.py @@ -1,12 +1,13 @@ -import torch, copy -import unittest +# SPDX-License-Identifier: LGPL-3.0-or-later import itertools +import unittest + import numpy as np +import torch try: - from deepmd.model_format import ( - DescrptDPA1 as DPDescrptDPA1 - ) + from deepmd.model_format import DescrptDPA1 as DPDescrptDPA1 + support_se_atten = True except ModuleNotFoundError: support_se_atten = False @@ -14,7 +15,7 @@ support_se_atten = False from deepmd.pt.model.descriptor.dpa1 import ( - DescrptDPA1 + DescrptDPA1, ) from deepmd.pt.utils import ( env, @@ -32,13 +33,14 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION + @unittest.skipIf(not support_se_atten, "EnvMat not supported") class TestDescrptSeAtten(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) def test_consistency( - self, + self, ): rng = np.random.default_rng(100) nf, nloc, nnei = self.nlist.shape @@ -47,15 +49,19 @@ def test_consistency( dstd = 0.1 + np.abs(dstd) for idt, prec in itertools.product( - [False, True], - ["float64", "float32"], + [False, True], + ["float64", "float32"], ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) err_msg = f"idt={idt} prec={prec}" # dpa1 new impl dd0 = DescrptDPA1( - self.rcut, self.rcut_smth, self.sel, self.nt, attn_layer=0, # TODO add support for non-zero layer + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + attn_layer=0, # TODO add support for non-zero layer # precision=prec, # resnet_dt=idt, old_impl=False, @@ -75,22 +81,34 @@ def test_consistency( torch.tensor(self.nlist, dtype=int, device=env.DEVICE), ) np.testing.assert_allclose( - rd0.detach().cpu().numpy(), rd1.detach().cpu().numpy(), - rtol=rtol, atol=atol, err_msg=err_msg, + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, ) # dp impl dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) rd2, _, _, _, _ = dd2.call( - self.coord_ext, self.atype_ext, self.nlist, + self.coord_ext, + self.atype_ext, + self.nlist, ) np.testing.assert_allclose( - rd0.detach().cpu().numpy(), rd2, - rtol=rtol, atol=atol, err_msg=err_msg, + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, ) # old impl if idt 
is False and prec == "float64": dd3 = DescrptDPA1( - self.rcut, self.rcut_smth, self.sel, self.nt, attn_layer=0, # TODO add support for non-zero layer + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + attn_layer=0, # TODO add support for non-zero layer # precision=prec, # resnet_dt=idt, old_impl=True, @@ -101,16 +119,29 @@ def test_consistency( dd0_state_dict_attn = dd0.se_atten.dpa1_attention.state_dict() dd3_state_dict_attn = dd3.se_atten.dpa1_attention.state_dict() for i in dd3_state_dict: - dd3_state_dict[i] = dd0_state_dict[i.replace('.deep_layers.', '.layers.') - .replace('filter_layers_old.', 'filter_layers._networks.').replace('.attn_layer_norm.weight', '.attn_layer_norm.matrix')].detach().clone() - if '.bias' in i and 'attn_layer_norm' not in i: + dd3_state_dict[i] = ( + dd0_state_dict[ + i.replace(".deep_layers.", ".layers.") + .replace("filter_layers_old.", "filter_layers._networks.") + .replace( + ".attn_layer_norm.weight", ".attn_layer_norm.matrix" + ) + ] + .detach() + .clone() + ) + if ".bias" in i and "attn_layer_norm" not in i: dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0) dd3.se_atten.load_state_dict(dd3_state_dict) dd0_state_dict_tebd = dd0.type_embedding.state_dict() dd3_state_dict_tebd = dd3.type_embedding_old.state_dict() for i in dd3_state_dict_tebd: - dd3_state_dict_tebd[i] = dd0_state_dict_tebd[i.replace('embedding.weight', 'matrix')].detach().clone() + dd3_state_dict_tebd[i] = ( + dd0_state_dict_tebd[i.replace("embedding.weight", "matrix")] + .detach() + .clone() + ) dd3.type_embedding_old.load_state_dict(dd3_state_dict_tebd) rd3, _, _, _, _ = dd3( @@ -119,12 +150,15 @@ def test_consistency( torch.tensor(self.nlist, dtype=int, device=env.DEVICE), ) np.testing.assert_allclose( - rd0.detach().cpu().numpy(), rd3.detach().cpu().numpy(), - rtol=rtol, atol=atol, err_msg=err_msg, + rd0.detach().cpu().numpy(), + rd3.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, ) def test_jit( - self, + self, ): rng = np.random.default_rng() nf, nloc, nnei = self.nlist.shape @@ -133,15 +167,18 @@ def test_jit( dstd = 0.1 + np.abs(dstd) for idt, prec in itertools.product( - [False, True], - ["float64", "float32"], + [False, True], + ["float64", "float32"], ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) err_msg = f"idt={idt} prec={prec}" # sea new impl dd0 = DescrptDPA1( - self.rcut, self.rcut_smth, self.sel, self.nt, + self.rcut, + self.rcut_smth, + self.sel, + self.nt, # precision=prec, # resnet_dt=idt, old_impl=False, @@ -150,4 +187,4 @@ def test_jit( dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) # dd1 = DescrptDPA1.deserialize(dd0.serialize()) model = torch.jit.script(dd0) - # model = torch.jit.script(dd1) \ No newline at end of file + # model = torch.jit.script(dd1)
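
Note: both the NumPy (deepmd/model_format/network.py) and PyTorch (deepmd/pt/model/network/mlp.py) EmbdLayer classes reformatted above build a type-embedding table with one extra channel when padding=True, and that extra row is explicitly zeroed (self.w[-1] = 0.0 / nn.init.zeros_(self.matrix.data[-1])), so a reserved padding type index maps to an all-zero embedding vector. A minimal sketch of that idea in plain PyTorch follows; the tensor names, shapes, and example indices are illustrative assumptions, not the deepmd-kit API.

    # Sketch of a padded type-embedding table, assuming ntypes real types
    # plus one reserved padding index (ntypes) that must embed to zeros.
    import torch

    ntypes, tebd_dim = 4, 8
    weight = torch.randn(ntypes + 1, tebd_dim)  # one extra row for padding
    weight[-1].zero_()                          # padding row kept at zero

    atype = torch.tensor([[0, 2, 3, ntypes]])   # last slot uses the padding index
    tebd = weight[atype]                        # shape (1, 4, tebd_dim)
    assert torch.all(tebd[0, -1] == 0)          # padded slot contributes nothing

Because the padding row is all zeros, padded entries drop out of any downstream sum or concatenation without extra masking, which is presumably why the descriptor code above can pass the embedded types straight into the embedding networks.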