Skip to content

Commit

Permalink
Fixed PyTorch version check in sparse module (#1136)
Browse files Browse the repository at this point in the history
* Fixed version check in sparse module

* Updated test_factories.py

* Reverted to the string split method of version check

---------

Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
  • Loading branch information
Mystic-Slice and ClaudiaComito authored Apr 17, 2023
1 parent 1e1bd1e commit cb46e2c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion heat/sparse/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def sparse_csr_matrix(
(indptr: tensor([0, 1, 3, 4]), indices: tensor([2, 0, 2, 2]), data: tensor([1, 1, 2, 3]), dtype=ht.int64, device=cpu:0, split=None)
"""
# version check
if int(torch.__version__.split(".")[1]) < 10:
if int(torch.__version__.split(".")[0]) <= 1 and int(torch.__version__.split(".")[1]) < 10:
raise RuntimeError(f"ht.sparse requires torch >= 1.10. Found version {torch.__version__}.")

# sanitize the data type
Expand Down

0 comments on commit cb46e2c

Please sign in to comment.