Source code for scdori.evaluation
#################################
# evaluation.py
#################################
import torch
import numpy as np
from tqdm import tqdm
from scdori.dataloader import create_minibatch
[docs]
def get_latent_topics(
model,
device,
data_loader,
rna_anndata,
atac_anndata,
num_cells,
tf_indices,
encoding_batch_onehot
):
"""
Extract the softmaxed topic activations (theta) for each cell in the dataset.
Parameters
----------
model : torch.nn.Module
The scDoRI model containing an encoder for generating topic distributions.
device : torch.device
The PyTorch device (e.g., 'cpu' or 'cuda') used for computations.
data_loader : torch.utils.data.DataLoader
A DataLoader that yields batches of cell indices.
rna_anndata : anndata.AnnData
RNA single-cell data in AnnData format.
atac_anndata : anndata.AnnData
ATAC single-cell data in AnnData format.
num_cells : np.ndarray
Number of cells in each row (e.g., if using metacells). set to ones for single-cell data.
tf_indices : np.ndarray
Indices of transcription factor genes in the RNA data.
encoding_batch_onehot : np.ndarray
One-hot encoding of batch information (cells x num_batches).
Returns
-------
np.ndarray
A 2D NumPy array of shape (n_cells, n_topics) representing the softmaxed
topic activations for each cell in the order given by the DataLoader.
"""
model.eval()
all_thetas = []
with torch.no_grad():
for batch_data in tqdm(data_loader, desc="Extracting latent topics"):
cell_indices = batch_data[0].to(device)
B = cell_indices.shape[0]
(input_matrix, tf_exp, library_size_value, num_cells_value,
input_batch) = create_minibatch(
device,
cell_indices,
rna_anndata,
atac_anndata,
num_cells,
tf_indices,
encoding_batch_onehot
)
# Split into rna_input, atac_input, tf_input
rna_input = input_matrix[:, :model.num_genes]
atac_input = input_matrix[:, model.num_genes:]
tf_input = tf_exp
log_lib_rna = library_size_value[:, 0].reshape(-1, 1)
log_lib_atac = library_size_value[:, 1].reshape(-1, 1)
batch_onehot = input_batch
# We only need an encoder pass => e.g. use phase="warmup_1"
out = model(
rna_input,
atac_input,
tf_input,
tf_input,
log_lib_rna,
log_lib_atac,
num_cells_value,
batch_onehot,
phase="warmup_1"
)
# out["theta"] contains the softmaxed topic distribution
theta = out["theta"].detach().cpu().numpy()
all_thetas.append(theta)
latent_topics = np.concatenate(all_thetas, axis=0)
return latent_topics