Source code for preprocessing_pipeline.metacells

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import logging
import scipy.sparse as sp

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

def create_metacells(
    data_rna: ad.AnnData,
    data_atac: ad.AnnData,
    grouping_key: str = "leiden",
    resolution: float = 5.0,
    batch_key: str = "sample",
) -> tuple[ad.AnnData, ad.AnnData]:
    """
    Create metacell-level RNA and ATAC AnnData objects by clustering cells
    and computing mean values per cluster.

    This function:

    1. Normalizes and logs the RNA data, then runs PCA.
    2. Uses Harmony integration for batch correction on the PCA embeddings.
    3. Clusters the RNA data with Leiden at the specified resolution, storing
       the cluster labels in ``data_rna.obs[grouping_key]``.
    4. Summarizes RNA expression and ATAC accessibility for each cluster by
       taking the mean of each feature across all cells in that cluster.

    Parameters
    ----------
    data_rna : anndata.AnnData
        RNA single-cell data. A layer "counts" is added (overwriting any
        existing layer of that name) and used to restore ``.X`` afterwards.
        The shape is (n_cells, n_genes).
    data_atac : anndata.AnnData
        ATAC single-cell data, ideally with the same set or a superset of the
        cell IDs in `data_rna.obs_names`. The shape is (n_cells, n_peaks).
        Cells of a cluster that are absent from `data_atac` are ignored when
        averaging; a cluster with no cells in `data_atac` gets an all-zero
        ATAC profile.
    grouping_key : str, optional
        The key in `data_rna.obs` where the Leiden cluster labels will be
        stored. Default is "leiden".
    resolution : float, optional
        The resolution parameter for Leiden clustering. Higher values yield
        more clusters. Default is 5.0.
    batch_key : str, optional
        The column in `data_rna.obs` indicating batch information for Harmony
        integration. Default is "sample".

    Returns
    -------
    (rna_metacell, atac_metacell) : tuple of anndata.AnnData
        - rna_metacell : shape (#clusters, n_genes)
        - atac_metacell : shape (#clusters, n_peaks)

        The `.obs` index is set to the cluster labels (in order of first
        appearance in `data_rna.obs[grouping_key]`), and the `.var` is
        inherited from the original `data_rna`/`data_atac`.

    Notes
    -----
    - The function uses `scanpy.external.pp.harmony_integrate` for batch
      integration; neighbors are then computed on the "X_pca_harmony"
      representation.
    - The ATAC data is transformed by `(atac_vals + 1) // 2` to interpret
      insertions as fragment presence, following Martens et al. (2023).
    - Mean values are computed across cells in each cluster for both RNA
      (on the normalized, log-transformed values) and ATAC.
    - `data_rna.X` is restored to the raw counts before the function exits,
      including when one of the processing steps raises. Other side effects
      on `data_rna` (the "counts" layer, the cluster labels in `.obs`, and
      whatever the scanpy calls store on the object) are kept.
    """
    import scanpy.external as sce

    logger.info(f"Creating metacells with resolution={resolution} (grouping key={grouping_key}).")

    # Keep original counts in a layer so .X can be restored afterwards.
    data_rna.layers["counts"] = data_rna.X.copy()

    try:
        # Normalize & run PCA
        sc.pp.normalize_total(data_rna)
        sc.pp.log1p(data_rna)
        sc.pp.pca(data_rna)

        # Harmony integration, then clustering on the corrected embedding
        sce.pp.harmony_integrate(data_rna, batch_key)
        sc.pp.neighbors(data_rna, use_rep="X_pca_harmony")
        sc.tl.leiden(data_rna, resolution=resolution, key_added=grouping_key)

        # Summarize by cluster. `observed=True` restricts the groups to labels
        # that actually occur, which is all we look up below.
        clusters = data_rna.obs[grouping_key].unique()
        cluster_groups = data_rna.obs.groupby(grouping_key, observed=True)
        atac_names = data_atac.obs_names
        n_peaks = data_atac.shape[1]

        mean_rna_list = []
        mean_atac_list = []
        cluster_names = []

        for cluster_name in clusters:
            cell_idx = cluster_groups.get_group(cluster_name).index

            # RNA: mean of the normalized/log-transformed values
            rna_vals = data_rna[cell_idx].X
            if sp.issparse(rna_vals):
                rna_vals = rna_vals.toarray()
            mean_rna_list.append(np.array(rna_vals.mean(axis=0)).ravel())

            # ATAC: only cells that exist in data_atac can be averaged.
            # Indexing with names missing from data_atac would fail, so a
            # partial overlap is reduced to the shared cells first.
            shared_idx = cell_idx[cell_idx.isin(atac_names)]
            if shared_idx.empty:
                mean_atac_list.append(np.zeros(n_peaks))
            else:
                atac_vals = data_atac[shared_idx].X
                if sp.issparse(atac_vals):
                    atac_vals = atac_vals.toarray()
                # Convert insertions to fragments
                atac_bin = (atac_vals + 1) // 2
                mean_atac_list.append(np.array(atac_bin.mean(axis=0)).ravel())

            cluster_names.append(cluster_name)

        # Build new AnnData for metacells
        mean_rna_arr = np.vstack(mean_rna_list)
        mean_atac_arr = np.vstack(mean_atac_list)

        obs_df = pd.DataFrame({grouping_key: cluster_names}).set_index(grouping_key)

        # Each output gets its own obs frame so the two objects never share one.
        rna_metacell = ad.AnnData(X=mean_rna_arr, obs=obs_df.copy(), var=data_rna.var)
        atac_metacell = ad.AnnData(X=mean_atac_arr, obs=obs_df.copy(), var=data_atac.var)
    finally:
        # Restore original raw counts, even if a step above failed, so the
        # caller's object is not left normalized/log-transformed.
        data_rna.X = data_rna.layers["counts"].copy()

    logger.info(f"Metacell shapes: RNA={rna_metacell.shape}, ATAC={atac_metacell.shape}")
    return rna_metacell, atac_metacell