Source code for scdori.train_grn

import torch
import torch.nn as nn
import logging
import copy
from tqdm import tqdm
import scipy.sparse as sp
import numpy as np
#from scdori import config
from scdori.utils import log_nb_positive
from scdori.dataloader import create_minibatch
from scdori.evaluation import get_latent_topics
from pathlib import Path
from scdori.data_io import save_model_weights

logger = logging.getLogger(__name__)

def set_encoder_frozen(model, freeze=True):
    """
    Freeze or unfreeze the encoder parameters.

    Parameters
    ----------
    model : torch.nn.Module
        scDoRI model containing the encoder modules.
    freeze : bool, optional
        If True, freeze the encoder parameters; if False, unfreeze them.
        Default is True.
    """
    for param in model.encoder_rna.parameters():
        param.requires_grad = not freeze
    for param in model.encoder_atac.parameters():
        param.requires_grad = not freeze
    for param in model.mu_theta.parameters():
        param.requires_grad = not freeze
    logger.info(f"Encoder is now {'frozen' if freeze else 'unfrozen'} in GRN phase.")

def set_peak_gene_frozen(model, freeze=True):
    """
    Freeze or unfreeze the peak-gene link parameters.

    Parameters
    ----------
    model : torch.nn.Module
        scDoRI model containing the peak-gene factor.
    freeze : bool, optional
        If True, freeze the peak-gene parameters; if False, unfreeze them.
        Default is True.
    """
    model.gene_peak_factor_learnt.requires_grad = not freeze
    logger.info(f"Peak-gene links are now {'frozen' if freeze else 'unfrozen'} in GRN phase.")

def set_topic_peak_frozen(model, freeze=True):
    """
    Freeze or unfreeze the topic-peak decoder parameters.

    Parameters
    ----------
    model : torch.nn.Module
        scDoRI model containing the topic-peak decoder.
    freeze : bool, optional
        If True, freeze the topic-peak decoder; if False, unfreeze it.
        Default is True.
    """
    model.topic_peak_decoder.requires_grad = not freeze
    logger.info(f"Topic-peak decoder is now {'frozen' if freeze else 'unfrozen'} in GRN phase.")

def set_topic_tf_frozen(model, freeze=True):
    """
    Freeze or unfreeze the topic-TF decoder parameters.

    Parameters
    ----------
    model : torch.nn.Module
        scDoRI model containing the topic-TF decoder.
    freeze : bool, optional
        If True, freeze the topic-TF decoder; if False, unfreeze it.
        Default is True.
    """
    model.topic_tf_decoder.requires_grad = not freeze
    logger.info(f"Topic-TF decoder is now {'frozen' if freeze else 'unfrozen'} in GRN phase.")
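
# --- Usage sketch (illustrative, not part of the original module) ---
# A typical Phase-2 (GRN) setup freezes the modules learned in Phase 1 so
# that only the GRN activator/repressor matrices receive gradient updates;
# `model` is assumed to be a scDoRI instance trained through Phase 1.
def _example_freeze_for_grn(model):
    set_encoder_frozen(model, freeze=True)      # RNA/ATAC encoders fixed
    set_peak_gene_frozen(model, freeze=True)    # peak-gene links fixed
    set_topic_peak_frozen(model, freeze=True)   # topic-peak decoder fixed
    set_topic_tf_frozen(model, freeze=True)     # topic-TF decoder fixed
    # Only parameters that still have requires_grad=True are picked up by
    # the optimizer filter inside train_model_grn below.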

def get_tf_expression(tf_expression_mode, model, device, train_loader, rna_anndata,
                      atac_anndata, num_cells, tf_indices, encoding_batch_onehot,
                      config_file):
    """
    Compute TF expression per topic.

    If `tf_expression_mode` is "True", this function computes the mean TF
    expression of the top-k cells in each topic. Otherwise, it uses a
    normalized topic-TF decoder matrix from the model.

    Parameters
    ----------
    tf_expression_mode : str
        Mode for TF expression. "True" calculates per-topic TF expression
        from top-k cells; "latent" uses the topic-TF decoder matrix.
    model : torch.nn.Module
        The scDoRI model containing encoder and decoder modules.
    device : torch.device
        The device (CPU or CUDA) used for PyTorch tensors.
    train_loader : DataLoader
        DataLoader for training data.
    rna_anndata : anndata.AnnData
        RNA single-cell data in AnnData format.
    atac_anndata : anndata.AnnData
        ATAC single-cell data in AnnData format.
    num_cells : np.ndarray
        Number of cells constituting each input metacell; set to 1 for
        single-cell data.
    tf_indices : list of int
        Indices of TF features in the RNA data.
    encoding_batch_onehot : np.ndarray
        One-hot encoding for batch information.
    config_file : module
        Configuration module with training settings.

    Returns
    -------
    torch.Tensor
        A (num_topics x num_tfs) tensor of TF expression values for each topic.
    """
    if tf_expression_mode == "True":
        latent_all_torch = get_latent_topics(
            model, device, train_loader, rna_anndata, atac_anndata,
            num_cells, tf_indices, encoding_batch_onehot
        )
        top_k_indices = np.argsort(latent_all_torch, axis=0)[-config_file.cells_per_topic:]

        rna_tf_vals = rna_anndata.X[:, tf_indices]
        if sp.issparse(rna_tf_vals):
            rna_tf_vals = rna_tf_vals.todense()
        rna_tf_vals = np.array(rna_tf_vals)

        # Normalize each cell to the median library size over TF features.
        median_cell = np.median(rna_tf_vals.sum(axis=1))
        rna_tf_vals = median_cell * (rna_tf_vals / rna_tf_vals.sum(axis=1, keepdims=True))

        topic_tf = []
        for t in range(model.num_topics):
            topic_vals = rna_tf_vals[top_k_indices[:, t], :]
            topic_vals = topic_vals.mean(axis=0)
            topic_tf.append(topic_vals)
        topic_tf = np.array(topic_tf)
        topic_tf = torch.from_numpy(topic_tf)

        # Min-max normalize per topic and clamp small values to zero.
        preds_tf_denoised_min, _ = torch.min(topic_tf, dim=1, keepdim=True)
        preds_tf_denoised_max, _ = torch.max(topic_tf, dim=1, keepdim=True)
        topic_tf = ((topic_tf - preds_tf_denoised_min)
                    / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9))
        topic_tf[topic_tf < config_file.tf_expression_clamp] = 0
        topic_tf = topic_tf.to(device)
        return topic_tf
    else:
        topic_tf = nn.Softmax(dim=1)(model.topic_tf_decoder.detach().cpu())
        preds_tf_denoised_min, _ = torch.min(topic_tf, dim=1, keepdim=True)
        preds_tf_denoised_max, _ = torch.max(topic_tf, dim=1, keepdim=True)
        tf_normalised = ((topic_tf - preds_tf_denoised_min)
                         / (preds_tf_denoised_max - preds_tf_denoised_min + 1e-9))
        tf_normalised[tf_normalised < config_file.tf_expression_clamp] = 0
        topic_tf = tf_normalised.to(device)
        return topic_tf
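
# --- Usage sketch (illustrative, not part of the original module) ---
# get_tf_expression yields the per-topic TF activity matrix consumed by the
# GRN forward pass. The shape check below follows from the docstring;
# argument values are assumed to come from the standard scDoRI setup.
def _example_topic_tf_input(model, device, train_loader, rna_anndata,
                            atac_anndata, num_cells, tf_indices,
                            encoding_batch_onehot, config_file):
    topic_tf = get_tf_expression(
        config_file.tf_expression_mode, model, device, train_loader,
        rna_anndata, atac_anndata, num_cells, tf_indices,
        encoding_batch_onehot, config_file
    )
    # (num_topics x num_tfs), min-max normalized per topic, with values
    # below config_file.tf_expression_clamp set to zero.
    assert tuple(topic_tf.shape) == (model.num_topics, len(tf_indices))
    return topic_tf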

def compute_eval_loss_grn(model, device, train_loader, eval_loader, rna_anndata,
                          atac_anndata, num_cells, tf_indices,
                          encoding_batch_onehot, config_file):
    """
    Compute the validation (evaluation) loss for the GRN phase.

    This function evaluates loss components for ATAC, TF, RNA, and
    RNA-from-GRN on a validation dataset.

    Parameters
    ----------
    model : torch.nn.Module
        The scDoRI model.
    device : torch.device
        The device (CPU or CUDA) used for PyTorch tensors.
    train_loader : DataLoader
        DataLoader for the training set (used to compute TF expression).
    eval_loader : DataLoader
        DataLoader for the validation set.
    rna_anndata : anndata.AnnData
        RNA single-cell data in AnnData format.
    atac_anndata : anndata.AnnData
        ATAC single-cell data in AnnData format.
    num_cells : np.ndarray
        Number of cells constituting each input metacell; set to 1 for
        single-cell data.
    tf_indices : list of int
        Indices of TF features in the RNA data.
    encoding_batch_onehot : np.ndarray
        One-hot encoding for batch information.
    config_file : module
        Configuration module with training settings.

    Returns
    -------
    tuple of float
        (eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna,
        eval_loss_rna_grn).
    """
    model.eval()
    running_loss = 0.0
    running_loss_atac = 0.0
    running_loss_tf = 0.0
    running_loss_rna = 0.0
    running_loss_rna_grn = 0.0
    nbatch = 0

    topic_tf_input = get_tf_expression(
        config_file.tf_expression_mode, model, device, train_loader,
        rna_anndata, atac_anndata, num_cells, tf_indices,
        encoding_batch_onehot, config_file
    )

    with torch.no_grad():
        for batch_data in eval_loader:
            cell_indices = batch_data[0].to(device)
            B = cell_indices.shape[0]
            input_matrix, tf_exp, library_size_value, num_cells_value, input_batch = create_minibatch(
                device, cell_indices, rna_anndata, atac_anndata,
                num_cells, tf_indices, encoding_batch_onehot
            )
            rna_input = input_matrix[:, :model.num_genes]
            atac_input = input_matrix[:, model.num_genes:]
            log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
            log_lib_atac = library_size_value[:, 1].reshape(-1, 1)

            out = model(
                rna_input, atac_input, tf_exp, topic_tf_input,
                log_lib_rna, log_lib_atac, num_cells_value, input_batch,
                phase="grn"
            )
            preds_atac = out["preds_atac"]
            mu_nb_tf = out["mu_nb_tf"]
            mu_nb_rna = out["mu_nb_rna"]
            mu_nb_rna_grn = out["mu_nb_rna_grn"]

            # ATAC reconstruction: Poisson NLL with library-size scaling.
            criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction='sum')
            library_factor_peak = torch.exp(log_lib_atac.view(B, 1))
            preds_poisson = preds_atac * library_factor_peak
            loss_atac = criterion_poisson(preds_poisson, atac_input)

            # TF and RNA reconstruction: negative binomial log-likelihoods.
            alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
            nb_tf_ll = log_nb_positive(tf_exp, mu_nb_tf, alpha_tf).sum(dim=1).mean()
            loss_tf = -nb_tf_ll

            alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)
            nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean()
            loss_rna = -nb_rna_ll

            nb_rna_grn_ll = log_nb_positive(rna_input, mu_nb_rna_grn, alpha_rna).sum(dim=1).mean()
            loss_rna_grn = -nb_rna_grn_ll

            # Regularization terms (no gradients needed during evaluation).
            l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
            l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
            l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
            l2_norm_peak = torch.norm(model.topic_peak_decoder.data, p=2)
            l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=1)
            l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=2)
            l1_norm_grn_activator = torch.norm(model.tf_gene_topic_activator_grn.data, p=1)
            l1_norm_grn_repressor = torch.norm(model.tf_gene_topic_repressor_grn.data, p=1)

            loss_norm = (
                config_file.l1_penalty_topic_tf * l1_norm_tf
                + config_file.l2_penalty_topic_tf * l2_norm_tf
                + config_file.l1_penalty_topic_peak * l1_norm_peak
                + config_file.l2_penalty_topic_peak * l2_norm_peak
                + config_file.l1_penalty_gene_peak * l1_norm_gene_peak
                + config_file.l2_penalty_gene_peak * l2_norm_gene_peak
                + config_file.l1_penalty_grn_activator * l1_norm_grn_activator
                + config_file.l1_penalty_grn_repressor * l1_norm_grn_repressor
            )

            total_loss = (
                config_file.weight_atac_grn * loss_atac
                + config_file.weight_tf_grn * loss_tf
                + config_file.weight_rna_grn * loss_rna
                + config_file.weight_rna_from_grn * loss_rna_grn
                + loss_norm
            )

            running_loss += total_loss.item()
            running_loss_atac += loss_atac.item()
            running_loss_tf += loss_tf.item()
            running_loss_rna += loss_rna.item()
            running_loss_rna_grn += loss_rna_grn.item()
            nbatch += 1

    eval_loss = running_loss / max(1, nbatch)
    eval_loss_atac = running_loss_atac / max(1, nbatch)
    eval_loss_tf = running_loss_tf / max(1, nbatch)
    eval_loss_rna = running_loss_rna / max(1, nbatch)
    eval_loss_rna_grn = running_loss_rna_grn / max(1, nbatch)
    return eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna, eval_loss_rna_grn
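
# --- Usage sketch (illustrative, not part of the original module) ---
# compute_eval_loss_grn is called every `config_file.eval_frequency` epochs
# inside train_model_grn, but it can also be run standalone to audit a saved
# checkpoint; all arguments mirror those of train_model_grn.
def _example_eval_checkpoint(model, device, train_loader, eval_loader,
                             rna_anndata, atac_anndata, num_cells, tf_indices,
                             encoding_batch_onehot, config_file):
    (eval_loss, eval_loss_atac, eval_loss_tf,
     eval_loss_rna, eval_loss_rna_grn) = compute_eval_loss_grn(
        model, device, train_loader, eval_loader, rna_anndata, atac_anndata,
        num_cells, tf_indices, encoding_batch_onehot, config_file
    )
    # The GRN phase early-stops on the RNA-from-GRN component, so this is
    # usually the number to watch.
    logger.info(f"Standalone eval: total={eval_loss:.4f}, RNA-GRN={eval_loss_rna_grn:.4f}")
    return eval_loss_rna_grn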

def train_model_grn(model, device, train_loader, eval_loader, rna_anndata,
                    atac_anndata, num_cells, tf_indices, encoding_batch_onehot,
                    config_file):
    """
    Train the model in Phase 2 (GRN phase).

    In this phase, the model focuses on learning activator and repressor
    TF-gene links per topic (module 4 of scDoRI). Other modules of the model
    can be optionally frozen or unfrozen based on the configuration.

    Parameters
    ----------
    model : torch.nn.Module
        The scDoRI model to train.
    device : torch.device
        The device (CPU or CUDA) used for PyTorch tensors.
    train_loader : DataLoader
        DataLoader for the training set.
    eval_loader : DataLoader
        DataLoader for the validation set, used to check early-stopping criteria.
    rna_anndata : anndata.AnnData
        RNA single-cell data in AnnData format.
    atac_anndata : anndata.AnnData
        ATAC single-cell data in AnnData format.
    num_cells : np.ndarray
        Number of cells constituting each input metacell; set to 1 for
        single-cell data.
    tf_indices : list of int
        Indices of TF features in the RNA data.
    encoding_batch_onehot : np.ndarray
        One-hot encoding for batch information.
    config_file : module
        Configuration module with training settings.

    Returns
    -------
    torch.nn.Module
        The trained model after the GRN phase completes or early stopping occurs.
    """
    set_encoder_frozen(model, freeze=not config_file.update_encoder_in_grn)
    set_peak_gene_frozen(model, freeze=not config_file.update_peak_gene_in_grn)
    set_topic_peak_frozen(model, freeze=not config_file.update_topic_peak_in_grn)
    set_topic_tf_frozen(model, freeze=not config_file.update_topic_tf_in_grn)

    optimizer_grn = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config_file.learning_rate_grn
    )

    best_eval_loss = float('inf')
    val_patience = 0
    max_val_patience = config_file.grn_val_patience

    topic_tf_input = None
    if config_file.tf_expression_mode == "True":
        topic_tf_input = get_tf_expression(
            config_file.tf_expression_mode, model, device, train_loader,
            rna_anndata, atac_anndata, num_cells, tf_indices,
            encoding_batch_onehot, config_file
        )

    logger.info("Starting GRN training")
    for epoch in range(config_file.max_grn_epochs):
        model.train()
        running_loss = 0.0
        running_loss_atac = 0.0
        running_loss_tf = 0.0
        running_loss_rna = 0.0
        running_loss_rna_grn = 0.0
        nbatch = 0

        # If the encoder is being updated, recompute topic_tf_input each epoch.
        if config_file.update_encoder_in_grn:
            topic_tf_input = get_tf_expression(
                config_file.tf_expression_mode, model, device, train_loader,
                rna_anndata, atac_anndata, num_cells, tf_indices,
                encoding_batch_onehot, config_file
            )

        for batch_data in tqdm(train_loader, desc=f"GRN Epoch {epoch}"):
            cell_indices = batch_data[0].to(device)
            B = cell_indices.shape[0]
            input_matrix, tf_exp, library_size_value, num_cells_value, input_batch = create_minibatch(
                device, cell_indices, rna_anndata, atac_anndata,
                num_cells, tf_indices, encoding_batch_onehot
            )
            rna_input = input_matrix[:, :model.num_genes]
            atac_input = input_matrix[:, model.num_genes:]
            tf_input = tf_exp
            log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
            log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
            batch_onehot = input_batch

            if config_file.tf_expression_mode == "latent":
                topic_tf_input = get_tf_expression(
                    config_file.tf_expression_mode, model, device, train_loader,
                    rna_anndata, atac_anndata, num_cells, tf_indices,
                    encoding_batch_onehot, config_file
                )

            out = model(
                rna_input, atac_input, tf_input, topic_tf_input,
                log_lib_rna, log_lib_atac, num_cells_value, batch_onehot,
                phase="grn"
            )
            preds_atac = out["preds_atac"]
            mu_nb_tf = out["mu_nb_tf"]
            mu_nb_rna = out["mu_nb_rna"]
            preds_rna_grn = out["preds_rna_from_grn"]
            mu_nb_rna_grn = out["mu_nb_rna_grn"]

            # ATAC reconstruction: Poisson NLL with library-size scaling.
            criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction='sum')
            library_factor_peak = torch.exp(log_lib_atac.view(B, 1))
            preds_poisson = preds_atac * library_factor_peak
            loss_atac = criterion_poisson(preds_poisson, atac_input)

            alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
            nb_tf_ll = log_nb_positive(tf_input, mu_nb_tf, alpha_tf).sum(dim=1).mean()
            loss_tf = -nb_tf_ll

            alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)
            nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean()
            loss_rna = -nb_rna_ll

            nb_rna_grn_ll = log_nb_positive(rna_input, mu_nb_rna_grn, alpha_rna).sum(dim=1).mean()
            loss_rna_grn = -nb_rna_grn_ll

            # Regularization terms; computed on the parameters themselves
            # (not .data) so the penalties contribute gradients.
            l1_norm_tf = torch.norm(model.topic_tf_decoder, p=1)
            l2_norm_tf = torch.norm(model.topic_tf_decoder, p=2)
            l1_norm_peak = torch.norm(model.topic_peak_decoder, p=1)
            l2_norm_peak = torch.norm(model.topic_peak_decoder, p=2)
            l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt, p=1)
            l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt, p=2)
            l1_norm_grn_activator = torch.norm(model.tf_gene_topic_activator_grn, p=1)
            l1_norm_grn_repressor = torch.norm(model.tf_gene_topic_repressor_grn, p=1)

            loss_norm = (
                config_file.l1_penalty_topic_tf * l1_norm_tf
                + config_file.l2_penalty_topic_tf * l2_norm_tf
                + config_file.l1_penalty_topic_peak * l1_norm_peak
                + config_file.l2_penalty_topic_peak * l2_norm_peak
                + config_file.l1_penalty_gene_peak * l1_norm_gene_peak
                + config_file.l2_penalty_gene_peak * l2_norm_gene_peak
                + config_file.l1_penalty_grn_activator * l1_norm_grn_activator
                + config_file.l1_penalty_grn_repressor * l1_norm_grn_repressor
            )

            total_loss = (
                config_file.weight_atac_grn * loss_atac
                + config_file.weight_tf_grn * loss_tf
                + config_file.weight_rna_grn * loss_rna
                + config_file.weight_rna_from_grn * loss_rna_grn
                + loss_norm
            )

            optimizer_grn.zero_grad()
            total_loss.backward()
            optimizer_grn.step()

            running_loss += total_loss.item()
            running_loss_atac += loss_atac.item()
            running_loss_tf += loss_tf.item()
            running_loss_rna += loss_rna.item()
            running_loss_rna_grn += loss_rna_grn.item()
            nbatch += 1

            # Keep peak-gene links within [0, 1] after each update.
            model.gene_peak_factor_learnt.data.clamp_(min=0, max=1)

        epoch_loss = running_loss / max(1, nbatch)
        epoch_loss_atac = running_loss_atac / max(1, nbatch)
        epoch_loss_tf = running_loss_tf / max(1, nbatch)
        epoch_loss_rna = running_loss_rna / max(1, nbatch)
        epoch_loss_rna_grn = running_loss_rna_grn / max(1, nbatch)
        logger.info(
            f"[GRN-Train] Epoch={epoch}, Loss={epoch_loss:.4f}, "
            f"Atac={epoch_loss_atac:.4f}, TF={epoch_loss_tf:.4f}, "
            f"RNA={epoch_loss_rna:.4f}, RNA-GRN={epoch_loss_rna_grn:.4f}"
        )

        # Evaluate every config_file.eval_frequency epochs.
        if (epoch + 1) % config_file.eval_frequency == 0:
            eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna, eval_loss_rna_grn = compute_eval_loss_grn(
                model, device, train_loader, eval_loader, rna_anndata,
                atac_anndata, num_cells, tf_indices, encoding_batch_onehot,
                config_file
            )
            logger.info(
                f"[GRN-Eval] Epoch={epoch}, EvalLoss={eval_loss:.4f}, "
                f"EvalAtac={eval_loss_atac:.4f}, EvalTF={eval_loss_tf:.4f}, "
                f"EvalRNA={eval_loss_rna:.4f}, EvalRNA-GRN={eval_loss_rna_grn:.4f}"
            )

            # Early stopping on eval_loss_rna_grn.
            if eval_loss_rna_grn < best_eval_loss:
                best_eval_loss = eval_loss_rna_grn
                val_patience = 0
                save_model_weights(model, Path(config_file.weights_folder_grn), "scdori_best_eval")
            else:
                val_patience += 1
                if val_patience > max_val_patience:
                    logger.info(f"[GRN] Validation not improving => early stop at epoch={epoch}.")
                    break

    logger.info("Finished GRN phase with validation checks.")
    return model
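
# --- End-to-end sketch (illustrative, not part of the original module) ---
# The attributes below are exactly the ones this module reads from
# `config_file`; the values shown are hypothetical placeholders, not
# recommended defaults.
def _example_grn_phase(model, train_loader, eval_loader, rna_anndata,
                       atac_anndata, num_cells, tf_indices,
                       encoding_batch_onehot):
    import types
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config_file = types.SimpleNamespace(
        tf_expression_mode="True", cells_per_topic=50, tf_expression_clamp=0.1,
        update_encoder_in_grn=False, update_peak_gene_in_grn=False,
        update_topic_peak_in_grn=False, update_topic_tf_in_grn=False,
        learning_rate_grn=1e-3, max_grn_epochs=100, grn_val_patience=5,
        eval_frequency=5, weights_folder_grn="weights_grn",
        weight_atac_grn=1.0, weight_tf_grn=1.0,
        weight_rna_grn=1.0, weight_rna_from_grn=1.0,
        l1_penalty_topic_tf=0.0, l2_penalty_topic_tf=0.0,
        l1_penalty_topic_peak=0.0, l2_penalty_topic_peak=0.0,
        l1_penalty_gene_peak=0.0, l2_penalty_gene_peak=0.0,
        l1_penalty_grn_activator=1e-4, l1_penalty_grn_repressor=1e-4,
    )
    return train_model_grn(model.to(device), device, train_loader, eval_loader,
                           rna_anndata, atac_anndata, num_cells, tf_indices,
                           encoding_batch_onehot, config_file)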