From 91a0d1ca83a95945c857b29a728816360e7bb646 Mon Sep 17 00:00:00 2001 From: Nina Montana Brown Date: Fri, 13 Oct 2023 11:26:39 +0100 Subject: [PATCH] Issue #2: added get_device function, with additional mps support --- vec2text/models/model_utils.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/vec2text/models/model_utils.py b/vec2text/models/model_utils.py index 408752f2..5c8fe583 100644 --- a/vec2text/models/model_utils.py +++ b/vec2text/models/model_utils.py @@ -33,7 +33,23 @@ EMBEDDING_TRANSFORM_STRATEGIES = ["repeat"] -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def get_device(): + """ + Function that checks + for GPU availability and returns + the appropriate device. + :return: torch.device + """ + if torch.cuda.is_available(): + dev = "cuda" + elif torch.backends.mps.is_available(): + dev = "mps" + else: + dev = "cpu" + device = torch.device(dev) + return device + +device = get_device() def disable_dropout(model: nn.Module):