@@ -89,66 +89,10 @@ def train(cfg: DictConfig):
8989 # evaluate after finished training
9090 solver .eval ()
9191
92- # visualize prediction for different functions u and corresponding G(u)
93- dtype = paddle .get_default_dtype ()
94-
95- def generate_y_u_G_ref (
96- u_func : Callable , G_u_func : Callable
97- ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
98- """Generate discretized data of given function u and corresponding G(u).
99-
100- Args:
101- u_func (Callable): Function u.
102- G_u_func (Callable): Function G(u).
92+ def predict_func (input_dict ):
93+ return solver .predict (input_dict , return_numpy = True )[cfg .MODEL .G_key ]
10394
104- Returns:
105- Tuple[np.ndarray, np.ndarray, np.ndarray]: Discretized data of u, y and G(u).
106- """
107- x = np .linspace (0 , 1 , cfg .MODEL .num_loc , dtype = dtype ).reshape (
108- [1 , cfg .MODEL .num_loc ]
109- )
110- u = u_func (x )
111- u = np .tile (u , [cfg .NUM_Y , 1 ])
112-
113- y = np .linspace (0 , 1 , cfg .NUM_Y , dtype = dtype ).reshape ([cfg .NUM_Y , 1 ])
114- G_ref = G_u_func (y )
115- return u , y , G_ref
116-
117- func_u_G_pair = [
118- # (title_string, func_u, func_G(u)), s.t. dG/dx == u and G(u)(0) = 0
119- (r"$u=\cos(x), G(u)=sin(x$)" , lambda x : np .cos (x ), lambda y : np .sin (y )), # 1
120- (
121- r"$u=sec^2(x), G(u)=tan(x$)" ,
122- lambda x : (1 / np .cos (x )) ** 2 ,
123- lambda y : np .tan (y ),
124- ), # 2
125- (
126- r"$u=sec(x)tan(x), G(u)=sec(x) - 1$" ,
127- lambda x : (1 / np .cos (x ) * np .tan (x )),
128- lambda y : 1 / np .cos (y ) - 1 ,
129- ), # 3
130- (
131- r"$u=1.5^x\ln{1.5}, G(u)=1.5^x-1$" ,
132- lambda x : 1.5 ** x * np .log (1.5 ),
133- lambda y : 1.5 ** y - 1 ,
134- ), # 4
135- (r"$u=3x^2, G(u)=x^3$" , lambda x : 3 * x ** 2 , lambda y : y ** 3 ), # 5
136- (r"$u=4x^3, G(u)=x^4$" , lambda x : 4 * x ** 3 , lambda y : y ** 4 ), # 6
137- (r"$u=5x^4, G(u)=x^5$" , lambda x : 5 * x ** 4 , lambda y : y ** 5 ), # 7
138- (r"$u=6x^5, G(u)=x^6$" , lambda x : 5 * x ** 4 , lambda y : y ** 5 ), # 8
139- (r"$u=e^x, G(u)=e^x-1$" , lambda x : np .exp (x ), lambda y : np .exp (y ) - 1 ), # 9
140- ]
141-
142- os .makedirs (os .path .join (cfg .output_dir , "visual" ), exist_ok = True )
143- for i , (title , u_func , G_func ) in enumerate (func_u_G_pair ):
144- u , y , G_ref = generate_y_u_G_ref (u_func , G_func )
145- G_pred = solver .predict ({"u" : u , "y" : y }, return_numpy = True )["G" ]
146- plt .plot (y , G_pred , label = r"$G(u)(y)_{ref}$" )
147- plt .plot (y , G_ref , label = r"$G(u)(y)_{pred}$" )
148- plt .legend ()
149- plt .title (title )
150- plt .savefig (os .path .join (cfg .output_dir , "visual" , f"func_{ i } _result.png" ))
151- plt .clf ()
95+ plot (cfg , predict_func )
15296
15397
15498def evaluate (cfg : DictConfig ):
@@ -189,6 +133,50 @@ def evaluate(cfg: DictConfig):
189133 )
190134 solver .eval ()
191135
136+ def predict_func (input_dict ):
137+ return solver .predict (input_dict , return_numpy = True )[cfg .MODEL .G_key ]
138+
139+ plot (cfg , predict_func )
140+
141+
142+ def export (cfg : DictConfig ):
143+ # set model
144+ model = ppsci .arch .DeepONet (** cfg .MODEL )
145+
146+ # initialize solver
147+ solver = ppsci .solver .Solver (
148+ model ,
149+ pretrained_model_path = cfg .INFER .pretrained_model_path ,
150+ )
151+
152+ # export model
153+ from paddle .static import InputSpec
154+
155+ input_spec = [
156+ {
157+ model .input_keys [0 ]: InputSpec (
158+ [None , 1000 ], "float32" , name = model .input_keys [0 ]
159+ ),
160+ model .input_keys [1 ]: InputSpec (
161+ [None , 1 ], "float32" , name = model .input_keys [1 ]
162+ ),
163+ }
164+ ]
165+ solver .export (input_spec , cfg .INFER .export_path )
166+
167+
168+ def inference (cfg : DictConfig ):
169+ from deploy import python_infer
170+
171+ predictor = python_infer .GeneralPredictor (cfg )
172+
173+ def predict_func (input_dict ):
174+ return next (iter (predictor .predict (input_dict ).values ()))
175+
176+ plot (cfg , predict_func )
177+
178+
179+ def plot (cfg : DictConfig , predict_func : Callable ):
192180 # visualize prediction for different functions u and corresponding G(u)
193181 dtype = paddle .get_default_dtype ()
194182
@@ -242,13 +230,17 @@ def generate_y_u_G_ref(
242230 os .makedirs (os .path .join (cfg .output_dir , "visual" ), exist_ok = True )
243231 for i , (title , u_func , G_func ) in enumerate (func_u_G_pair ):
244232 u , y , G_ref = generate_y_u_G_ref (u_func , G_func )
245- G_pred = solver . predict ({"u" : u , "y" : y }, return_numpy = True )[ "G" ]
233+ G_pred = predict_func ({"u" : u , "y" : y })
246234 plt .plot (y , G_pred , label = r"$G(u)(y)_{ref}$" )
247235 plt .plot (y , G_ref , label = r"$G(u)(y)_{pred}$" )
248236 plt .legend ()
249237 plt .title (title )
250238 plt .savefig (os .path .join (cfg .output_dir , "visual" , f"func_{ i } _result.png" ))
239+ logger .message (
240+ f"Saved result of function { i } to { cfg .output_dir } /visual/func_{ i } _result.png"
241+ )
251242 plt .clf ()
243+ plt .close ()
252244
253245
254246@hydra .main (version_base = None , config_path = "./conf" , config_name = "deeponet.yaml" )
@@ -257,8 +249,14 @@ def main(cfg: DictConfig):
257249 train (cfg )
258250 elif cfg .mode == "eval" :
259251 evaluate (cfg )
252+ elif cfg .mode == "export" :
253+ export (cfg )
254+ elif cfg .mode == "infer" :
255+ inference (cfg )
260256 else :
261- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
257+ raise ValueError (
258+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
259+ )
262260
263261
264262if __name__ == "__main__" :
0 commit comments