@@ -75,7 +75,6 @@ def __init__(
         use_raw_recompute=False,
         recompute_kwargs={},
         raise_value_error=False,
-        recompute_use_kwargs_as_inputs=False,
     ):
         super().__init__()
         self.recompute_blocks = recompute_blocks
@@ -116,7 +115,6 @@ def __init__(
                 self.runfunc2, self.runfunc3, self.runfunc4
             ),
         ]
-        self.recompute_use_kwargs_as_inputs = recompute_use_kwargs_as_inputs
 
     def forward(self, inputs):
         if self.use_fleet_sq and not self.use_raw_recompute:
@@ -137,14 +135,9 @@ def forward(self, inputs):
         )
         for i in range(len(self.layers)):
             if i in self.recompute_blocks:
-                if self.recompute_use_kwargs_as_inputs:
-                    inputs = recompute(
-                        self.layers[i], pos=pos, x=inputs, **recompute_kwargs
-                    )
-                else:
-                    inputs = recompute(
-                        self.layers[i], inputs, pos, **recompute_kwargs
-                    )
+                inputs = recompute(
+                    self.layers[i], inputs, pos, **recompute_kwargs
+                )
             else:
                 inputs = self.layers[i](inputs, pos)
 
@@ -160,7 +153,6 @@ def run_model(
     segments=1,
     enable_autocast=False,
     pure_fp16=False,
-    recompute_use_kwargs_as_inputs=False,
 ):
     gen = paddle.seed(10)
     gen.manual_seed(10)
@@ -176,7 +168,6 @@ def run_model(
         segments=segments,
         recompute_kwargs=recompute_kwargs,
         raise_value_error=raise_value_error,
-        recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
     )
 
     if pure_fp16:
@@ -217,12 +208,7 @@ def run_model(
 
 
 class TestRecompute(unittest.TestCase):
-    def test_base_case(
-        self,
-        enable_autocast=False,
-        pure_fp16=False,
-        recompute_use_kwargs_as_inputs=False,
-    ):
+    def test_base_case(self, enable_autocast=False, pure_fp16=False):
         def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
             self.assertEqual(loss_ref, loss)
             self.assertEqual(param_ref, param)
@@ -245,7 +231,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -255,7 +240,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -265,7 +249,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -275,7 +258,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -286,7 +268,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -310,42 +291,31 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
 
     def test_fc_net_with_dropout(self):
         self.test_base_case()
-        self.test_base_case(recompute_use_kwargs_as_inputs=True)
 
     def test_fc_net_without_restore_rng(self):
         for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[2],
-                    recompute_kwargs={
-                        "preserve_rng_state": False,
-                        "use_reentrant": flag,
-                    },
-                    enable_autocast=True,
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+            loss_ref, param_ref, grad_ref = run_model(
+                recompute_block=[2],
+                recompute_kwargs={
+                    "preserve_rng_state": False,
+                    "use_reentrant": flag,
+                },
+                enable_autocast=True,
+            )
 
     def test_fc_net_with_amp(self):
         self.test_base_case(enable_autocast=True)
-        self.test_base_case(
-            enable_autocast=True, recompute_use_kwargs_as_inputs=True
-        )
 
     def test_fc_net_with_fp16(self):
         self.test_base_case(enable_autocast=True, pure_fp16=True)
-        self.test_base_case(
-            enable_autocast=True,
-            pure_fp16=True,
-            recompute_use_kwargs_as_inputs=True,
-        )
 
     def test_recompute_kwargs(self):
         paddle.set_device("gpu")
         pos = paddle.randn(shape=[10, 10], dtype="float32")
         pos.stop_gradient = False
 
         kwargs = {"pos": pos, "use_reentrant": True}
-        with self.assertRaises(TypeError):
+        with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
                 recompute_block=[2],
                 recompute_kwargs=kwargs,
@@ -358,57 +328,46 @@ def test_recompute_kwargs(self):
             )
 
     def test_recompute_inputs_with_param(self):
-        for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                pos = paddle.randn(shape=[10, 10], dtype="float32")
-                new_pos = EagerParamBase(
-                    shape=pos.shape, dtype=pos.dtype, name=pos.name
-                )
-                pos._share_buffer_to(new_pos)
-                new_pos.stop_gradient = False
+        pos = paddle.randn(shape=[10, 10], dtype="float32")
+        new_pos = EagerParamBase(
+            shape=pos.shape, dtype=pos.dtype, name=pos.name
+        )
+        pos._share_buffer_to(new_pos)
+        new_pos.stop_gradient = False
 
-                loss, param, grad = run_model(
-                    recompute_block=[2, 4],
-                    recompute_kwargs={"pos": new_pos, "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss, param, grad = run_model(
+            recompute_block=[], recompute_kwargs={"pos": new_pos}
+        )
 
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[1, 2, 3],
-                    recompute_kwargs={"pos": new_pos, "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[1, 2, 3], recompute_kwargs={"pos": new_pos}
+        )
 
-                self.assertEqual(loss_ref, loss)
-                self.assertEqual(param_ref, param)
-                self.assertEqual(grad_ref, grad)
+        self.assertEqual(loss_ref, loss)
+        self.assertEqual(param_ref, param)
+        self.assertEqual(grad_ref, grad)
 
     def test_recompute_inputs_with_tuple(self):
-        for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                pos = paddle.randn(shape=[10, 10], dtype="float32")
-                new_pos = EagerParamBase(
-                    shape=pos.shape, dtype=pos.dtype, name=pos.name
-                )
-                pos._share_buffer_to(new_pos)
-                pos.stop_gradient = False
-                new_pos.stop_gradient = False
-
-                loss, param, grad = run_model(
-                    recompute_block=[2, 4],
-                    recompute_kwargs={"pos": (pos,), "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        pos = paddle.randn(shape=[10, 10], dtype="float32")
+        new_pos = EagerParamBase(
+            shape=pos.shape, dtype=pos.dtype, name=pos.name
+        )
+        pos._share_buffer_to(new_pos)
+        pos.stop_gradient = False
+        new_pos.stop_gradient = False
 
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[1, 2, 3],
-                    recompute_kwargs={"pos": (new_pos,), "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss, param, grad = run_model(
+            recompute_block=[2, 4], recompute_kwargs={"pos": (pos,)}
+        )
+
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[1, 2, 3],
+            recompute_kwargs={"pos": (new_pos,)},
+        )
 
-                self.assertEqual(loss_ref, loss)
-                self.assertEqual(param_ref, param)
-                self.assertEqual(grad_ref, grad)
+        self.assertEqual(loss_ref, loss)
+        self.assertEqual(param_ref, param)
+        self.assertEqual(grad_ref, grad)
 
 
 if __name__ == '__main__':
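
Note on the assertRaises change in test_recompute_kwargs: forwarding extra keyword arguments (beyond preserve_rng_state and use_reentrant) to recompute together with use_reentrant=True is rejected, and the expected error type is now ValueError rather than TypeError. A minimal sketch of the behavior this diff tests, assuming the public paddle.distributed.fleet.utils.recompute API; the helper fn and the pos tensor below are illustrative, not part of the patch:

    import paddle
    from paddle.distributed.fleet.utils import recompute

    layer = paddle.nn.Linear(10, 10)
    x = paddle.randn([10, 10])
    x.stop_gradient = False
    pos = paddle.randn([10, 10])
    pos.stop_gradient = False

    # Hypothetical block taking an extra tensor argument, mirroring the
    # "pos" input threaded through the test model above.
    def fn(a, pos=None):
        return layer(a) + pos

    # Positional tensor inputs work in either mode.
    out = recompute(layer, x, use_reentrant=True)

    # Extra kwargs are supported by the non-reentrant path.
    out = recompute(fn, x, pos=pos, use_reentrant=False)

    # The reentrant path is expected to reject extra kwargs with ValueError,
    # which is what the updated assertRaises(ValueError) checks.
    try:
        recompute(fn, x, pos=pos, use_reentrant=True)
    except ValueError:
        pass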