diff --git a/examples/index_fit_example.py b/examples/index_fit_example.py new file mode 100644 index 00000000..a88f0c3b --- /dev/null +++ b/examples/index_fit_example.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python + +from sklearn.decomposition import PCA +import umap +from sklearn.datasets import fetch_openml +import matplotlib.pyplot as plt +import seaborn as sns + +# install hnswlib with pip install hnswlib +# or update to use any other index (e.g., nndescent) +import hnswlib + +sns.set(context="paper", style="white") + +mnist = fetch_openml("mnist_784", version=1) + +sample_size = 70000 + +# create a knn index +knn_index = hnswlib.Index(space='l2', dim=mnist.data.shape[1]) +knn_index.init_index(max_elements=sample_size, ef_construction=100, M=16) +knn_index.add_items(mnist.data.values[:sample_size]) + +knn_indices, knn_dists = knn_index.knn_query(mnist.data.values[:sample_size], k=15) + + +pca_init = PCA(n_components=2).fit_transform(mnist.data.values[:sample_size]) + +reducer = umap.UMAP(random_state=42, n_neighbors=15, + precomputed_knn=(knn_indices, knn_dists), + verbose=True, n_jobs=-1, init=pca_init, metric="euclidean") +embedding = reducer.fit_transform_index() + +fig, ax = plt.subplots(figsize=(12, 10)) +color = mnist.target[:sample_size].astype(int) +plt.scatter(embedding[:, 0], embedding[:, 1], c=color, cmap="Spectral", s=0.1) +plt.setp(ax, xticks=[], yticks=[]) +plt.title("MNIST data embedded into two dimensions by UMAP", fontsize=18) + +plt.show() \ No newline at end of file diff --git a/umap/umap_.py b/umap/umap_.py index 997bfd60..5c92ed7d 100644 --- a/umap/umap_.py +++ b/umap/umap_.py @@ -440,13 +440,12 @@ def compute_membership_strengths( def fuzzy_simplicial_set( - X, + knn_indices, + knn_dists, n_neighbors, random_state, metric, metric_kwds={}, - knn_indices=None, - knn_dists=None, angular=False, set_op_mix_ratio=1.0, local_connectivity=1.0, @@ -463,8 +462,18 @@ def fuzzy_simplicial_set( Parameters ---------- - X: array of shape (n_samples, n_features) - The data to be modelled as a fuzzy simplicial set. + + knn_indices: array of shape (n_samples, n_neighbors) + If the k-nearest neighbors of each point has already been calculated + you can pass them in here to save computation time. This should be + an array with the indices of the k-nearest neighbors as a row for + each data point. + + knn_dists: array of shape (n_samples, n_neighbors) + If the k-nearest neighbors of each point has already been calculated + you can pass them in here to save computation time. This should be + an array with the distances of the k-nearest neighbors as a row for + each data point. n_neighbors: int The number of neighbors to use to approximate geodesic distance. @@ -517,18 +526,6 @@ def fuzzy_simplicial_set( Arguments to pass on to the metric, such as the ``p`` value for Minkowski distance. - knn_indices: array of shape (n_samples, n_neighbors) (optional) - If the k-nearest neighbors of each point has already been calculated - you can pass them in here to save computation time. This should be - an array with the indices of the k-nearest neighbors as a row for - each data point. - - knn_dists: array of shape (n_samples, n_neighbors) (optional) - If the k-nearest neighbors of each point has already been calculated - you can pass them in here to save computation time. This should be - an array with the distances of the k-nearest neighbors as a row for - each data point. - angular: bool (optional, default False) Whether to use angular/cosine distance for the random projection forest for seeding NN-descent to determine approximate nearest @@ -563,16 +560,9 @@ def fuzzy_simplicial_set( 1-simplex between the ith and jth sample points. """ if knn_indices is None or knn_dists is None: - knn_indices, knn_dists, _ = nearest_neighbors( - X, - n_neighbors, - metric, - metric_kwds, - angular, - random_state, - verbose=verbose, - ) - + raise ValueError("knn_indices and knn_dists must be set") + + n_samples = knn_indices.shape[0] knn_dists = knn_dists.astype(np.float32) sigmas, rhos = smooth_knn_dist( @@ -586,7 +576,7 @@ def fuzzy_simplicial_set( ) result = scipy.sparse.coo_matrix( - (vals, (rows, cols)), shape=(X.shape[0], X.shape[0]) + (vals, (rows, cols)), shape=(n_samples, n_samples) ) result.eliminate_zeros() @@ -607,7 +597,7 @@ def fuzzy_simplicial_set( else: if return_dists: dmat = scipy.sparse.coo_matrix( - (dists, (rows, cols)), shape=(X.shape[0], X.shape[0]) + (dists, (rows, cols)), shape=(n_samples, n_samples) ) dists = dmat.maximum(dmat.transpose()).todok() @@ -1259,13 +1249,12 @@ def simplicial_set_embedding( ) emb_graph, emb_sigmas, emb_rhos, emb_dists = fuzzy_simplicial_set( - embedding, + knn_indices, + knn_dists, densmap_kwds["n_neighbors"], random_state, "euclidean", {}, - knn_indices, - knn_dists, verbose=verbose, return_dists=True, ) @@ -1833,7 +1822,7 @@ def _validate_parameters(self): self._target_metric_kwds = self.target_metric_kwds # check sparsity of data upfront to set proper _input_distance_func & # save repeated checks later on - if scipy.sparse.isspmatrix_csr(self._raw_data): + if self._raw_data is not None and scipy.sparse.isspmatrix_csr(self._raw_data): self._sparse_data = True else: self._sparse_data = False @@ -2032,7 +2021,7 @@ def _dist_only(x, y, *kwds): self.knn_indices = None self.knn_dists = None self.knn_search_index = None - elif self.knn_dists.shape[0] != self._raw_data.shape[0]: + elif self._raw_data is not None and self.knn_dists.shape[0] != self._raw_data.shape[0]: warn( "precomputed_knn has a different number of samples than the" " data you are fitting. precomputed_knn will be ignored and" @@ -2335,6 +2324,176 @@ def __sub__(self, other): result.rad_emb_ = aux_data["rad_emb"] return result + + def fit_transform_index(self, ensure_all_finite=True, **kwargs): + """Generate a UMAP embedding from a precomputed index + + Parameters + ---------- + ensure_all_finite : Whether to raise an error on np.inf, np.nan, pd.NA in array. + The possibilities are: - True: Force all values of array to be finite. + - False: accepts np.inf, np.nan, pd.NA in array. + - 'allow-nan': accepts only np.nan and pd.NA values in array. + Values cannot be infinite. + + **kwargs : optional + Any additional keyword arguments are passed to _fit_embed_data. + """ + self._raw_data = None + if self.precomputed_knn[0] is None or self.precomputed_knn[1] is None: + raise ValueError("precomputed_knn must be set") + + if self.init is None: + raise ValueError("init must be set") + + if isinstance(self.init, np.ndarray): + init = check_array( + self.init, + dtype=np.float32, + accept_sparse=False, + ensure_all_finite=ensure_all_finite, + ) + else: + raise ValueError("init must be a numpy array") + + if self.transform_mode != "embedding": + raise ValueError("fit_transform_index is only supported for embedding transform mode") + + # Handle all the optional arguments, setting default + if self.a is None or self.b is None: + self._a, self._b = find_ab_params(self.spread, self.min_dist) + else: + self._a = self.a + self._b = self.b + + self._initial_alpha = self.learning_rate + + self.knn_indices = self.precomputed_knn[0] + self.knn_dists = self.precomputed_knn[1] + # #848: allow precomputed knn to not have a search index + if len(self.precomputed_knn) == 2: + self.knn_search_index = None + else: + self.knn_search_index = self.precomputed_knn[2] + + self._validate_parameters() + + if self.n_neighbors > self.knn_indices.shape[1]: + raise ValueError("n_neighbors is larger than the dataset size") + + if self.verbose: + print(str(self)) + + self._original_n_threads = numba.get_num_threads() + if self.n_jobs > 0 and self.n_jobs is not None: + numba.set_num_threads(self.n_jobs) + + inverse = list(range(self.precomputed_knn[0].shape[0])) + random_state = check_random_state(self.random_state) + + if self.verbose: + print(ts(), "Construct fuzzy simplicial set") + + self._knn_indices = self.knn_indices + self._knn_dists = self.knn_dists + self._knn_search_index = self.knn_search_index + + # Disconnect any vertices farther apart than _disconnection_distance + disconnected_index = self._knn_dists >= self._disconnection_distance + self._knn_indices[disconnected_index] = -1 + self._knn_dists[disconnected_index] = np.inf + edges_removed = disconnected_index.sum() + + ( + self.graph_, + self._sigmas, + self._rhos, + self.graph_dists_, + ) = fuzzy_simplicial_set( + self._knn_indices, + self._knn_dists, + self.n_neighbors, + random_state, + self.metric, + self._metric_kwds, + self.angular_rp_forest, + self.set_op_mix_ratio, + self.local_connectivity, + True, + self.verbose, + self.densmap or self.output_dens, + ) + # Report the number of vertices with degree 0 in our umap.graph_ + # This ensures that they were properly disconnected. + vertices_disconnected = np.sum( + np.array(self.graph_.sum(axis=1)).flatten() == 0 + ) + raise_disconnected_warning( + edges_removed, + vertices_disconnected, + self._disconnection_distance, + self.knn_indices.shape[0], + verbose=self.verbose, + ) + self._input_distance_func = self.metric + + self._supervised = False + + if self.densmap or self.output_dens: + self._densmap_kwds["graph_dists"] = self.graph_dists_ + + if self.verbose: + print(ts(), "Construct embedding") + + epochs = ( + self.n_epochs_list if self.n_epochs_list is not None else self.n_epochs + ) + self.embedding_, aux_data = self._fit_embed_data( + X = None, + n_epochs=epochs, + init=init, + random_state = random_state, + **kwargs, + ) + + if self.n_epochs_list is not None: + if "embedding_list" not in aux_data: + raise KeyError( + "No list of embedding were found in 'aux_data'. " + "It is likely the layout optimization function " + "doesn't support the list of int for 'n_epochs'." + ) + else: + self.embedding_list_ = [ + e[inverse] for e in aux_data["embedding_list"] + ] + + # Assign any points that are fully disconnected from our manifold(s) to have embedding + # coordinates of np.nan. These will be filtered by our plotting functions automatically. + # They also prevent users from being deceived a distance query to one of these points. + # Might be worth moving this into simplicial_set_embedding or _fit_embed_data + disconnected_vertices = np.array(self.graph_.sum(axis=1)).flatten() == 0 + if len(disconnected_vertices) > 0: + self.embedding_[disconnected_vertices] = np.full( + self.n_components, np.nan + ) + + self.embedding_ = self.embedding_[inverse] + if self.output_dens: + self.rad_orig_ = aux_data["rad_orig"][inverse] + self.rad_emb_ = aux_data["rad_emb"][inverse] + + if self.verbose: + print(ts() + " Finished embedding") + + numba.set_num_threads(self._original_n_threads) + self._input_hash = joblib.hash(self._knn_indices) + + # Set number of features out for sklearn API + self._n_features_out = self.embedding_.shape[1] + + return self.embedding_ + def fit(self, X, y=None, ensure_all_finite=True, **kwargs): """Fit X into an embedded space. @@ -2520,13 +2679,12 @@ def fit(self, X, y=None, ensure_all_finite=True, **kwargs): self._rhos, self.graph_dists_, ) = fuzzy_simplicial_set( - X[index], + self._knn_indices, + self._knn_dists, self.n_neighbors, random_state, "precomputed", self._metric_kwds, - self._knn_indices, - self._knn_dists, self.angular_rp_forest, self.set_op_mix_ratio, self.local_connectivity, @@ -2585,19 +2743,29 @@ def fit(self, X, y=None, ensure_all_finite=True, **kwargs): # This will have no effect when _disconnection_distance is not set since it defaults to np.inf. edges_removed = np.sum(dmat >= self._disconnection_distance) dmat[dmat >= self._disconnection_distance] = np.inf + + knn_indices, knn_dists, _ = nearest_neighbors( + dmat, + self._n_neighbors, + "precomputed", + self._metric_kwds, + self.angular_rp_forest, + random_state, + verbose=False, + ) + ( self.graph_, self._sigmas, self._rhos, self.graph_dists_, ) = fuzzy_simplicial_set( - dmat, + knn_indices, + knn_dists, self._n_neighbors, random_state, "precomputed", self._metric_kwds, - None, - None, self.angular_rp_forest, self.set_op_mix_ratio, self.local_connectivity, @@ -2660,13 +2828,12 @@ def fit(self, X, y=None, ensure_all_finite=True, **kwargs): self._rhos, self.graph_dists_, ) = fuzzy_simplicial_set( - X[index], + self._knn_indices, + self._knn_dists, self.n_neighbors, random_state, nn_metric, self._metric_kwds, - self._knn_indices, - self._knn_dists, self.angular_rp_forest, self.set_op_mix_ratio, self.local_connectivity, @@ -2755,37 +2922,55 @@ def fit(self, X, y=None, ensure_all_finite=True, **kwargs): ensure_all_finite=ensure_all_finite, ) + knn_indices, knn_dists, _ = nearest_neighbors ( + ydmat, + target_n_neighbors, + "precomputed", + self._target_metric_kwds, + False, + random_state, + verbose=False, + ) + ( target_graph, target_sigmas, target_rhos, ) = fuzzy_simplicial_set( - ydmat, + knn_indices, + knn_dists, target_n_neighbors, random_state, "precomputed", self._target_metric_kwds, - None, - None, False, 1.0, 1.0, False, ) else: + + knn_indices, knn_dists, _ = nearest_neighbors( + y_, + target_n_neighbors, + self.target_metric, + self._target_metric_kwds, + False, + random_state, + verbose=False, + ) # Standard case ( target_graph, target_sigmas, target_rhos, ) = fuzzy_simplicial_set( - y_, + knn_indices, + knn_dists, target_n_neighbors, random_state, self.target_metric, self._target_metric_kwds, - None, - None, False, 1.0, 1.0, @@ -3420,14 +3605,23 @@ def update(self, X, ensure_all_finite=True): kwds=self._metric_kwds, ensure_all_finite=ensure_all_finite, ) - self.graph_, self._sigmas, self._rhos = fuzzy_simplicial_set( + + knn_indices, knn_dists, _ = nearest_neighbors( dmat, self._n_neighbors, + "precomputed", + self._metric_kwds, + self.angular_rp_forest, + random_state, + verbose=False, + ) + self.graph_, self._sigmas, self._rhos = fuzzy_simplicial_set( + knn_indices, + knn_dists, + self._n_neighbors, random_state, "precomputed", self._metric_kwds, - None, - None, self.angular_rp_forest, self.set_op_mix_ratio, self.local_connectivity, @@ -3463,13 +3657,12 @@ def update(self, X, ensure_all_finite=True): ) self.graph_, self._sigmas, self._rhos = fuzzy_simplicial_set( - self._raw_data, + self._knn_indices, + self._knn_dists, self.n_neighbors, random_state, nn_metric, self._metric_kwds, - self._knn_indices, - self._knn_dists, self.angular_rp_forest, self.set_op_mix_ratio, self.local_connectivity, @@ -3531,13 +3724,12 @@ def update(self, X, ensure_all_finite=True): nn_metric = self._input_distance_func self.graph_, self._sigmas, self._rhos = fuzzy_simplicial_set( - self._raw_data, + self._knn_indices, + self._knn_dists, self.n_neighbors, random_state, nn_metric, self._metric_kwds, - self._knn_indices, - self._knn_dists, self.angular_rp_forest, self.set_op_mix_ratio, self.local_connectivity,