# Source code for preprocessing_pipeline.peak_selection
import numpy as np
import anndata as ad
import logging
import scipy.sparse as sp
logger = logging.getLogger(__name__)
def select_highly_variable_peaks_by_std(
    data_atac: ad.AnnData,
    n_top_peaks: int,
    cluster_key: str = "leiden"
) -> ad.AnnData:
    """
    Select highly variable peaks based on the standard deviation of peak
    accessibility across clusters.

    This function:

    1. Groups cells by `cluster_key` in `data_atac.obs`.
    2. Converts ATAC insertion counts to fragment counts and computes the
       mean accessibility per peak within each cluster.
    3. Computes the standard deviation of each peak's accessibility across clusters.
    4. Selects the top `n_top_peaks` peaks with the highest standard deviation.

    Parameters
    ----------
    data_atac : anndata.AnnData
        An AnnData containing ATAC data. Expects `.obs[cluster_key]` to exist.
    n_top_peaks : int
        Number of peaks to retain, ranked by standard deviation across clusters.
    cluster_key : str, optional
        The column in `data_atac.obs` specifying cluster labels. Defaults to "leiden".

    Returns
    -------
    anndata.AnnData
        A subset of `data_atac` containing only the selected peaks.
        If `n_top_peaks >= data_atac.shape[1]`, the original data is returned.

    Notes
    -----
    - If `cluster_key` is missing, a warning is logged and the original AnnData
      is returned unchanged.
    - The transformation `(X + 1) // 2` converts insertion counts to fragment
      counts: `(x + 1) // 2 == ceil(x / 2)`, so e.g. 3 or 4 insertions map to
      2 fragments.
    """
    if cluster_key not in data_atac.obs.columns:
        logger.warning(f"{cluster_key} not found in data_atac.obs; skipping peak selection.")
        return data_atac

    clusters = data_atac.obs[cluster_key].unique()
    cluster_groups = data_atac.obs.groupby(cluster_key)
    mean_list = []
    for c_label in clusters:
        idx_cells = cluster_groups.get_group(c_label).index
        mat = data_atac[idx_cells].X
        if sp.issparse(mat):
            mat = mat.toarray()
        mat = (mat + 1) // 2  # insertion counts -> fragment counts
        mean_list.append(mat.mean(axis=0))  # mat is dense here, so this is a 1-D array

    # shape => (n_clusters, n_peaks); peaks whose mean accessibility varies most
    # across clusters get the highest std and are the most cluster-discriminative.
    cluster_matrix = np.vstack(mean_list)
    stdev_peaks = cluster_matrix.std(axis=0)
    data_atac.var["std_cluster"] = stdev_peaks

    if n_top_peaks < data_atac.shape[1]:
        sorted_idx = np.argsort(stdev_peaks)[::-1]
        keep_idx = sorted_idx[:n_top_peaks]
        mask = np.zeros(data_atac.shape[1], dtype=bool)
        mask[keep_idx] = True
        data_atac_sub = data_atac[:, mask].copy()
        logger.info(f"Selected top {n_top_peaks} variable peaks (by std across {cluster_key}).")
        return data_atac_sub
    else:
        logger.info("n_top_peaks >= total peaks; no filtering applied.")
        return data_atac
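
# A minimal usage sketch for select_highly_variable_peaks_by_std (illustrative
# only: the random AnnData, its shape, and the toy "leiden" labels below are
# assumptions, not part of the pipeline):
#
#     rng = np.random.default_rng(0)
#     adata = ad.AnnData(rng.integers(0, 5, size=(60, 40)).astype(np.float32))
#     adata.obs["leiden"] = rng.integers(0, 3, size=60).astype(str)
#     adata_hv = select_highly_variable_peaks_by_std(adata, n_top_peaks=10)
#     # adata_hv.var["std_cluster"] now holds each peak's std across clusters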
def keep_promoters_and_select_hv_peaks(
    data_atac: ad.AnnData,
    total_n_peaks: int,
    cluster_key: str = "leiden",
    promoter_col: str = "is_promoter"
) -> ad.AnnData:
    """
    Retain all promoter peaks, then select highly variable (HV) peaks among
    the non-promoter peaks.

    Steps
    -----
    1. Identify peaks marked as promoters via `var[promoter_col] == True`.
    2. Keep all promoter peaks unconditionally.
    3. If the promoter peaks alone number >= `total_n_peaks`, keep only the
       promoters (the user target may then be exceeded).
    4. Otherwise, among non-promoter peaks, select the top
       `total_n_peaks - n_promoters` peaks by standard deviation across clusters.
    5. Combine the promoter and HV non-promoter peak sets.

    Parameters
    ----------
    data_atac : anndata.AnnData
        An AnnData containing ATAC data, with a boolean promoter column in `.var`.
    total_n_peaks : int
        The target total number of peaks to keep. May be exceeded if promoter
        peaks alone surpass this number.
    cluster_key : str, optional
        Column in `data_atac.obs` defining cluster labels for HV peak selection.
        Default is "leiden".
    promoter_col : str, optional
        Column in `data_atac.var` indicating which peaks are promoters.
        Default is "is_promoter".

    Returns
    -------
    anndata.AnnData
        A subset of `data_atac` containing all promoter peaks plus HV
        non-promoter peaks.

    Notes
    -----
    - If `promoter_col` is missing, falls back to standard HV peak selection
      (without promoter logic).
    - The standard deviation for HV peaks is computed by
      `select_highly_variable_peaks_by_std`.
    """
    if promoter_col not in data_atac.var.columns:
        logger.warning(f"Column {promoter_col} not found in data_atac.var; no special promoter logic.")
        return select_highly_variable_peaks_by_std(data_atac, total_n_peaks, cluster_key)

    # (A) Extract promoter vs non-promoter (missing/NaN flags count as non-promoter)
    promoter_mask = data_atac.var[promoter_col].eq(True).values
    promoter_peaks = data_atac.var_names[promoter_mask]
    n_promoters = len(promoter_peaks)
    logger.info(f"Found {n_promoters} promoter peaks. Target total is {total_n_peaks}.")

    if n_promoters >= total_n_peaks:
        logger.warning(
            f"Promoter peaks ({n_promoters}) meet or exceed total_n_peaks={total_n_peaks}. "
            "Keeping all promoters; the final set may exceed the user target."
        )
        data_atac_sub = data_atac[:, promoter_peaks].copy()
        return data_atac_sub
    else:
        # (B) Keep all promoters, then select HV among non-promoters
        n_needed = total_n_peaks - n_promoters
        logger.info(f"Selecting HV among non-promoters => picking {n_needed} peaks.")
        # Subset to non-promoters
        non_promoter_mask = ~promoter_mask
        data_atac_nonprom = data_atac[:, non_promoter_mask].copy()
        # HV selection among non-promoters
        data_atac_nonprom_hv = select_highly_variable_peaks_by_std(
            data_atac_nonprom, n_needed, cluster_key
        )
        # Final union => promoter + HV(non-promoters). Subset with a boolean
        # mask so the original peak order is preserved (indexing with a plain
        # set union would reorder peaks nondeterministically).
        final_names = set(promoter_peaks).union(data_atac_nonprom_hv.var_names)
        keep_mask = data_atac.var_names.isin(list(final_names))
        data_atac_sub = data_atac[:, keep_mask].copy()
        logger.info(
            f"Final set => {n_promoters} promoter + "
            f"{data_atac_nonprom_hv.shape[1]} HV => total {data_atac_sub.shape[1]} peaks."
        )
        return data_atac_sub
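
# A minimal smoke test under assumed synthetic data: the shapes, cluster
# labels, and promoter flags below are illustrative stand-ins, not real
# pipeline inputs.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(0)
    X = rng.integers(0, 5, size=(60, 40)).astype(np.float32)
    adata = ad.AnnData(X=X)
    adata.obs["leiden"] = rng.integers(0, 3, size=60).astype(str)
    adata.var["is_promoter"] = rng.random(40) < 0.25
    sub = keep_promoters_and_select_hv_peaks(adata, total_n_peaks=20)
    # Every promoter peak should survive the selection, in either branch.
    assert sub.var["is_promoter"].sum() == adata.var["is_promoter"].sum()
    print(sub)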