
MPS Support #10

Open · m1guelpf wants to merge 2 commits into main
Conversation

@m1guelpf commented Dec 21, 2022

This PR introduces Metal GPU support, at the cost of slightly lowering accuracy in the gaussian_diffusion step (changing float64 to float32, but only when running on MPS).
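For anyone trying this out, a minimal sketch of the device selection this enables (illustrative only, not code from this PR):

import torch

# Prefer CUDA, then Apple's Metal (MPS) backend, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")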

@jameshennessytempus

I think this works well

@henrycunh

Perhaps something like this to preserve float precision on CUDA?

diff --git a/point_e/util/precision_compatibility.py b/point_e/util/precision_compatibility.py
new file mode 100644
--- /dev/null
+++ b/point_e/util/precision_compatibility.py
@@ -0,0 +1,5 @@
+import torch
+import numpy as np
+
+NP_FLOAT32_64 = np.float32 if torch.backends.mps.is_available() else np.float64
+TH_FLOAT32_64 = torch.float32 if torch.backends.mps.is_available() else torch.float64
\ No newline at end of file

diff --git a/point_e/diffusion/gaussian_diffusion.py b/point_e/diffusion/gaussian_diffusion.py
--- point_e/diffusion/gaussian_diffusion.py
+++ point_e/diffusion/gaussian_diffusion.py
@@ -6,8 +6,9 @@
 from typing import Any, Dict, Iterable, Optional, Sequence, Union
 
 import numpy as np
 import torch as th
+from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64
 
 
 def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
     """
@@ -15,9 +16,9 @@
 
     See get_named_beta_schedule() for the new library of schedules.
     """
     if beta_schedule == "linear":
-        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float32)
+        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=NP_FLOAT32_64)
     else:
         raise NotImplementedError(beta_schedule)
     assert betas.shape == (num_diffusion_timesteps,)
     return betas
@@ -159,9 +160,9 @@
         self.channel_scales = channel_scales
         self.channel_biases = channel_biases
 
         # originally uses float64 for accuracy, moving to float32 for mps compatibility
-        betas = np.array(betas, dtype=np.float32)
+        betas = np.array(betas, dtype=NP_FLOAT32_64)
         self.betas = betas
         assert len(betas.shape) == 1, "betas must be 1-D"
         assert (betas > 0).all() and (betas <= 1).all()
 
@@ -1012,9 +1013,9 @@
     :param broadcast_shape: a larger shape of K dimensions with the batch
                             dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = th.from_numpy(arr).to(dtype=th.float32, device=timesteps.device)[timesteps].to(th.float32)
+    res = th.from_numpy(arr).to(dtype=TH_FLOAT32_64, device=timesteps.device)[timesteps].to(TH_FLOAT32_64)
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res + th.zeros(broadcast_shape, device=timesteps.device)
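
If the precision_compatibility module above lands, downstream usage would look like this (a quick illustration, not part of the diff):

import numpy as np
import torch
from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64

# Both constants resolve to 32-bit floats on MPS hosts and 64-bit elsewhere.
arr = np.zeros(4, dtype=NP_FLOAT32_64)
res = torch.from_numpy(arr).to(dtype=TH_FLOAT32_64)
print(arr.dtype, res.dtype)  # float32/float32 on Apple Silicon, float64/float64 otherwise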
 

@jamesthesnake

I love it!

@m1guelpf (Author)

@henrycunh Added!

@xmario3 commented Dec 22, 2022

Tried it just now on a MacBook Air M2.

It worked very well. For reference:

  • about 30 s on Windows 11 with CUDA on an RTX 2080
  • 3 m 53 s on the MacBook Air M2
  • more than 20 minutes on CPU on an 11th-gen Intel i9

The only problem is the current PyTorch implementation of MPS, which produces this warning:
UserWarning: The operator 'aten::linalg_vector_norm' is not currently supported on the MPS backend and will fall back to run on the CPU.
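
That warning comes from PyTorch's CPU-fallback path for ops not yet implemented on MPS, which is controlled by the PYTORCH_ENABLE_MPS_FALLBACK environment variable (a sketch; you can also set it in the shell instead):

import os

# Opt in to CPU fallback for aten ops missing on MPS; set this before
# importing torch so it is picked up reliably.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # unsupported MPS ops now run on CPU with a UserWarning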

@peruginiandrea

I apologize for my question, but how noticeable is the change to float32?

@henrycunh

> I apologize for my question, but how noticeable is the change to float32?

I'm pretty confident that using higher precision like float64 will almost always give tighter, more accurate results when smoothing out noise with a Gaussian diffusion algorithm. Higher precision is a bit more computationally intensive, but the extra accuracy and stability are usually worth it.
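
For a rough sense of the difference on the quantity the diffusion code actually accumulates, here is a small check (the schedule values are the standard DDPM linear defaults, used here only for illustration):

import numpy as np

# Standard linear beta schedule (illustrative values).
betas64 = np.linspace(1e-4, 2e-2, 1000, dtype=np.float64)
betas32 = betas64.astype(np.float32)

# alphas_cumprod is what gaussian_diffusion derives from betas.
acp64 = np.cumprod(1.0 - betas64)
acp32 = np.cumprod(1.0 - betas32)  # stays float32 end to end

# Worst-case absolute drift between the two precisions over 1000 steps.
print(np.abs(acp64 - acp32).max())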

@jameshennessytempus

Could we make this a parameter that defaults to float64, with float32 as an opt-in alternative?

@jamesthesnake

^ agree
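
One possible shape for the parameter suggested above, purely as a sketch (the dtype argument is hypothetical and not part of this PR):

import numpy as np

def get_beta_schedule(beta_schedule, *, beta_start, beta_end,
                      num_diffusion_timesteps, dtype=np.float64):
    """Same as the existing helper, but with an explicit dtype that
    defaults to float64; callers on MPS could pass np.float32."""
    if beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=dtype)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas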
