Fix transform when using cuML HDBSCAN #1960

Merged

beckernick
Contributor

@beckernick beckernick commented May 1, 2024

This PR:

  • Uses cuML's membership_vector functionality in the transform pipeline rather than approximate_predict. This now matches the default CPU-based HDBSCAN behavior.
  • Updates the test suite to include an optional cuML-based test. Because pytest fixtures are materialized only when a test first requests them at runtime, we can use cuML functionality inside the fixture and mark the test to run only in environments where cuML is available.
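
To illustrate the first bullet, here is a minimal sketch of dispatching to a backend's membership_vector function. The helper names is_cuml_model and soft_memberships are hypothetical and are not taken from the BERTopic code base; the cuML import path is assumed from the cuML documentation, and both backends are assumed to need a model fitted with prediction_data=True.

```python
def is_cuml_model(model) -> bool:
    """Return True when the model's class is defined in a cuml module."""
    return type(model).__module__.startswith("cuml")


def soft_memberships(model, embeddings):
    """Hypothetical helper: per-cluster membership probabilities for new points.

    Assumes the model was fitted with prediction_data=True.
    """
    if is_cuml_model(model):
        # Import path assumed from the cuML documentation.
        from cuml.cluster.hdbscan import membership_vector
    else:
        # CPU implementation from the hdbscan package.
        from hdbscan import membership_vector
    # One row per point, one column per cluster.
    return membership_vector(model, embeddings)
```

By contrast, approximate_predict returns a label and a single membership strength per point, so it cannot fill a full topic-probability matrix.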

Closes #1764 and closes #1317

For completeness, I've included the output of a local run of this pytest using RAPIDS/cuML 24.04:

(rapids-24.04-bertopic) nicholasb@nicholasb-HP-Z8-G4-Workstation:~/NVIDIA/BERTopic/tests$ pytest test_bertopic.py -v
========================================== test session starts ==========================================
platform linux -- Python 3.10.14, pytest-8.2.0, pluggy-1.5.0 -- /home/nicholasb/miniconda3/envs/rapids-24.04-bertopic/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/nicholasb/NVIDIA/BERTopic
plugins: anyio-4.3.0
collected 10 items                                                                                      

test_bertopic.py::test_full_model[base_topic_model] PASSED                                        [ 10%]
test_bertopic.py::test_full_model[kmeans_pca_topic_model] PASSED                                  [ 20%]
test_bertopic.py::test_full_model[custom_topic_model] PASSED                                      [ 30%]
test_bertopic.py::test_full_model[merged_topic_model] PASSED                                      [ 40%]
test_bertopic.py::test_full_model[reduced_topic_model] PASSED                                     [ 50%]
test_bertopic.py::test_full_model[online_topic_model] PASSED                                      [ 60%]
test_bertopic.py::test_full_model[supervised_topic_model] PASSED                                  [ 70%]
test_bertopic.py::test_full_model[representation_topic_model] PASSED                              [ 80%]
test_bertopic.py::test_full_model[zeroshot_topic_model] PASSED                                    [ 90%]
test_bertopic.py::test_full_model[cuml_base_topic_model] PASSED                                   [100%]

=========================================== warnings summary ============================================
...
...
============================== 10 passed, 27 warnings in 108.42s (0:01:48) ==============================

And a local run of this pytest in an environment without cuML:

(bertopic) nicholasb@nicholasb-HP-Z8-G4-Workstation:~/NVIDIA/BERTopic/tests$ pytest test_bertopic.py -v
========================================== test session starts ==========================================
platform linux -- Python 3.10.14, pytest-8.2.0, pluggy-1.5.0 -- /home/nicholasb/miniconda3/envs/bertopic/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/nicholasb/NVIDIA/BERTopic
plugins: anyio-4.3.0
collected 10 items                                                                                      

test_bertopic.py::test_full_model[base_topic_model] PASSED                                        [ 10%]
test_bertopic.py::test_full_model[kmeans_pca_topic_model] PASSED                                  [ 20%]
test_bertopic.py::test_full_model[custom_topic_model] PASSED                                      [ 30%]
test_bertopic.py::test_full_model[merged_topic_model] PASSED                                      [ 40%]
test_bertopic.py::test_full_model[reduced_topic_model] PASSED                                     [ 50%]
test_bertopic.py::test_full_model[online_topic_model] PASSED                                      [ 60%]
test_bertopic.py::test_full_model[supervised_topic_model] PASSED                                  [ 70%]
test_bertopic.py::test_full_model[representation_topic_model] PASSED                              [ 80%]
test_bertopic.py::test_full_model[zeroshot_topic_model] PASSED                                    [ 90%]
test_bertopic.py::test_full_model[cuml_base_topic_model] SKIPPED (cuML not available)             [100%]

=========================================== warnings summary ============================================
...
...
========================== 9 passed, 1 skipped, 8 warnings in 97.41s (0:01:37) ==========================
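
The SKIPPED entry in the run above can be produced by a pattern like the following. This is a sketch under stated assumptions, not the code from this PR: the names CUML_AVAILABLE, requires_cuml, cpu_clusterer, cuml_clusterer and test_clusterer are hypothetical, and the CPU fixture returns a placeholder instead of a real model.

```python
import importlib.util

import pytest

# True only when the cuml package can be located; find_spec does not import it.
CUML_AVAILABLE = importlib.util.find_spec("cuml") is not None

# Hypothetical marker; the reason string mirrors the skip message in the log.
requires_cuml = pytest.mark.skipif(not CUML_AVAILABLE, reason="cuML not available")


@pytest.fixture(scope="session")
def cpu_clusterer():
    # Placeholder standing in for a CPU-based model.
    return {"backend": "cpu"}


@pytest.fixture(scope="session")
def cuml_clusterer():
    # Imported lazily: this body runs only when a non-skipped test requests it.
    from cuml.cluster import HDBSCAN

    return HDBSCAN(min_cluster_size=10, prediction_data=True)


@pytest.mark.parametrize(
    "fixture_name",
    ["cpu_clusterer", pytest.param("cuml_clusterer", marks=requires_cuml)],
)
def test_clusterer(fixture_name, request):
    # The fixture is materialized here, after the skip condition was evaluated.
    clusterer = request.getfixturevalue(fixture_name)
    assert clusterer is not None
```

Because the cuML import lives inside the fixture body, collecting the test module never touches cuML, and the skipif mark keeps the fixture from being requested on systems without it.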

@MaartenGr
Owner

Hi @beckernick!

Thanks for the PR and your work on this. Definitely helps to have somebody from the RAPIDS team on this! I just went through your code and it all looks good to me. When the pipeline runs without any problems (which I assume it will), I'll go ahead and merge it.

@MaartenGr MaartenGr merged commit 30d6fcf into MaartenGr:master May 7, 2024
2 checks passed
@beckernick beckernick deleted the bug/cuml-membership-vector-transform branch May 11, 2024 21:21