diff --git a/tests/test_connector.py b/tests/test_connector.py index 0cc06375..c24fc398 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch from litserve.connector import _Connector, check_cuda_with_nvidia_smi import pytest @@ -22,6 +23,17 @@ def test_check_cuda_with_nvidia_smi(): assert check_cuda_with_nvidia_smi() == torch.cuda.device_count() +@pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Non Nvidia GPU only") +@patch( + "litserve.connector.subprocess.check_output", + return_value=b"GPU 0: NVIDIA GeForce RTX 4090 (UUID: GPU-rb438fre-0ar-9702-de35-ref4rjn34omk3 )", +) +def test_check_cuda_with_nvidia_smi_mock_gpu(mock_subprocess): + check_cuda_with_nvidia_smi.cache_clear() + assert check_cuda_with_nvidia_smi() == 1 + check_cuda_with_nvidia_smi.cache_clear() + + @pytest.mark.parametrize( ("input_accelerator", "expected_accelerator", "expected_devices"), [ @@ -32,6 +44,12 @@ def test_check_cuda_with_nvidia_smi(): torch.cuda.device_count(), marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"), ), + pytest.param( + "gpu", + "cuda", + torch.cuda.device_count(), + marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"), + ), pytest.param( None, "cuda", @@ -50,6 +68,12 @@ def test_check_cuda_with_nvidia_smi(): 1, marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"), ), + pytest.param( + "gpu", + "mps", + 1, + marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"), + ), pytest.param( "mps", "mps", @@ -65,14 +89,15 @@ def test_check_cuda_with_nvidia_smi(): ], ) def test_connector(input_accelerator, expected_accelerator, expected_devices): + check_cuda_with_nvidia_smi.cache_clear() connector = _Connector(accelerator=input_accelerator) assert ( connector.accelerator == expected_accelerator - ), f"accelerator was supposed to be {expected_accelerator} but was {connector.accelerator}" + ), f"accelerator mismatch - expected: {expected_accelerator}, actual: {connector.accelerator}" assert ( connector.devices == expected_devices - ), f"devices was supposed to be {expected_devices} but was {connector.devices}" + ), f"devices mismatch - expected {expected_devices}, actual: {connector.devices}" with pytest.raises(ValueError, match="accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'"): _Connector(accelerator="SUPER_CHIP")