Source code for preprocessing_pipeline.gene_selection
import numpy as np
import scanpy as sc
#import muon as mu
import pandas as pd
import logging
from pathlib import Path
from gtfparse import read_gtf
import anndata as ad
logger = logging.getLogger(__name__)
[docs]
def load_gtf(gtf_path: Path) -> pd.DataFrame:
    """
    Parse a GTF annotation file and return its records as a pandas DataFrame.

    Parameters
    ----------
    gtf_path : pathlib.Path
        Location of the GTF file (optionally gzipped). It is passed unchanged
        to ``gtfparse.read_gtf``.

    Returns
    -------
    pd.DataFrame
        The parsed GTF records, with column labels copied from the object
        returned by ``read_gtf`` (GTF fields such as "gene_name",
        "gene_type", "start", "end", etc.).
    """
    logger.info(f"Loading GTF from {gtf_path}")

    records = read_gtf(gtf_path)
    column_labels = records.columns

    # Wrap the parser output in a pandas DataFrame and re-apply the parser's
    # own column labels, so the result is labelled the same way regardless of
    # what table type read_gtf handed back.
    # NOTE(review): presumably needed because newer gtfparse releases return a
    # non-pandas table -- confirm against the pinned gtfparse version.
    table = pd.DataFrame(records)
    table.columns = column_labels
    return table
[docs]
def filter_protein_coding_genes(
    data_rna: ad.AnnData,
    gtf_df: pd.DataFrame
) -> ad.AnnData:
    """
    Subset an RNA AnnData object to genes annotated as protein-coding in a GTF.

    Parameters
    ----------
    data_rna : anndata.AnnData
        RNA single-cell data whose ``var_names`` are matched against the
        GTF "gene_name" column.
    gtf_df : pd.DataFrame
        GTF table (e.g. from `load_gtf`) with at least the columns
        "gene_type" and "gene_name".

    Returns
    -------
    anndata.AnnData
        A copy of `data_rna` restricted to the genes that are both present in
        ``data_rna.var_names`` and listed with ``gene_type == "protein_coding"``
        in the GTF. The retained genes are ordered alphabetically.
    """
    # Names of every GTF record flagged as protein-coding (deduplicated).
    is_protein_coding = gtf_df["gene_type"] == "protein_coding"
    coding_names = set(gtf_df.loc[is_protein_coding, "gene_name"].unique())

    # Keep only those that actually occur in the expression data; sorting
    # gives a deterministic column order in the result.
    measured_names = set(data_rna.var_names)
    shared_names = sorted(coding_names.intersection(measured_names))

    data_rna_sub = data_rna[:, shared_names].copy()
    logger.info(f"Filtered to protein-coding genes: {data_rna_sub.shape[1]} genes left.")
    return data_rna_sub
