
Conversation

@Dev-Sudarshan
Contributor

This PR adds optional GPU support to L-C2ST by introducing a PyTorch-based
MLP classifier implemented via skorch. This addresses issue #1160.

Changes:

  • Add a PyTorch MLP with skorch to support GPU training
  • Preserve sklearn-like defaults and training dynamics
  • Add device handling (cpu / cuda)
  • Support user overrides via classifier_kwargs while preserving sbi defaults
  • Extend tests to cover:
    • GPU/CPU device placement
    • Default parameter behavior
    • User override merging
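The override-merging behavior described above could, for instance, look like the following minimal sketch (the function name and example kwargs are illustrative, not the actual sbi internals):

```python
# Minimal sketch of merging user-supplied classifier_kwargs with library
# defaults. The function name and the example kwargs are illustrative only.
def merge_classifier_kwargs(defaults, user_kwargs=None):
    """Return a new dict: defaults updated with user overrides (user wins)."""
    merged = dict(defaults)  # copy so the shared defaults stay untouched
    if user_kwargs:
        merged.update(user_kwargs)
    return merged


defaults = {"max_iter": 200, "module__hidden_dim": 64}
user = {"max_iter": 500}
print(merge_classifier_kwargs(defaults, user))
# {'max_iter': 500, 'module__hidden_dim': 64}
```

Copying the defaults before updating keeps the shared default dict untouched across calls.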

@janfb
Contributor

janfb commented Jan 7, 2026

Hi @Dev-Sudarshan,

Thank you for this contribution adding GPU support to LC2ST! Apologies for the long silence on this PR.

Note: This PR contains some unrelated PyMC version pinning changes. Please remove those by rebasing on main so we can focus on the GPU support implementation.

I've done a review and found the following issues:

Issues Found

  1. Missing dependency: skorch is not added to pyproject.toml. Please add it as a core dependency, in alphabetical order, in pyproject.toml.

  2. Incorrect input_dim for PyTorchMLP: At lines 191-195, module__input_dim is set to thetas.shape[-1], but the classifier receives the concatenated [theta, x] data in train_lc2st() (line 769). This causes a dimension mismatch error when training on GPU.

Current:

elif self.clf_class == NeuralNetClassifier:
    ndim = thetas.shape[-1]
    self.clf_kwargs = {
        "module": PyTorchMLP,
        "module__input_dim": ndim,  # Bug: only the theta dimension

Fix:

elif self.clf_class == NeuralNetClassifier:
    ndim = thetas.shape[-1]
    input_dim = thetas.shape[-1] + xs.shape[-1]  # theta + x concatenated
    self.clf_kwargs = {
        "module": PyTorchMLP,
        "module__input_dim": input_dim,

(Keep ndim for hidden_layer_sizes calculation as that scaling is intentional.)
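A quick standalone check illustrates why the concatenated dimension matters (toy dimensions, not the actual LC2ST data):

```python
import numpy as np

# Toy shapes: the classifier in train_lc2st() sees [theta, x] concatenated,
# so the input layer must cover both dimensions, not theta alone.
thetas = np.random.randn(100, 3)  # 3 parameter dims
xs = np.random.randn(100, 5)      # 5 observation dims
joint = np.concatenate([thetas, xs], axis=-1)

input_dim = thetas.shape[-1] + xs.shape[-1]
assert joint.shape[-1] == input_dim  # 8, not 3
print(input_dim)  # 8
```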

  3. PyTorchMLP should live in sbi/utils/metrics.py

The PyTorchMLP class is currently defined in lc2st.py, but we'd like it moved to sbi/utils/metrics.py where the existing C2ST code lives. This will allow reuse when we add GPU support to C2ST as well. We could also rename it to a more generic MLPClassifierModule, matching sklearn's MLPClassifier and skorch's module kwarg.
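One possible shape for the relocated and renamed class, as a sketch only; the exact constructor signature, defaults, and layer layout are for the maintainers to decide:

```python
import torch
from torch import nn


class MLPClassifierModule(nn.Module):
    """Generic MLP classifier module, suitable as skorch's `module` kwarg.

    Sketch only: hidden_dim and num_classes defaults are illustrative.
    """

    def __init__(self, input_dim, hidden_dim=100, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        return self.net(x)


logits = MLPClassifierModule(input_dim=8)(torch.randn(4, 8))
print(logits.shape)  # torch.Size([4, 2])
```

A generic module like this could then back both C2ST and LC2ST once C2ST gains GPU support.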

  4. Add MPS support and use generic GPU detection

The current implementation only checks for CUDA. Please add MPS support (for Apple Silicon) and refactor to a generic use_gpu pattern, for example something like:

if classifier.lower() == 'mlp':
    use_gpu = (
        (self.device.lower() == 'cuda' and torch.cuda.is_available()) or
        (self.device.lower() == 'mps' and torch.backends.mps.is_available())
    )
    if use_gpu:
        classifier = NeuralNetClassifier
    else:
        classifier = MLPClassifier

Similarly, update the RandomForest warning to check for device.lower() in ('cuda', 'mps').

  5. Warn when requested GPU is unavailable

When a user explicitly requests device="cuda" (or "mps") but that backend isn't available, the code silently falls back to sklearn's MLPClassifier. Please add a warning (similar to the RandomForest case) so users know their GPU request wasn't honored.
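A hypothetical sketch of such a fallback warning (availability flags are passed in here purely to keep the example self-contained; in sbi they would come from torch.cuda.is_available() and torch.backends.mps.is_available()):

```python
import warnings


def resolve_device(requested, cuda_available, mps_available):
    """Hypothetical helper: warn and fall back to CPU when the requested
    GPU backend is unavailable. Name and signature are illustrative."""
    requested = requested.lower()
    gpu_ok = {"cuda": cuda_available, "mps": mps_available}
    if requested in gpu_ok and not gpu_ok[requested]:
        warnings.warn(
            f"device='{requested}' was requested but is not available; "
            "falling back to the CPU-based MLPClassifier."
        )
        return "cpu"
    return requested


print(resolve_device("cuda", cuda_available=False, mps_available=False))
# cpu
```

This mirrors the existing RandomForest warning, so users get consistent feedback whenever a requested accelerator is silently unusable.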

  6. Test organization

Instead of running the lc2st device test on both cpu and gpu, restrict it to GPU by adding a marker decorator:

@pytest.mark.gpu
def test_lc2st_runs_on_requested_device(calibration_data):
    """Test that LC2ST runs on cuda/mps (if available)."""
    ...

The test should then cover only cuda and mps (if available).

Thanks again for working on this feature – the overall approach using skorch looks good!

Let me know if you have any questions about the feedback above.

@janfb janfb self-requested a review January 7, 2026 14:04
Contributor

@janfb janfb left a comment


See PR comment above.

@Dev-Sudarshan
Contributor Author

Thank you for the feedback. I will review everything thoroughly and make the required changes.

