1+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
115"""
216Sympy to python function conversion module
317"""
418
19+ from __future__ import annotations
20+
521import functools
622from typing import TYPE_CHECKING
723from typing import Dict
1733
1834from ppsci .autodiff import hessian
1935from ppsci .autodiff import jacobian
20- from ppsci .utils import logger
2136
2237if TYPE_CHECKING :
2338 from ppsci import arch
@@ -235,7 +250,7 @@ def __init__(self, expr: Union[sp.Number, sp.NumberSymbol]):
235250 self .expr = float (self .expr )
236251 else :
237252 raise TypeError (
238- f"expr({ expr } ) should be float/int/bool , but got { type (self .expr )} "
253+ f"expr({ expr } ) should be Float/Integer/Boolean/Rational , but got { type (self .expr )} "
239254 )
240255 self .expr = paddle .to_tensor (self .expr )
241256
@@ -253,10 +268,9 @@ class ComposedNode(nn.Layer):
253268 Compose list of several callable objects together.
254269 """
255270
256- def __init__ (self , target : str , funcs : List [Node ]):
271+ def __init__ (self , funcs : List [Node ]):
257272 super ().__init__ ()
258273 self .funcs = funcs
259- self .target = target
260274
261275 def forward (self , data_dict : Dict ):
262276 # call all funcs in order
@@ -299,19 +313,57 @@ def _post_traverse(cur_node: sp.Basic, nodes: List[sp.Basic]) -> List[sp.Basic]:
299313
300314
301315def sympy_to_function (
302- target : str ,
303316 expr : sp .Expr ,
304317 models : Optional [Union [arch .Arch , Tuple [arch .Arch , ...]]] = None ,
305318) -> ComposedNode :
306319 """Convert sympy expression to callable function.
307320
308321 Args:
309- target (str): Alias of `expr`, such as "z" for expression: "z = a + b * c".
310322 expr (sp.Expr): Sympy expression to be converted.
311323 models (Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]]): Model(s) for computing forward result in `LayerNode`.
312324
313325 Returns:
314326 ComposedNode: Callable object for computing expr with necessary input(s) data in dict given.
327+
328+ Examples:
329+ >>> import paddle
330+ >>> import sympy as sp
331+ >>> from ppsci import arch
332+ >>> from ppsci.utils import sym_to_func
333+
334+ >>> a, b, c, x, y = sp.symbols("a b c x y")
335+ >>> u = sp.Function("u")(x, y)
336+ >>> v = sp.Function("v")(x, y)
337+ >>> z = -a + b * (c ** 2) + u * v + 2.3
338+
339+ >>> model = arch.MLP(("x", "y"), ("u", "v"), 4, 16)
340+
341+ >>> batch_size = 13
342+ >>> a_tensor = paddle.randn([batch_size, 1])
343+ >>> b_tensor = paddle.randn([batch_size, 1])
344+ >>> c_tensor = paddle.randn([batch_size, 1])
345+ >>> x_tensor = paddle.randn([batch_size, 1])
346+ >>> y_tensor = paddle.randn([batch_size, 1])
347+
348+ >>> model_output_dict = model({"x": x_tensor, "y": y_tensor})
349+ >>> u_tensor, v_tensor = model_output_dict["u"], model_output_dict["v"]
350+
351+ >>> z_tensor_manually = (
352+ ... -a_tensor + b_tensor * (c_tensor ** 2)
353+ ... + u_tensor * v_tensor + 2.3
354+ ... )
355+ >>> z_tensor_sympy = sym_to_func.sympy_to_function(z, model)(
356+ ... {
357+ ... "a": a_tensor,
358+ ... "b": b_tensor,
359+ ... "c": c_tensor,
360+ ... "x": x_tensor,
361+ ... "y": y_tensor,
362+ ... }
363+ ... )
364+
365+ >>> paddle.allclose(z_tensor_manually, z_tensor_sympy).item()
366+ True
315367 """
316368
317369 # simplify expression to reduce nodes in tree
@@ -330,9 +382,10 @@ def sympy_to_function(
330382 sympy_nodes = list (dict .fromkeys (sympy_nodes ))
331383
332384 # convert sympy node to callable node
385+ if not isinstance (models , (tuple , list )):
386+ models = (models ,)
333387 callable_nodes = []
334388 for i , node in enumerate (sympy_nodes ):
335- logger .debug (f"tree node [{ i + 1 } /{ len (sympy_nodes )} ]: { node } " )
336389 if isinstance (node .func , sp .core .function .UndefinedFunction ):
337390 match = False
338391 for model in models :
@@ -359,4 +412,4 @@ def sympy_to_function(
359412 )
360413
361414 # Compose callable nodes into one callable object
362- return ComposedNode (target , callable_nodes )
415+ return ComposedNode (callable_nodes )
0 commit comments