[docs]
def compute_hvgs_and_tfs(
    data_rna: ad.AnnData,
    tf_names: list[str],
    user_genes: list[str] = None,
    user_tfs: list[str] = None,
    num_genes: int = 3000,
    num_tfs: int = 300,
    min_cells: int = 20
) -> tuple[ad.AnnData, list[str], list[str]]:
    """
    Compute sets of Highly Variable Genes (HVGs) and TFs (transcription factors)
    for scDoRI training.

    This function:
      1. Identifies user-specified genes and TFs present in `data_rna`.
      2. Tops up the TF set with the most variable remaining TFs, up to `num_tfs`.
      3. Selects the most variable remaining genes as non-TF HVGs.
      4. Combines these sets into a final AnnData subset and returns them.

    Parameters
    ----------
    data_rna : anndata.AnnData
        The RNA single-cell data from which to select HVGs and TFs. It is not
        modified; all normalization happens on copies of column subsets.
    tf_names : list of str
        A list of all TF names (from a motif database or known TF list).
    user_genes : list of str, optional
        Genes that must be included in the final set. Names absent from
        ``data_rna.var_names`` are ignored. Default is None (no forced genes).
    user_tfs : list of str, optional
        TFs that must be included in the final set. Only names that are both
        in ``data_rna.var_names`` and in `tf_names` are kept.
        Default is None (no forced TFs).
    num_genes : int, optional
        Gene budget. The number of non-TF HVGs picked automatically is
        ``max(0, num_genes - n_valid_user_genes - num_tfs)``. Default is 3000.
    num_tfs : int, optional
        Total number of TFs desired. The number picked automatically is
        ``max(0, num_tfs - n_valid_user_tfs)``. Default is 300.
    min_cells : int, optional
        Currently unused by this function; kept so existing callers that pass
        it keep working. Default is 20.

    Returns
    -------
    data_rna_processed : anndata.AnnData
        A copy of `data_rna` restricted to the selected genes, ordered as
        ``final_genes + final_tfs``.
    final_genes : list of str
        The final, alphabetically sorted list of HVGs (non-TFs).
    final_tfs : list of str
        The final, alphabetically sorted list of TFs.

    Notes
    -----
    - HVG selection is done by `scanpy.pp.highly_variable_genes` on a copy of
      each candidate subset after `normalize_total` and `log1p`.
    - When the user-supplied lists already fill a quota (or there are no
      candidates left), no automatic selection is run for that group. Calling
      scanpy with ``n_top_genes=0`` would otherwise keep every candidate.
    - A user gene that ends up in the TF set is reported as a TF only.
    - The column `data_rna_processed.var["gene_type"]` is set to "HVG" or "TF"
      for each gene.
    """
    if user_genes is None:
        user_genes = []
    if user_tfs is None:
        user_tfs = []

    logger.info("Selecting HVGs and TFs...")

    available = set(data_rna.var_names)
    known_tfs = set(tf_names)

    # 1) Validate user-specified lists against the data (and the TF catalogue).
    valid_genes_user = sorted(available.intersection(user_genes))
    valid_tfs_user = sorted(available.intersection(user_tfs) & known_tfs)

    num_tfs_hvg = max(0, num_tfs - len(valid_tfs_user))
    num_genes_hvg = max(0, num_genes - len(valid_genes_user) - num_tfs)

    # 2) HVGs among the TFs that were not already forced in by the user.
    tf_candidates = sorted((known_tfs - set(valid_tfs_user)) & available)
    hvg_tfs = _select_top_hvgs(data_rna, tf_candidates, num_tfs_hvg)
    selected_tfs = sorted(hvg_tfs + valid_tfs_user)
    selected_tf_set = set(selected_tfs)

    # 3) HVGs among everything that is neither a selected TF nor a user gene.
    non_tf_candidates = sorted(
        available - selected_tf_set - set(valid_genes_user)
    )
    hvg_genes = _select_top_hvgs(data_rna, non_tf_candidates, num_genes_hvg)

    # User genes are always added; anything already counted as a TF is dropped
    # so the two output lists never overlap.
    selected_non_tfs = sorted(
        set(hvg_genes).union(valid_genes_user) - selected_tf_set
    )

    final_genes = selected_non_tfs
    final_tfs = selected_tfs
    combined = final_genes + final_tfs
    data_rna_processed = data_rna[:, combined].copy()

    # Mark gene_type in .var (same order as `combined`).
    gene_types = ["HVG"] * len(final_genes) + ["TF"] * len(final_tfs)
    data_rna_processed.var["gene_type"] = gene_types

    logger.info(f"Selected {len(final_genes)} HVGs + {len(final_tfs)} TFs.")
    return data_rna_processed, final_genes, final_tfs


def _select_top_hvgs(
    data_rna: ad.AnnData,
    candidates: list[str],
    n_top: int
) -> list[str]:
    """Return the names of the `n_top` most variable genes among `candidates`.

    Returns an empty list without touching scanpy when nothing is requested
    or there is nothing to choose from.
    """
    if n_top <= 0 or not candidates:
        return []

    # Work on a copy so the caller's AnnData is never normalized in place.
    subset = data_rna[:, candidates].copy()
    sc.pp.normalize_total(subset)
    sc.pp.log1p(subset)
    sc.pp.highly_variable_genes(subset, n_top_genes=n_top, subset=True)
    return list(subset.var_names)