Source code for scdori.downstream

###############################################
# downstream.py
###############################################
import torch
import torch.nn as nn
import numpy as np
import umap
import scanpy as sc
import anndata
import logging
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import scipy
from scdori import config
from scdori.utils import set_seed, log_nb_positive

logger = logging.getLogger(__name__)

def load_best_model(model, best_model_path, device):
    """
    Load the best model weights from disk into the given model.

    Parameters
    ----------
    model : torch.nn.Module
        The model instance into which the weights will be loaded.
    best_model_path : str or Path
        Path to the file containing the best model weights.
    device : torch.device
        The device (CPU or CUDA) where the model will be moved.

    Returns
    -------
    torch.nn.Module
        The same model, now loaded with weights and set to eval mode.

    Raises
    ------
    FileNotFoundError
        If the specified `best_model_path` does not exist.
    """
    if not os.path.isfile(best_model_path):
        raise FileNotFoundError(f"Best model file {best_model_path} not found.")
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.to(device)
    model.eval()
    logger.info(f"Loaded best model weights from {best_model_path}")
    return model
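# Usage sketch (illustrative, not part of the original module): restore trained
# weights into an existing scDoRI model instance before any downstream analysis.
# `scdori_model` and the checkpoint path are assumptions standing in for objects
# produced by a prior training run.
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     scdori_model = load_best_model(scdori_model, "outputs/best_model.pth", device)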
def compute_neighbors_umap(rna_anndata, rep_key="X_scdori"):
    """
    Compute neighbors and UMAP on the specified representation in an AnnData object.

    Parameters
    ----------
    rna_anndata : anndata.AnnData
        An AnnData object containing single-cell RNA data.
    rep_key : str, optional
        The key in `rna_anndata.obsm` that holds the latent representation
        used for computing UMAP. Default is "X_scdori".

    Returns
    -------
    None
        Updates `rna_anndata` in place with the neighbor graph and UMAP coordinates.
    """
    logger.info("=== Computing neighbors + UMAP on scDoRI latent ===")
    sc.pp.neighbors(rna_anndata, use_rep=rep_key, n_neighbors=config.umap_n_neighbors)
    sc.tl.umap(
        rna_anndata,
        min_dist=config.umap_min_dist,
        spread=1.0,
        random_state=config.umap_random_state,
    )
    logger.info("Done. UMAP stored in rna_anndata.obsm['X_umap'].")
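# Usage sketch (illustrative): assumes the scDoRI latent topic activations have
# already been stored in `rna_anndata.obsm["X_scdori"]` by the encoder, and that
# `rna_anndata.obs` carries a "celltype" column (an assumption for coloring).
#
#     compute_neighbors_umap(rna_anndata, rep_key="X_scdori")
#     sc.pl.umap(rna_anndata, color=["celltype"])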
def compute_topic_peak_umap(model, device):
    """
    Compute a UMAP embedding of the topic-peak decoder matrix.
    Each point in this embedding is a peak.

    Steps
    -----
    1. Apply softmax to `model.topic_peak_decoder` => (num_topics, num_peaks).
    2. Min-max normalize across topics.
    3. Transpose to get (num_peaks, num_topics).
    4. Run UMAP on the resulting matrix to get a (num_peaks, 2) embedding.

    Parameters
    ----------
    model : torch.nn.Module
        The scDoRI model containing the topic_peak_decoder.
    device : torch.device
        The device (CPU or CUDA) used for PyTorch operations.

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        embedding_peaks : shape (num_peaks, 2)
            The UMAP embedding of the peaks.
        peak_mat : shape (num_peaks, num_topics)
            The min-max normalized topic-peak matrix.
    """
    model.eval()
    with torch.no_grad():
        topic_peaks = model.topic_peak_decoder.detach().to(device)
        # (num_topics, num_peaks): per-topic distribution over peaks
        topic_peaks_smx = torch.nn.functional.softmax(topic_peaks, dim=1)

        # Min-max normalize each peak across topics
        tmin, _ = torch.min(topic_peaks_smx, dim=0, keepdim=True)
        tmax, _ = torch.max(topic_peaks_smx, dim=0, keepdim=True)
        topic_peaks_norm = (topic_peaks_smx - tmin) / (tmax - tmin + 1e-8)

        peak_mat = topic_peaks_norm.T.cpu().numpy()  # (num_peaks, num_topics)

    reducer = umap.UMAP(
        n_neighbors=config.umap_n_neighbors,
        min_dist=config.umap_min_dist,
        random_state=config.umap_random_state,
    )
    embedding_peaks = reducer.fit_transform(peak_mat)
    logger.info(
        f"Done. umap_embedding_peaks shape => {embedding_peaks.shape}, "
        f"topic_embedding_peaks shape => {peak_mat.shape}"
    )
    return embedding_peaks, peak_mat
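# Usage sketch (illustrative): embed peaks by their topic loadings and color the
# scatter by an arbitrary topic (topic 0 here) to highlight topic-specific peak
# groups.
#
#     embedding_peaks, peak_mat = compute_topic_peak_umap(scdori_model, device)
#     plt.scatter(embedding_peaks[:, 0], embedding_peaks[:, 1],
#                 c=peak_mat[:, 0], s=1, cmap="viridis")
#     plt.title("Peak UMAP colored by Topic 0 loading")
#     plt.show()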
def compute_topic_gene_matrix(model, device):
    """
    Compute a topic-gene matrix for downstream analysis (e.g., GSEA).

    Steps
    -----
    1. Apply softmax to `model.topic_peak_decoder` => (num_topics, num_peaks).
    2. Min-max normalize each peak across topics.
    3. Multiply by (gene_peak_factor_fixed * gene_peak_factor_learnt).
    4. Apply batch normalization and softmax.
    5. Return the topic-gene matrix of shape (num_topics, num_genes).

    Parameters
    ----------
    model : torch.nn.Module
        The scDoRI model containing topic_peak_decoder and gene_peak factors.
    device : torch.device
        The device (CPU or CUDA) used for PyTorch operations.

    Returns
    -------
    np.ndarray
        A matrix of shape (num_topics, num_genes) representing topic-gene scores.
    """
    model.eval()
    with torch.no_grad():
        topic_peaks = model.topic_peak_decoder.detach()
        # (num_topics, num_peaks)
        topic_peaks_smx = torch.nn.functional.softmax(topic_peaks, dim=1)

        # Effective peak-gene links: fixed mask * learnt weights
        gene_peak_factor1 = model.gene_peak_factor_fixed.detach()
        gene_peak_factor2 = model.gene_peak_factor_learnt.detach()
        gene_peak_factor = gene_peak_factor1 * gene_peak_factor2

        tmin, _ = torch.min(topic_peaks_smx, dim=0, keepdim=True)
        tmax, _ = torch.max(topic_peaks_smx, dim=0, keepdim=True)
        topic_peaks_norm = (topic_peaks_smx - tmin) / (tmax - tmin + 1e-8)

        # (num_topics, num_genes)
        preds_gene = torch.mm(topic_peaks_norm, gene_peak_factor.T)
        # A freshly initialized BatchNorm1d (in training mode) normalizes each
        # gene using the statistics of the current batch of topics.
        preds_gene = nn.BatchNorm1d(preds_gene.shape[1])(preds_gene.detach().cpu())
        preds_gene = nn.Softmax(dim=1)(preds_gene.detach().cpu())

    logger.info(f"Done. Computed topic-gene matrix, shape => {preds_gene.numpy().shape}")
    return preds_gene.detach().cpu().numpy()
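# Usage sketch (illustrative): rank genes per topic, e.g. as input to GSEA.
# Gene order is assumed to match `rna_anndata.var_names`.
#
#     topic_gene = compute_topic_gene_matrix(scdori_model, device)
#     gene_names = np.asarray(rna_anndata.var_names)
#     top_topic0 = gene_names[np.argsort(topic_gene[0])[::-1][:20]]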
def compute_atac_grn_activator_with_significance(model, device, cutoff_val, outdir):
    """
    Compute significant ATAC-derived TF-gene links for activators with
    permutation-based significance.

    Uses only the learned peak-gene links and in silico ChIP-seq activator
    matrices. Significance is computed by permuting TF-binding profiles on peaks.

    Parameters
    ----------
    model : torch.nn.Module
        The trained model containing peak and TF decoders.
    device : torch.device
        The device (CPU or CUDA) for PyTorch operations.
    cutoff_val : float
        Significance cutoff (e.g., 0.95) for the percentile filtering.
    outdir : str
        Directory in which to save the computed GRN results.

    Returns
    -------
    np.ndarray
        A (num_topics, num_tfs, num_genes) array of significant ATAC-derived
        activator GRNs.
    """
    os.makedirs(outdir, exist_ok=True)
    num_topics = model.num_topics
    logger.info(f"Computing significant ATAC-derived TF-gene links for activators. Output => {outdir}")

    with torch.no_grad():
        if model.device != device:
            model = model.to(device)
        topic_peaks = torch.nn.Softmax(dim=1)(model.topic_peak_decoder)
        effective_gene_peak_factor1 = model.gene_peak_factor_fixed
        effective_gene_peak_factor2 = model.gene_peak_factor_learnt
        effective_gene_peak_factor = effective_gene_peak_factor1 * effective_gene_peak_factor2
        insilico_chipseq_embeddings = model.tf_binding_matrix_activator

        grn_atac_significant1 = []
        cutoff_val1 = cutoff_val
        for i in range(num_topics):
            logger.info(f"Processing Topic {i + 1}/{num_topics}")
            print(f"Processing Topic {i + 1}/{num_topics}")
            # Weight each peak-gene link by the topic's accessibility of the peak
            topic_gene_peak = (topic_peaks[i][:, None].T.clone()) * effective_gene_peak_factor
            topic_gene_tf = torch.matmul(topic_gene_peak, insilico_chipseq_embeddings)
            grn_fg = topic_gene_tf / (effective_gene_peak_factor.sum(dim=1, keepdim=True) + 1e-8)
            grn_fg = grn_fg.T

            # Background distribution: shuffle TF-binding profiles across peaks
            grn_bg_topic = []
            for permutation in tqdm(range(config.num_permutations), desc=f"Permutations for Topic {i + 1}"):
                insilico_chipseq_random = insilico_chipseq_embeddings[
                    torch.randperm(insilico_chipseq_embeddings.size(0))
                ]
                topic_gene_tf_bg = torch.matmul(topic_gene_peak, insilico_chipseq_random)
                grn_bg = topic_gene_tf_bg / (effective_gene_peak_factor.sum(dim=1, keepdim=True) + 1e-8)
                grn_bg_topic.append(grn_bg.T.detach().cpu())

            grn_bg_topic = torch.stack(grn_bg_topic).cpu().numpy()
            # Keep foreground links exceeding the background percentile cutoff
            cutoff1 = np.percentile(grn_bg_topic, 100 * (1 - cutoff_val1), axis=0)
            grn_fg1 = grn_fg.cpu().numpy()
            significant_grn1 = np.where(grn_fg1 > cutoff1, grn_fg1, 0)
            significant_grn1 = significant_grn1 / (significant_grn1.max() + 1e-15)
            grn_atac_significant1.append(significant_grn1)

    grn_atac_significant1 = np.array(grn_atac_significant1)
    np.save(os.path.join(outdir, f'grn_atac_activator_{cutoff_val}.npy'), grn_atac_significant1)
    logger.info("Completed computing activator ATAC GRNs.")
    return grn_atac_significant1
def compute_atac_grn_repressor_with_significance(model, device, cutoff_val, outdir):
    """
    Compute significant ATAC-derived TF-gene links for repressors using
    permutation-based significance.

    Uses the learned peak-gene links and in silico ChIP-seq repressor matrices.
    Significance is computed by permuting TF-binding profiles on peaks.

    Parameters
    ----------
    model : torch.nn.Module
        The trained model containing peak and TF decoders.
    device : torch.device
        The device (CPU or CUDA) for PyTorch operations.
    cutoff_val : float
        Significance cutoff (e.g., 0.05) for percentile filtering.
    outdir : str
        Directory in which to save the computed GRN results.

    Returns
    -------
    np.ndarray
        A (num_topics, num_tfs, num_genes) array of significant ATAC-derived
        repressor GRNs.
    """
    os.makedirs(outdir, exist_ok=True)
    num_topics = model.num_topics
    logger.info(f"Computing significant ATAC-derived TF-gene links for repressors. Output => {outdir}")

    with torch.no_grad():
        if model.device != device:
            model = model.to(device)
        topic_peaks = torch.nn.Softmax(dim=1)(model.topic_peak_decoder)
        effective_gene_peak_factor1 = model.gene_peak_factor_fixed
        effective_gene_peak_factor2 = model.gene_peak_factor_learnt
        effective_gene_peak_factor = effective_gene_peak_factor1 * effective_gene_peak_factor2
        insilico_chipseq_embeddings_repressor = model.tf_binding_matrix_repressor

        grn_atac_significant1 = []
        cutoff_val1 = cutoff_val
        for i in range(num_topics):
            logger.info(f"Processing Topic {i + 1}/{num_topics}")
            print(f"Processing Topic {i + 1}/{num_topics}")
            # Repressors: weight peak-gene links by *inverse* topic accessibility
            topic_gene_peak_rep = (1 / (topic_peaks[i].clone() + 1e-20))[:, None].T * effective_gene_peak_factor
            topic_gene_tf = torch.matmul(topic_gene_peak_rep, insilico_chipseq_embeddings_repressor)
            grn_fg = topic_gene_tf / (effective_gene_peak_factor.sum(dim=1, keepdim=True) + 1e-8)
            grn_fg = grn_fg.T

            # Background distribution: shuffle TF-binding profiles across peaks
            grn_bg_topic = []
            for permutation in tqdm(range(config.num_permutations), desc=f"Permutations for Topic {i + 1}"):
                insilico_chipseq_random = insilico_chipseq_embeddings_repressor[
                    torch.randperm(insilico_chipseq_embeddings_repressor.size(0))
                ]
                topic_gene_tf_bg = torch.matmul(topic_gene_peak_rep, insilico_chipseq_random)
                grn_bg = topic_gene_tf_bg / (effective_gene_peak_factor.sum(dim=1, keepdim=True) + 1e-8)
                grn_bg_topic.append(grn_bg.T.detach().cpu())

            grn_bg_topic = torch.stack(grn_bg_topic).cpu().numpy()
            # Keep foreground links below the lower-tail percentile of the background
            cutoff1 = np.percentile(grn_bg_topic, 100 * cutoff_val1, axis=0)
            grn_fg1 = grn_fg.cpu().numpy()
            significant_grn1 = np.where(grn_fg1 < cutoff1, grn_fg1, 0)
            significant_grn1 = significant_grn1 / (significant_grn1.min() + 1e-15)
            grn_atac_significant1.append(significant_grn1)

    grn_atac_significant1 = np.array(grn_atac_significant1)
    np.save(os.path.join(outdir, f'grn_atac_repressor_{cutoff_val}.npy'), grn_atac_significant1)
    logger.info("Completed computing repressor ATAC GRNs.")
    return grn_atac_significant1
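# Usage sketch (illustrative) for the two permutation-based GRN builders above.
# The cutoff values follow the examples given in the docstrings; "grn_out" is an
# assumed output directory.
#
#     grn_act_atac = compute_atac_grn_activator_with_significance(
#         scdori_model, device, cutoff_val=0.95, outdir="grn_out")
#     grn_rep_atac = compute_atac_grn_repressor_with_significance(
#         scdori_model, device, cutoff_val=0.05, outdir="grn_out")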
def compute_significant_grn(model, device, cutoff_val_activator, cutoff_val_repressor, tf_normalised, outdir):
    """
    Combine significant ATAC-derived and scDoRI-learned GRN links into final
    activator and repressor GRNs.

    Parameters
    ----------
    model : torch.nn.Module
        The scDoRI model containing learned TF-gene topic parameters.
    device : torch.device
        CPU or CUDA device for PyTorch operations.
    cutoff_val_activator : float
        Significance cutoff used for the activator GRN file.
    cutoff_val_repressor : float
        Significance cutoff used for the repressor GRN file.
    tf_normalised : np.ndarray
        A matrix of normalized TF usage, of shape (num_topics, num_tfs) or any
        shape reshapeable to (num_topics, num_tfs, 1).
    outdir : str
        Directory containing the ATAC-based GRN files and where results are saved.

    Returns
    -------
    tuple of np.ndarray
        grn_act : shape (num_topics, num_tfs, num_genes)
            Computed activator GRN array.
        grn_rep : shape (num_topics, num_tfs, num_genes)
            Computed repressor GRN array.

    Raises
    ------
    FileNotFoundError
        If the required ATAC-derived GRN files are missing.
    """
    os.makedirs(outdir, exist_ok=True)
    activator_path = os.path.join(outdir, f'grn_atac_activator_{cutoff_val_activator}.npy')
    repressor_path = os.path.join(outdir, f'grn_atac_repressor_{cutoff_val_repressor}.npy')
    num_topics = model.num_topics
    num_tfs = model.num_tfs

    # Check that the ATAC-derived GRN files exist
    if not os.path.exists(activator_path):
        raise FileNotFoundError(
            f"Activator GRN file not found: {activator_path}\n"
            "Please compute ATAC-based GRNs using the `compute_atac_grn_activator_with_significance` function first."
        )
    if not os.path.exists(repressor_path):
        raise FileNotFoundError(
            f"Repressor GRN file not found: {repressor_path}\n"
            "Please compute ATAC-based GRNs using the `compute_atac_grn_repressor_with_significance` function first."
        )

    logger.info("Loading ATAC-derived GRNs...")
    grn_atac_activator = np.load(activator_path)
    grn_atac_repressor = np.load(repressor_path)

    logger.info("Computing combined GRNs...")
    # Scale ATAC-derived links by the (rectified) learned topic-TF-gene weights
    grn_rep = torch.tensor(grn_atac_repressor) * (
        -1 * torch.nn.functional.relu(model.tf_gene_topic_repressor_grn).detach().cpu()
    )
    grn_act = torch.tensor(grn_atac_activator) * (
        torch.nn.functional.relu(model.tf_gene_topic_activator_grn).detach().cpu()
    )
    grn_tot = grn_rep.clone() + grn_act.clone()
    grn_tot = grn_tot.numpy()

    # Split the combined GRN by sign and weight by per-topic TF usage
    grn_rep = grn_tot.copy()
    grn_rep[grn_rep > 0] = 0
    grn_rep = torch.from_numpy(tf_normalised).reshape((num_topics, num_tfs, 1)) * grn_rep

    grn_act = grn_tot.copy()
    grn_act[grn_act < 0] = 0
    grn_act = torch.from_numpy(tf_normalised).reshape((num_topics, num_tfs, 1)) * grn_act

    logger.info("Saving computed GRNs...")
    np.save(os.path.join(outdir, f'grn_activator__{cutoff_val_activator}.npy'), grn_act.numpy())
    np.save(os.path.join(outdir, f'grn_repressor__{cutoff_val_repressor}.npy'), grn_rep.numpy())

    logger.info("GRN computation completed successfully.")
    return grn_act.numpy(), grn_rep.numpy()
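# Usage sketch (illustrative): combine the saved ATAC-derived GRNs with the
# learned topic-TF weights. `tf_normalised` is an assumed (num_topics, num_tfs)
# array of normalized per-topic TF usage; its construction is pipeline-specific
# and not defined in this module.
#
#     grn_act, grn_rep = compute_significant_grn(
#         scdori_model, device,
#         cutoff_val_activator=0.95, cutoff_val_repressor=0.05,
#         tf_normalised=tf_normalised, outdir="grn_out")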
def save_regulons(grn_matrix, tf_names, gene_names, num_topics, output_dir, mode="activator"):
    """
    Save regulons (TF-gene links across topics) for each TF based on a given GRN matrix.

    Parameters
    ----------
    grn_matrix : np.ndarray
        A GRN matrix of shape (num_topics, num_tfs, num_genes).
    tf_names : list of str
        List of transcription factor names, length = num_tfs.
    gene_names : list of str
        List of gene names, length = num_genes.
    num_topics : int
        Number of topics in the GRN matrix.
    output_dir : str
        Directory where the regulon files will be saved.
    mode : str, optional
        "activator" or "repressor", used to name the output subdirectory/files.

    Returns
    -------
    None
        Saves one TSV file per TF in `output_dir`, each of shape
        (num_topics, num_genes), where non-zero values represent a link.
    """
    output_path = os.path.join(output_dir, f"regulons_tf/{mode}")
    os.makedirs(output_path, exist_ok=True)
    for i, tf_name in enumerate(tf_names):
        regulon = pd.DataFrame(
            grn_matrix[:, i, :],
            index=[f"Topic_{k}" for k in range(num_topics)],
            columns=gene_names
        )
        regulon.to_csv(os.path.join(output_path, f"{tf_name}_{mode}.tsv"), sep="\t")
        print(f"Saved {mode} regulon for TF: {tf_name} in {output_path}")
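# Usage sketch (illustrative): write one TSV per TF. `tf_names` and `gene_names`
# are assumed lists matching the TF and gene axes of the GRN array.
#
#     save_regulons(grn_act, tf_names, gene_names,
#                   num_topics=scdori_model.num_topics,
#                   output_dir="grn_out", mode="activator")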
def visualize_downstream_targets(rna_anndata, gene_list, score_name="target_score", layer="log"):
    """
    Visualize the average expression of the given genes on a UMAP embedding.

    Uses `scanpy.tl.score_genes` to compute a gene score, then plots it with
    `scanpy.pl.umap`.

    Parameters
    ----------
    rna_anndata : anndata.AnnData
        The AnnData object containing RNA data with `.obsm["X_umap"]`.
    gene_list : list of str
        A list of gene names to score.
    score_name : str, optional
        Name of the resulting gene score in `rna_anndata.obs`. Default is "target_score".
    layer : str, optional
        Which layer to use when plotting with `scanpy.pl.umap`. Default is "log".

    Returns
    -------
    None
        Plots the UMAP colored by the computed gene score.
    """
    sc.tl.score_genes(rna_anndata, gene_list, score_name=score_name)
    sc.pl.umap(rna_anndata, color=[score_name], layer=layer)
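# Usage sketch (illustrative): score the predicted targets of one TF in one
# topic. The TF name "Gata1" and the thresholding at zero are assumptions for
# illustration; any TF present in `tf_names` works.
#
#     tf_idx = list(tf_names).index("Gata1")
#     targets = [g for g, w in zip(gene_names, grn_act[0, tf_idx, :]) if w > 0]
#     visualize_downstream_targets(rna_anndata, targets, score_name="Gata1_targets")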
def plot_topic_activation_heatmap(rna_anndata, groupby_key="celltype", aggregation="median"):
    """
    Compute aggregated scDoRI latent topic activation across groups, then plot a clustermap.

    Parameters
    ----------
    rna_anndata : anndata.AnnData
        An AnnData object containing scDoRI latent factors in `obsm["X_scdori"]`.
    groupby_key : str, optional
        Column in `rna_anndata.obs` by which to group cells. Default is "celltype".
    aggregation : str, optional
        Either "median" or "mean" for aggregating factor values per group. Default is "median".

    Returns
    -------
    pd.DataFrame
        The transposed aggregated DataFrame (topics x groups).

    Notes
    -----
    Uses a Seaborn clustermap to visualize the aggregated data.
    """
    latent = rna_anndata.obsm["X_scdori"]  # shape (n_cells, num_topics)
    df_latent = pd.DataFrame(latent, columns=[f"Topic_{i}" for i in range(latent.shape[1])])
    df_latent[groupby_key] = rna_anndata.obs[groupby_key].values

    if aggregation == "median":
        df_grouped = df_latent.groupby(groupby_key).median()
    else:
        df_grouped = df_latent.groupby(groupby_key).mean()

    sns.set(font_scale=0.5)
    g = sns.clustermap(df_grouped.T, cmap="RdBu_r", center=0, figsize=(8, 8))
    plt.show()
    return df_grouped.T
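# Usage sketch (illustrative): summarize topic activity per cell type; the
# "celltype" column is assumed to exist in `rna_anndata.obs`.
#
#     df_topics_by_celltype = plot_topic_activation_heatmap(
#         rna_anndata, groupby_key="celltype", aggregation="median")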
def get_top_activators_per_topic(
    grn_final,          # shape (num_topics, num_tfs, num_genes)
    tf_names,           # list of TF names, length = num_tfs
    latent_all_torch,   # shape (num_cells, num_topics)
    selected_topics=None,
    top_k=10,
    clamp_value=1e-8,
    zscore=True,
    figsize=(25, 10),
    out_fig=None
):
    """
    Identify and plot top activator transcription factors per topic (topic regulators, TRs).

    Parameters
    ----------
    grn_final : np.ndarray or torch.Tensor
        An array of shape (num_topics, num_tfs, num_genes), representing an activator GRN.
    tf_names : list of str
        List of TF names, length = num_tfs.
    latent_all_torch : np.ndarray or torch.Tensor
        scDoRI latent topic activity of shape (num_cells, num_topics).
        Not always used, but can be referenced.
    selected_topics : list of int, optional
        Which topics to analyze. If None, all topics are used.
    top_k : int, optional
        Number of top TFs to select per topic. Default is 10.
    clamp_value : float, optional
        Small cutoff to avoid division by zero. Default is 1e-8.
    zscore : bool, optional
        If True, apply z-score normalization across topics in the final matrix. Default is True.
    figsize : tuple, optional
        Size for the Seaborn clustermap. Default is (25, 10).
    out_fig : str or Path, optional
        If provided, the figure is saved to this path; otherwise it is shown.

    Returns
    -------
    tuple
        df_topic_grn : pd.DataFrame
            The final DataFrame of shape (#topics, #TFs).
        selected_tf : list of str
            A sorted list of TFs used in the final clustermap.
    """
    logger.info("=== Plotting top activator regulators per topic ===")
    num_topics = grn_final.shape[0]
    num_tfs = grn_final.shape[1]
    if selected_topics is None:
        selected_topics = range(num_topics)

    grn_final_np = grn_final
    if isinstance(grn_final, torch.Tensor):
        grn_final_np = grn_final.detach().cpu().numpy()

    # Sum each TF's links across genes => per-topic TF activity, then normalize
    topic_tf_grn = []
    topic_tf_grn_norm = []
    for i in range(num_topics):
        grn_topic = grn_final_np[i].sum(axis=1)  # shape (num_tfs,)
        topic_tf_grn.append(grn_topic)
        total_activity = grn_topic.sum() + clamp_value
        topic_tf_grn_norm.append(grn_topic / total_activity)

    topic_tf_grn_act = np.array(topic_tf_grn)
    topic_tf_grn_norm_act = np.array(topic_tf_grn_norm)

    df_topic_grn = pd.DataFrame(
        topic_tf_grn_norm_act[selected_topics],
        columns=tf_names,
        index=[f"Topic_{k}" for k in selected_topics]
    )
    if zscore:
        df_topic_grn = df_topic_grn.apply(
            lambda x: (x - x.mean()) / (x.std() + 1e-8), axis=0
        )

    # Collect the union of each topic's top-k TFs
    selected_tf = set()
    for row_name in df_topic_grn.index:
        row = df_topic_grn.loc[row_name].sort_values(ascending=False)
        selected_tf.update(row.head(top_k).index.values)
    selected_tf = sorted(selected_tf)

    df_plot = df_topic_grn[selected_tf]

    sns.set_style("darkgrid")
    plt.rcParams['figure.facecolor'] = "white"
    sns.set(font_scale=1.2)
    g = sns.clustermap(
        df_plot,
        row_cluster=False,
        col_cluster=True,
        linewidths=1.5,
        dendrogram_ratio=0.1,
        cmap='Spectral',
        vmin=-4, vmax=4,
        figsize=figsize,
        annot_kws={"size": 20}
    )
    if out_fig:
        g.figure.savefig(out_fig, dpi=300)
    plt.show()
    logger.info("=== Done plotting top regulators per topic ===")
    return df_topic_grn, selected_tf
def get_top_repressor_per_topic(
    grn_final,          # shape (num_topics, num_tfs, num_genes)
    tf_names,           # list of TF names, length = num_tfs
    latent_all_torch,   # shape (num_cells, num_topics)
    selected_topics=None,
    top_k=5,
    clamp_value=1e-8,
    zscore=True,
    figsize=(25, 10),
    out_fig=None
):
    """
    Identify and plot top repressor transcription factors per topic.

    Parameters
    ----------
    grn_final : np.ndarray or torch.Tensor
        An array of shape (num_topics, num_tfs, num_genes), representing a repressor GRN.
    tf_names : list of str
        List of TF names, length = num_tfs.
    latent_all_torch : np.ndarray or torch.Tensor
        scDoRI latent topic activity of shape (num_cells, num_topics).
    selected_topics : list of int, optional
        Which topics to analyze. If None, all topics are used.
    top_k : int, optional
        Number of top TFs to select per topic. Default is 5.
    clamp_value : float, optional
        Small cutoff to avoid division by zero. Default is 1e-8.
    zscore : bool, optional
        If True, apply z-score normalization across topics in the final matrix. Default is True.
    figsize : tuple, optional
        Size for the Seaborn clustermap. Default is (25, 10).
    out_fig : str or Path, optional
        If provided, the figure is saved to this path; otherwise it is shown.

    Returns
    -------
    tuple
        df_plot : pd.DataFrame
            The final DataFrame of shape (#topics, #TFs).
        selected_tf : list of str
            A sorted list of TFs used in the final clustermap.
    """
    logger.info("=== Plotting top repressor regulators per topic ===")
    num_topics = grn_final.shape[0]
    num_tfs = grn_final.shape[1]
    if selected_topics is None:
        selected_topics = range(num_topics)

    grn_final_np = grn_final
    if isinstance(grn_final, torch.Tensor):
        grn_final_np = grn_final.detach().cpu().numpy()

    # Sum the magnitude of each TF's links across genes, then normalize per topic
    topic_tf_grn = []
    topic_tf_grn_norm = []
    for i in range(num_topics):
        grn_topic = np.abs(grn_final_np[i]).sum(axis=1)
        topic_tf_grn.append(grn_topic)
        total_activity = grn_topic.sum() + clamp_value
        topic_tf_grn_norm.append(grn_topic / total_activity)

    topic_tf_grn_rep = np.array(topic_tf_grn)
    topic_tf_grn_norm_rep = np.array(topic_tf_grn_norm)

    df_topic_grn = pd.DataFrame(
        topic_tf_grn_norm_rep[selected_topics],
        columns=tf_names,
        index=[f"Topic_{k}" for k in selected_topics]
    )
    if zscore:
        df_topic_grn = df_topic_grn.apply(
            lambda x: (x - x.mean()) / (x.std() + 1e-8), axis=0
        )

    # Collect the union of each topic's top-k TFs
    selected_tf = set()
    for row_name in df_topic_grn.index:
        row = df_topic_grn.loc[row_name].sort_values(ascending=False)
        selected_tf.update(row.head(top_k).index.values)
    selected_tf = sorted(selected_tf)

    df_plot = df_topic_grn[selected_tf]

    sns.set_style("darkgrid")
    plt.rcParams['figure.facecolor'] = "white"
    sns.set(font_scale=1.2)
    g = sns.clustermap(
        df_plot,
        row_cluster=False,
        col_cluster=True,
        linewidths=1.5,
        dendrogram_ratio=0.1,
        cmap='Spectral',
        vmin=-4, vmax=4,
        figsize=figsize,
        annot_kws={"size": 20}
    )
    if out_fig:
        g.figure.savefig(out_fig, dpi=300)
    plt.show()
    logger.info("=== Done plotting top repressor regulators per topic ===")
    return df_plot, selected_tf
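# Usage sketch (illustrative) for the two plotting helpers above. `latent_all`
# is an assumed (num_cells, num_topics) array of scDoRI topic activities.
#
#     df_act, top_act_tfs = get_top_activators_per_topic(
#         grn_act, tf_names, latent_all, top_k=10, out_fig="top_activators.png")
#     df_rep, top_rep_tfs = get_top_repressor_per_topic(
#         grn_rep, tf_names, latent_all, top_k=5)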
def compute_activator_tf_activity_per_cell(
    grn_final,          # shape (num_topics, num_tfs, num_genes)
    tf_names,           # list of TF names, length = num_tfs
    latent_all_torch,   # shape (num_cells, num_topics)
    selected_topics=None,
    clamp_value=1e-8,
    zscore=True
):
    """
    Compute per-cell activity of activator TFs.

    Parameters
    ----------
    grn_final : np.ndarray or torch.Tensor
        Activator GRN of shape (num_topics, num_tfs, num_genes).
    tf_names : list of str
        List of TF names, length = num_tfs.
    latent_all_torch : np.ndarray or torch.Tensor
        scDoRI latent topic activity of shape (num_cells, num_topics).
    selected_topics : list of int, optional
        Which topics to analyze. If None, all topics are used.
    clamp_value : float, optional
        Small constant to avoid division by zero. Default is 1e-8.
    zscore : bool, optional
        If True, apply z-score normalization across cells in the final matrix. Default is True.

    Returns
    -------
    np.ndarray
        A (num_cells, num_tfs) array of TF activity values.
    """
    logger.info("=== Computing TF activity per cell ===")
    num_topics = grn_final.shape[0]
    num_tfs = grn_final.shape[1]
    if selected_topics is None:
        selected_topics = range(num_topics)

    grn_final_np = grn_final
    if isinstance(grn_final, torch.Tensor):
        grn_final_np = grn_final.detach().cpu().numpy()

    # Per-topic TF activity: sum of each TF's links across genes, normalized per topic
    topic_tf_grn = []
    topic_tf_grn_norm = []
    for i in range(num_topics):
        grn_topic = grn_final_np[i].sum(axis=1)
        topic_tf_grn.append(grn_topic)
        total_activity = grn_topic.sum() + clamp_value
        topic_tf_grn_norm.append(grn_topic / total_activity)

    topic_tf_grn_act = np.array(topic_tf_grn)
    topic_tf_grn_norm_act = np.array(topic_tf_grn_norm)

    # Project per-topic TF activity onto cells via the cell-topic matrix
    cell_tf_act = np.einsum('ij,jk->ik', latent_all_torch, topic_tf_grn_norm_act)
    if zscore:
        cell_tf_act = scipy.stats.zscore(cell_tf_act, axis=0)
    return cell_tf_act
def compute_repressor_tf_activity_per_cell(
    grn_final,          # shape (num_topics, num_tfs, num_genes)
    tf_names,           # list of TF names, length = num_tfs
    latent_all_torch,   # shape (num_cells, num_topics)
    selected_topics=None,
    clamp_value=1e-8,
    zscore=True
):
    """
    Compute per-cell activity of repressor TFs.

    Parameters
    ----------
    grn_final : np.ndarray or torch.Tensor
        Repressor GRN of shape (num_topics, num_tfs, num_genes).
    tf_names : list of str
        List of TF names, length = num_tfs.
    latent_all_torch : np.ndarray or torch.Tensor
        scDoRI latent topic activity of shape (num_cells, num_topics).
    selected_topics : list of int, optional
        Which topics to analyze. If None, all topics are used.
    clamp_value : float, optional
        Small constant to avoid division by zero. Default is 1e-8.
    zscore : bool, optional
        If True, apply z-score normalization across cells in the final matrix. Default is True.

    Returns
    -------
    np.ndarray
        A (num_cells, num_tfs) array of TF activity values.
    """
    logger.info("=== Computing TF activity per cell ===")
    num_topics = grn_final.shape[0]
    num_tfs = grn_final.shape[1]
    if selected_topics is None:
        selected_topics = range(num_topics)

    grn_final_np = grn_final
    if isinstance(grn_final, torch.Tensor):
        grn_final_np = grn_final.detach().cpu().numpy()

    # Per-topic TF activity: summed magnitude of links across genes, normalized per topic
    topic_tf_grn = []
    topic_tf_grn_norm = []
    for i in range(num_topics):
        grn_topic = np.abs(grn_final_np[i]).sum(axis=1)
        topic_tf_grn.append(grn_topic)
        total_activity = grn_topic.sum() + clamp_value
        topic_tf_grn_norm.append(grn_topic / total_activity)

    topic_tf_grn_rep = np.array(topic_tf_grn)
    topic_tf_grn_norm_rep = np.array(topic_tf_grn_norm)

    # Project per-topic TF activity onto cells via the cell-topic matrix
    cell_tf_rep = np.einsum('ij,jk->ik', latent_all_torch, topic_tf_grn_norm_rep)
    if zscore:
        cell_tf_rep = scipy.stats.zscore(cell_tf_rep, axis=0)
    return cell_tf_rep
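# Usage sketch (illustrative) for the two per-cell activity helpers above: store
# one TF's activity in `.obs` and plot it on the UMAP. The TF name "Gata1" is an
# assumption for illustration.
#
#     act_per_cell = compute_activator_tf_activity_per_cell(grn_act, tf_names, latent_all)
#     rep_per_cell = compute_repressor_tf_activity_per_cell(grn_rep, tf_names, latent_all)
#     tf_idx = list(tf_names).index("Gata1")
#     rna_anndata.obs["Gata1_activator_activity"] = act_per_cell[:, tf_idx]
#     sc.pl.umap(rna_anndata, color=["Gata1_activator_activity"])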