Source code for scdori.utils
#################################
# utils.py
#################################
import torch
import numpy as np
import random
import logging
logger = logging.getLogger(__name__)
[docs]
def set_seed(seed=200):
"""
Set the random seed for Python, NumPy, and PyTorch (including CUDA if available).
Parameters
----------
seed : int, optional
The desired random seed. Default is 200.
Returns
-------
None
Modifies global states of Python, NumPy, and PyTorch seeds in place.
Notes
-----
Useful for ensuring reproducible results across runs when training or testing
the model. However, full reproducibility can still be subject to GPU hardware
determinism settings.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
logger.info(f"Random seed set to {seed}.")
[docs]
def log_nb_positive(
x,
mu,
theta,
eps: float = 1e-8,
log_fn: callable = torch.log,
lgamma_fn: callable = torch.lgamma,
):
"""
Compute the log-likelihood for a Negative Binomial (NB) distribution.
This function is often used for modeling overdispersed count data in
scRNA-seq .
Parameters
----------
x : torch.Tensor
Observed count data, shape (batch_size, num_features).
mu : torch.Tensor
Mean of the negative binomial, must be > 0. Same shape as x.
theta : torch.Tensor
Inverse-dispersion (overdispersion) parameter, must be > 0. Same shape as x.
eps : float, optional
A small constant for numerical stability in logarithms. Default is 1e-8.
log_fn : callable, optional
A function to take the logarithm, typically `torch.log`. Default is `torch.log`.
lgamma_fn : callable, optional
A function for computing log-gamma, typically `torch.lgamma`. Default is `torch.lgamma`.
Returns
-------
torch.Tensor
Element-wise log-likelihood of shape (batch_size, num_features).
"""
log = log_fn
lgamma = lgamma_fn
log_theta_mu_eps = log(theta + mu + eps)
res = (
theta * (log(theta + eps) - log_theta_mu_eps)
+ x * (log(mu + eps) - log_theta_mu_eps)
+ lgamma(x + theta)
- lgamma(theta)
- lgamma(x + 1)
)
return res