@@ -91,28 +91,25 @@ def test_script(self):
9191 test_input = torch .ones (2 , 1 , 8 , 8 )
9292 test_script_save (loss , test_input , test_input )
9393
94- @parameterized .expand ([
95- ("sum_None_0.5_0.25" , "sum" , None , 0.5 , 0.25 ),
96- ("sum_weight_0.5_0.25" , "sum" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
97- ("sum_weight_tuple_0.5_0.25" , "sum" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
98- ("mean_None_0.5_0.25" , "mean" , None , 0.5 , 0.25 ),
99- ("mean_weight_0.5_0.25" , "mean" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
100- ("mean_weight_tuple_0.5_0.25" , "mean" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
101- ("none_None_0.5_0.25" , "none" , None , 0.5 , 0.25 ),
102- ("none_weight_0.5_0.25" , "none" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
103- ("none_weight_tuple_0.5_0.25" , "none" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
104- ])
94+ @parameterized .expand (
95+ [
96+ ("sum_None_0.5_0.25" , "sum" , None , 0.5 , 0.25 ),
97+ ("sum_weight_0.5_0.25" , "sum" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
98+ ("sum_weight_tuple_0.5_0.25" , "sum" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
99+ ("mean_None_0.5_0.25" , "mean" , None , 0.5 , 0.25 ),
100+ ("mean_weight_0.5_0.25" , "mean" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
101+ ("mean_weight_tuple_0.5_0.25" , "mean" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
102+ ("none_None_0.5_0.25" , "none" , None , 0.5 , 0.25 ),
103+ ("none_weight_0.5_0.25" , "none" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
104+ ("none_weight_tuple_0.5_0.25" , "none" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
105+ ]
106+ )
105107 def test_with_alpha (self , name , reduction , weight , lambda_focal , alpha ):
106108 size = [3 , 3 , 5 , 5 ]
107109 label = torch .randint (low = 0 , high = 2 , size = size )
108110 pred = torch .randn (size )
109111
110- common_params = {
111- "include_background" : True ,
112- "to_onehot_y" : False ,
113- "reduction" : reduction ,
114- "weight" : weight ,
115- }
112+ common_params = {"include_background" : True , "to_onehot_y" : False , "reduction" : reduction , "weight" : weight }
116113
117114 dice_focal = DiceFocalLoss (gamma = 1.0 , lambda_focal = lambda_focal , alpha = alpha , ** common_params )
118115 dice = DiceLoss (** common_params )
@@ -123,5 +120,6 @@ def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
123120
124121 np .testing .assert_allclose (result , expected_val , err_msg = f"Failed on case: { name } " )
125122
123+
126124if __name__ == "__main__" :
127125 unittest .main ()
0 commit comments