Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix transform when using cuML HDBSCAN #1960

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions bertopic/cluster/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hdbscan
import numpy as np


def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
""" Function used to select the HDBSCAN-like model for generating
predictions and probabilities.
Expand Down Expand Up @@ -51,8 +51,8 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):

str_type_model = str(type(model)).lower()
if "cuml" in str_type_model and "hdbscan" in str_type_model:
from cuml.cluster.hdbscan.prediction import approximate_predict
probabilities = approximate_predict(model, embeddings)
from cuml.cluster import hdbscan as cuml_hdbscan
probabilities = cuml_hdbscan.membership_vector(model, embeddings)
return probabilities

return None
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,17 @@ def online_topic_model(documents, document_embeddings, embedding_model):
topics.extend(model.topics_)
model.topics_ = topics
return model

@pytest.fixture(scope="session")
def cuml_base_topic_model(documents, document_embeddings, embedding_model):
    """Session-scoped fixture: a BERTopic model fitted with cuML UMAP and HDBSCAN.

    The cuML classes are imported inside the fixture body rather than at
    module level, so the import is attempted only when this fixture runs.

    Returns:
        The BERTopic instance after ``fit(documents, document_embeddings)``.
    """
    from cuml.cluster import HDBSCAN as cuml_hdbscan
    from cuml.manifold import UMAP as cuml_umap

    # Build the two cuML components up front, in the same order in which
    # they are handed to BERTopic below.
    reducer = cuml_umap(n_components=5, n_neighbors=5, random_state=42)
    clusterer = cuml_hdbscan(min_cluster_size=3, prediction_data=True)

    topic_model = BERTopic(
        embedding_model=embedding_model,
        calculate_probabilities=True,
        umap_model=reducer,
        hdbscan_model=clusterer,
    )
    topic_model.fit(documents, document_embeddings)
    return topic_model
16 changes: 15 additions & 1 deletion tests/test_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
import pytest
from bertopic import BERTopic

def cuml_available():
    """Report whether the ``cuml`` package can be imported.

    Returns:
        bool: True when ``import cuml`` succeeds, False when it raises
        ImportError.
    """
    try:
        import cuml  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True

@pytest.mark.parametrize(
'model',
Expand All @@ -14,7 +20,10 @@
('online_topic_model'),
('supervised_topic_model'),
('representation_topic_model'),
('zeroshot_topic_model')
('zeroshot_topic_model'),
pytest.param(
"cuml_base_topic_model", marks=pytest.mark.skipif(not cuml_available(), reason="cuML not available")
),
])
def test_full_model(model, documents, request):
""" Tests the entire pipeline in one go. This serves as a sanity check to see if the default
Expand All @@ -26,6 +35,11 @@ def test_full_model(model, documents, request):
if model == "base_topic_model":
topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model="sentence-transformers/all-MiniLM-L6-v2")
topic_model = BERTopic.load("model_dir")

if model == "cuml_base_topic_model":
assert "cuml" in str(type(topic_model.umap_model)).lower()
assert "cuml" in str(type(topic_model.hdbscan_model)).lower()

topics = topic_model.topics_

for topic in set(topics):
Expand Down