1- # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1515from __future__ import annotations
1616
1717from typing import Dict
18+ from typing import List
1819from typing import Optional
1920from typing import Tuple
2021from typing import Union
2122
23+ import paddle
2224import paddle .nn as nn
2325
2426from ppsci .arch import base
2729
2830
2931class SPINN (base .Arch ):
30- """
31- SPINN: Sparse Interaction Neural Network
32+ """Separable Physics-Informed Neural Networks.
33+
3234 Args:
3335 input_keys (Tuple[str, ...]): Keys of input variables.
3436 output_keys (Tuple[str, ...]): Keys of output variables.
35- r (int): Number of features for each output.
37+ r (int): Number of features for each output dimension .
3638 num_layers (int): Number of layers.
3739 hidden_size (Union[int, Tuple[int, ...]]): Size of hidden layer.
3840 activation (str, optional): Name of activation function.
@@ -46,17 +48,19 @@ class SPINN(base.Arch):
4648 >>> from ppsci.arch import SPINN
4749 >>> model = SPINN(
4850 ... input_keys=('x', 'y', 'z'),
49- ... output_keys=('u',),
50- ... r=316 ,
51+ ... output_keys=('u', 'v' ),
52+ ... r=32 ,
5153 ... num_layers=4,
5254 ... hidden_size=32,
5355 ... )
54- >>> input_dict = {"x": paddle.rand([10 ]),
55- ... "y": paddle.rand([10 ]),
56- ... "z": paddle.rand([10 ])}
56+ >>> input_dict = {"x": paddle.rand([3, 1 ]),
57+ ... "y": paddle.rand([4, 1 ]),
58+ ... "z": paddle.rand([5, 1 ])}
5759 >>> output_dict = model(input_dict)
5860 >>> print(output_dict["u"].shape)
59- [10, 10, 10, 1]
61+ [3, 4, 5, 1]
62+ >>> print(output_dict["v"].shape)
63+ [3, 4, 5, 1]
6064 """
6165
6266 def __init__ (
@@ -106,18 +110,45 @@ def _init_weights(self):
106110 initializer .glorot_normal_ (m .weight )
107111 initializer .zeros_ (m .bias )
108112
109- def forward_tensor (self , x , y , z ):
def _tensor_contraction(self, x: paddle.Tensor, y: paddle.Tensor) -> paddle.Tensor:
    """Broadcast-multiply two tensors that share their last (channel) axis.

    Every leading axis of ``x`` is paired with every leading axis of ``y``
    while the channel axis stays aligned. No reduction is performed here:
    the channel axis is still present in the returned tensor.

    Args:
        x (Tensor): Input tensor with shape [*N, C].
        y (Tensor): Input tensor with shape [*M, C].

    Returns:
        Tensor: Output tensor with shape [*N, *M, C].
    """
    # The result has rank x.ndim + y.ndim - 1, so each operand needs as many
    # extra singleton axes as the other operand has leading (non-channel) axes.
    pad_for_x = y.ndim - 1
    pad_for_y = x.ndim - 1

    if pad_for_x > 0:
        # [*N, C] -> [*N, 1, ..., 1, C]: new axes go right before the channel.
        x = x.unsqueeze([-2] * pad_for_x)
    if pad_for_y > 0:
        # [*M, C] -> [1, ..., 1, *M, C]: new axes go at the front.
        y = y.unsqueeze([0] * pad_for_y)

    # Implicit broadcasting expands the singleton axes to [*N, *M, C].
    return x * y
139+
140+ def forward_tensor (self , x , y , z ) -> List [paddle .Tensor ]:
110141 # forward each dim branch
111142 feature_f = []
112143 for i , input_var in enumerate ((x , y , z )):
113- input_i = {self .input_keys [i ]: input_var . unsqueeze ( 1 ) }
144+ input_i = {self .input_keys [i ]: input_var }
114145 output_f_i = self .branch_nets [i ](input_i )
115146 feature_f .append (output_f_i ["f" ]) # [B, r*output_dim]
116147
117- # dot product and sum over all branch outputs and
118148 output = []
119149 for i , key in enumerate (self .output_keys ):
120150 st , ed = i * self .r , (i + 1 ) * self .r
151+ # do tensor contraction and sum over all branch outputs
121152 if ed - st == self .r :
122153 output_i = feature_f [0 ]
123154 else :
@@ -128,20 +159,18 @@ def forward_tensor(self, x, y, z):
128159 output_ii = feature_f [j ]
129160 else :
130161 output_ii = feature_f [j ][:, st :ed ]
131- if j != len (self .input_keys ) - 1 :
132- output_i = output_i .unsqueeze (1 ) * output_ii .unsqueeze (0 )
133- else :
134- output_i = (
135- output_i .unsqueeze (2 ) * output_ii .unsqueeze (0 ).unsqueeze (0 )
136- ).sum (axis = - 1 , keepdim = True )
162+ output_i = self ._tensor_contraction (output_i , output_ii )
163+
164+ output_i = output_i .sum (- 1 , keepdim = True )
137165 output .append (output_i )
138166
139- return output [ - 1 ]
167+ return output
140168
141169 def forward (self , x ):
142170 if self ._input_transform is not None :
143171 x = self ._input_transform (x )
144- output = [self .forward_tensor (x ["x" ], x ["y" ], x ["z" ])]
172+
173+ output = self .forward_tensor (* [x [key ] for key in self .input_keys ])
145174
146175 output = {key : output [i ] for i , key in enumerate (self .output_keys )}
147176
0 commit comments