From b6ff0bfe332eac1621ac5106d7fb0a8098dec5d8 Mon Sep 17 00:00:00 2001 From: wwdda <1155107718@link.cuhk.edu.hk> Date: Sat, 23 Sep 2023 04:13:24 +0800 Subject: [PATCH] bug fix for transform() when using hdbscan & calculate_probabilities=True --- bertopic/cluster/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 355a53f6..b2f979cf 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -51,8 +51,8 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: - from cuml.cluster.hdbscan.prediction import approximate_predict - probabilities = approximate_predict(model, embeddings) + from cuml.cluster import hdbscan as cuml_hdbscan + probabilities = cuml_hdbscan.membership_vector(model, embeddings) return probabilities return None