Source code for scdori.train_scdori

##################################
# train_scdori.py
##################################
import torch
import logging
from tqdm import tqdm
import copy
from pathlib import Path
# from scdori import config
from scdori.utils import log_nb_positive
from scdori.dataloader import create_minibatch
from scdori.data_io import save_model_weights

logger = logging.getLogger(__name__)

def get_phase_scdori(epoch, config_file):
    """
    Determine which training phase to use at a given epoch.

    In "warmup_1", only modules 1 and 3 are trained (ATAC and TF
    reconstruction); in "warmup_2", RNA reconstruction from ATAC is added.

    Parameters
    ----------
    epoch : int
        The current training epoch.
    config_file : object
        Configuration object that includes `epoch_warmup_1` to define the
        cutoff for switching from phase "warmup_1" to "warmup_2".

    Returns
    -------
    str
        The phase: "warmup_1" if `epoch < config_file.epoch_warmup_1`,
        else "warmup_2".
    """
    if epoch < config_file.epoch_warmup_1:
        return "warmup_1"
    else:
        return "warmup_2"
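
# Illustrative usage (a sketch, not part of the module): with a hypothetical
# config object whose `epoch_warmup_1` is 10, epochs 0-9 select "warmup_1"
# and all later epochs select "warmup_2".
#
#   >>> class Cfg: epoch_warmup_1 = 10
#   >>> get_phase_scdori(3, Cfg())
#   'warmup_1'
#   >>> get_phase_scdori(10, Cfg())
#   'warmup_2'
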
def get_loss_weights_scdori(phase, config_file):
    """
    Get the loss weight dictionary for the specified phase.

    Parameters
    ----------
    phase : str
        The phase of training, one of {"warmup_1", "warmup_2"}.
    config_file : object
        Configuration object containing attributes like
        `weight_atac_phase1`, `weight_tf_phase1`, `weight_rna_phase1`, etc.

    Returns
    -------
    dict
        A dictionary with keys {"atac", "tf", "rna"} indicating the
        respective loss weights.
    """
    if phase == "warmup_1":
        return {
            "atac": config_file.weight_atac_phase1,
            "tf": config_file.weight_tf_phase1,
            "rna": config_file.weight_rna_phase1
        }
    else:  # warmup_2
        return {
            "atac": config_file.weight_atac_phase2,
            "tf": config_file.weight_tf_phase2,
            "rna": config_file.weight_rna_phase2
        }
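
# Illustrative usage (a sketch, assuming hypothetical weight values on the
# config object; real values come from the scDoRI configuration):
#
#   >>> class Cfg:
#   ...     weight_atac_phase1, weight_tf_phase1, weight_rna_phase1 = 1.0, 1.0, 0.0
#   ...     weight_atac_phase2, weight_tf_phase2, weight_rna_phase2 = 1.0, 1.0, 1.0
#   >>> get_loss_weights_scdori("warmup_1", Cfg())
#   {'atac': 1.0, 'tf': 1.0, 'rna': 0.0}
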
def compute_eval_loss_scdori(
    model,
    device,
    eval_loader,
    rna_anndata,
    atac_anndata,
    num_cells,
    tf_indices,
    encoding_batch_onehot,
    config_file
):
    """
    Compute the validation loss for scDoRI.

    Parameters
    ----------
    model : torch.nn.Module
        The scDoRI model to evaluate.
    device : torch.device
        The device (CPU or CUDA) used for PyTorch operations.
    eval_loader : torch.utils.data.DataLoader
        A DataLoader providing validation cell indices.
    rna_anndata : anndata.AnnData
        RNA single-cell data in AnnData format.
    atac_anndata : anndata.AnnData
        ATAC single-cell data in AnnData format.
    num_cells : np.ndarray
        Number of cells per row (if metacells) or ones for single-cell data.
    tf_indices : list or np.ndarray
        Indices of transcription factor genes in the RNA data.
    encoding_batch_onehot : np.ndarray
        One-hot encoding for batch information (cells x num_batches).
    config_file : object
        Configuration object with hyperparameters (loss weights, penalties, etc.).

    Returns
    -------
    tuple
        (eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna), each a float.
    """
    model.eval()
    running_loss = 0.0
    running_loss_atac = 0.0
    running_loss_tf = 0.0
    running_loss_rna = 0.0
    nbatch = 0

    with torch.no_grad():
        for batch_data in eval_loader:
            cell_indices = batch_data[0].to(device)
            B = cell_indices.shape[0]

            (input_matrix, tf_exp, library_size_value,
             num_cells_value, input_batch) = create_minibatch(
                device, cell_indices, rna_anndata, atac_anndata,
                num_cells, tf_indices, encoding_batch_onehot
            )
            rna_input = input_matrix[:, :model.num_genes]
            atac_input = input_matrix[:, model.num_genes:]
            tf_input = tf_exp
            log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
            log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
            batch_onehot = input_batch

            # Evaluate in "warmup_2" style
            out = model(
                rna_input, atac_input,
                tf_input, tf_input,
                log_lib_rna, log_lib_atac,
                num_cells_value, batch_onehot,
                phase="warmup_2"
            )
            preds_atac = out["preds_atac"]
            mu_nb_tf = out["mu_nb_tf"]
            mu_nb_rna = out["mu_nb_rna"]

            # ATAC => Poisson
            library_factor_peak = torch.exp(log_lib_atac.view(B, 1))
            preds_poisson = preds_atac * library_factor_peak
            criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction='sum')
            loss_atac = criterion_poisson(preds_poisson, atac_input)

            # TF => NB
            alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
            nb_tf_ll = log_nb_positive(tf_input, mu_nb_tf, alpha_tf).sum(dim=1).mean()
            loss_tf = -nb_tf_ll

            # RNA => NB
            alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)
            nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean()
            loss_rna = -nb_rna_ll

            # Regularization
            l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
            l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
            l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
            l2_norm_peak = torch.norm(model.topic_peak_decoder.data, p=2)
            l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=1)
            l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=2)

            loss_norm = (
                config_file.l1_penalty_topic_tf * l1_norm_tf +
                config_file.l2_penalty_topic_tf * l2_norm_tf +
                config_file.l1_penalty_topic_peak * l1_norm_peak +
                config_file.l2_penalty_topic_peak * l2_norm_peak +
                config_file.l1_penalty_gene_peak * l1_norm_gene_peak +
                config_file.l2_penalty_gene_peak * l2_norm_gene_peak
            )

            total_loss = (
                config_file.weight_atac_phase2 * loss_atac +
                config_file.weight_tf_phase2 * loss_tf +
                config_file.weight_rna_phase2 * loss_rna +
                loss_norm
            )

            running_loss += total_loss.item()
            running_loss_atac += loss_atac.item()
            running_loss_tf += loss_tf.item()
            running_loss_rna += loss_rna.item()
            nbatch += 1

    eval_loss = running_loss / max(1, nbatch)
    eval_loss_atac = running_loss_atac / max(1, nbatch)
    eval_loss_tf = running_loss_tf / max(1, nbatch)
    eval_loss_rna = running_loss_rna / max(1, nbatch)
    return eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna
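
# Illustrative call (a sketch; `model`, the loaders, the AnnData objects, and
# `cfg` are assumed to be prepared as described in the docstring above):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   eval_loss, eval_atac, eval_tf, eval_rna = compute_eval_loss_scdori(
#       model, device, eval_loader, rna_anndata, atac_anndata,
#       num_cells, tf_indices, encoding_batch_onehot, cfg
#   )
#   print(f"validation loss: {eval_loss:.4f}")
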
def train_scdori_phases(
    model,
    device,
    train_loader,
    eval_loader,
    rna_anndata,
    atac_anndata,
    num_cells,
    tf_indices,
    encoding_batch_onehot,
    config_file
):
    """
    Train the scDoRI model in two warmup phases:

    1) Warmup phase 1 (ATAC + TF focus).
    2) Warmup phase 2 (adding RNA).

    Includes early stopping based on validation performance.

    Parameters
    ----------
    model : torch.nn.Module
        The scDoRI model to be trained.
    device : torch.device
        The device (CPU or CUDA) for running PyTorch operations.
    train_loader : torch.utils.data.DataLoader
        DataLoader for the training set, providing cell indices.
    eval_loader : torch.utils.data.DataLoader
        DataLoader for the validation set, providing cell indices.
    rna_anndata : anndata.AnnData
        RNA single-cell data in AnnData format.
    atac_anndata : anndata.AnnData
        ATAC single-cell data in AnnData format.
    num_cells : np.ndarray
        Number of cells per row (metacells) or ones for single-cell data.
    tf_indices : list or np.ndarray
        Indices of transcription factor genes in the RNA data.
    encoding_batch_onehot : np.ndarray
        One-hot encoding matrix for batch information (cells x num_batches).
    config_file : object
        Configuration with hyperparameters including:

        - learning_rate_scdori
        - max_scdori_epochs
        - epoch_warmup_1
        - weight_atac_phase1, weight_tf_phase1, weight_rna_phase1
        - weight_atac_phase2, weight_tf_phase2, weight_rna_phase2
        - l1_penalty_topic_tf, etc.
        - eval_frequency
        - phase1_patience (early-stopping patience for validation loss)

    Returns
    -------
    torch.nn.Module
        The trained scDoRI model after both warmup phases.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=config_file.learning_rate_scdori)
    best_eval_loss = float('inf')
    val_patience = 0
    max_val_patience = config_file.phase1_patience

    logger.info("Starting scDoRI phase 1 training (module 1,2,3) with validation + early stopping.")
    for epoch in range(config_file.max_scdori_epochs):
        phase = get_phase_scdori(epoch, config_file)
        weights = get_loss_weights_scdori(phase, config_file)

        model.train()
        running_loss = 0.0
        running_loss_atac = 0.0
        running_loss_tf = 0.0
        running_loss_rna = 0.0
        nbatch = 0

        for batch_data in tqdm(train_loader, desc=f"Epoch {epoch} [{phase}]"):
            cell_indices = batch_data[0].to(device)
            B = cell_indices.shape[0]

            (input_matrix, tf_exp, library_size_value,
             num_cells_value, input_batch) = create_minibatch(
                device, cell_indices, rna_anndata, atac_anndata,
                num_cells, tf_indices, encoding_batch_onehot
            )
            rna_input = input_matrix[:, :model.num_genes]
            atac_input = input_matrix[:, model.num_genes:]
            tf_input = tf_exp
            log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
            log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
            batch_onehot = input_batch

            # This phase does not use a separate topic_tf_input
            topic_tf_input = tf_input

            out = model(
                rna_input, atac_input,
                tf_input, topic_tf_input,
                log_lib_rna, log_lib_atac,
                num_cells_value, batch_onehot,
                phase=phase
            )
            preds_atac = out["preds_atac"]
            mu_nb_tf = out["mu_nb_tf"]
            mu_nb_rna = out["mu_nb_rna"]

            # 1) ATAC => Poisson
            library_factor_peak = torch.exp(log_lib_atac.view(B, 1))
            preds_poisson = preds_atac * library_factor_peak
            criterion_poisson = torch.nn.PoissonNLLLoss(log_input=False, reduction='sum')
            loss_atac = criterion_poisson(preds_poisson, atac_input)

            # 2) TF => NB
            alpha_tf = torch.nn.functional.softplus(model.tf_alpha_nb).repeat(B, 1)
            nb_tf_ll = log_nb_positive(tf_input, mu_nb_tf, alpha_tf).sum(dim=1).mean()
            loss_tf = -nb_tf_ll

            # 3) RNA => NB
            alpha_rna = torch.nn.functional.softplus(model.rna_alpha_nb).repeat(B, 1)
            nb_rna_ll = log_nb_positive(rna_input, mu_nb_rna, alpha_rna).sum(dim=1).mean()
            loss_rna = -nb_rna_ll

            # Regularization
            l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
            l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
            l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
            l2_norm_peak = torch.norm(model.topic_peak_decoder.data, p=2)
            l1_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=1)
            l2_norm_gene_peak = torch.norm(model.gene_peak_factor_learnt.data, p=2)

            loss_norm = (
                config_file.l1_penalty_topic_tf * l1_norm_tf +
                config_file.l2_penalty_topic_tf * l2_norm_tf +
                config_file.l1_penalty_topic_peak * l1_norm_peak +
                config_file.l2_penalty_topic_peak * l2_norm_peak +
                config_file.l1_penalty_gene_peak * l1_norm_gene_peak +
                config_file.l2_penalty_gene_peak * l2_norm_gene_peak
            )

            total_loss = (
                weights["atac"] * loss_atac +
                weights["tf"] * loss_tf +
                weights["rna"] * loss_rna +
                loss_norm
            )

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            running_loss_atac += loss_atac.item()
            running_loss_tf += loss_tf.item()
            running_loss_rna += loss_rna.item()
            nbatch += 1

            # Clamp gene-peak factors to [0, 1]
            model.gene_peak_factor_learnt.data.clamp_(min=0)
            model.gene_peak_factor_learnt.data.clamp_(max=1)

        epoch_loss = running_loss / max(1, nbatch)
        epoch_loss_atac = running_loss_atac / max(1, nbatch)
        epoch_loss_tf = running_loss_tf / max(1, nbatch)
        epoch_loss_rna = running_loss_rna / max(1, nbatch)
        logger.info(
            f"[Train] Epoch={epoch}, Phase={phase}, Loss={epoch_loss:.4f}, "
            f"Atac={epoch_loss_atac:.4f}, TF={epoch_loss_tf:.4f}, RNA={epoch_loss_rna:.4f}"
        )

        # Evaluate periodically
        if (epoch + 1) % config_file.eval_frequency == 0:
            eval_loss, eval_loss_atac, eval_loss_tf, eval_loss_rna = compute_eval_loss_scdori(
                model, device, eval_loader,
                rna_anndata, atac_anndata, num_cells,
                tf_indices, encoding_batch_onehot, config_file
            )
            logger.info(
                f"[Eval ] Epoch={epoch}, Phase={phase}, EvalLoss={eval_loss:.4f}, "
                f"EvalAtac={eval_loss_atac:.4f}, EvalTF={eval_loss_tf:.4f}, EvalRNA={eval_loss_rna:.4f}"
            )

            # Early stopping based on eval_loss
            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                val_patience = 0
                save_model_weights(model, Path(config_file.weights_folder_scdori), "scdori_best_eval")
            else:
                val_patience += 1
                if val_patience > max_val_patience:
                    logger.info(f"Validation loss not improving => early stop at epoch={epoch}")
                    break

    logger.info("Finished scDoRI phase 1 training (module 1,2,3) with validation checks.")
    return model
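
# Illustrative end-to-end call (a sketch; all inputs are assumed to be built
# beforehand, e.g. the loaders via scdori.dataloader and `cfg` from the
# package configuration):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = model.to(device)
#   model = train_scdori_phases(
#       model, device, train_loader, eval_loader,
#       rna_anndata, atac_anndata, num_cells,
#       tf_indices, encoding_batch_onehot, cfg
#   )
#   # The best validation weights are also saved during training via
#   # save_model_weights(...) under the name "scdori_best_eval".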