Source code for scdori.main

#################################
# main.py
#################################
import logging
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

from scdori import config
from scdori.data_io import load_scdori_inputs, save_model_weights
from scdori.utils import set_seed
from scdori.models import scDoRI
from scdori.train_scdori import train_scdori_phases
from scdori.train_grn import train_model_grn
from scdori.models import initialize_scdori_parameters
from pathlib import Path

logger = logging.getLogger(__name__)

def run_scdori_pipeline():
    """
    Train the scDoRI model end to end and write its weights to disk.

    The run is driven entirely by values read from ``config`` and proceeds
    through two training calls:

    * ``train_scdori_phases`` -- the integrated scDoRI training (described in
      the project as the ATAC+TF warmup followed by adding RNA).
    * ``train_model_grn`` -- the GRN training stage that follows.

    Steps
    -----
    1. Configure logging, seed the RNGs via ``set_seed`` and pick the device
       (``cuda:0`` when CUDA is available, otherwise ``cpu``).
    2. Load the inputs with ``load_scdori_inputs``: RNA and ATAC metacell
       objects, the gene-peak distance matrix and the in silico activator /
       repressor matrices.
    3. Derive a binarised copy of the gene-peak matrix (positive entries
       set to 1).
    4. Split cell indices 80/20 into train and eval sets and wrap each in a
       ``DataLoader`` that yields batches of cell indices.
    5. Build the ``scDoRI`` model and initialise its parameters with
       ``phase="warmup"``.
    6. Run ``train_scdori_phases`` and save the weights as ``scdori_final``
       under ``config.weights_folder_scdori``.
    7. Re-initialise parameters with ``phase="grn"``, run
       ``train_model_grn`` and save the weights as ``grn_final`` under
       ``config.weights_folder_grn``.

    Returns
    -------
    None
        Results are persisted as model weight files; nothing is returned.

    Notes
    -----
    - The number of TFs is currently a hard-coded placeholder (100).
    - The train/eval split uses a fixed ``random_state`` of 42, independent
      of ``config.random_seed``.
    """
    logging.basicConfig(level=config.logging_level)
    logger.info("Starting scDoRI pipeline with integrated GRN.")
    set_seed(config.random_seed)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")

    # --- Inputs ---------------------------------------------------------
    inputs = load_scdori_inputs()
    rna_metacell, atac_metacell, gene_peak_dist, insilico_act, insilico_rep = inputs

    # Binarised copy of the gene-peak matrix: every positive entry becomes 1.
    gene_peak_mask = gene_peak_dist.copy()
    gene_peak_mask[gene_peak_mask > 0] = 1

    # --- Train / eval split over cell indices ---------------------------
    # NOTE(review): random_state is fixed at 42 rather than taken from
    # config.random_seed -- confirm this is intentional.
    cell_ids = np.arange(rna_metacell.n_obs)
    train_ids, eval_ids = train_test_split(
        cell_ids, test_size=0.2, random_state=42
    )

    # The loaders yield batches of cell indices, not the data itself.
    train_loader = DataLoader(
        TensorDataset(torch.from_numpy(train_ids)),
        batch_size=config.batch_size_cell,
        shuffle=True,
    )
    eval_loader = DataLoader(
        TensorDataset(torch.from_numpy(eval_ids)),
        batch_size=config.batch_size_cell,
        shuffle=False,
    )

    # --- Model construction ---------------------------------------------
    # NOTE(review): num_tfs is a placeholder value; it should presumably be
    # derived from the loaded inputs or from config -- TODO confirm source.
    model_kwargs = {
        "device": device,
        "num_genes": rna_metacell.n_vars,
        "num_peaks": atac_metacell.n_vars,
        "num_tfs": 100,
        "num_topics": config.num_topics,
        "num_batches": config.num_batches,
        "dim_encoder1": config.dim_encoder1,
        "dim_encoder2": config.dim_encoder2,
    }
    model = scDoRI(**model_kwargs).to(device)

    # Keyword arguments shared by both parameter-initialisation calls.
    insilico_kwargs = {
        "insilico_act": insilico_act,
        "insilico_rep": insilico_rep,
    }

    # --- Integrated scDoRI training -------------------------------------
    initialize_scdori_parameters(
        model, gene_peak_dist, gene_peak_mask, phase="warmup", **insilico_kwargs
    )
    model = train_scdori_phases(model, device, train_loader, eval_loader)
    scdori_weights_dir = Path(config.weights_folder_scdori)
    save_model_weights(model, scdori_weights_dir, "scdori_final")

    # --- GRN training ---------------------------------------------------
    initialize_scdori_parameters(
        model, gene_peak_dist, gene_peak_mask, phase="grn", **insilico_kwargs
    )
    model = train_model_grn(model, device, train_loader, eval_loader)
    grn_weights_dir = Path(config.weights_folder_grn)
    save_model_weights(model, grn_weights_dir, "grn_final")

    logger.info("All phases complete. scDoRI pipeline done.")