From 098f0d770985079471ad97137da6be13c31a9b3c Mon Sep 17 00:00:00 2001 From: Robrecht Cannoodt Date: Mon, 25 Aug 2025 14:54:01 +0200 Subject: [PATCH] WIP add concat function --- NAMESPACE | 3 + R/concat.R | 416 ++++++++++++++++++++++++++ R/concat_helpers.R | 554 +++++++++++++++++++++++++++++++++++ R/concat_strategies.R | 184 ++++++++++++ man/concat.Rd | 107 +++++++ tests/testthat/test-concat.R | 69 +++++ 6 files changed, 1333 insertions(+) create mode 100644 R/concat.R create mode 100644 R/concat_helpers.R create mode 100644 R/concat_strategies.R create mode 100644 man/concat.Rd create mode 100644 tests/testthat/test-concat.R diff --git a/NAMESPACE b/NAMESPACE index 74d5bb21..9d3a1fd1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,8 +2,11 @@ S3method(as_AnnData,Seurat) S3method(as_AnnData,SingleCellExperiment) +S3method(cbind,AbstractAnnData) +S3method(rbind,AbstractAnnData) export(AnnData) export(as_AnnData) +export(concat) export(generate_dataset) export(read_h5ad) export(write_h5ad) diff --git a/R/concat.R b/R/concat.R new file mode 100644 index 00000000..acf65a87 --- /dev/null +++ b/R/concat.R @@ -0,0 +1,416 @@ +#' Concatenate AnnData objects +#' +#' Combine multiple AnnData objects along specified axis with flexible +#' strategies for handling mismatched dimensions and metadata. +#' +#' Related files: +#' - concat_helpers.R: Helper functions for matrix and annotation concatenation +#' - concat_strategies.R: Merge strategies for handling non-aligned elements +#' +#' @param adatas List of AnnData objects to concatenate, or named list for batch labels +#' @param axis Which axis to concatenate along. Either "obs"/0 for observations (rows) +#' or "var"/1 for variables (columns) +#' @param join How to align the other axis. "outer" takes the union of indices, +#' "inner" takes the intersection +#' @param merge Strategy for elements not aligned to concatenation axis. 
One of: +#' \itemize{ +#' \item "same": Keep elements that are identical across all objects +#' \item "unique": Keep elements that have a single possible value (present in only one object, or identical wherever present) +#' \item "first": Keep the first occurrence of each element +#' \item "only": Keep elements that appear in exactly one object +#' } +#' @param uns_merge Strategy for merging .uns metadata (same options as merge). Defaults to the value of merge +#' @param label Column name under which batch labels are stored in obs (axis = "obs") or var (axis = "var"). NULL for no label +#' @param keys Batch labels, one per object. Defaults to names(adatas) if named list, otherwise sequential integers +#' @param index_unique Separator used to append the batch key to row/column names. If NULL, names are kept as-is unless they contain duplicates, in which case the batch key is appended with "_" +#' @param fill_value Value for missing data when join="outer". Default: 0 if any X is sparse, otherwise NA +#' @param backend Backend to use for result. Only "memory" is currently implemented; "hdf5" raises an error +#' +#' @return A new AnnData object containing the concatenated data +#' +#' @details +#' This function provides flexible concatenation of AnnData objects similar to +#' Python's anndata.concat(). 
It handles:
+#'
+#' - Mismatched observation/variable names through join strategies
+#' - Complex metadata merging through merge strategies
+#' - Batch tracking through labeling
+#' - Memory-efficient operations for large datasets
+#'
+#' @examples
+#' \dontrun{
+#' # Create example datasets
+#' ad1 <- generate_dataset(n_obs = 100, n_vars = 50, format = "AnnData")
+#' ad2 <- generate_dataset(n_obs = 80, n_vars = 50, format = "AnnData")
+#' ad3 <- generate_dataset(n_obs = 120, n_vars = 50, format = "AnnData")
+#'
+#' # Basic row concatenation
+#' combined <- concat(list(ad1, ad2, ad3), axis = "obs")
+#'
+#' # With batch tracking
+#' combined <- concat(list(ctrl = ad1, treat = ad2), axis = "obs", label = "condition")
+#'
+#' # Inner join (intersection of variables)
+#' combined <- concat(list(ad1, ad2), axis = "obs", join = "inner")
+#' }
+#'
+#' @export
+concat <- function(
+  adatas,
+  axis = c("obs", "var", 0, 1),
+  join = c("outer", "inner"),
+  merge = c("unique", "same", "first", "only"),
+  uns_merge = merge,
+  label = NULL,
+  keys = NULL,
+  index_unique = NULL,
+  fill_value = NULL,
+  backend = c("memory", "hdf5")
+) {
+  # Wrap a bare AnnData before any list-style handling. The names()/unname()
+  # step below is meant for a list of objects; applied to the object itself
+  # it would pick up the object's own names instead of batch keys.
+  if (inherits(adatas, "AbstractAnnData")) {
+    adatas <- list(adatas)
+  }
+
+  # Input validation
+  if (is.null(adatas) || length(adatas) == 0) {
+    cli::cli_abort("adatas must be a non-empty list")
+  }
+
+  # A named list supplies the batch keys unless `keys` was given explicitly
+  if (is.null(keys) && !is.null(names(adatas))) {
+    keys <- names(adatas)
+    adatas <- unname(adatas)
+  }
+
+  # Validate all objects are AnnData (vapply: always a logical vector)
+  if (!all(vapply(adatas, inherits, logical(1), "AbstractAnnData"))) {
+    cli::cli_abort("All objects in adatas must inherit from AbstractAnnData")
+  }
+
+  # Nothing to combine: warn and hand back an in-memory copy
+  if (length(adatas) == 1) {
+    cli::cli_warn("Only one object provided, returning copy")
+    return(as_InMemoryAnnData(adatas[[1]]))
+  }
+
+  # Resolve arguments. `merge` is resolved before `uns_merge` is first
+  # touched, so the lazy default `uns_merge = merge` sees the resolved value.
+  axis <- resolve_axis(axis)
+  join <- match.arg(join)
+  merge <- match.arg(merge)
+  uns_merge <- match.arg(uns_merge, c("unique", "same", "first",
"only")) + backend <- match.arg(backend) + + if (is.null(keys)) { + keys <- as.character(seq_along(adatas)) + } else if (length(keys) != length(adatas)) { + cli::cli_abort("Length of keys must match length of adatas") + } + + # Check for empty objects + empty_mask <- sapply(adatas, function(ad) any(dim(ad) == 0)) + if (any(empty_mask)) { + cli::cli_warn("Removing {sum(empty_mask)} empty object{?s}") + adatas <- adatas[!empty_mask] + keys <- keys[!empty_mask] + + if (length(adatas) == 0) { + cli::cli_abort("No non-empty objects remaining") + } + } + + # Set default fill value + if (is.null(fill_value)) { + all_X <- lapply(adatas, function(ad) ad$X) + fill_value <- default_fill_value(all_X) + } + + # Perform concatenation + result <- concat_impl( + adatas = adatas, + axis = axis, + join = join, + merge = merge, + uns_merge = uns_merge, + label = label, + keys = keys, + index_unique = index_unique, + fill_value = fill_value, + backend = backend + ) + + return(result) +} + +#' Internal implementation of concat +#' @inheritParams concat +#' @keywords internal +#' @noRd +concat_impl <- function( + adatas, + axis, + join, + merge, + uns_merge, + label, + keys, + index_unique, + fill_value, + backend +) { + n_obs <- sapply(adatas, function(ad) ad$n_obs()) + n_vars <- sapply(adatas, function(ad) ad$n_vars()) + + # Determine concatenation and alignment axes + if (axis == 0L) { + # Concatenating observations (rows) + concat_axis <- 0L + align_axis <- 1L + concat_indices_list <- lapply(adatas, function(ad) ad$obs_names) + align_indices_list <- lapply(adatas, function(ad) ad$var_names) + } else { + # Concatenating variables (columns) + concat_axis <- 1L + align_axis <- 0L + concat_indices_list <- lapply(adatas, function(ad) ad$var_names) + align_indices_list <- lapply(adatas, function(ad) ad$obs_names) + } + + # Create batch labels + if (axis == 0L) { + batch_sizes <- n_obs + } else { + batch_sizes <- n_vars + } + + batch_labels <- rep(keys, batch_sizes) + batch_labels <- 
factor(batch_labels, levels = keys) + + # Merge and create indices for alignment axis + merged_align_indices <- merge_indices(align_indices_list, join = join) + reindexers <- lapply(align_indices_list, function(idx) { + gen_reindexer(merged_align_indices, idx) + }) + + # Create concatenated indices + concat_indices <- unlist(concat_indices_list) + if (!is.null(index_unique)) { + # Make indices unique by appending batch keys + concat_indices <- paste( + concat_indices, + rep(keys, batch_sizes), + sep = index_unique + ) + } else { + # Check if indices are already unique, if not make them unique + if (any(duplicated(concat_indices))) { + concat_indices <- paste(concat_indices, rep(keys, batch_sizes), sep = "_") + } + } + + # Concatenate main matrix (X) + X <- concat_X(adatas, reindexers, concat_axis, fill_value) + + # Concatenate observation/variable annotations + if (axis == 0L) { + # Concatenating observations - combine obs, merge var + obs <- concat_annotations( + lapply(adatas, function(ad) ad$obs), + concat_indices, + join = "outer" # obs always outer join along concat axis + ) + + if (!is.null(label)) { + obs[[label]] <- batch_labels + } + + var <- merge_annotations( + lapply(adatas, function(ad) ad$var), + merged_align_indices, + strategy = merge + ) + } else { + # Concatenating variables - merge obs, combine var + obs <- merge_annotations( + lapply(adatas, function(ad) ad$obs), + merged_align_indices, + strategy = merge + ) + + var <- concat_annotations( + lapply(adatas, function(ad) ad$var), + concat_indices, + join = "outer" # var always outer join along concat axis + ) + + if (!is.null(label)) { + var[[label]] <- batch_labels + } + } + + # Handle layers + layers <- concat_layers( + adatas, + reindexers, + concat_axis, + join, + fill_value, + merge + ) + + # Handle obsm/varm (observation/variable matrices) + if (axis == 0L) { + # Concatenating obs - combine obsm, merge varm + obsm <- concat_matrices_dict( + lapply(adatas, function(ad) ad$obsm), + axis = 0L, 
+      join = "outer",
+      fill_value = fill_value
+    )
+
+    varm <- merge_matrices_dict(
+      lapply(adatas, function(ad) ad$varm),
+      reindexers,
+      axis = 0L,
+      strategy = merge
+    )
+  } else {
+    # Concatenating var - merge obsm, combine varm
+    obsm <- merge_matrices_dict(
+      lapply(adatas, function(ad) ad$obsm),
+      reindexers,
+      axis = 0L,
+      strategy = merge
+    )
+
+    varm <- concat_matrices_dict(
+      lapply(adatas, function(ad) ad$varm),
+      axis = 0L,
+      join = "outer",
+      fill_value = fill_value
+    )
+  }
+
+  # Handle obsp/varp (pairwise matrices) - these need special block diagonal handling
+  obsp <- concat_pairwise_dict(
+    lapply(adatas, function(ad) ad$obsp),
+    batch_sizes,
+    axis = if (axis == 0L) 0L else NULL,
+    strategy = merge
+  )
+
+  varp <- concat_pairwise_dict(
+    lapply(adatas, function(ad) ad$varp),
+    batch_sizes,
+    axis = if (axis == 1L) 1L else NULL,
+    strategy = merge
+  )
+
+  # Handle uns (unstructured metadata)
+  uns <- merge_nested(
+    lapply(adatas, function(ad) ad$uns),
+    strategy = uns_merge
+  )
+
+  # Create result object
+  if (backend == "memory") {
+    result <- InMemoryAnnData$new(
+      X = X,
+      obs = obs,
+      var = var,
+      layers = layers,
+      obsm = obsm,
+      varm = varm,
+      obsp = obsp,
+      varp = varp,
+      uns = uns
+    )
+  } else {
+    cli::cli_abort("HDF5 backend not yet implemented for concat")
+  }
+
+  return(result)
+}
+
+#' Resolve axis argument to integer
+#'
+#' Accepts "obs"/"var" (partial matching allowed), the strings "0"/"1", or
+#' the numbers 0/1. The untouched default of `concat()`, which arrives as the
+#' full choice vector, resolves to the first choice (observations).
+#'
+#' @param axis Axis specification
+#' @return Integer axis (0 or 1)
+#' @keywords internal
+#' @noRd
+resolve_axis <- function(axis) {
+  choices <- c("obs", "var", "0", "1")
+  axis_codes <- c(0L, 1L, 0L, 1L)
+
+  # concat() declares `axis = c("obs", "var", 0, 1)`, which c() coerces to a
+  # length-4 character vector. As match.arg() does for an untouched default,
+  # treat exactly that vector as "use the first choice".
+  if (identical(as.character(axis), choices)) {
+    return(0L)
+  }
+
+  # Any other vector of length != 1 is ambiguous and falls through to the
+  # error below instead of reaching a vector-valued `if` condition.
+  if (length(axis) == 1L) {
+    if (is.character(axis)) {
+      # pmatch() keeps the partial matching match.arg() used to provide and
+      # returns NA for "", NA and unknown strings
+      hit <- pmatch(axis, choices)
+      if (!is.na(hit)) {
+        return(axis_codes[[hit]])
+      }
+    } else if (is.numeric(axis)) {
+      axis <- as.integer(axis)
+      if (axis %in% c(0L, 1L)) {
+        return(axis)
+      }
+    }
+  }
+
+  cli::cli_abort("axis must be 'obs', 'var', 0, or 1")
+}
+
+#' @rdname concat
+#' @method rbind AbstractAnnData
+#' @export
+rbind.AbstractAnnData <- function(
+  ...,
+  join = "outer",
+  merge = "unique",
+  label = "batch",
+  fill_value = NULL
+) { + adatas <- list(...) + + # Create keys from argument names if available + arg_names <- ...names() + if (!is.null(arg_names) && !all(arg_names == "")) { + keys <- arg_names + keys[keys == ""] <- paste0("X", which(keys == "")) + } else { + keys <- paste0("X", seq_along(adatas)) + } + + concat( + adatas, + axis = "obs", + join = join, + merge = merge, + label = label, + keys = keys, + fill_value = fill_value + ) +} + +#' @rdname concat +#' @method cbind AbstractAnnData +#' @export +cbind.AbstractAnnData <- function( + ..., + join = "outer", + merge = "unique", + label = "batch", + fill_value = NULL +) { + adatas <- list(...) + + # Create keys from argument names if available + arg_names <- ...names() + if (!is.null(arg_names) && !all(arg_names == "")) { + keys <- arg_names + keys[keys == ""] <- paste0("X", which(keys == "")) + } else { + keys <- paste0("X", seq_along(adatas)) + } + + concat( + adatas, + axis = "var", + join = join, + merge = merge, + label = label, + keys = keys, + fill_value = fill_value + ) +} diff --git a/R/concat_helpers.R b/R/concat_helpers.R new file mode 100644 index 00000000..1e7c783c --- /dev/null +++ b/R/concat_helpers.R @@ -0,0 +1,554 @@ +#' Helper functions for concatenating AnnData objects +#' +#' This file contains internal functions used for concatenating AnnData objects, +#' including reindexing, matrix concatenation, and annotation merging. +#' +#' Related files: +#' - concat.R: Main concatenation interface (concat, rbind, cbind) +#' - concat_strategies.R: Merge strategies for handling non-aligned elements +#' +#' @keywords internal +#' @name concat_helpers + +#' Reindexer class for handling dimension alignment during concatenation +#' +#' This class handles the complex logic of reindexing matrices and data frames +#' when concatenating AnnData objects with mismatched dimensions. 
+#' +#' @field old_idx Original index (character vector) +#' @field new_idx Target index (character vector) +#' @field old_pos Positions in original index that will be kept (integer vector) +#' @field new_pos Positions in new index where data will be placed (integer vector) +#' @field no_change Whether indices are identical (logical) +#' +#' @keywords internal +#' @noRd +Reindexer <- R6::R6Class( + "Reindexer", + public = list( + old_idx = NULL, + new_idx = NULL, + old_pos = NULL, + new_pos = NULL, + no_change = NULL, + + #' @description + #' Create a new Reindexer + #' @param old_idx Original index + #' @param new_idx Target index + initialize = function(old_idx, new_idx) { + self$old_idx <- as.character(old_idx) + self$new_idx <- as.character(new_idx) + self$no_change <- identical(self$old_idx, self$new_idx) + + if (!self$no_change) { + # Find positions for reindexing (1-based indexing in R) + new_pos <- match(self$old_idx, self$new_idx) + old_pos <- seq_along(new_pos) + + # Keep only valid matches (remove NA positions) + mask <- !is.na(new_pos) + self$new_pos <- new_pos[mask] + self$old_pos <- old_pos[mask] + } + }, + + #' @description + #' Apply reindexing to an element + #' @param el Element to reindex (matrix, data.frame, etc.) 
+ #' @param axis Axis to reindex along (0 for rows, 1 for columns) + #' @param fill_value Value to use for missing positions + apply = function(el, axis = 1L, fill_value = NULL) { + if (self$no_change) { + return(el) + } + + if (is.data.frame(el)) { + return(self$apply_to_dataframe(el, axis, fill_value)) + } else if (inherits(el, "sparseMatrix") || methods::is(el, "Matrix")) { + return(self$apply_to_sparse(el, axis, fill_value)) + } else if (is.matrix(el) || is.array(el)) { + return(self$apply_to_dense(el, axis, fill_value)) + } else { + cli::cli_abort("Cannot reindex object of class {class(el)[1]}") + } + }, + + #' @description Apply reindexing to data.frame + apply_to_dataframe = function(el, axis, fill_value = NULL) { + if (is.null(fill_value)) { + fill_value <- NA + } + + if (axis == 0L) { + # Reindex rows + result <- el[rep(NA_integer_, length(self$new_idx)), , drop = FALSE] + rownames(result) <- self$new_idx + if (length(self$new_pos) > 0) { + result[self$new_pos, ] <- el[self$old_pos, , drop = FALSE] + } + } else { + # Reindex columns + result <- data.frame(matrix( + fill_value, + nrow = nrow(el), + ncol = length(self$new_idx) + )) + colnames(result) <- self$new_idx + rownames(result) <- rownames(el) + if (length(self$new_pos) > 0) { + result[, self$new_pos] <- el[, self$old_pos, drop = FALSE] + } + } + + return(result) + }, + + #' @description Apply reindexing to sparse matrix + apply_to_sparse = function(el, axis, fill_value = NULL) { + if (is.null(fill_value)) { + fill_value <- 0 + } + + if (axis == 0L) { + # Reindex rows + new_nrow <- length(self$new_idx) + if (length(self$new_pos) == 0) { + # No matching rows - return empty sparse matrix + result <- Matrix::sparseMatrix( + i = integer(0), + j = integer(0), + x = numeric(0), + dims = c(new_nrow, ncol(el)) + ) + } else { + # Create indexing matrix for efficient sparse reindexing + idx_i <- rep(self$new_pos, each = ncol(el)) + idx_j <- rep(seq_len(ncol(el)), times = length(self$new_pos)) + + # Extract 
values from original positions + orig_vals <- as.matrix(el[self$old_pos, , drop = FALSE]) + vals <- as.numeric(orig_vals) + + result <- Matrix::sparseMatrix( + i = idx_i, + j = idx_j, + x = vals, + dims = c(new_nrow, ncol(el)) + ) + } + rownames(result) <- self$new_idx + colnames(result) <- colnames(el) + } else { + # Reindex columns + new_ncol <- length(self$new_idx) + if (length(self$new_pos) == 0) { + # No matching columns - return empty sparse matrix + result <- Matrix::sparseMatrix( + i = integer(0), + j = integer(0), + x = numeric(0), + dims = c(nrow(el), new_ncol) + ) + } else { + # Create indexing matrix for efficient sparse reindexing + idx_i <- rep(seq_len(nrow(el)), times = length(self$new_pos)) + idx_j <- rep(self$new_pos, each = nrow(el)) + + # Extract values from original positions + orig_vals <- as.matrix(el[, self$old_pos, drop = FALSE]) + vals <- as.numeric(orig_vals) + + result <- Matrix::sparseMatrix( + i = idx_i, + j = idx_j, + x = vals, + dims = c(nrow(el), new_ncol) + ) + } + rownames(result) <- rownames(el) + colnames(result) <- self$new_idx + } + + return(result) + }, + + #' @description Apply reindexing to dense matrix/array + apply_to_dense = function(el, axis, fill_value = NULL) { + if (is.null(fill_value)) { + fill_value <- NA + } + + if (axis == 0L) { + # Reindex rows + result <- array(fill_value, dim = c(length(self$new_idx), ncol(el))) + if (length(self$new_pos) > 0) { + result[self$new_pos, ] <- el[self$old_pos, , drop = FALSE] + } + rownames(result) <- self$new_idx + colnames(result) <- colnames(el) + } else { + # Reindex columns + result <- array(fill_value, dim = c(nrow(el), length(self$new_idx))) + if (length(self$new_pos) > 0) { + result[, self$new_pos] <- el[, self$old_pos, drop = FALSE] + } + rownames(result) <- rownames(el) + colnames(result) <- self$new_idx + } + + return(result) + } + ) +) + +#' Create a reindexer for aligning dimensions +#' @param new_idx Target index +#' @param cur_idx Current index +#' @return A 
Reindexer object +#' @keywords internal +#' @noRd +gen_reindexer <- function(new_idx, cur_idx) { + Reindexer$new(cur_idx, new_idx) +} + +#' Merge indices using specified join strategy +#' @param indices_list List of index vectors +#' @param join Join strategy ("inner" or "outer") +#' @return Merged index vector +#' @keywords internal +#' @noRd +merge_indices <- function(indices_list, join = "outer") { + if (join == "inner") { + # Intersection of all indices + result <- indices_list[[1]] + for (i in seq_along(indices_list)[-1]) { + result <- intersect(result, indices_list[[i]]) + } + return(result) + } else if (join == "outer") { + # Union of all indices, preserving order + result <- character(0) + for (idx in indices_list) { + result <- union(result, idx) + } + return(result) + } else { + cli::cli_abort("Join must be 'inner' or 'outer', not {join}") + } +} + +#' @noRd +gen_outer_reindexers <- function(indices_list, merged_index) { + lapply(indices_list, function(idx) gen_reindexer(merged_index, idx)) +} + +#' @noRd +gen_inner_reindexers <- function(indices_list) { + common_idx <- merge_indices(indices_list, join = "inner") + lapply(indices_list, function(idx) gen_reindexer(common_idx, idx)) +} + +#' @noRd +default_fill_value <- function(elements) { + # Check if any element is sparse - use 0 for sparse matrices + if ( + any(sapply(elements, function(x) { + if (is.null(x)) { + return(FALSE) + } + inherits(x, "sparseMatrix") || + methods::is(x, "sparseMatrix") || + (methods::is(x, "Matrix") && attr(class(x), "package") == "Matrix") + })) + ) { + return(0) + } else { + return(NA) + } +} + +#' @noRd +concat_X <- function(adatas, reindexers, axis, fill_value) { + X_matrices <- lapply(adatas, function(ad) ad$X) + + # Check if all X are NULL + all_null <- all(sapply(X_matrices, is.null)) + any_null <- any(sapply(X_matrices, is.null)) + + if (all_null) { + return(NULL) + } else if (any_null) { + cli::cli_abort( + "Cannot concatenate: some (but not all) AnnData objects have 
X = NULL" + ) + } + + # Apply reindexing to align dimensions + reindexed <- Map( + function(X, reindexer) { + if (axis == 0L) { + # Concatenating rows - reindex columns + reindexer$apply(X, axis = 1L, fill_value = fill_value) + } else { + # Concatenating columns - reindex rows + reindexer$apply(X, axis = 0L, fill_value = fill_value) + } + }, + X_matrices, + reindexers + ) + + # Concatenate along specified axis + if (axis == 0L) { + # Stack rows + do.call(rbind, reindexed) + } else { + # Stack columns + do.call(cbind, reindexed) + } +} + +#' @noRd +concat_annotations <- function(annotations, new_index, join = "outer") { + if (length(annotations) == 0) { + return(data.frame()) + } + + # Remove empty annotations + annotations <- annotations[sapply(annotations, function(x) nrow(x) > 0)] + + if (length(annotations) == 0) { + return(data.frame(row.names = new_index)) + } + + # Concatenate data frames + result <- do.call(rbind, annotations) + + # Set new index + rownames(result) <- new_index + + return(result) +} + +#' @noRd +merge_annotations <- function(annotations, new_index, strategy) { + if (length(annotations) == 0) { + return(data.frame()) + } + + # Apply merge strategy + merged <- apply_merge_strategy(annotations, strategy) + + # Convert to data.frame with proper index + if (length(merged) == 0) { + result <- data.frame(row.names = new_index) + } else { + # Take first annotation as template for structure + template <- annotations[[1]] + result <- template[new_index, , drop = FALSE] + + # Replace columns with merged values + for (col_name in names(merged)) { + if (col_name %in% colnames(template)) { + result[[col_name]] <- merged[[col_name]] + } + } + } + + return(result) +} + +#' @noRd +concat_layers <- function(adatas, reindexers, axis, join, fill_value, merge) { + all_layers <- lapply(adatas, function(ad) ad$layers) + + # Get layer names based on join strategy + layer_names <- merge_indices(lapply(all_layers, names), join = join) + + if (length(layer_names) == 
0) { + return(list()) + } + + result_layers <- list() + + for (layer_name in layer_names) { + # Get matrices for this layer from all objects + layer_matrices <- lapply(all_layers, function(layers) { + if (layer_name %in% names(layers)) { + layers[[layer_name]] + } else { + NULL + } + }) + + # Handle case where not all objects have this layer + if (any(sapply(layer_matrices, is.null))) { + if (join == "inner") { + next # Skip layers not present in all objects + } else { + # For outer join, create empty matrices for missing layers + for (i in seq_along(layer_matrices)) { + if (is.null(layer_matrices[[i]])) { + # Create empty matrix with correct dimensions + if (axis == 0L) { + dims <- c(adatas[[i]]$n_obs(), length(reindexers[[i]]$new_idx)) + } else { + dims <- c(length(reindexers[[i]]$new_idx), adatas[[i]]$n_vars()) + } + layer_matrices[[i]] <- Matrix::sparseMatrix( + i = integer(0), + j = integer(0), + x = numeric(0), + dims = dims + ) + } + } + } + } + + # Apply reindexing and concatenate + reindexed <- Map( + function(mat, reindexer) { + if (is.null(mat)) { + return(NULL) + } + + if (axis == 0L) { + reindexer$apply(mat, axis = 1L, fill_value = fill_value) + } else { + reindexer$apply(mat, axis = 0L, fill_value = fill_value) + } + }, + layer_matrices, + reindexers + ) + + # Remove NULL entries + reindexed <- reindexed[!sapply(reindexed, is.null)] + + if (length(reindexed) > 0) { + if (axis == 0L) { + result_layers[[layer_name]] <- do.call(rbind, reindexed) + } else { + result_layers[[layer_name]] <- do.call(cbind, reindexed) + } + } + } + + return(result_layers) +} + +#' @noRd +concat_matrices_dict <- function(matrices_list, axis, join, fill_value) { + # Get all matrix names + all_names <- unique(unlist(lapply(matrices_list, names))) + + if (length(all_names) == 0) { + return(list()) + } + + result <- list() + + for (mat_name in all_names) { + # Get matrices for this name + matrices <- lapply(matrices_list, function(mats) { + if (mat_name %in% names(mats)) 
mats[[mat_name]] else NULL + }) + + # Skip if not all objects have this matrix for inner join + if (join == "inner" && any(sapply(matrices, is.null))) { + next + } + + # Remove NULL entries and concatenate + matrices <- matrices[!sapply(matrices, is.null)] + if (length(matrices) > 0) { + if (axis == 0L) { + result[[mat_name]] <- do.call(rbind, matrices) + } else { + result[[mat_name]] <- do.call(cbind, matrices) + } + } + } + + return(result) +} + +#' @noRd +merge_matrices_dict <- function(matrices_list, reindexers, axis, strategy) { + # Apply merge strategy to get surviving matrices + merged <- apply_merge_strategy(matrices_list, strategy) + + if (length(merged) == 0) { + return(list()) + } + + # Note: For merge operations, we typically don't need reindexing + # as the merged matrices should already have compatible dimensions + # But we include reindexers parameter for API consistency + + return(merged) +} + +#' @noRd +concat_pairwise_dict <- function( + pairwise_list, + batch_sizes, + axis = NULL, + strategy +) { + if (is.null(axis)) { + # Not concatenating along this axis - merge using strategy + return(apply_merge_strategy(pairwise_list, strategy)) + } + + # Get all matrix names + all_names <- unique(unlist(lapply(pairwise_list, names))) + + if (length(all_names) == 0) { + return(list()) + } + + result <- list() + + for (mat_name in all_names) { + # Get matrices for this name + matrices <- lapply(pairwise_list, function(mats) { + if (mat_name %in% names(mats)) mats[[mat_name]] else NULL + }) + + # Remove NULL entries + non_null_matrices <- matrices[!sapply(matrices, is.null)] + non_null_sizes <- batch_sizes[!sapply(matrices, is.null)] + + if (length(non_null_matrices) > 0) { + # Create block diagonal matrix + result[[mat_name]] <- create_block_diagonal( + non_null_matrices, + non_null_sizes + ) + } + } + + return(result) +} + +#' @noRd +create_block_diagonal <- function(matrices, sizes) { + if (length(matrices) == 0) { + return(Matrix::sparseMatrix( + i = 
integer(0), + j = integer(0), + x = numeric(0), + dims = c(0, 0) + )) + } + + if (length(matrices) == 1) { + return(matrices[[1]]) + } + + # Use Matrix::bdiag for efficient block diagonal construction + Matrix::bdiag(matrices) +} diff --git a/R/concat_strategies.R b/R/concat_strategies.R new file mode 100644 index 00000000..f2dc2aed --- /dev/null +++ b/R/concat_strategies.R @@ -0,0 +1,184 @@ +#' Concatenation strategies for anndataR +#' +#' Functions that implement different strategies for merging elements +#' that are not aligned to the concatenation axis when concatenating AnnData objects. +#' +#' @name concat_strategies +#' @keywords internal + +#' @noRd +apply_merge_strategy <- function(mappings, strategy) { + if (is.character(strategy)) { + strategy <- get_merge_strategy(strategy) + } + + # Get union of all keys + all_keys <- unique(unlist(lapply(mappings, names))) + + result <- list() + for (key in all_keys) { + # Get values for this key from all mappings + values <- lapply(mappings, function(m) { + if (key %in% names(m)) m[[key]] else NULL + }) + + # Apply strategy to determine which value to keep + merged_value <- strategy(values) + if (!is.null(merged_value)) { + result[[key]] <- merged_value + } + } + + return(result) +} + +#' @noRd +get_merge_strategy <- function(strategy_name) { + strategies <- list( + "same" = merge_same, + "unique" = merge_unique, + "first" = merge_first, + "only" = merge_only + ) + + if (!strategy_name %in% names(strategies)) { + cli::cli_abort( + "Unknown merge strategy: {strategy_name}. 
Must be one of: {names(strategies)}"
+    )
+  }
+
+  return(strategies[[strategy_name]])
+}
+
+#' "same" strategy: keep a value only if every object has it, identically
+#'
+#' @param values List with one entry per object; `NULL` marks an object that
+#'   lacks the element
+#' @return The shared value, or `NULL` when the element is missing from any
+#'   object or the values differ
+#' @noRd
+merge_same <- function(values) {
+  # A single missing entry disqualifies the element. Without this check
+  # "same" would behave exactly like "unique".
+  if (length(values) == 0 || any(vapply(values, is.null, logical(1)))) {
+    return(NULL)
+  }
+
+  # identical() covers data frames, matrices and plain vectors alike
+  first_value <- values[[1]]
+  all_same <- all(vapply(values, identical, logical(1), first_value))
+
+  if (all_same) {
+    return(first_value)
+  }
+  return(NULL)
+}
+
+#' "unique" strategy: keep a value if there is only one candidate for it
+#'
+#' @param values List with one entry per object; `NULL` marks an object that
+#'   lacks the element
+#' @return The value when it occurs in one object only, or is identical in
+#'   every object that has it; otherwise `NULL`
+#' @noRd
+merge_unique <- function(values) {
+  # Remove NULL values
+  non_null_values <- Filter(Negate(is.null), values)
+
+  if (length(non_null_values) == 0) {
+    return(NULL)
+  }
+
+  # Compare only the objects that actually carry the element; a single
+  # survivor is trivially identical to itself.
+  merge_same(non_null_values)
+}
+
+#' "first" strategy: keep the first non-NULL value
+#'
+#' @param values List with one entry per object
+#' @return The first non-`NULL` entry, or `NULL` if there is none
+#' @noRd
+merge_first <- function(values) {
+  for (value in values) {
+    if (!is.null(value)) {
+      return(value)
+    }
+  }
+  return(NULL)
+}
+
+#' "only" strategy: keep a value that occurs in exactly one object
+#'
+#' @param values List with one entry per object
+#' @return The single non-`NULL` entry, or `NULL` when zero or several
+#'   objects carry the element
+#' @noRd
+merge_only <- function(values) {
+  # Remove NULL values
+  non_null_values <- Filter(Negate(is.null), values)
+
+  if (length(non_null_values) == 1) {
+    return(non_null_values[[1]])
+  }
+  return(NULL)
+}
+
+#' @noRd
+merge_nested <- function(mappings, strategy) {
+  if (is.character(strategy)) {
+    strategy <- get_merge_strategy(strategy)
+  }
+
+  # Get all top-level keys
+  all_keys <- unique(unlist(lapply(mappings, names)))
+
+  result <- list()
+  for (key in all_keys) {
+    # Get values for this key from all mappings
+    values <- lapply(mappings, function(m) {
+      if (key %in% names(m)) m[[key]] else NULL
+    })
+
+    # Check if all non-NULL values are themselves mappings (lists)
+    non_null_values <- values[!sapply(values, is.null)]
+    if (
+      length(non_null_values) >
0 && + all(sapply(non_null_values, function(x) { + is.list(x) && !is.data.frame(x) + })) + ) { + # Recursively merge nested mappings + nested_result <- merge_nested(non_null_values, strategy) + if (length(nested_result) > 0) { + result[[key]] <- nested_result + } + } else { + # Apply strategy to leaf values + merged_value <- strategy(values) + if (!is.null(merged_value)) { + result[[key]] <- merged_value + } + } + } + + return(result) +} + +#' @noRd +merge_uns_with_batches <- function(uns_list, batch_keys, join_str = "_batch_") { + if (length(uns_list) != length(batch_keys)) { + cli::cli_abort("Length of uns_list and batch_keys must be equal") + } + + # First apply normal merge strategy + merged <- apply_merge_strategy(uns_list, "unique") + + # Then add batch-specific keys for elements that didn't merge + all_keys <- unique(unlist(lapply(uns_list, names))) + + for (key in all_keys) { + if (!key %in% names(merged)) { + # This key didn't survive merging, add batch-specific versions + for (i in seq_along(uns_list)) { + if (key %in% names(uns_list[[i]])) { + batch_key <- paste0(key, join_str, batch_keys[i]) + merged[[batch_key]] <- uns_list[[i]][[key]] + } + } + } + } + + return(merged) +} diff --git a/man/concat.Rd b/man/concat.Rd new file mode 100644 index 00000000..f0e4f2c1 --- /dev/null +++ b/man/concat.Rd @@ -0,0 +1,107 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/concat.R +\name{concat} +\alias{concat} +\alias{rbind.AbstractAnnData} +\alias{cbind.AbstractAnnData} +\title{Concatenate AnnData objects} +\usage{ +concat( + adatas, + axis = c("obs", "var", 0, 1), + join = c("outer", "inner"), + merge = c("unique", "same", "first", "only"), + uns_merge = merge, + label = NULL, + keys = NULL, + index_unique = NULL, + fill_value = NULL, + backend = c("memory", "hdf5") +) + +\method{rbind}{AbstractAnnData}( + ..., + join = "outer", + merge = "unique", + label = "batch", + fill_value = NULL +) + +\method{cbind}{AbstractAnnData}( + ..., 
+ join = "outer", + merge = "unique", + label = "batch", + fill_value = NULL +) +} +\arguments{ +\item{adatas}{List of AnnData objects to concatenate, or named list for batch labels} + +\item{axis}{Which axis to concatenate along. Either "obs"/0 for observations (rows) +or "var"/1 for variables (columns)} + +\item{join}{How to align the other axis. "outer" takes the union of indices, +"inner" takes the intersection} + +\item{merge}{Strategy for elements not aligned to concatenation axis. One of: +\itemize{ +\item "same": Keep elements that are identical across all objects +\item "unique": Keep elements that appear in only one object (or are identical) +\item "first": Keep the first occurrence of each element +\item "only": Keep elements that appear in exactly one object +}} + +\item{uns_merge}{Strategy for merging .uns metadata (same options as merge)} + +\item{label}{Column name to add batch information to obs/var. NULL for no label} + +\item{keys}{Batch labels. Defaults to names(adatas) if named list, otherwise sequential integers} + +\item{index_unique}{Separator to make row/column names unique. NULL keeps original names} + +\item{fill_value}{Value for missing data when join="outer". Default: 0 for sparse, NA for dense} + +\item{backend}{Backend to use for result ("memory" or "hdf5")} +} +\value{ +A new AnnData object containing the concatenated data +} +\description{ +Combine multiple AnnData objects along specified axis with flexible +strategies for handling mismatched dimensions and metadata. +} +\details{ +Related files: +\itemize{ +\item concat_helpers.R: Helper functions for matrix and annotation concatenation +\item concat_strategies.R: Merge strategies for handling non-aligned elements +} + +This function provides flexible concatenation of AnnData objects similar to +Python's anndata.concat(). 
It handles: +\itemize{ +\item Mismatched observation/variable names through join strategies +\item Complex metadata merging through merge strategies +\item Batch tracking through labeling +\item Memory-efficient operations for large datasets +} +} +\examples{ +\dontrun{ +# Create example datasets +ad1 <- generate_dataset(n_obs = 100, n_vars = 50, format = "AnnData") +ad2 <- generate_dataset(n_obs = 80, n_vars = 50, format = "AnnData") +ad3 <- generate_dataset(n_obs = 120, n_vars = 50, format = "AnnData") + +# Basic row concatenation +combined <- concat(list(ad1, ad2, ad3), axis = "obs") + +# With batch tracking +combined <- concat(list(ctrl = ad1, treat = ad2), axis = "obs", label = "condition") + +# Inner join (intersection of variables) +combined <- concat(list(ad1, ad2), axis = "obs", join = "inner") +} + +} diff --git a/tests/testthat/test-concat.R b/tests/testthat/test-concat.R new file mode 100644 index 00000000..d396e14b --- /dev/null +++ b/tests/testthat/test-concat.R @@ -0,0 +1,69 @@
+test_that("concat works with basic functionality", {
+  # Two datasets sharing the same number of variables
+  first <- generate_dataset(n_obs = 10, n_vars = 5, format = "AnnData")
+  second <- generate_dataset(n_obs = 8, n_vars = 5, format = "AnnData")
+
+  # Stack them along the observation axis
+  stacked <- concat(list(first, second), axis = "obs")
+
+  expect_s3_class(stacked, "InMemoryAnnData")
+  expect_s3_class(stacked, "AbstractAnnData")
+  expect_equal(nrow(stacked), 18) # 10 + 8 observations
+  expect_equal(ncol(stacked), 5)
+
+  # rbind() on the same two inputs
+  via_rbind <- rbind(first, second)
+  expect_equal(dim(via_rbind), c(18, 5))
+
+  # cbind() with a dataset that has a different number of variables
+  third <- generate_dataset(n_obs = 10, n_vars = 3, format = "AnnData")
+  via_cbind <- cbind(first, third)
+  expect_equal(dim(via_cbind), c(10, 8)) # 5 + 3 variables
+})
+
+test_that("concat handles mismatched dimensions correctly", {
+  wide <- generate_dataset(n_obs = 10, n_vars = 5, format = "AnnData")
+  narrow <- generate_dataset(n_obs = 10, n_vars = 3, format = "AnnData")
+
+  # Give the two datasets disjoint variable names so the join mode matters
+  wide$var_names <- paste0("gene_A_", seq_len(5))
+  narrow$var_names <- paste0("gene_B_", seq_len(3))
+
+  # Inner join keeps only shared variables, and there are none here
+  joined_inner <- concat(list(wide, narrow), axis = "obs", join = "inner")
+  expect_equal(ncol(joined_inner), 0) # no common variables
+
+  # Outer join keeps the union of variables
+  joined_outer <- concat(list(wide, narrow), axis = "obs", join = "outer")
+  expect_equal(ncol(joined_outer), 8) # 5 + 3 distinct variables
+  expect_equal(nrow(joined_outer), 20) # 10 + 10 observations
+})
+
+test_that("concat handles empty inputs gracefully", {
+  expect_error(concat(list()), "adatas must be a non-empty list")
+
+  lone <- generate_dataset(n_obs = 10, n_vars = 5, format = "AnnData")
+
+  # A one-element list is expected to warn and keep the dimensions
+  expect_warning(copied <- concat(list(lone)), "Only one object provided")
+  expect_equal(dim(copied), dim(lone))
+})
+
+test_that("concat preserves metadata correctly", {
+  first <- generate_dataset(n_obs = 5, n_vars = 3, format = "AnnData")
+  second <- generate_dataset(n_obs = 4, n_vars = 3, format = "AnnData")
+
+  # One uns key that differs between inputs, one set only on the first
+  first$uns$test_meta <- "from_ad1"
+  second$uns$test_meta <- "from_ad2"
+  first$uns$unique_to_ad1 <- "only_in_ad1"
+
+  combined <- concat(list(first, second), axis = "obs", label = "batch")
+
+  # The label column exists and holds one level per input
+  expect_true("batch" %in% colnames(combined$obs))
+  expect_equal(levels(combined$obs$batch), c("1", "2"))
+
+  # The key set on only one input is kept in the merged uns
+  expect_true("unique_to_ad1" %in% names(combined$uns))
+})