diff --git a/vec2text/models/model_utils.py b/vec2text/models/model_utils.py index 408752f2..5c8fe583 100644 --- a/vec2text/models/model_utils.py +++ b/vec2text/models/model_utils.py @@ -33,7 +33,23 @@ EMBEDDING_TRANSFORM_STRATEGIES = ["repeat"] -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def get_device(): + """ + Function that checks + for GPU availability and returns + the appropriate device. + :return: torch.device + """ + if torch.cuda.is_available(): + dev = "cuda" + elif torch.backends.mps.is_available(): + dev = "mps" + else: + dev = "cpu" + device = torch.device(dev) + return device + +device = get_device() def disable_dropout(model: nn.Module):