@@ -97,6 +97,10 @@ def _geom_mean(t):
97
97
return np .exp (np .mean (np .log (np_t ), axis = 0 ))
98
98
99
99
100
+ def _mean (y_true ):
101
+ return y_true .mean (dim = 0 ).numpy ()
102
+
103
+
100
104
def test_geom_average ():
101
105
102
106
with pytest .raises (NotComputableError ):
@@ -129,44 +133,30 @@ def test_geom_average():
129
133
np .testing .assert_almost_equal (m .numpy (), _geom_mean (y_true .reshape (- 1 , 10 )), decimal = 5 )
130
134
131
135
132
- def test_integration ():
133
- def _test (metric_cls , true_result_fn ):
134
-
135
- size = 100
136
- custom_variable = 10.0 + 5.0 * torch .rand (size , 12 )
137
-
138
- def update_fn (engine , batch ):
139
- return 0 , custom_variable [engine .state .iteration - 1 ]
140
-
141
- engine = Engine (update_fn )
142
-
143
- custom_var_mean = metric_cls (output_transform = lambda output : output [1 ])
144
- custom_var_mean .attach (engine , "agg_custom_var" )
136
+ @pytest .mark .parametrize ("metric_cls, true_result_fn" , [(Average , _mean ), (GeometricAverage , _geom_mean )])
137
+ @pytest .mark .parametrize ("shape" , [[100 , 12 ], [100 ]])
138
+ def test_integration (metric_cls , true_result_fn , shape ):
145
139
146
- state = engine .run ([0 ] * size )
147
- np .testing .assert_almost_equal (
148
- state .metrics ["agg_custom_var" ].numpy (), true_result_fn (custom_variable ), decimal = 5
149
- )
140
+ assert len (shape ) > 0 and len (shape ) < 3
150
141
151
- size = 100
152
- custom_variable = 10.0 + 5.0 * torch .rand (size )
142
+ custom_variable = 10.0 + 5.0 * torch .rand (shape )
153
143
154
- def update_fn (engine , batch ):
155
- return 0 , custom_variable [engine .state .iteration - 1 ].item ()
144
+ def update_fn (engine , batch ):
156
145
157
- engine = Engine (update_fn )
146
+ output = custom_variable [engine .state .iteration - 1 ]
147
+ output = output .item () if output .ndimension () < 1 else output
148
+ return 0 , output
158
149
159
- custom_var_mean = metric_cls (output_transform = lambda output : output [1 ])
160
- custom_var_mean .attach (engine , "agg_custom_var" )
150
+ engine = Engine (update_fn )
161
151
162
- state = engine . run ([ 0 ] * size )
163
- assert state . metrics [ "agg_custom_var" ] == pytest . approx ( true_result_fn ( custom_variable ) )
152
+ custom_var_mean = metric_cls ( output_transform = lambda output : output [ 1 ] )
153
+ custom_var_mean . attach ( engine , "agg_custom_var" )
164
154
165
- def _mean (y_true ):
166
- return y_true .mean (dim = 0 ).numpy ()
155
+ state = engine .run ([0 ] * shape [0 ])
167
156
168
- _test (Average , _mean )
169
- _test (GeometricAverage , _geom_mean )
157
+ np .testing .assert_almost_equal (
158
+ np .array (state .metrics ["agg_custom_var" ]), true_result_fn (custom_variable ), decimal = 5
159
+ )
170
160
171
161
172
162
def test_compute_mean_std ():
0 commit comments