From cb46e2c933a570a41ac0faf43cdcef5892a1dac2 Mon Sep 17 00:00:00 2001 From: Ashwath V A <73862377+Mystic-Slice@users.noreply.github.com> Date: Mon, 17 Apr 2023 14:25:40 +0530 Subject: [PATCH] Fixed PyTorch version check in `sparse` module (#1136) * Fixed version check in sparse module * Updated test_factories.py * Reverted to the string split method of version check --------- Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> --- heat/sparse/factories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/sparse/factories.py b/heat/sparse/factories.py index 865feb324f..6d04a2964c 100644 --- a/heat/sparse/factories.py +++ b/heat/sparse/factories.py @@ -94,7 +94,7 @@ def sparse_csr_matrix( (indptr: tensor([0, 1, 3, 4]), indices: tensor([2, 0, 2, 2]), data: tensor([1, 1, 2, 3]), dtype=ht.int64, device=cpu:0, split=None) """ # version check - if int(torch.__version__.split(".")[1]) < 10: + if int(torch.__version__.split(".")[0]) <= 1 and int(torch.__version__.split(".")[1]) < 10: raise RuntimeError(f"ht.sparse requires torch >= 1.10. Found version {torch.__version__}.") # sanitize the data type