Skip to content

Commit d26bb81

Browse files
committed
Determinism use fix
Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 25c2219 commit d26bb81

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

tests/integration/test_seg_loss_integration.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss, TverskyLoss
2323
from monai.networks import one_hot
24+
from monai.utils import set_determinism
2425

2526
TEST_CASES = [
2627
[DiceLoss, {"to_onehot_y": True, "squared_pred": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, {}],
@@ -49,14 +50,11 @@
4950
class TestSegLossIntegration(unittest.TestCase):
5051

5152
def setUp(self):
52-
torch.backends.cudnn.deterministic = True
53-
torch.backends.cudnn.benchmark = False
54-
torch.manual_seed(0)
53+
set_determinism(0)
5554
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
5655

5756
def tearDown(self):
58-
torch.backends.cudnn.deterministic = False
59-
torch.backends.cudnn.benchmark = True
57+
set_determinism(None)
6058

6159
@parameterized.expand(TEST_CASES)
6260
def test_convergence(self, loss_type, loss_args, forward_args):

0 commit comments

Comments
 (0)