
Commit bfcfc1b

Detect if mps is available across python backends
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent a8f8073 commit bfcfc1b

7 files changed: +16 additions, -28 deletions

7 files changed

+16
-28
lines changed
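Each backend below applies the same device-selection change. As a rough consolidated sketch (the pick_device helper is illustrative only and is not part of this commit), the order the diffs implement is: default to CPU, prefer CUDA when present, and let MPS (Apple's Metal backend) take over when torch.backends.mps reports availability.

import torch

def pick_device() -> str:
    # Illustrative helper (not in the commit): mirrors the selection
    # order used in the diffs below.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    # hasattr() guards older torch builds that ship without torch.backends.mps.
    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if mps_available:
        device = "mps"
    return device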

backend/python/chatterbox/backend.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,9 @@ def LoadModel(self, request, context):
         else:
             print("CUDA is not available", file=sys.stderr)
             device = "cpu"
-
+        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        if mps_available:
+            device = "mps"
         if not torch.cuda.is_available() and request.CUDA:
             return backend_pb2.Result(success=False, message="CUDA is not available")

backend/python/coqui/backend.py

Lines changed: 3 additions & 1 deletion
@@ -40,7 +40,9 @@ def LoadModel(self, request, context):
         else:
             print("CUDA is not available", file=sys.stderr)
             device = "cpu"
-
+        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        if mps_available:
+            device = "mps"
         if not torch.cuda.is_available() and request.CUDA:
             return backend_pb2.Result(success=False, message="CUDA is not available")

backend/python/diffusers/backend.py

Lines changed: 3 additions & 0 deletions
@@ -368,6 +368,9 @@ def LoadModel(self, request, context):
         device = "cpu" if not request.CUDA else "cuda"
         if XPU:
             device = "xpu"
+        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        if mps_available:
+            device = "mps"
         self.device = device
         if request.LoraAdapter:
             # Check if its a local file and not a directory ( we load lora differently for a safetensor file )

backend/python/faster-whisper/backend.py

Lines changed: 4 additions & 2 deletions
@@ -10,7 +10,7 @@
 import os
 import backend_pb2
 import backend_pb2_grpc
-
+import torch
 from faster_whisper import WhisperModel
 
 import grpc
@@ -35,7 +35,9 @@ def LoadModel(self, request, context):
         # device = "cuda" if request.CUDA else "cpu"
         if request.CUDA:
             device = "cuda"
-
+        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        if mps_available:
+            device = "mps"
         try:
             print("Preparing models, please wait", file=sys.stderr)
             self.model = WhisperModel(request.Model, device=device, compute_type="float16")

backend/python/kitten-tts/backend.py

Lines changed: 0 additions & 12 deletions
@@ -33,18 +33,6 @@ def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
     def LoadModel(self, request, context):
 
-        # Get device
-        # device = "cuda" if request.CUDA else "cpu"
-        if torch.cuda.is_available():
-            print("CUDA is available", file=sys.stderr)
-            device = "cuda"
-        else:
-            print("CUDA is not available", file=sys.stderr)
-            device = "cpu"
-
-        if not torch.cuda.is_available() and request.CUDA:
-            return backend_pb2.Result(success=False, message="CUDA is not available")
-
         self.AudioPath = None
         # List available KittenTTS models
         print("Available KittenTTS voices: expr-voice-2-m, expr-voice-2-f, expr-voice-3-m, expr-voice-3-f, expr-voice-4-m, expr-voice-4-f, expr-voice-5-m, expr-voice-5-f")

backend/python/kokoro/backend.py

Lines changed: 0 additions & 11 deletions
@@ -33,17 +33,6 @@ def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
 
     def LoadModel(self, request, context):
-        # Get device
-        if torch.cuda.is_available():
-            print("CUDA is available", file=sys.stderr)
-            device = "cuda"
-        else:
-            print("CUDA is not available", file=sys.stderr)
-            device = "cpu"
-
-        if not torch.cuda.is_available() and request.CUDA:
-            return backend_pb2.Result(success=False, message="CUDA is not available")
-
         try:
             print("Preparing Kokoro TTS pipeline, please wait", file=sys.stderr)
             # empty dict

backend/python/transformers/backend.py

Lines changed: 3 additions & 1 deletion
@@ -94,7 +94,9 @@ def LoadModel(self, request, context):
         self.SentenceTransformer = False
 
         device_map="cpu"
-
+        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        if mps_available:
+            device_map = "mps"
         quantization = None
         autoTokenizer = True
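To see what the new detection resolves to on a given machine, the same expression can be run standalone (a quick check, not code from this commit):

import torch

# Same guard used across the backends above.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
print("cuda available:", torch.cuda.is_available())
print("mps available:", mps_available)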