@@ -36,21 +36,21 @@ def test_multi_trainer():
3636 x = gluon .Parameter ('x' , shape = (10 ,), stype = 'row_sparse' )
3737 x .initialize ()
3838 # test set trainer
39- trainer0 = gluon .Trainer ([ x ] , 'sgd' )
39+ trainer0 = gluon .Trainer ({ 'x' : x } , 'sgd' )
4040 assert (x ._trainer () is trainer0 )
4141 # test unset trainer
4242 x ._set_trainer (None )
4343 assert (x ._trainer is None )
4444 x ._set_trainer (trainer0 )
4545 with pytest .raises (RuntimeError ):
4646 # multiple trainers for a sparse Parameter is not allowed
47- trainer1 = gluon .Trainer ([ x ] , 'sgd' )
47+ trainer1 = gluon .Trainer ({ 'x' : x } , 'sgd' )
4848
4949@with_seed ()
5050def test_trainer_with_sparse_grad_on_single_context ():
5151 x = gluon .Parameter ('x' , shape = (10 ,), grad_stype = 'row_sparse' )
5252 x .initialize (ctx = [mx .cpu (0 )], init = 'zeros' )
53- trainer = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : 1.0 , 'momentum' : 0.5 })
53+ trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : 1.0 , 'momentum' : 0.5 })
5454 with mx .autograd .record ():
5555 for w in x .list_data ():
5656 y = w + 1
@@ -66,7 +66,7 @@ def test_trainer_with_teststore():
6666 x = gluon .Parameter ('x' , shape = (10 ,))
6767 x .initialize (ctx = [mx .cpu (0 ), mx .cpu (1 )], init = 'zeros' )
6868 kv = mx .kv .create ('teststore' )
69- trainer = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : 1.0 , 'momentum' : 0.5 }, kvstore = kv )
69+ trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : 1.0 , 'momentum' : 0.5 }, kvstore = kv )
7070 with mx .autograd .record ():
7171 for w in x .list_data ():
7272 y = w + 1
@@ -77,14 +77,14 @@ def test_trainer_with_teststore():
7777 assert (x .data (mx .cpu (1 )).asnumpy () == - 2 ).all ()
7878 # Expect exceptions if update_on_kvstore is set to True,
7979 # because TestStore does not support that
80- invalid_trainer = gluon .Trainer ([ x ] , 'sgd' , kvstore = kv , update_on_kvstore = True )
80+ invalid_trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , kvstore = kv , update_on_kvstore = True )
8181 pytest .raises (ValueError , invalid_trainer ._init_kvstore )
8282
8383@with_seed ()
8484def test_trainer ():
8585 x = gluon .Parameter ('x' , shape = (10 ,))
8686 x .initialize (ctx = [mx .cpu (0 ), mx .cpu (1 )], init = 'zeros' )
87- trainer = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : 1.0 , 'momentum' : 0.5 })
87+ trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : 1.0 , 'momentum' : 0.5 })
8888 with mx .autograd .record ():
8989 for w in x .list_data ():
9090 y = w + 1
@@ -119,7 +119,7 @@ def test_trainer():
119119
120120 x = gluon .Parameter ('x' , shape = (10 ,))
121121 x .initialize (ctx = [mx .cpu (0 ), mx .cpu (1 )], init = 'zeros' )
122- trainer2 = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : 1.0 , 'momentum' : 0.5 },
122+ trainer2 = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : 1.0 , 'momentum' : 0.5 },
123123 update_on_kvstore = False )
124124 with mx .autograd .record ():
125125 for i , w in enumerate (x .list_data ()):
@@ -139,7 +139,7 @@ def test_trainer_save_load():
139139
140140 x = gluon .Parameter ('x' , shape = (10 ,), lr_mult = 1.0 )
141141 x .initialize (ctx = [mx .cpu (0 ), mx .cpu (1 )], init = 'zeros' )
142- trainer = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : 0.1 })
142+ trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : 0.1 })
143143 with mx .autograd .record ():
144144 for w in x .list_data ():
145145 y = w + 1
@@ -158,7 +158,7 @@ def test_trainer_sparse_save_load():
158158 x = gluon .Parameter ('x' , shape = (10 , 1 ), lr_mult = 1.0 ,
159159 stype = 'row_sparse' , grad_stype = 'row_sparse' )
160160 x .initialize (ctx = [mx .cpu (0 )], init = 'zeros' )
161- trainer = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : 0.1 })
161+ trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : 0.1 })
162162 all_rows = mx .nd .arange (0 , 10 , ctx = mx .cpu (0 ))
163163 with mx .autograd .record ():
164164 for w in x .list_row_sparse_data (all_rows ):
@@ -257,7 +257,7 @@ def test_trainer_sparse_kv():
257257 def check_trainer_sparse_kv (kv , stype , grad_stype , update_on_kv , expected ):
258258 x = mx .gluon .Parameter ('x' , shape = (10 ,1 ), lr_mult = 1.0 , stype = stype , grad_stype = grad_stype )
259259 x .initialize (ctx = [mx .cpu (0 ), mx .cpu (1 )], init = 'zeros' )
260- trainer = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : 0.1 },
260+ trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : 0.1 },
261261 kvstore = kv , update_on_kvstore = update_on_kv )
262262 all_rows = mx .nd .arange (0 , 10 , ctx = mx .cpu (0 ))
263263 try :
@@ -297,7 +297,7 @@ def test_trainer_lr_sched():
297297 factor = 0.1
298298 lr = 1
299299 lr_sched = mx .lr_scheduler .FactorScheduler (freq , factor = factor , base_lr = lr )
300- trainer = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : lr , 'lr_scheduler' : lr_sched })
300+ trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : lr , 'lr_scheduler' : lr_sched })
301301 for i in range (10 ):
302302 with mx .autograd .record ():
303303 for w in x .list_data ():
@@ -316,7 +316,7 @@ def test_trainer_lr_sched():
316316 factor = 0.1
317317 lr = 1
318318 lr_sched = mx .lr_scheduler .FactorScheduler (freq , factor = factor , base_lr = lr )
319- trainer = gluon .Trainer ([ x ] , 'sgd' , {'learning_rate' : lr , 'lr_scheduler' : lr_sched },
319+ trainer = gluon .Trainer ({ 'x' : x } , 'sgd' , {'learning_rate' : lr , 'lr_scheduler' : lr_sched },
320320 update_on_kvstore = False )
321321 for i in range (10 ):
322322 with mx .autograd .record ():
0 commit comments