RuntimeError (type mismatch) when double-precision GMM training #59

Open
tenk-9 opened this issue Jun 17, 2024 · 0 comments

Hi, thank you for this wonderful repo.

Abstract

I'm trying to train a GMM in double precision by passing precision=64 in trainer_params, and I get a RuntimeError during initialization.
The error reports a dtype mismatch, and I found that adding a single line resolves it.
This looks like a bug, so I'm reporting it.

Environment

My environment is as follows.

  • pycave: 3.2.1
  • pytorch-lightning: 1.9.5
  • torch: 1.11.0+cu113
  • torchmetrics: 0.11.4

Error message

Traceback (most recent call last):
...
  File "****.py", line 97, in train_gmm
    gmm = gmm.fit(X)
  File "/usr/local/lib/python3.8/dist-packages/pycave/bayes/gmm/estimator.py", line 153, in fit
    estimator = KMeans(
  File "/usr/local/lib/python3.8/dist-packages/pycave/clustering/kmeans/estimator.py", line 129, in fit
    self.trainer(max_epochs=num_epochs).fit(module, loader)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
...
    self.distance_sampler.update(batch, shortest_distances)
  File "/usr/local/lib/python3.8/dist-packages/torchmetrics/metric.py", line 400, in wrapped_func
    raise err
  File "/usr/local/lib/python3.8/dist-packages/torchmetrics/metric.py", line 390, in wrapped_func
    update(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pycave/clustering/kmeans/metrics.py", line 151, in update
    self.choices.masked_scatter_(
RuntimeError: masked_scatter: expected self and source to have same dtypes but gotFloat and Double

In short, it reports a dtype mismatch between self.choices (apparently float) and the data (double).
Looking at the implementation, DistanceSampler.choices in pycave/clustering/kmeans/metrics.py is just a torch.Tensor.
So its dtype needs to be converted to data.dtype before the scatter.
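The mismatch can be reproduced outside pycave with plain torch. This is only a minimal sketch: the tensors `choices`, `data`, and `mask` are hypothetical stand-ins for the values in the traceback, not the actual state held by DistanceSampler.

```python
import torch

# Hypothetical stand-ins: a buffer created with the default dtype (float32)
# and a double-precision batch, as happens with precision=64.
choices = torch.zeros(3, 2)
data = torch.randn(5, 2, dtype=torch.float64)
mask = torch.tensor([True, False, True])

try:
    # In-place scatter requires the buffer and the source to share a dtype.
    choices.masked_scatter_(mask.unsqueeze(1), data[:2])
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```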

Resolving the error

I added self.choices = self.choices.to(data.dtype) just before the masked_scatter_ call (line 151 in the traceback), and training now works!

        # Then, we sample from the data `num_choices` times and replace if needed
        choices = (squared_distances + eps).multinomial(self.num_choices, replacement=True)
+       self.choices = self.choices.to(data.dtype)
        self.choices.masked_scatter_(
            use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]]
        )
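The effect of the added line can be checked in isolation with plain torch. This is a sketch under assumptions: `choices`, `data`, `mask`, and `picked` are hypothetical stand-ins with made-up values, and whether reassigning the buffer is the fix the maintainers would prefer is not known.

```python
import torch

# Hypothetical stand-ins for the sampler buffer and a double-precision batch.
choices = torch.zeros(3, 2)                                  # float32 buffer
data = torch.arange(10, dtype=torch.float64).reshape(5, 2)   # rows [0,1], [2,3], ...
mask = torch.tensor([True, False, True])                     # rows to overwrite
picked = torch.tensor([4, 0, 1])                             # sampled row indices

# Same idea as the one-line fix: align the buffer dtype with the data first.
choices = choices.to(data.dtype)
choices.masked_scatter_(mask.unsqueeze(1), data[picked[mask]])

print(choices.dtype)
print(choices)
```

Casting the buffer to the data's dtype, rather than casting the data down to float32, keeps the selected rows in double precision.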

I would appreciate it if you could look into this issue.

Best regards,
tenk-9
