Skip to content

Commit

Permalink
add litserve.connector tests (#256)
Browse files Browse the repository at this point in the history
* add gpu

* update

* clear cache

* fix tests

* fix
  • Loading branch information
aniketmaurya authored Aug 30, 2024
1 parent 8ec911d commit 042cd08
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch

from litserve.connector import _Connector, check_cuda_with_nvidia_smi
import pytest
Expand All @@ -22,6 +23,17 @@ def test_check_cuda_with_nvidia_smi():
assert check_cuda_with_nvidia_smi() == torch.cuda.device_count()


@pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Non Nvidia GPU only")
def test_check_cuda_with_nvidia_smi_mock_gpu():
    """On a machine without a real NVIDIA GPU, fake the `nvidia-smi` output and
    verify that check_cuda_with_nvidia_smi() counts exactly one GPU line."""
    fake_smi_output = b"GPU 0: NVIDIA GeForce RTX 4090 (UUID: GPU-rb438fre-0ar-9702-de35-ref4rjn34omk3 )"
    # The result is memoized; drop any cached value before and after so this
    # test neither sees nor leaks stale state across the test session.
    check_cuda_with_nvidia_smi.cache_clear()
    with patch("litserve.connector.subprocess.check_output", return_value=fake_smi_output):
        assert check_cuda_with_nvidia_smi() == 1
    check_cuda_with_nvidia_smi.cache_clear()


@pytest.mark.parametrize(
("input_accelerator", "expected_accelerator", "expected_devices"),
[
Expand All @@ -32,6 +44,12 @@ def test_check_cuda_with_nvidia_smi():
torch.cuda.device_count(),
marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"),
),
pytest.param(
"gpu",
"cuda",
torch.cuda.device_count(),
marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"),
),
pytest.param(
None,
"cuda",
Expand All @@ -50,6 +68,12 @@ def test_check_cuda_with_nvidia_smi():
1,
marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"),
),
pytest.param(
"gpu",
"mps",
1,
marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"),
),
pytest.param(
"mps",
"mps",
Expand All @@ -65,14 +89,15 @@ def test_check_cuda_with_nvidia_smi():
],
)
def test_connector(input_accelerator, expected_accelerator, expected_devices):
check_cuda_with_nvidia_smi.cache_clear()
connector = _Connector(accelerator=input_accelerator)
assert (
connector.accelerator == expected_accelerator
), f"accelerator was supposed to be {expected_accelerator} but was {connector.accelerator}"
), f"accelerator mismatch - expected: {expected_accelerator}, actual: {connector.accelerator}"

assert (
connector.devices == expected_devices
), f"devices was supposed to be {expected_devices} but was {connector.devices}"
), f"devices mismatch - expected {expected_devices}, actual: {connector.devices}"

with pytest.raises(ValueError, match="accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'"):
_Connector(accelerator="SUPER_CHIP")
Expand Down

0 comments on commit 042cd08

Please sign in to comment.