# Source code for scdori.dataloader

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import anndata as ad
import scanpy
import scipy.sparse as sp

def create_minibatch(
    device,
    index_matrix,
    rna_anndata,
    atac_anndata,
    num_cells,
    tf_indices,
    encoding_batch_onehot,
):
    """
    Create a minibatch of required input tensors using integer indices of cells.

    Parameters
    ----------
    device : torch.device
        The device (CPU or CUDA) to which the data should be moved.
    index_matrix : torch.Tensor
        A 1D tensor containing integer indices of the cells in the minibatch.
    rna_anndata : anndata.AnnData
        AnnData object for RNA data. The ``.X`` matrix should contain RNA
        counts or expression values.
    atac_anndata : anndata.AnnData
        AnnData object for ATAC data. The ``.X`` matrix should contain
        accessibility counts.
    num_cells : np.ndarray
        A NumPy array (N x 1) indicating the number of cells represented by
        each row (if using metacells). For single-cell level data, this may be
        an array of ones. A flat array of shape (N,) is also accepted and is
        treated as a single column.
    tf_indices : np.ndarray
        Indices corresponding to transcription factors (TFs) in the RNA
        AnnData.
    encoding_batch_onehot : np.ndarray
        A one-hot encoded matrix representing batch information for each cell
        (cells x num_batches).

    Returns
    -------
    tuple
        A tuple containing:

        - input_matrix (torch.Tensor): Concatenated RNA and ATAC input of
          shape (B, g + p), where B is batch size, g is the number of genes,
          p is the number of peaks. Values are floats on the given device.
        - tf_exp (torch.Tensor): RNA expression values for TFs,
          shape (B, num_tfs).
        - library_size_value (torch.Tensor): Log-scale library sizes for RNA
          and ATAC, shape (B, 2).
        - num_cells_value (torch.Tensor): Number of cells per row in the
          minibatch (B, 1).
        - input_batch (torch.Tensor): One-hot batch-encoding,
          shape (B, num_batches).

    Notes
    -----
    - This function converts sparse arrays to dense if necessary.
    - ATAC counts are converted from insertion counts to fragment counts by
      using ``(x + 1) // 2``.
    - The ATAC library size is computed on the fragment counts, i.e. after
      the conversion above.
    """
    # The row selectors below index host-side (NumPy / AnnData) objects, so
    # the indices are brought back to the CPU as a NumPy array first.
    cell_idx = index_matrix.detach().cpu().numpy()

    atac_counts = atac_anndata[cell_idx, :].X
    rna_counts = rna_anndata[cell_idx, :].X
    if sp.issparse(atac_counts):
        atac_counts = atac_counts.toarray()
    if sp.issparse(rna_counts):
        rna_counts = rna_counts.toarray()

    # Convert ATAC insertions to fragment counts.
    atac_counts = (np.asarray(atac_counts) + 1) // 2
    rna_counts = np.asarray(rna_counts)

    # The small epsilon keeps log() finite for rows whose counts sum to zero.
    eps = 1e-8
    rna_depth = rna_counts.sum(axis=1).reshape(-1, 1) + eps
    atac_depth = atac_counts.sum(axis=1).reshape(-1, 1) + eps
    log_depths = np.concatenate((np.log(rna_depth), np.log(atac_depth)), axis=1)

    def _as_float_tensor(array):
        # Every output shares the same placement and dtype.
        return torch.from_numpy(array).to(device, dtype=torch.float)

    # Accept a flat (N,) vector as well as the documented (N, 1) column.
    cells_per_row = np.asarray(num_cells)
    if cells_per_row.ndim == 1:
        cells_per_row = cells_per_row.reshape(-1, 1)

    input_matrix = _as_float_tensor(
        np.concatenate((rna_counts, atac_counts), axis=1)
    )
    tf_exp = _as_float_tensor(rna_counts[:, tf_indices])
    library_size_value = _as_float_tensor(log_depths)
    num_cells_value = _as_float_tensor(cells_per_row[cell_idx, :])
    input_batch = _as_float_tensor(encoding_batch_onehot[cell_idx, :])

    return (
        input_matrix,
        tf_exp,
        library_size_value,
        num_cells_value,
        input_batch,
    )