# Source code for scdori.models

import torch
import torch.nn as nn
import torch.nn.functional as F
import logging

logger = logging.getLogger(__name__)

class scDoRI(nn.Module):
    """
    The scDoRI model integrates single-cell multiome RNA and ATAC data to learn
    latent topic representations and perform gene regulatory network (GRN)
    inference.

    This model contains:

    - **Encoders** for RNA and ATAC, producing a shared topic distribution.
    - **Decoders** for ATAC, TF, and RNA reconstruction.
    - **GRN logic** for combining TF binding data with gene-peak links and TF
      expression to reconstruct RNA profiles.

    Parameters
    ----------
    device : torch.device
        The device (CPU or CUDA) for PyTorch operations.
    num_genes : int
        Number of genes in the RNA data.
    num_peaks : int
        Number of peaks in the ATAC data.
    num_tfs : int
        Number of transcription factors being modeled.
    num_topics : int
        Number of latent topics or factors.
    num_batches : int
        Number of distinct batches (for batch correction).
    dim_encoder1 : int
        Dimension of the first encoder layer.
    dim_encoder2 : int
        Dimension of the second encoder layer.
    batch_norm : bool, optional
        If True, use batch normalization in the encoder and library-factor
        MLPs. Default is True.

    Attributes
    ----------
    encoder_rna : torch.nn.Sequential
        The neural network layers for the RNA encoder.
    encoder_atac : torch.nn.Sequential
        The neural network layers for the ATAC encoder.
    mu_theta : torch.nn.Linear
        Linear layer converting combined RNA+ATAC encoder outputs into raw
        topic logits.
    topic_peak_decoder : torch.nn.Parameter
        A (num_topics x num_peaks) parameter for ATAC reconstruction.
    atac_batch_factor : torch.nn.Parameter
        A (num_batches x num_peaks) parameter for batch effects in ATAC.
    atac_batch_norm : torch.nn.BatchNorm1d
        Batch normalization layer for ATAC predictions.
    topic_tf_decoder : torch.nn.Parameter
        A (num_topics x num_tfs) parameter for TF expression reconstruction.
    tf_batch_factor : torch.nn.Parameter
        A (num_batches x num_tfs) parameter for batch effects in TF
        reconstruction.
    tf_batch_norm : torch.nn.BatchNorm1d
        Batch normalization layer for TF predictions.
    tf_alpha_nb : torch.nn.Parameter
        A (1 x num_tfs) parameter for TF negative binomial overdispersion.
    gene_peak_factor_learnt : torch.nn.Parameter
        A (num_genes x num_peaks) learned matrix linking peaks to genes.
    gene_peak_factor_fixed : torch.nn.Parameter
        A (num_genes x num_peaks) fixed mask for feasible gene-peak links.
    rna_batch_factor : torch.nn.Parameter
        A (num_batches x num_genes) parameter for batch effects in RNA
        reconstruction.
    rna_batch_norm : torch.nn.BatchNorm1d
        Batch normalization layer for RNA predictions.
    rna_alpha_nb : torch.nn.Parameter
        A (1 x num_genes) parameter for RNA negative binomial overdispersion.
    tf_library_factor : torch.nn.Sequential
        An MLP to predict the library scaling factor for TF data from the
        observed TF expression.
    rna_library_factor : torch.nn.Sequential
        An MLP to predict the library scaling factor for RNA data from the
        observed gene counts.
    tf_binding_matrix_activator : torch.nn.Parameter
        A (num_peaks x num_tfs) matrix of in silico ChIP-seq (activator)
        TF-peak binding; precomputed and fixed.
    tf_binding_matrix_repressor : torch.nn.Parameter
        A (num_peaks x num_tfs) matrix of in silico ChIP-seq (repressor)
        TF-peak binding; precomputed and fixed.
    tf_gene_topic_activator_grn : torch.nn.Parameter
        A (num_topics x num_tfs x num_genes) matrix capturing per-topic
        activator regulation.
    tf_gene_topic_repressor_grn : torch.nn.Parameter
        A (num_topics x num_tfs x num_genes) matrix capturing per-topic
        repressor regulation.
    rna_grn_batch_factor : torch.nn.Parameter
        A (num_batches x num_genes) batch-effect parameter for the GRN-based
        RNA reconstruction (module 4).
    rna_grn_batch_norm : torch.nn.BatchNorm1d
        Batch normalization layer for GRN-based RNA predictions.
    """

    def __init__(
        self,
        device,
        num_genes,
        num_peaks,
        num_tfs,
        num_topics,
        num_batches,
        dim_encoder1,
        dim_encoder2,
        batch_norm=True,
    ):
        super().__init__()
        self.device = device
        self.num_genes = num_genes
        self.num_peaks = num_peaks
        self.num_tfs = num_tfs
        self.num_topics = num_topics
        self.num_batches = num_batches
        self.dim_encoder1 = dim_encoder1
        self.dim_encoder2 = dim_encoder2
        self.batch_norm = batch_norm

        # ENCODER for RNA: input is gene counts plus log library size and
        # number of aggregated cells (hence num_genes + 2).
        self.encoder_rna = nn.Sequential(
            nn.Linear(num_genes + 2, dim_encoder1),
            nn.BatchNorm1d(dim_encoder1) if batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Dropout(0.05),
            nn.Linear(dim_encoder1, dim_encoder2),
            nn.BatchNorm1d(dim_encoder2) if batch_norm else nn.Identity(),
            nn.ReLU(),
        )

        # ENCODER for ATAC: same structure, over peaks.
        self.encoder_atac = nn.Sequential(
            nn.Linear(num_peaks + 2, dim_encoder1),
            nn.BatchNorm1d(dim_encoder1) if batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Dropout(0.05),
            nn.Linear(dim_encoder1, dim_encoder2),
            nn.BatchNorm1d(dim_encoder2) if batch_norm else nn.Identity(),
            nn.ReLU(),
        )

        # Combined RNA+ATAC encoder outputs -> raw topic logits.
        self.mu_theta = nn.Linear(dim_encoder2 * 2, num_topics)

        # ATAC decoder (module 1)
        self.topic_peak_decoder = nn.Parameter(torch.rand(num_topics, num_peaks))
        self.atac_batch_factor = nn.Parameter(torch.rand(num_batches, num_peaks))
        self.atac_batch_norm = nn.BatchNorm1d(num_peaks)

        # RNA from ATAC (module 2)
        self.gene_peak_factor_learnt = nn.Parameter(torch.rand(num_genes, num_peaks))
        self.gene_peak_factor_fixed = nn.Parameter(torch.ones(num_genes, num_peaks))
        self.rna_batch_factor = nn.Parameter(torch.rand(num_batches, num_genes))
        self.rna_batch_norm = nn.BatchNorm1d(num_genes)
        self.rna_alpha_nb = nn.Parameter(torch.rand(1, num_genes))

        # TF decoder (module 3)
        self.topic_tf_decoder = nn.Parameter(torch.rand(num_topics, num_tfs))
        self.tf_batch_factor = nn.Parameter(torch.rand(num_batches, num_tfs))
        self.tf_batch_norm = nn.BatchNorm1d(num_tfs)
        self.tf_alpha_nb = nn.Parameter(torch.rand(1, num_tfs))

        # MLPs predicting per-cell library scaling factors (TF, RNA)
        self.tf_library_factor = nn.Sequential(
            nn.Linear(num_tfs, dim_encoder2),
            nn.BatchNorm1d(dim_encoder2) if batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Dropout(0.05),
            nn.Linear(dim_encoder2, dim_encoder1),
            nn.BatchNorm1d(dim_encoder1) if batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Linear(dim_encoder1, 1),
            nn.Softplus(),
        )
        self.rna_library_factor = nn.Sequential(
            nn.Linear(num_genes, dim_encoder2),
            nn.BatchNorm1d(dim_encoder2) if batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Dropout(0.05),
            nn.Linear(dim_encoder2, dim_encoder1),
            nn.BatchNorm1d(dim_encoder1) if batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Linear(dim_encoder1, 1),
            nn.Softplus(),
        )

        # GRN portion (module 4)
        self.tf_binding_matrix_activator = nn.Parameter(torch.rand(num_peaks, num_tfs))
        self.tf_binding_matrix_repressor = nn.Parameter(torch.rand(num_peaks, num_tfs))
        self.tf_gene_topic_activator_grn = nn.Parameter(torch.rand(num_topics, num_tfs, num_genes))
        self.tf_gene_topic_repressor_grn = nn.Parameter(torch.rand(num_topics, num_tfs, num_genes))
        self.rna_grn_batch_factor = nn.Parameter(torch.rand(num_batches, num_genes))
        self.rna_grn_batch_norm = nn.BatchNorm1d(num_genes)
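    # A minimal construction sketch (hedged): the dimension values below are
    # illustrative placeholders, not taken from the original source or any
    # scDoRI dataset.
    #
    #   model = scDoRI(
    #       device=torch.device("cpu"),
    #       num_genes=100, num_peaks=500, num_tfs=20,
    #       num_topics=10, num_batches=2,
    #       dim_encoder1=64, dim_encoder2=32,
    #   )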
    def encode(self, rna_input, atac_input, log_lib_rna, log_lib_atac, num_cells):
        """
        Encode RNA and ATAC input into a topic distribution (theta).

        Parameters
        ----------
        rna_input : torch.Tensor
            A (B, num_genes) tensor of RNA counts per cell.
        atac_input : torch.Tensor
            A (B, num_peaks) tensor of ATAC counts per cell.
        log_lib_rna : torch.Tensor
            A (B, 1) tensor of log RNA library sizes.
        log_lib_atac : torch.Tensor
            A (B, 1) tensor of log ATAC library sizes.
        num_cells : torch.Tensor
            A (B, 1) tensor representing how many cells are aggregated
            (if metacells), or all ones for single-cell data.

        Returns
        -------
        (theta, mu_theta) : tuple of torch.Tensor
            theta : (B, num_topics), softmaxed topic distribution.
            mu_theta : (B, num_topics), raw topic logits.
        """
        B = rna_input.shape[0]

        # Concatenate counts with log library size and number of cells.
        x_rna = torch.cat([rna_input, log_lib_rna.view(B, 1), num_cells.view(B, 1)], dim=1)
        x_atac = torch.cat([atac_input, log_lib_atac.view(B, 1), num_cells.view(B, 1)], dim=1)

        qrna = self.encoder_rna(x_rna)
        qatac = self.encoder_atac(x_atac)
        combined = torch.cat([qrna, qatac], dim=1)

        mu_theta = self.mu_theta(combined)
        theta = F.softmax(mu_theta, dim=-1)
        return theta, mu_theta
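    # A hedged usage sketch for `encode` (random placeholder tensors; shapes
    # follow the docstring above):
    #
    #   B = 8
    #   theta, mu_theta = model.encode(
    #       torch.rand(B, model.num_genes),   # rna_input
    #       torch.rand(B, model.num_peaks),   # atac_input
    #       torch.rand(B, 1),                 # log_lib_rna
    #       torch.rand(B, 1),                 # log_lib_atac
    #       torch.ones(B, 1),                 # num_cells (single-cell data)
    #   )
    #   # theta is row-stochastic: theta.sum(dim=1) is approximately all ones.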
    def forward(
        self,
        rna_input,
        atac_input,
        tf_input,
        topic_tf_input,
        log_lib_rna,
        log_lib_atac,
        num_cells,
        batch_onehot,
        phase="warmup_1",
    ):
        """
        Forward pass through scDoRI, producing predictions for ATAC, TF, and
        RNA reconstructions (phase 1), as well as GRN-based RNA predictions in
        the GRN phase (phase 2).

        Parameters
        ----------
        rna_input : torch.Tensor
            Shape (B, num_genes). RNA counts per cell in the batch.
        atac_input : torch.Tensor
            Shape (B, num_peaks). ATAC counts per cell in the batch.
        tf_input : torch.Tensor
            Shape (B, num_tfs). Observed TF expression.
        topic_tf_input : torch.Tensor
            Shape (num_topics, num_tfs). TF expression aggregated by topic,
            used only if phase == "grn".
        log_lib_rna : torch.Tensor
            Shape (B, 1). Log of RNA library sizes.
        log_lib_atac : torch.Tensor
            Shape (B, 1). Log of ATAC library sizes.
        num_cells : torch.Tensor
            Shape (B, 1). Number of cells aggregated (if metacells), else ones.
        batch_onehot : torch.Tensor
            Shape (B, num_batches). One-hot batch encoding for each cell.
        phase : str, optional
            Which training phase: "warmup_1", "warmup_2", or "grn".
            If phase == "grn", the GRN-based RNA predictions are included.

        Returns
        -------
        dict
            A dictionary with the following keys:

            - "theta": (B, num_topics), the softmaxed topic distribution.
            - "mu_theta": (B, num_topics), raw topic logits.
            - "preds_atac": (B, num_peaks), predicted peak accessibility.
            - "preds_tf": (B, num_tfs), predicted TF expression.
            - "mu_nb_tf": (B, num_tfs), TF negative binomial mean
              = preds_tf * TF library factor.
            - "preds_rna": (B, num_genes), predicted RNA expression.
            - "mu_nb_rna": (B, num_genes), RNA negative binomial mean
              = preds_rna * RNA library factor.
            - "preds_rna_from_grn": (B, num_genes), optional GRN-based RNA
              predictions.
            - "mu_nb_rna_grn": (B, num_genes), negative binomial mean of
              GRN-based RNA predictions.
            - "library_factor_tf": (B, 1), predicted library factor for TF.
            - "library_factor_rna": (B, 1), predicted library factor for RNA.
        """
        B = rna_input.shape[0]

        # 1) ENCODE => topic distribution
        theta, mu_theta = self.encode(rna_input, atac_input, log_lib_rna, log_lib_atac, num_cells)

        # 2) ATAC decoding
        batch_factor_atac = torch.mm(batch_onehot, self.atac_batch_factor)
        preds_atac = torch.mm(theta, self.topic_peak_decoder) + batch_factor_atac
        preds_atac = self.atac_batch_norm(preds_atac)
        preds_atac = F.softmax(preds_atac, dim=-1)

        # 3) TF decoding => library factor
        batch_factor_tf = torch.mm(batch_onehot, self.tf_batch_factor)
        tf_logits = torch.mm(theta, self.topic_tf_decoder) + batch_factor_tf
        tf_logits = self.tf_batch_norm(tf_logits)
        preds_tf = F.softmax(tf_logits, dim=-1)

        # Library MLP for TF
        library_factor_tf = self.tf_library_factor(tf_input)
        mu_nb_tf = preds_tf * library_factor_tf

        # 4) RNA from ATAC => library factor
        # Softmax over peaks per topic, then min-max normalize across topics.
        topic_peak_denoised1 = F.softmax(self.topic_peak_decoder, dim=1)
        topic_peak_min, _ = torch.min(topic_peak_denoised1, dim=0, keepdim=True)
        topic_peak_max, _ = torch.max(topic_peak_denoised1, dim=0, keepdim=True)
        topic_peak_denoised = (topic_peak_denoised1 - topic_peak_min) / (topic_peak_max - topic_peak_min + 1e-8)

        gene_peak = (self.gene_peak_factor_learnt * self.gene_peak_factor_fixed).T
        batch_factor_rna = torch.mm(batch_onehot, self.rna_batch_factor)
        topicxgene = torch.mm(topic_peak_denoised, gene_peak)
        rna_logits = torch.mm(theta, topicxgene) + batch_factor_rna
        rna_logits = self.rna_batch_norm(rna_logits)
        preds_rna = F.softmax(rna_logits, dim=-1)

        # Library MLP for RNA
        library_factor_rna = self.rna_library_factor(rna_input)
        mu_nb_rna = preds_rna * library_factor_rna

        # 5) GRN => preds_rna_from_grn if phase == "grn"
        if phase == "grn":
            grn_atac_activator = torch.empty(self.num_topics, self.num_tfs, self.num_genes, device=self.device)
            grn_atac_repressor = torch.empty(self.num_topics, self.num_tfs, self.num_genes, device=self.device)

            # Calculate ATAC-based TF-gene links (activator/repressor) for
            # each topic; topic_peak_denoised1 (softmax over peaks) is reused
            # from step 4.
            for topic in range(self.num_topics):
                topic_gene_peak = topic_peak_denoised1[topic][:, None] * gene_peak
                G_topic = self.tf_binding_matrix_activator.T @ topic_gene_peak
                G_topic = G_topic / (gene_peak.sum(dim=0, keepdim=True) + 1e-7)
                grn_atac_activator[topic] = G_topic

                topic_gene_peak = (1 / (topic_peak_denoised1[topic] + 1e-20))[:, None] * gene_peak
                G_topic = self.tf_binding_matrix_repressor.T @ topic_gene_peak
                G_topic = G_topic / (gene_peak.sum(dim=0, keepdim=True) + 1e-7)
                grn_atac_repressor[topic] = G_topic

            C = torch.empty(self.num_topics, self.num_genes, device=self.device)
            tf_expression_input = topic_tf_input.to(self.device)
            for topic in range(self.num_topics):
                gene_atac_activator_topic = grn_atac_activator[topic] / (grn_atac_activator[topic].max() + 1e-15)
                gene_atac_repressor_topic = grn_atac_repressor[topic] / (grn_atac_repressor[topic].min() + 1e-15)
                G_act = gene_atac_activator_topic * F.relu(self.tf_gene_topic_activator_grn[topic])
                G_rep = gene_atac_repressor_topic * -1 * F.relu(self.tf_gene_topic_repressor_grn[topic])
                C[topic] = tf_expression_input[topic] @ G_act + tf_expression_input[topic] @ G_rep

            batch_factor_rna_grn = torch.mm(batch_onehot, self.rna_grn_batch_factor)
            preds_rna_from_grn = torch.mm(theta, C) + batch_factor_rna_grn
            preds_rna_from_grn = self.rna_grn_batch_norm(preds_rna_from_grn)
            preds_rna_from_grn = F.softmax(preds_rna_from_grn, dim=1)
        else:
            preds_rna_from_grn = torch.zeros_like(preds_rna)

        mu_nb_rna_grn = preds_rna_from_grn * library_factor_rna

        return {
            "theta": theta,
            "mu_theta": mu_theta,
            "preds_atac": preds_atac,
            "preds_tf": preds_tf,
            "mu_nb_tf": mu_nb_tf,
            "preds_rna": preds_rna,
            "mu_nb_rna": mu_nb_rna,
            "preds_rna_from_grn": preds_rna_from_grn,
            "mu_nb_rna_grn": mu_nb_rna_grn,
            "library_factor_tf": library_factor_tf,
            "library_factor_rna": library_factor_rna,
        }
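    # A hedged sketch of a GRN-phase call (placeholder tensors). The caller is
    # assumed to aggregate TF expression per topic into `topic_tf_input` of
    # shape (num_topics, num_tfs); that aggregation is not defined here.
    #
    #   out = model(rna, atac, tf, topic_tf, log_lib_rna, log_lib_atac,
    #               num_cells, batch_onehot, phase="grn")
    #   out["preds_rna_from_grn"]  # (B, num_genes) GRN-based reconstruction
    #   out["mu_nb_rna_grn"]       # its negative binomial mean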
    def initialize_scdori_parameters(
        model,
        gene_peak_distance_exp: torch.Tensor,
        gene_peak_fixed: torch.Tensor,
        insilico_act: torch.Tensor,
        insilico_rep: torch.Tensor,
        phase="warmup",
    ):
        """
        Initialize or freeze certain scDoRI parameters, preparing for either
        the warmup or GRN phase.

        Parameters
        ----------
        model : torch.nn.Module
            An instance of the scDoRI model.
        gene_peak_distance_exp : torch.Tensor
            Shape (num_genes, num_peaks). Peak-gene distance matrix, usually
            an exponential decay.
        gene_peak_fixed : torch.Tensor
            Shape (num_genes, num_peaks). A binary mask indicating allowable
            gene-peak links.
        insilico_act : torch.Tensor
            Shape (num_peaks, num_tfs). In silico ChIP-seq matrix for
            activators.
        insilico_rep : torch.Tensor
            Shape (num_peaks, num_tfs). In silico ChIP-seq matrix for
            repressors.
        phase : str, optional
            "warmup" or "grn". In "warmup", sets the gene-peak and TF-binding
            matrices and keeps them fixed or partially trainable. In "grn",
            enables the TF-gene parameters to be trainable.

        Returns
        -------
        None
            Modifies `model` in place, setting appropriate `.data` values and
            `.requires_grad` booleans.
        """
        with torch.no_grad():
            if phase != "grn":
                # 1) Set the fixed gene-peak mask.
                model.gene_peak_factor_fixed.data.copy_(gene_peak_fixed)
                model.gene_peak_factor_fixed.requires_grad = False

                # 2) Initialize the learnable gene-peak factor with
                #    distance-based weights.
                model.gene_peak_factor_learnt.data.copy_(gene_peak_distance_exp)
                model.gene_peak_factor_learnt.requires_grad = True

                # 3) Initialize TF binding matrices for activator & repressor.
                model.tf_binding_matrix_activator.data.copy_(insilico_act)
                model.tf_binding_matrix_activator.requires_grad = False
                model.tf_binding_matrix_repressor.data.copy_(insilico_rep)
                model.tf_binding_matrix_repressor.requires_grad = False

                model.tf_gene_topic_activator_grn.requires_grad = False
                model.tf_gene_topic_repressor_grn.requires_grad = False
            else:
                model.gene_peak_factor_fixed.requires_grad = False
                # Enable fine-tuning of TF-gene links per topic.
                model.tf_gene_topic_activator_grn.requires_grad = True
                model.tf_gene_topic_repressor_grn.requires_grad = True

        logger.info(
            "scDoRI parameters (peak-gene distance & TF binding) initialized "
            "and relevant parameters frozen."
        )
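

# A minimal end-to-end sketch of the intended call order (initialize, then a
# warmup forward pass). All tensors are random placeholders and the small
# dimensions are hypothetical; this is illustrative, not part of the package.
if __name__ == "__main__":
    device = torch.device("cpu")
    G, P, T, K, NB = 100, 500, 20, 10, 2  # genes, peaks, TFs, topics, batches
    model = scDoRI(device, G, P, T, K, NB, dim_encoder1=64, dim_encoder2=32).to(device)
    model.initialize_scdori_parameters(
        gene_peak_distance_exp=torch.rand(G, P),
        gene_peak_fixed=torch.ones(G, P),
        insilico_act=torch.rand(P, T),
        insilico_rep=torch.rand(P, T),
        phase="warmup",
    )
    B = 8
    out = model(
        rna_input=torch.rand(B, G),
        atac_input=torch.rand(B, P),
        tf_input=torch.rand(B, T),
        topic_tf_input=torch.rand(K, T),  # only consumed when phase == "grn"
        log_lib_rna=torch.rand(B, 1),
        log_lib_atac=torch.rand(B, 1),
        num_cells=torch.ones(B, 1),
        batch_onehot=F.one_hot(torch.randint(0, NB, (B,)), NB).float(),
        phase="warmup_1",
    )
    print(out["theta"].shape)  # torch.Size([8, 10])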