MPS Support #10
base: main
Conversation
I think this works well.
Perhaps something like this to preserve float precision on CUDA?

diff --git a/point_e/util/precision_compatibility.py b/point_e/util/precision_compatibility.py
new file mode 100644
--- /dev/null
+++ b/point_e/util/precision_compatibility.py
@@ -0,0 +1,5 @@
+import torch
+import numpy as np
+
+NP_FLOAT32_64 = np.float32 if torch.backends.mps.is_available() else np.float64
+TH_FLOAT32_64 = torch.float32 if torch.backends.mps.is_available() else torch.float64
\ No newline at end of file
diff --git a/point_e/diffusion/gaussian_diffusion.py b/point_e/diffusion/gaussian_diffusion.py
--- point_e/diffusion/gaussian_diffusion.py
+++ point_e/diffusion/gaussian_diffusion.py
@@ -6,8 +6,9 @@
from typing import Any, Dict, Iterable, Optional, Sequence, Union
import numpy as np
import torch as th
+from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
@@ -15,9 +16,9 @@
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "linear":
- betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float32)
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=NP_FLOAT32_64)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
@@ -159,9 +160,9 @@
self.channel_scales = channel_scales
self.channel_biases = channel_biases
# originally uses float64 for accuracy, moving to float32 for mps compatibility
- betas = np.array(betas, dtype=np.float32)
+ betas = np.array(betas, dtype=NP_FLOAT32_64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
@@ -1012,9 +1013,9 @@
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
- res = th.from_numpy(arr).to(dtype=th.float32, device=timesteps.device)[timesteps].to(th.float32)
+ res = th.from_numpy(arr).to(dtype=TH_FLOAT32_64, device=timesteps.device)[timesteps].to(TH_FLOAT32_64)
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
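For context, here is a minimal sketch of the device selection this kind of change enables. get_device is a hypothetical helper for illustration, not part of this PR:

import torch

def get_device() -> torch.device:
    # Prefer CUDA, then Apple's Metal backend (MPS), then fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")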
I love it!
Co-Authored-By: henrycunh <henrycunh@gmail.com>
@henrycunh Added!
Tried it now on a MacBook Air M2. It worked very well, for reference:
The only problem is the actual PyTorch implementation for MPS, which produces this:
I apologize for my question, but how noticeable is the change to float32?
I'm pretty confident that using higher precision, like float64, …
Could we make that a parameter that defaults to 64, with another option to set it to 32?
^ agree
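A rough sketch of what that could look like (get_dtypes is hypothetical, not part of this PR): the default stays float64, and callers opt into float32 explicitly; MPS forces float32 since the backend has no float64 support.

import numpy as np
import torch

def get_dtypes(use_float32: bool = False):
    # Default to float64 for accuracy; fall back to float32 when the
    # caller asks for it or when running on MPS, which lacks float64.
    if use_float32 or torch.backends.mps.is_available():
        return np.float32, torch.float32
    return np.float64, torch.float64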
This PR introduces Metal GPU support, at the cost of slightly lowering accuracy in the gaussian_diffusion step (changing float64 to float32, only when running on MPS).
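As a rough illustration of that accuracy cost (the schedule values below are typical linear-schedule betas, not taken from this repository): float32 carries roughly 7 significant decimal digits, so the per-element relative error against float64 stays near float32 machine epsilon.

import numpy as np

betas64 = np.linspace(1e-4, 0.02, 1000, dtype=np.float64)
betas32 = np.linspace(1e-4, 0.02, 1000, dtype=np.float32)
# Maximum relative error is on the order of 1e-7 (float32 eps ~ 1.2e-7).
print(np.max(np.abs(betas32.astype(np.float64) - betas64) / betas64))

The bigger risk in diffusion code is usually accumulated quantities such as alphas_cumprod, where rounding error can compound over many timesteps.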