Skip to content

Commit

Permalink
Issue jxmorris12#2: added get_device function, with additional mps su…
Browse files Browse the repository at this point in the history
…pport
  • Loading branch information
NMontanaBrown committed Oct 13, 2023
1 parent b121c68 commit 91a0d1c
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion vec2text/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,23 @@
EMBEDDING_TRANSFORM_STRATEGIES = ["repeat"]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_device():
"""
Function that checks
for GPU availability and returns
the appropriate device.
:return: torch.device
"""
if torch.cuda.is_available():
dev = "cuda"
elif torch.backends.mps.is_available():
dev = "mps"
else:
dev = "cpu"
device = torch.device(dev)
return device

device = get_device()


def disable_dropout(model: nn.Module):
Expand Down

0 comments on commit 91a0d1c

Please sign in to comment.