Source code for crispyx.de

"""Differential expression testing utilities."""

from __future__ import annotations

import gc
import dataclasses
import logging
import os
import tempfile
import warnings
import threading
import time
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, Iterable, Literal, Mapping, Tuple

from joblib import Parallel, delayed
try:
    from tqdm.auto import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False

from numpy.typing import ArrayLike

import anndata as ad
import numpy as np
import pandas as pd
import scipy.sparse as sp
import h5py
from scipy.stats import norm, rankdata, t as t_dist

from .data import (
    AnnData,
    calculate_optimal_chunk_size,
    calculate_optimal_gene_chunk_size,
    calculate_wilcoxon_chunk_size,
    calculate_nb_glm_chunk_size,
    drop_file_cache,
    ensure_gene_symbol_column,
    get_matrix_storage_format,
    get_perturbation_slice,
    iter_matrix_chunks,
    needs_sorting_for_nbglm,
    read_backed,
    resolve_control_label,
    resolve_data_path,
    resolve_output_path,
    sort_by_perturbation,
)
from .glm import (
    NBGLMFitter,
    NBGLMBatchFitter,
    ControlStatisticsCache,
    build_design_matrix,
    estimate_covariate_effects_streaming,
    estimate_dispersion_map,
    estimate_global_dispersion_streaming,
    fit_dispersion_trend,
    precompute_control_statistics,
    precompute_control_statistics_streaming,
    precompute_global_dispersion,
    precompute_global_dispersion_from_path,
    shrink_dispersions,
    shrink_lfc_apeglm,
    shrink_lfc_apeglm_from_stats,
    _estimate_apeglm_prior_scale,
)
from ._kernels import (
    _rankdata_2d_numba,
    _tie_correction_numba,
    _compute_rank_sums_batch_numba,
    _wilcoxon_sparse_batch_numba,
    _wilcoxon_all_perts_numba,
    _presort_control_nonzeros,
    _compute_ctrl_tie_sums,
    _wilcoxon_presorted_ctrl_numba,
    _wilcoxon_batch_perts_presorted_numba,
    _ZERO_PARTITION_THRESHOLD,
)
from ._checkpoint import (
    _write_checkpoint_atomic,
    _read_checkpoint,
    _scan_h5ad_completed,
    _get_resumable_candidates,
    _get_checkpoint_interval,
    _create_progress_context,
    _DummyProgress,
)
from ._memory import _should_use_streaming
from ._size_factors import (
    _validate_size_factors,
    _median_of_ratios_size_factors,
    _deseq2_style_size_factors,
    _compute_subset_size_factors,
)
from ._statistics import (
    _tie_correction,
    _adjust_pvalue_matrix,
    _compute_se_batched,
    _compute_mom_dispersion_batched,
    _low_expr_in_both_mask,
)
from ._memory import (
    _estimate_max_workers,
    _get_available_memory_mb,
)

logger = logging.getLogger(__name__)


[docs] @dataclass class DifferentialExpressionResult: genes: pd.Index effect_size: np.ndarray statistic: np.ndarray pvalue: np.ndarray method: str perturbation: str pvalue_adj: np.ndarray | None = None result: AnnData | None = field(default=None, repr=False) @property def result_path(self) -> Path: """Path to the on-disk result file. Returns ------- Path Absolute path to the ``.h5ad`` file backing the result. Raises ------ AttributeError If the result AnnData has not been initialised. """ if self.result is None: raise AttributeError("Result AnnData has not been initialised.") return self.result.path def __getstate__(self) -> dict: """Exclude the AnnData file handle from the pickle payload.""" return { f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name != "result" } def __setstate__(self, state: dict) -> None: """Restore from pickle; ``result`` is set to ``None`` (no open file handle).""" for key, value in state.items(): object.__setattr__(self, key, value) object.__setattr__(self, "result", None)
[docs] @dataclass class RankGenesGroupsResult(Mapping[str, DifferentialExpressionResult]): genes: pd.Index groups: list[str] statistics: np.ndarray pvalues: np.ndarray pvalues_adj: np.ndarray logfoldchanges: np.ndarray effect_size: np.ndarray u_statistics: np.ndarray pts: np.ndarray pts_rest: np.ndarray order: np.ndarray groupby: str method: str control_label: str tie_correct: bool pvalue_correction: Literal["benjamini-hochberg", "bonferroni"] result: AnnData | None = field(default=None, repr=False) _group_cache: Dict[str, DifferentialExpressionResult] = field( init=False, repr=False, default_factory=dict ) def __post_init__(self) -> None: self.genes = pd.Index(self.genes).astype(str) self.groups = [str(group) for group in self.groups] @property def result_path(self) -> Path: if self.result is None: raise AttributeError("Result AnnData has not been initialised.") return self.result.path def __getstate__(self) -> dict: """Exclude the AnnData file handle and group cache from the pickle payload.""" return { f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name not in ("result", "_group_cache") } def __setstate__(self, state: dict) -> None: """Restore from pickle; ``result`` is ``None`` and cache is empty.""" for key, value in state.items(): object.__setattr__(self, key, value) object.__setattr__(self, "result", None) object.__setattr__(self, "_group_cache", {}) def _ensure_cache(self) -> None: if self._group_cache: return for idx, group in enumerate(self.groups): self._group_cache[group] = DifferentialExpressionResult( genes=self.genes, effect_size=self.effect_size[idx], statistic=self.statistics[idx], pvalue=self.pvalues[idx], method=self.method, perturbation=group, pvalue_adj=self.pvalues_adj[idx], result=self.result, ) def __getitem__(self, key: str) -> DifferentialExpressionResult: """Return per-gene DE results for a single perturbation group. Parameters ---------- key : str Perturbation group name. Returns ------- DifferentialExpressionResult Results for the requested group. """ self._ensure_cache() return self._group_cache[key] def __iter__(self): # type: ignore[override] """Iterate over perturbation group names.""" return iter(self.groups) def __len__(self) -> int: """Return the number of perturbation groups.""" return len(self.groups)
[docs] def items(self): # type: ignore[override] """Return (group_name, DifferentialExpressionResult) pairs.""" self._ensure_cache() return self._group_cache.items()
[docs] def to_rank_genes_groups_dict(self) -> dict: """Convert results to Scanpy-compatible ``rank_genes_groups`` dict. Returns ------- dict Dictionary with keys ``'params'``, ``'names'``, ``'scores'``, ``'logfoldchanges'``, ``'pvals'``, ``'pvals_adj'``, ``'pts'``, ``'pts_rest'``, ``'auc'``, ``'u_stat'``, and ``'full'``. Each value (except ``'params'`` and ``'full'``) is a NumPy record array sorted by descending absolute score. """ gene_array = self.genes.to_numpy() sorted_names = gene_array[self.order] sorted_scores = np.take_along_axis(self.statistics, self.order, axis=1) sorted_lfc = np.take_along_axis(self.logfoldchanges, self.order, axis=1) sorted_pvals = np.take_along_axis(self.pvalues, self.order, axis=1) sorted_padj = np.take_along_axis(self.pvalues_adj, self.order, axis=1) sorted_pts = np.take_along_axis(self.pts, self.order, axis=1) sorted_pts_rest = np.take_along_axis(self.pts_rest, self.order, axis=1) sorted_effect = np.take_along_axis(self.effect_size, self.order, axis=1) sorted_u = np.take_along_axis(self.u_statistics, self.order, axis=1) def to_recarray(matrix: np.ndarray) -> np.recarray: arrays = [matrix[idx] for idx in range(matrix.shape[0])] return np.rec.fromarrays(arrays, names=self.groups) rank_genes_groups = { "params": { "groupby": self.groupby, "method": self.method, "reference": self.control_label, "tie_correct": self.tie_correct, "corr_method": self.pvalue_correction, }, "names": to_recarray(sorted_names.astype(object)), "scores": to_recarray(sorted_scores), "logfoldchanges": to_recarray(sorted_lfc), "pvals": to_recarray(sorted_pvals), "pvals_adj": to_recarray(sorted_padj), "pts": to_recarray(sorted_pts), "pts_rest": to_recarray(sorted_pts_rest), "auc": to_recarray(sorted_effect), "u_stat": to_recarray(sorted_u), } rank_genes_groups["full"] = self.to_full_order_dict() return rank_genes_groups
[docs] def to_full_order_dict(self) -> dict: """Return unsorted matrices keyed by statistic name. Unlike :meth:`to_rank_genes_groups_dict`, values are in the original gene order (not sorted by score). Returns ------- dict Keys: ``'scores'``, ``'pvals'``, ``'pvals_adj'``, ``'logfoldchanges'``, ``'auc'``, ``'u_stat'``, ``'pts'``, ``'pts_rest'``. Each value is a 2-D ``(n_groups, n_genes)`` NumPy array (copy). """ return { "scores": self.statistics.copy(), "pvals": self.pvalues.copy(), "pvals_adj": self.pvalues_adj.copy(), "logfoldchanges": self.logfoldchanges.copy(), "auc": self.effect_size.copy(), "u_stat": self.u_statistics.copy(), "pts": self.pts.copy(), "pts_rest": self.pts_rest.copy(), }
def _load_existing_nb_glm_result( output_path: Path, candidates: list[str], gene_symbols: list[str], perturbation_column: str, control_label: str, corr_method: str, ) -> "RankGenesGroupsResult": """Load an existing NB-GLM result from an h5ad file. Used when resume=True and all perturbations are already completed. """ adata = ad.read_h5ad(output_path) # NB-GLM stores results in layers, not uns["rank_genes_groups"] statistic_matrix = np.array(adata.layers["z_score"]) pvalue_matrix = np.array(adata.layers["pvalue"]) pvalue_adj_matrix = np.array(adata.layers["pvalue_adj"]) logfc_matrix = np.array(adata.layers["logfoldchanges"]) effect_matrix = logfc_matrix.copy() # effect_size equals logfc for NB-GLM pts_matrix = np.array(adata.layers.get("pts", np.zeros_like(effect_matrix, dtype=np.float32))) pts_rest_matrix = np.array(adata.layers.get("pts_rest", np.zeros_like(effect_matrix, dtype=np.float32))) # Reconstruct order from statistics statistic_for_order = np.where( np.isfinite(statistic_matrix), np.abs(statistic_matrix), -np.inf ) order_matrix = np.argsort(-statistic_for_order, axis=1, kind="mergesort") return RankGenesGroupsResult( genes=pd.Index(gene_symbols), groups=candidates, statistics=statistic_matrix, pvalues=pvalue_matrix, pvalues_adj=pvalue_adj_matrix, logfoldchanges=logfc_matrix, effect_size=effect_matrix, u_statistics=np.zeros_like(effect_matrix), pts=pts_matrix, pts_rest=pts_rest_matrix, order=order_matrix, groupby=perturbation_column, method="nb_glm", control_label=control_label, tie_correct=False, pvalue_correction=corr_method, result=adata, ) def _resolve_candidates( labels: np.ndarray, control_label: str, perturbations: Iterable[str] | None, ) -> list[str]: """Determine which perturbation groups to test. Parameters ---------- labels : ndarray All perturbation labels from the dataset. control_label : str Control group label to exclude. perturbations : iterable of str or None Explicit subset to test. ``None`` means test all non-control groups. Returns ------- list of str Candidate perturbation group names. Raises ------ ValueError If no candidate groups remain after filtering. """ if perturbations is None: unique = pd.Index(labels).unique().tolist() else: unique = [str(p) for p in perturbations] candidates = [label for label in unique if label != control_label] if not candidates: raise ValueError("No perturbation groups available for differential expression testing") return candidates def _release_chunk_memory() -> None: """Force Python and glibc to return freed memory to the OS. After large per-chunk arrays are deleted, glibc may hold freed pages in thread-local arenas. ``gc.collect()`` tears down Python reference cycles, then ``malloc_trim(0)`` (Linux-only) asks glibc to return free heap pages to the OS, shrinking virtual address space. This prevents RLIMIT_AS violations over many gene chunks. """ gc.collect() try: import ctypes libc = ctypes.CDLL(None) libc.malloc_trim(0) except (OSError, AttributeError): pass # Non-Linux or libc without malloc_trim def _write_wilcoxon_result_h5ad( output_path: Path, *, effect_matrix: np.ndarray, z_matrix: np.ndarray, pvalue_matrix: np.ndarray, pvalue_adj_matrix: np.ndarray, lfc_matrix: np.ndarray, u_matrix: np.ndarray, pts_matrix: np.ndarray, pts_rest_matrix: np.ndarray, candidates: list[str], gene_symbols: pd.Index, perturbation_column: str, control_label: str, tie_correct: bool, corr_method: str, ) -> None: """Write wilcoxon result arrays directly to h5ad via h5py. Writes each array (which may be a memmap) as an HDF5 dataset one at a time, avoiding the triple-allocation of memmap → np.array copy → AnnData → ``.write()``. h5py reads memmap pages on demand and writes HDF5 chunks without requiring the full array in RAM simultaneously. """ obs_names = pd.Index(candidates, name="perturbation").astype(str) var_names = gene_symbols.astype(str) # Create a minimal AnnData structure via h5py with h5py.File(output_path, "w") as hf: # X = effect_size matrix hf.create_dataset("X", data=effect_matrix) # obs obs_grp = hf.create_group("obs") obs_grp.attrs["_index"] = "_index" obs_grp.attrs["encoding-type"] = "dataframe" obs_grp.attrs["encoding-version"] = "0.2.0" obs_grp.create_dataset("_index", data=np.array(obs_names, dtype="S")) obs_grp.attrs["column-order"] = [perturbation_column] obs_grp.create_dataset( perturbation_column, data=np.array(obs_names, dtype="S") ) # var var_grp = hf.create_group("var") var_grp.attrs["_index"] = "_index" var_grp.attrs["encoding-type"] = "dataframe" var_grp.attrs["encoding-version"] = "0.2.0" var_grp.create_dataset("_index", data=np.array(var_names, dtype="S")) var_grp.attrs["column-order"] = [] # layers layers_grp = hf.create_group("layers") layers_grp.create_dataset("z_score", data=z_matrix) layers_grp.create_dataset("pvalue", data=pvalue_matrix) layers_grp.create_dataset("pvalue_adj", data=pvalue_adj_matrix) layers_grp.create_dataset("logfoldchanges", data=lfc_matrix) layers_grp.create_dataset("u_statistic", data=u_matrix) layers_grp.create_dataset("pts", data=pts_matrix) layers_grp.create_dataset("pts_rest", data=pts_rest_matrix) # uns metadata uns_grp = hf.create_group("uns") uns_grp.attrs["method"] = "wilcoxon" uns_grp.attrs["control_label"] = control_label uns_grp.attrs["perturbation_column"] = perturbation_column uns_grp.attrs["tie_correct"] = tie_correct uns_grp.attrs["pvalue_correction"] = corr_method def _build_result_from_h5ad( output_path: Path, candidates: list[str], gene_symbols: pd.Index, perturbation_column: str, control_label: str, tie_correct: bool, corr_method: str, *, memory_limit_gb: float | None = None, ) -> "RankGenesGroupsResult": """Build a RankGenesGroupsResult by reading back arrays from the h5ad file. For very large result files (> 25% of *physical* memory), returns a result with empty arrays and only the ``result`` AnnData reference, so callers can access data from disk on demand. The lazy threshold is based on *actual physical memory* (``None`` → auto-detect via psutil), not ``memory_limit_gb``, because the latter may be set artificially low to force the streaming dispatch path. """ from ._memory import _resolve_memory_limit_bytes n_groups = len(candidates) n_genes = len(gene_symbols) result_bytes = n_groups * n_genes * (7 * 8 + 2 * 4) # same as memmap estimate # Use actual physical memory for the lazy-load decision, not the # (possibly tiny) memory_limit_gb used for streaming dispatch. physical_budget = _resolve_memory_limit_bytes(None) # If an explicit limit was given and is *larger* than physical, honour it # (e.g., user trusts their system won't OOM at 256 GB). if memory_limit_gb is not None: explicit_budget = memory_limit_gb * 1e9 budget = max(physical_budget, explicit_budget) else: budget = physical_budget skip_loading = result_bytes > budget * 0.25 if skip_loading: logger.info( "Result arrays too large to fit in memory (%.1f GB > %.1f GB budget × 0.25). " "Returning lazy result — access data via result.result (AnnData h5ad).", result_bytes / 1e9, budget / 1e9, ) # Construct a minimal result with zero-size arrays empty_f64 = np.empty((0, 0), dtype=np.float64) empty_f32 = np.empty((0, 0), dtype=np.float32) empty_i64 = np.empty((0, 0), dtype=np.int64) result = RankGenesGroupsResult( genes=gene_symbols, groups=candidates, statistics=empty_f64, pvalues=empty_f64, pvalues_adj=empty_f64, logfoldchanges=empty_f64, effect_size=empty_f64, u_statistics=empty_f64, pts=empty_f32, pts_rest=empty_f32, order=empty_i64, groupby=perturbation_column, method="wilcoxon", control_label=control_label, tie_correct=tie_correct, pvalue_correction=corr_method, ) result.result = AnnData(output_path) return result with h5py.File(output_path, "r") as hf: z_arr = hf["layers/z_score"][:] pval_arr = hf["layers/pvalue"][:] pval_adj_arr = hf["layers/pvalue_adj"][:] lfc_arr = hf["layers/logfoldchanges"][:] effect_arr = hf["X"][:] u_arr = hf["layers/u_statistic"][:] pts_arr = np.array(hf["layers/pts"][:], dtype=np.float32) pts_rest_arr = np.array(hf["layers/pts_rest"][:], dtype=np.float32) order_arr = np.argsort(-np.abs(z_arr), axis=1, kind="mergesort").astype(np.int64) result = RankGenesGroupsResult( genes=gene_symbols, groups=candidates, statistics=z_arr, pvalues=pval_arr, pvalues_adj=pval_adj_arr, logfoldchanges=lfc_arr, effect_size=effect_arr, u_statistics=u_arr, pts=pts_arr, pts_rest=pts_rest_arr, order=order_arr, groupby=perturbation_column, method="wilcoxon", control_label=control_label, tie_correct=tie_correct, pvalue_correction=corr_method, ) result.result = AnnData(output_path) return result def _write_rank_genes_groups_hdf5( output_path: Path, result: "RankGenesGroupsResult", ) -> None: """Write rank_genes_groups to HDF5 for Scanpy compatibility. Writes arrays in full matrix order (groups × genes) to uns/rank_genes_groups/full. This format is compatible with Scanpy's rank_genes_groups output but avoids the recarray format which causes HDF5 header size limits for large group counts. Parameters ---------- output_path Path to the h5ad file to modify. result RankGenesGroupsResult containing the DE statistics. Notes ----- For datasets with many groups (>1000), this adds ~2-6 seconds of I/O overhead. The recarray format (with group names as dtype fields) is avoided because it hits HDF5 header size limits at ~2000+ groups. """ with h5py.File(output_path, "r+") as handle: uns_group = handle.require_group("uns") if "rank_genes_groups" in uns_group: del uns_group["rank_genes_groups"] rgg = uns_group.create_group("rank_genes_groups") # Store full-order matrices (groups × genes) full = rgg.create_group("full") full.create_dataset("scores", data=result.statistics) full.create_dataset("pvals", data=result.pvalues) full.create_dataset("pvals_adj", data=result.pvalues_adj) full.create_dataset("logfoldchanges", data=result.logfoldchanges) full.create_dataset("auc", data=result.effect_size) full.create_dataset("u_stat", data=result.u_statistics) full.create_dataset("pts", data=result.pts) full.create_dataset("pts_rest", data=result.pts_rest) # Store order and metadata rgg.create_dataset("order", data=result.order) rgg.create_dataset("names", data=np.array(result.groups, dtype="S")) # Store params for compatibility params = rgg.create_group("params") params.attrs["groupby"] = result.groupby params.attrs["method"] = result.method params.attrs["reference"] = result.control_label params.attrs["tie_correct"] = result.tie_correct params.attrs["corr_method"] = result.pvalue_correction def _load_completed_de_result( output_path: Path, *, memory_limit_gb: float | None = None, ) -> "RankGenesGroupsResult": """Load a completed DE result directly from an existing h5ad file. Reads all necessary metadata (method, control_label, perturbation_column, etc.) from the result file itself so the original input data is not required. Dispatches to :func:`_build_result_from_h5ad` for wilcoxon results and :func:`_load_existing_nb_glm_result` for nb_glm / t_test results. """ def _decode(v) -> str: if isinstance(v, (bytes, np.bytes_)): return v.decode() if hasattr(v, "item"): v = v.item() if isinstance(v, (bytes, np.bytes_)): return v.decode() return str(v) with h5py.File(output_path, "r") as hf: # obs index → candidate perturbation labels obs_grp = hf["obs"] idx_key = obs_grp.attrs.get("_index", "_index") if isinstance(idx_key, (bytes, np.bytes_)): idx_key = idx_key.decode() candidates: list[str] = [_decode(s) for s in hf[f"obs/{idx_key}"][:]] # var index → gene symbols var_grp = hf["var"] var_idx_key = var_grp.attrs.get("_index", "_index") if isinstance(var_idx_key, (bytes, np.bytes_)): var_idx_key = var_idx_key.decode() gene_symbols = pd.Index([_decode(s) for s in hf[f"var/{var_idx_key}"][:]]) # uns metadata — wilcoxon writes to group attrs; anndata (nb_glm, t_test) # writes scalar datasets inside the uns group. uns = hf["uns"] def _read_uns(key: str, default=None): if key in uns.attrs: return _decode(uns.attrs[key]) if key in uns: return _decode(uns[key][()]) return default method = _read_uns("method", "wilcoxon") control_label = _read_uns("control_label") perturbation_column = _read_uns("perturbation_column") corr_method = _read_uns("pvalue_correction", "benjamini-hochberg") # tie_correct is wilcoxon-specific (stored as a bool) tie_correct: bool = True if "tie_correct" in uns.attrs: tie_correct = bool(uns.attrs["tie_correct"]) elif "tie_correct" in uns: tie_correct = bool(uns["tie_correct"][()]) if method == "wilcoxon": return _build_result_from_h5ad( output_path, candidates=candidates, gene_symbols=gene_symbols, perturbation_column=perturbation_column, control_label=control_label, tie_correct=tie_correct, corr_method=corr_method, memory_limit_gb=memory_limit_gb, ) # nb_glm and t_test share the same layer layout return _load_existing_nb_glm_result( output_path=output_path, candidates=candidates, gene_symbols=gene_symbols.tolist(), perturbation_column=perturbation_column, control_label=control_label, corr_method=corr_method, ) # --------------------------------------------------------------------------- # Shared helpers for public DE functions # --------------------------------------------------------------------------- def _resolve_de_aliases( *, perturbation_column: str | None, groupby: str | None, control_label: str | None, reference: str | None, min_pct_both: float | None, min_pct_ctrl: float, min_pct_pert: float, fn_name: str, ) -> tuple[str, str | None, float, float]: """Resolve groupby/reference aliases and handle min_pct_both. Returns ``(perturbation_column, control_label, min_pct_ctrl, min_pct_pert)``. """ # groupby alias if groupby is not None and perturbation_column is not None: raise TypeError( f"{fn_name}() received both 'perturbation_column' and 'groupby'; " "they are aliases for the same parameter — pass only one." ) if groupby is not None: perturbation_column = groupby if perturbation_column is None: raise TypeError( f"{fn_name}() requires either 'perturbation_column' or its alias 'groupby'." ) # reference alias if reference is not None and control_label is not None: raise TypeError( f"{fn_name}() received both 'control_label' and 'reference'; " "they are aliases for the same parameter — pass only one." ) if reference is not None: control_label = reference # min_pct_both silent alias if min_pct_both is not None: min_pct_ctrl = float(min_pct_both) min_pct_pert = float(min_pct_both) return perturbation_column, control_label, min_pct_ctrl, min_pct_pert def _try_load_existing_de_result( output_path: "Path", *, force: bool, verbose: int | bool, method_name: str, memory_limit_gb: float | None, ) -> "RankGenesGroupsResult | None": """Return the loaded result if it already exists and ``force=False``, else ``None``.""" if not (output_path.exists() and not force): return None logger.info( "Found existing %s result at %s. Loading instead of rerunning.", method_name, output_path, ) if verbose: print(f"[crispyx] Loading existing result: {output_path}") print("[crispyx] Pass force=True to rerun the analysis.") return _load_completed_de_result(output_path, memory_limit_gb=memory_limit_gb) def _print_de_summary( verbose: int | bool, method_name: str, n_completed: int, n_groups: int, n_tested_list: "list[int]", n_genes: int, ) -> None: """Print verbose ≥ 1 summary for a DE run (t-test / NB-GLM).""" if int(verbose) >= 1 and n_tested_list: _mean = int(sum(n_tested_list) / len(n_tested_list)) _pct = 100.0 * _mean / n_genes if n_genes else 0 print( f"[crispyx] {method_name}: {n_completed}/{n_groups} perturbations " f"complete, mean {_mean}/{n_genes} genes tested ({_pct:.0f}%)" ) def _print_de_perturbation_verbose( verbose: int | bool, label: str, n_tested: int, n_genes: int, ) -> None: """Print verbose ≥ 2 per-perturbation gene-count line.""" if int(verbose) >= 2: _pct = 100.0 * n_tested / n_genes if n_genes else 0 print( f"[crispyx] {label}: {n_tested}/{n_genes} genes tested " f"({_pct:.0f}%), {n_genes - n_tested} filtered" )
[docs] def t_test( data: str | Path | AnnData | ad.AnnData, *, perturbation_column: str | None = None, groupby: str | None = None, control_label: str | None = None, reference: str | None = None, gene_name_column: str | None = None, perturbations: Iterable[str] | None = None, min_cells_expressed: int = 0, min_pct_ctrl: float = 0.01, min_pct_pert: float = 0.002, min_pct_both: float | None = None, min_mean_ctrl: float = 0.05, min_mean_pert: float = 0.005, cell_chunk_size: int | None = None, output_dir: str | Path | None = None, data_name: str | None = None, n_jobs: int | None = None, verbose: int | bool = False, resume: bool = False, checkpoint_interval: int | None = None, scanpy_format: bool = False, memory_limit_gb: float | None = None, force: bool = False, ) -> RankGenesGroupsResult: """Perform a t-test comparing log-expression means for each perturbation. Returns a RankGenesGroupsResult containing differential expression statistics. Results are stored in an h5ad file with layers containing the statistics. The RankGenesGroupsResult implements the Mapping interface, so it can be used like a dict: `result[perturbation_label]` returns a DifferentialExpressionResult for that perturbation. Input data **should already be normalized and log-transformed** (for example using `scanpy.pp.normalize_total` followed by `scanpy.pp.log1p`). To maintain backward compatibility with Scanpy-style workflows, count-like inputs are automatically normalized and log-transformed in streaming fashion, with a warning to encourage explicit preprocessing upstream. Parameters ---------- data Path to an h5ad file, or a crispyx/anndata AnnData object containing log-transformed expression data. perturbation_column Column in `adata.obs` indicating perturbation labels. groupby Alias for ``perturbation_column`` (Scanpy-compatible). Mutually exclusive with ``perturbation_column``. control_label Label for the control/reference group. If None, infers from common patterns. reference Alias for ``control_label`` (Scanpy-compatible). Mutually exclusive with ``control_label``. gene_name_column Column in `adata.var` with gene symbols. If None, uses `adata.var_names`. perturbations Specific perturbations to test. If None, tests all non-control groups. min_cells_expressed Minimum total cells (control + perturbation) expressing a gene for testing. min_pct_ctrl Minimum fraction of expressing cells for the *control* side. A gene is excluded only when *both* the control side *and* the perturbed side are jointly low. Default ``0.01``. min_pct_pert Minimum fraction of expressing cells for the *perturbed* side. Default ``0.002`` (lower than ctrl; induction from near-zero baseline is biologically valid). Set to ``0.0`` to disable the pct check on pert. min_pct_both If not ``None``, overrides both ``min_pct_ctrl`` and ``min_pct_pert`` with the same value. min_mean_ctrl Minimum mean expression (log1p units) for the *control* side. Default ``0.05``. Excluded genes are written as NaN in ``score`` / ``pvalue`` / ``logfoldchanges`` / ``effect_size``; ``pts`` and ``pts_rest`` remain populated. min_mean_pert Minimum mean expression for the *perturbed* side. Default ``0.005``. Together with ``min_pct_pert`` this forms a dual condition that is more robust to doublet / ambient-RNA artefacts than pct alone. cell_chunk_size Number of cells to process per chunk (memory vs. speed tradeoff). This controls streaming along the cell axis and is distinct from any future perturbation_chunk_size option that would batch perturbations. Data must already be normalized/log-transformed before chunking. output_dir Directory for output h5ad file. Defaults to input file's directory. data_name Custom name for output file. If None, uses "t_test" suffix. n_jobs Number of parallel workers for computing statistics across perturbations. If None, uses all available cores. If 1, runs sequentially. If -1, uses all available cores. verbose If True, show a progress bar for perturbation processing. Requires tqdm. resume If True, attempt to resume from a previous interrupted run using checkpoint. checkpoint_interval Number of perturbations between checkpoint saves. Auto-determined if None. scanpy_format If True, write Scanpy-compatible ``uns['rank_genes_groups']`` structure in addition to the layer-based storage. Adds ~2-6 seconds of I/O overhead for large datasets. Default False for performance. memory_limit_gb Maximum memory budget in GB. Used for automatic chunk size calculation. When ``None`` (default), detects available system memory via ``psutil``. For HPC environments, set this to your SLURM ``--mem`` value (e.g., ``memory_limit_gb=128``). force If True, rerun the analysis even when the output h5ad file already exists. If False (default), load and return the existing result instead of rerunning. Returns ------- RankGenesGroupsResult Differential expression results. Access results via dict-like interface: `result[label].effect_size`, `result[label].pvalue`, etc. The h5ad file path is available at `result.result_path`. """ perturbation_column, control_label, min_pct_ctrl, min_pct_pert = _resolve_de_aliases( perturbation_column=perturbation_column, groupby=groupby, control_label=control_label, reference=reference, min_pct_both=min_pct_both, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, fn_name="t_test", ) path = resolve_data_path(data) output_path = resolve_output_path(path, suffix="t_test", output_dir=output_dir, data_name=data_name) if (r := _try_load_existing_de_result( output_path, force=force, verbose=verbose, method_name="t_test", memory_limit_gb=memory_limit_gb, )): return r backed = read_backed(path) try: # Calculate adaptive chunk_size if not provided if cell_chunk_size is None: cell_chunk_size = calculate_optimal_chunk_size( backed.n_obs, backed.n_vars, available_memory_gb=memory_limit_gb, ) gene_symbols = ensure_gene_symbol_column(backed, gene_name_column) if perturbation_column not in backed.obs.columns: raise KeyError( f"Perturbation column '{perturbation_column}' was not found in adata.obs. Available columns: {list(backed.obs.columns)}" ) labels = backed.obs[perturbation_column].astype(str).to_numpy() control_label = resolve_control_label(list(labels), control_label) n_genes = backed.n_vars candidates = _resolve_candidates(labels, control_label, perturbations) groups = [control_label] + candidates group_index = {label: idx for idx, label in enumerate(groups)} label_codes = pd.Categorical(labels, categories=groups).codes # Only allow sparse matrices; raise error if dense for _, chunk in iter_matrix_chunks(backed, axis=0, chunk_size=100, convert_to_dense=False): if not sp.issparse(chunk): raise ValueError( "t_test only supports sparse input matrices. Please provide a scipy sparse matrix (e.g., CSR/CSC)." ) # Raise if data looks like raw counts (matches scanpy's check_nonnegative_integers) if np.issubdtype(chunk.dtype, np.integer): raise ValueError( "Detected integer count data in t_test. " "Please log-normalize your data first (e.g. cx.pp.normalize_total_log1p)." ) elif np.issubdtype(chunk.dtype, np.floating): non_zero = chunk.data[chunk.data > 0] is_count_like = non_zero.size > 0 and np.all(np.isclose(non_zero, np.round(non_zero))) if is_count_like: raise ValueError( "Detected count-like (integer-valued) floating point data in t_test. " "Please log-normalize your data first (e.g. cx.pp.normalize_total_log1p)." ) break # Only check the first chunk n_groups_total = len(groups) # Use float64 for accumulation to maintain numerical precision sums = np.zeros((n_groups_total, n_genes), dtype=np.float64) sumsq = np.zeros((n_groups_total, n_genes), dtype=np.float64) counts = np.zeros(n_groups_total, dtype=np.int64) expr_counts = np.zeros((n_groups_total, n_genes), dtype=np.int32) for slc, block in iter_matrix_chunks( backed, axis=0, chunk_size=cell_chunk_size, convert_to_dense=False ): slice_codes = label_codes[slc] csr = sp.csr_matrix(block) for code in np.unique(slice_codes): row_mask = slice_codes == code group_block = csr[row_mask, :] # Expression count: number of nonzero per gene expr_counts[code] += np.asarray(group_block.getnnz(axis=0), dtype=np.int32) # Sum and sumsq using sparse ops sums[code] += group_block.sum(axis=0).A1.astype(np.float64) sumsq[code] += group_block.power(2).sum(axis=0).A1.astype(np.float64) counts[code] += row_mask.sum() finally: backed.file.close() control_idx = group_index[control_label] control_n = counts[control_idx] if control_n == 0: raise ValueError("Control group contains no cells") control_mean = sums[control_idx] / control_n # Precompute control term for LFC calculation once (avoid recomputing per perturbation) control_mean_expm1 = np.expm1(control_mean) + 1e-9 control_var = np.zeros_like(control_mean) if control_n > 1: control_var = (sumsq[control_idx] - (sums[control_idx] ** 2) / control_n) / (control_n - 1) control_var = np.clip(control_var, a_min=0, a_max=None) # Calculate control pts (proportion of cells expressing each gene) control_pts = np.divide( expr_counts[control_idx], control_n, out=np.zeros(n_genes, dtype=float), where=control_n > 0, ) # Determine worker count for parallelization n_groups = len(candidates) max_available_workers = os.cpu_count() or 1 if n_jobs is None or n_jobs == 0: worker_count = min(n_groups, max_available_workers) else: worker_count = min(n_groups, abs(n_jobs)) worker_count = max(worker_count, 1) # Prepare on-disk buffers for results shape = (n_groups, n_genes) output_path.parent.mkdir(parents=True, exist_ok=True) checkpoint_path = output_path.with_suffix(".progress.json") # Handle resume logic if resume: candidates_to_run, completed_labels, failed_labels = _get_resumable_candidates( checkpoint_path, output_path, candidates, retry_failed=True ) else: candidates_to_run = candidates completed_labels = [] failed_labels = [] # Determine checkpoint interval eff_checkpoint_interval = _get_checkpoint_interval(n_groups, checkpoint_interval) candidate_to_idx = {label: idx for idx, label in enumerate(candidates)} obs_index = pd.Index(candidates, name="perturbation").astype(str) obs = pd.DataFrame({perturbation_column: obs_index.to_list()}, index=obs_index) adata = ad.AnnData(np.zeros((len(candidates), 0)), obs=obs, var=pd.DataFrame(index=[])) adata.uns["method"] = "t_test" adata.uns["control_label"] = control_label adata.uns["genes"] = gene_symbols.to_numpy() adata.uns["pvalue_correction"] = "benjamini-hochberg" adata.uns["de_filter"] = { "min_cells_expressed": int(min_cells_expressed), "min_pct_ctrl": float(min_pct_ctrl), "min_pct_pert": float(min_pct_pert), "min_mean_ctrl": float(min_mean_ctrl), "min_mean_pert": float(min_mean_pert), } adata.write(output_path) candidate_indices = {label: i for i, label in enumerate(candidates)} with tempfile.TemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) stat_memmap = np.memmap(tmp_path / "statistics.dat", mode="w+", dtype=np.float64, shape=shape) pval_memmap = np.memmap(tmp_path / "pvalues.dat", mode="w+", dtype=np.float64, shape=shape) lfc_memmap = np.memmap(tmp_path / "logfoldchanges.dat", mode="w+", dtype=np.float32, shape=shape) effect_memmap = np.memmap(tmp_path / "effect_size.dat", mode="w+", dtype=np.float32, shape=shape) pts_memmap = np.memmap(tmp_path / "pts.dat", mode="w+", dtype=np.float32, shape=shape) order_memmap = np.memmap(tmp_path / "order.dat", mode="w+", dtype=np.int64, shape=shape) pts_rest_memmap = np.memmap( tmp_path / "pts_rest.dat", mode="w+", dtype=np.float32, shape=shape ) pts_rest_memmap[:] = control_pts.astype(np.float32) with h5py.File(output_path, "r+") as handle: uns_group = handle.require_group("uns") if "rank_genes_groups" in uns_group: del uns_group["rank_genes_groups"] rgg = uns_group.create_group("rank_genes_groups") full = rgg.create_group("full") ds_scores = full.create_dataset("scores", shape=shape, dtype="float64") ds_pvals = full.create_dataset("pvals", shape=shape, dtype="float64") ds_pvals_adj = full.create_dataset("pvals_adj", shape=shape, dtype="float64") ds_lfc = full.create_dataset("logfoldchanges", shape=shape, dtype="float32") ds_auc = full.create_dataset("auc", shape=shape, dtype="float32") ds_u = full.create_dataset("u_stat", shape=shape, dtype="float32") ds_pts = full.create_dataset("pts", shape=shape, dtype="float32") ds_pts_rest = full.create_dataset("pts_rest", shape=shape, dtype="float32") ds_order = rgg.create_dataset("order", shape=shape, dtype="int64") ds_auc[:] = 0.0 ds_u[:] = 0.0 ds_pts_rest[:] = pts_rest_memmap batch_size = worker_count effect_buffer = np.zeros((batch_size, n_genes), dtype=np.float32) stat_buffer = np.zeros((batch_size, n_genes), dtype=np.float64) pval_buffer = np.ones((batch_size, n_genes), dtype=np.float64) lfc_buffer = np.zeros((batch_size, n_genes), dtype=np.float32) pts_buffer = np.zeros((batch_size, n_genes), dtype=np.float32) order_buffer = np.zeros((batch_size, n_genes), dtype=np.int64) mean_buffer = np.zeros(n_genes, dtype=np.float64) var_buffer = np.zeros(n_genes, dtype=np.float64) se_buffer = np.zeros(n_genes, dtype=np.float64) lfc_work_buffer = np.zeros(n_genes, dtype=np.float64) # Work buffer for in-place LFC n_tested_per_slot = np.zeros(batch_size, dtype=np.int32) def compute_perturbation(label: str, slot: int) -> None: idx = group_index[label] n_cells = counts[idx] if n_cells == 0: raise ValueError(f"Perturbation '{label}' contains no cells") np.divide(sums[idx], n_cells, out=mean_buffer) np.copyto(var_buffer, sumsq[idx]) if n_cells > 1: np.subtract(var_buffer, np.square(sums[idx]) / n_cells, out=var_buffer) np.divide(var_buffer, n_cells - 1, out=var_buffer) else: var_buffer.fill(0) np.clip(var_buffer, a_min=0, a_max=None, out=var_buffer) np.subtract(mean_buffer, control_mean, out=effect_buffer[slot]) effect_f32 = effect_buffer[slot].astype(np.float32, copy=False) # Compute variance terms for SE and Welch-Satterthwaite df var_term_pert = var_buffer / n_cells # var_pert / n_pert var_term_ctrl = control_var / control_n # var_ctrl / n_ctrl # SE = sqrt(var_pert/n_pert + var_ctrl/n_ctrl) np.add(var_term_pert, var_term_ctrl, out=se_buffer) np.sqrt(se_buffer, out=se_buffer) total_expr = expr_counts[idx] + expr_counts[control_idx] valid = (se_buffer > 0) & (total_expr >= min_cells_expressed) # Per-condition low-expression filter: drop genes that are # jointly low in BOTH the perturbation and control groups. low_both = _low_expr_in_both_mask( pert_expr_counts=expr_counts[idx], control_expr_counts=expr_counts[control_idx], pert_mean=mean_buffer, control_mean=control_mean, n_pert_cells=n_cells, n_control_cells=control_n, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, ) valid &= ~low_both n_tested_per_slot[slot] = int(valid.sum()) stat_buffer[slot].fill(np.nan) pval_buffer[slot].fill(np.nan) stat_buffer[slot][valid] = effect_f32[valid] / se_buffer[valid] # Welch-Satterthwaite degrees of freedom for Welch's t-test # df = (var1/n1 + var2/n2)^2 / ((var1/n1)^2/(n1-1) + (var2/n2)^2/(n2-1)) numerator = (var_term_pert + var_term_ctrl) ** 2 denominator = np.zeros_like(numerator) if n_cells > 1: denominator += (var_term_pert ** 2) / (n_cells - 1) if control_n > 1: denominator += (var_term_ctrl ** 2) / (control_n - 1) # Avoid division by zero; set df to a large value when denominator is 0 # Use np.divide with where to prevent NaN from 0/0 df_welch = np.divide( numerator, denominator, out=np.full_like(numerator, 1e6), where=denominator > 0 ) # Clip df to reasonable bounds (minimum 1, no upper limit needed) df_welch = np.clip(df_welch, 1.0, None) # Use t-distribution for p-value calculation (matches scanpy's Welch's t-test) pval_buffer[slot][valid] = 2 * t_dist.sf(np.abs(stat_buffer[slot][valid]), df_welch[valid]) np.divide( expr_counts[idx], n_cells, out=pts_buffer[slot], where=n_cells > 0, casting="unsafe", ) order_buffer[slot] = np.argsort(-np.abs(stat_buffer[slot])) # Scanpy-compatible log2 fold change: log2((expm1(mean_group) + eps) / (expm1(mean_rest) + eps)) # Use in-place operations to minimize temporary allocations np.expm1(mean_buffer, out=lfc_work_buffer) np.add(lfc_work_buffer, 1e-9, out=lfc_work_buffer) np.divide(lfc_work_buffer, control_mean_expm1, out=lfc_work_buffer) np.log2(lfc_work_buffer, out=lfc_work_buffer) lfc_buffer[slot] = lfc_work_buffer.astype(np.float32) # Mask effect / lfc to NaN for genes excluded by per-condition # low-expression filter so downstream tools see them as untested. if low_both.any(): invalid = low_both effect_buffer[slot][invalid] = np.nan lfc_buffer[slot][invalid] = np.nan # Track completed labels newly_completed = list(completed_labels) newly_failed = list(failed_labels) completed_set = set(completed_labels) n_processed = 0 # Helper to save checkpoint def _save_t_test_checkpoint() -> None: checkpoint_data = { "total": n_groups, "completed": newly_completed, "failed": newly_failed, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "method": "t_test", "control_label": control_label, } _write_checkpoint_atomic(checkpoint_path, checkpoint_data) if n_groups > 0: n_tested_list: list[int] = [] with _create_progress_context(len(candidates_to_run), "t-test DE", verbose) as pbar: for batch_start in range(0, n_groups, batch_size): batch_labels = candidates[batch_start : batch_start + batch_size] # Filter to only labels that need processing batch_to_run = [l for l in batch_labels if l not in completed_set] for local_idx, label in enumerate(batch_labels): if label in completed_set: continue try: compute_perturbation(label, local_idx) except Exception as e: logger.error(f"Failed perturbation {label}: {e}") newly_failed.append(label) continue for local_idx, label in enumerate(batch_labels): if label in completed_set or label in newly_failed: continue global_idx = candidate_to_idx[label] effect_memmap[global_idx] = effect_buffer[local_idx] stat_memmap[global_idx] = stat_buffer[local_idx] pval_memmap[global_idx] = pval_buffer[local_idx] lfc_memmap[global_idx] = lfc_buffer[local_idx] pts_memmap[global_idx] = pts_buffer[local_idx] order_memmap[global_idx] = order_buffer[local_idx] ds_scores[global_idx] = stat_buffer[local_idx] ds_pvals[global_idx] = pval_buffer[local_idx] ds_lfc[global_idx] = lfc_buffer[local_idx] ds_pts[global_idx] = pts_buffer[local_idx] ds_order[global_idx] = order_buffer[local_idx] newly_completed.append(label) n_tested_list.append(int(n_tested_per_slot[local_idx])) _print_de_perturbation_verbose(verbose, label, int(n_tested_per_slot[local_idx]), n_genes) n_processed += 1 pbar.update(1) logger.debug(f"Completed perturbation: {label}") # Save checkpoint after each batch if len(batch_to_run) > 0 and n_processed % eff_checkpoint_interval == 0: _save_t_test_checkpoint() # Final checkpoint _save_t_test_checkpoint() logger.info(f"Completed {len(newly_completed)}/{n_groups} perturbations") _print_de_summary(verbose, "t-test DE", len(newly_completed), n_groups, n_tested_list, n_genes) pvalue_adj_memmap = np.memmap( tmp_path / "pvalues_adj.dat", mode="w+", dtype=np.float64, shape=shape ) _adjust_pvalue_matrix(pval_memmap, method="benjamini-hochberg", out=pvalue_adj_memmap) ds_pvals_adj[:] = pvalue_adj_memmap # Convert memmap arrays to regular arrays before tempdir cleanup stat_matrix = np.asarray(stat_memmap) pval_matrix = np.asarray(pval_memmap) pval_adj_matrix = np.asarray(pvalue_adj_memmap) lfc_matrix = np.asarray(lfc_memmap) effect_matrix = np.asarray(effect_memmap) pts_matrix = np.asarray(pts_memmap) pts_rest_matrix = np.asarray(pts_rest_memmap) order_matrix = np.asarray(order_memmap) result = RankGenesGroupsResult( genes=gene_symbols, groups=candidates, statistics=stat_matrix, pvalues=pval_matrix, pvalues_adj=pval_adj_matrix, logfoldchanges=lfc_matrix, effect_size=effect_matrix, u_statistics=np.zeros(shape, dtype=np.float32), pts=pts_matrix, pts_rest=pts_rest_matrix, order=order_matrix, groupby=perturbation_column, method="t_test", control_label=control_label, tie_correct=False, pvalue_correction="benjamini-hochberg", result=None, ) # Create AnnData with layer-based storage (avoid recarray-based rank_genes_groups # which fails with HDF5 header size limits for large group counts) var = pd.DataFrame(index=gene_symbols) adata = ad.AnnData(effect_matrix, obs=obs, var=var) adata.layers["z_score"] = stat_matrix # t-statistic (converges to z for large n) adata.layers["pvalue"] = pval_matrix adata.layers["pvalue_adj"] = pval_adj_matrix adata.layers["logfoldchanges"] = lfc_matrix adata.layers["pts"] = pts_matrix adata.layers["pts_rest"] = pts_rest_matrix adata.uns["method"] = "t_test" adata.uns["control_label"] = control_label adata.uns["perturbation_column"] = perturbation_column adata.uns["pvalue_correction"] = "benjamini-hochberg" adata.write(output_path) # Optionally write Scanpy-compatible rank_genes_groups structure if scanpy_format: _write_rank_genes_groups_hdf5(output_path, result) # Clean up checkpoint on successful completion if checkpoint_path.exists(): try: checkpoint_path.unlink() except Exception: pass result.result = AnnData(output_path) return result
[docs] def nb_glm_test( data: str | Path | AnnData | ad.AnnData, *, # ---- Data parameters ---- perturbation_column: str | None = None, groupby: str | None = None, control_label: str | None = None, reference: str | None = None, covariates: Iterable[str] | None = None, gene_name_column: str | None = None, perturbations: Iterable[str] | None = None, # ---- Size factor parameters ---- size_factors: ArrayLike | None = None, size_factor_method: Literal["sparse", "deseq2"] = "sparse", size_factor_scope: Literal["global", "per_comparison"] = "global", scale_size_factors: bool = True, # ---- Dispersion parameters ---- dispersion: float | None = None, dispersion_method: Literal["moments", "cox-reid"] = "cox-reid", dispersion_scope: Literal["global", "per_comparison"] = "global", share_dispersion: bool = False, use_map_dispersion: bool = True, shrink_dispersion: bool = True, # ---- Optimization parameters ---- optimization_method: Literal["irls", "lbfgsb"] = "lbfgsb", max_iter: int = 25, tol: float = 1e-6, min_mu: float = 0.5, poisson_init_iter: int = 5, chunk_size: int | None = None, irls_batch_size: int | None = 128, # ---- Filtering parameters ---- min_cells_expressed: int = 0, min_pct_ctrl: float = 0.01, min_pct_pert: float = 0.002, min_pct_both: float | None = None, min_mean_ctrl: float = 0.05, min_mean_pert: float = 0.005, min_total_count: float = 1.0, cook_filter: bool = False, # ---- Output parameters ---- lfc_shrinkage_type: Literal["apeglm", "none"] = "none", lfc_base: Literal["log2", "ln"] = "log2", corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg", se_method: Literal["sandwich", "fisher"] = "sandwich", output_dir: str | Path | None = None, data_name: str | None = None, scanpy_format: bool = False, verbose: int | bool = False, profiling: bool = False, # ---- Resume/Memory parameters ---- resume: bool = False, checkpoint_interval: int | None = None, memory_limit_gb: float | None = None, max_dense_fraction: float = 0.3, n_jobs: int | None = None, use_control_cache: bool = True, freeze_control: bool | None = None, force: bool = False, ) -> RankGenesGroupsResult: """Perform negative binomial GLM differential expression test. Returns a RankGenesGroupsResult containing differential expression statistics. Uses a negative binomial GLM framework that can incorporate covariates. Results are stored in an h5ad file with layers containing the statistics. The RankGenesGroupsResult implements the Mapping interface, so it can be used like a dict: `result[perturbation_label]` returns a DifferentialExpressionResult for that perturbation. Parameters ---------- data Path to an h5ad file, or a crispyx/anndata AnnData object containing raw count data. perturbation_column Column in `adata.obs` indicating perturbation labels. groupby Alias for ``perturbation_column`` (Scanpy-compatible). Mutually exclusive with ``perturbation_column``. control_label Label for the control/reference group. If None, infers from common patterns. reference Alias for ``control_label`` (Scanpy-compatible). Mutually exclusive with ``control_label``. covariates Additional columns in `adata.obs` to include as covariates in the GLM. gene_name_column Column in `adata.var` with gene symbols. If None, uses `adata.var_names`. perturbations Specific perturbations to test. If None, tests all non-control groups. size_factors Optional array of per-cell size factors. If None, computes size factors using the method specified by ``size_factor_method``. size_factor_method Method for computing size factors when ``size_factors`` is None. - ``"sparse"``: Sparse-aware median-of-ratios (default). Computes geometric means using only non-zero values, suitable for sparse single-cell data. - ``"deseq2"``: Classic DESeq2/PyDESeq2 style. Uses only genes expressed in ALL cells (typically ~50-100 genes). Provides better numerical alignment with PyDESeq2 results but may be less robust for very sparse data. size_factor_scope Scope for size factor computation. - ``"global"`` (default): Compute size factors once on the full dataset. Recommended for CRISPR screens where all cells come from the same experiment and share a common sequencing depth distribution. Faster when combined with use_control_cache=True. Note: produces different results from PyDESeq2 (rho ~ 0.7-0.8) which uses per-comparison normalization. - ``"per_comparison"``: Compute size factors separately for each control + perturbation comparison. This matches PyDESeq2's behavior exactly, leading to near-perfect LFC, statistic, and p-value concordance (rho > 0.97 on Tian-crispra, rho > 0.99 on Adamson_subset). Use this when PyDESeq2 compatibility is required or for bulk RNA-seq style analysis. scale_size_factors If True (default), scale size factors so their geometric mean equals 1. This is the standard DESeq2/crispyx behavior. If False, use raw median-of-ratios size factors without rescaling, which matches PyDESeq2's default behavior and can improve numerical alignment. dispersion Fixed dispersion parameter for negative binomial. If None, estimates per gene. dispersion_method Method for estimating dispersion when ``dispersion`` is None. - ``"moments"``: Method-of-moments (fast but less accurate). - ``"cox-reid"``: Cox-Reid adjusted profile likelihood (slower but more accurate, similar to DESeq2). This is the default. dispersion_scope Scope for dispersion estimation. - ``"global"`` (default): Precompute dispersion once using all cells (control + all perturbations). This is ~10x faster for multi-perturbation datasets since MAP dispersion estimation is done once instead of per-comparison. Recommended when perturbation effects are expected to be small relative to baseline expression (typical for CRISPR screens). - ``"per_comparison"``: Estimate dispersion separately for each control + perturbation comparison. More accurate when perturbations cause large changes in gene expression variance, but significantly slower. share_dispersion If True, estimate dispersion once using all cells, then use the same dispersion values for all Wald tests. If False (default), estimate dispersion separately for each perturbation comparison. use_map_dispersion If True (default), use MAP dispersion estimation with mean-dispersion trend. If False, use MLE dispersion without trend-based shrinkage. shrink_dispersion If True, fit a mean-dispersion trend and shrink gene-wise dispersions toward the trend using an empirical Bayes prior. optimization_method Method for coefficient optimization. - ``"lbfgsb"``: L-BFGS-B optimization (PyDESeq2 style, default). Directly optimizes the negative binomial log-likelihood. - ``"irls"``: Iteratively Reweighted Least Squares (Fisher scoring). The classic GLM fitting approach. max_iter Maximum iterations for GLM fitting. tol Convergence tolerance for GLM fitting. min_mu Minimum mean threshold for IRLS numerical stability. Predicted means are clamped to max(min_mu, predicted) during fitting to prevent numerical instability from very small means. Default: 0.5, matching PyDESeq2's default. Set to 0.0 to disable clamping. poisson_init_iter Initial Poisson iterations before switching to negative binomial. chunk_size Number of genes to process per chunk (memory vs. speed tradeoff). Smaller values stream more, reducing peak memory at the cost of additional I/O. irls_batch_size Maximum number of genes to densify per IRLS step. Keep this small to limit per-iteration memory when working with large sparse matrices. Set to ``None`` to process each chunk without additional batching. min_cells_expressed Minimum total cells (control + perturbation) expressing a gene for testing. min_total_count Minimum total count across all cells for a gene to be tested. min_pct_ctrl Minimum fraction of expressing cells for the *control* side. A gene is excluded only when *both* sides are jointly low. Default ``0.01``. min_pct_pert Minimum fraction of expressing cells for the *perturbed* side. Default ``0.002``. Combined with ``min_mean_pert`` this forms a dual condition that is more robust to doublet / ambient-RNA artefacts. min_pct_both If not ``None``, overrides both ``min_pct_ctrl`` and ``min_pct_pert`` with the same value. min_mean_ctrl Minimum mean (size-factor-normalised) expression for the *control* side. Default ``0.05``. Excluded genes appear as NaN in ``pvalue`` / ``effect`` / ``logfc`` / ``se``; ``pts`` and ``mean`` remain populated. min_mean_pert Minimum mean expression for the *perturbed* side. Default ``0.005``. cook_filter Whether to apply Cook's distance outlier filtering when available. lfc_shrinkage_type Type of log-fold change shrinkage to apply. - ``"none"``: No shrinkage (default). - ``"apeglm"``: Adaptive shrinkage using Cauchy prior (PyDESeq2-compatible). Preserves large effects while shrinking small/uncertain effects toward zero. Also updates standard errors to reflect posterior uncertainty. lfc_base Log base for fold change output. - ``"log2"`` (default): Output log2 fold change, matching PyDESeq2/edgeR. - ``"ln"``: Output natural log fold change (raw GLM coefficients). Standard error is also converted to match the selected log base. Wald statistics remain unchanged since both LFC and SE are scaled equally. corr_method Method for p-value correction: ``"benjamini-hochberg"`` or ``"bonferroni"``. se_method Method for computing standard errors. - ``"sandwich"`` (default): Sandwich estimator SE = sqrt(c' @ H @ M @ H @ c). More robust to model misspecification. - ``"fisher"``: Standard Fisher information SE = sqrt(diag(inv(X'WX + ridge*I))). Matches PyDESeq2's approach for better p-value parity. output_dir Directory for output h5ad file. Defaults to input file's directory. data_name Custom name for output file. If None, uses "nb_glm" suffix. scanpy_format If True, write Scanpy-compatible ``uns['rank_genes_groups']`` structure in addition to the layer-based storage. Adds ~2-6 seconds of I/O overhead for large datasets. Default False for performance. verbose If True, show a progress bar for perturbation fitting and log per-perturbation completion at DEBUG level. Requires tqdm to be installed for progress bar. profiling If True, enable timing and memory profiling. When enabled, stores profiling data in ``adata.uns["profiling"]`` with fields ``fit_seconds``, ``fit_peak_memory_mb``, and ``profiling_enabled``. When False (default), ``adata.uns["profiling"]`` is set to ``"NA"`` to avoid profiling overhead in production. resume If True, attempt to resume from a previous interrupted run. Reads the checkpoint file to determine which perturbations have already been completed and skips them. If the checkpoint file is missing or corrupted, falls back to scanning the output h5ad to detect completed perturbations. checkpoint_interval Number of perturbations to process between checkpoint saves. If None, auto-determined based on dataset size (1 for <100 perturbations, 10 for <1000, 50 for larger). The checkpoint file `<output>_progress.json` is written atomically to prevent corruption. memory_limit_gb Optional memory limit in GB. If provided, this is used together with available system memory to determine when to switch to streaming mode for global dispersion estimation. The effective limit is min(available_memory, memory_limit_gb). Default is None (use system memory only). max_dense_fraction Maximum fraction of available memory to use for dense matrix operations. If the estimated memory for densifying the full cell×gene matrix exceeds max_dense_fraction × min(available_memory, memory_limit_gb), the function switches to streaming mode. Default is 0.3 (30% of available memory). n_jobs Number of parallel workers for fitting GLMs across perturbations. If None, uses all available cores. If 1, runs sequentially. If -1, uses all available cores. use_control_cache If True (default), precompute control cell statistics (intercept, weights, XᵀWX contributions) once and reuse them across all perturbation comparisons. This can significantly reduce memory and computation time when there are many perturbations and the control group is large. Only applies when no covariates are specified and size_factor_scope="global". freeze_control Whether to use frozen control sufficient statistics instead of raw control matrix for parallel fitting. This dramatically reduces per-worker memory from ~5GB to ~1MB for large datasets, enabling more parallel workers. - None (default): Auto-detect based on dataset size. Frozen control is enabled when control_matrix serialization would limit workers to <4, AND the required settings (dispersion_scope='global', shrink_dispersion=True) are met. For most large datasets (>500K cells), this auto-enables. - True: Force frozen control mode. Raises ValueError if requirements not met. - False: Disable frozen control (use raw control matrix). Memory efficiency: Per-worker pickle size is reduced from (control_n × n_genes × 8) bytes to just ~1MB of sufficient statistics (W_sum, Wz_sum arrays). Example: For Replogle-GW-k562 (75K control cells × 8K genes): - Without freeze_control: ~4.7 GB per worker → 2 workers max @ 128GB - With freeze_control: ~1 MB per worker → 32 workers @ 128GB - Time reduction: ~300 hours → ~10 hours Requirements (enforced when True, auto-checked when None): - dispersion_scope="global" (per-comparison dispersion needs raw matrix) - shrink_dispersion=True (ensures global dispersion is computed) - use_control_cache=True (required for caching) Technical note: The intercept (β₀) is frozen to the control-only estimate. This is valid because control cells have perturbation indicator = 0, so μ_control = exp(β₀ + offset) is independent of the perturbation effect. force If True, rerun the analysis even when the output h5ad file already exists. If False (default), load and return the existing result instead of rerunning. Returns ------- RankGenesGroupsResult Differential expression results. Access results via dict-like interface: `result[label].effect_size`, `result[label].pvalue`, etc. The h5ad file path is available at `result.result_path`. """ perturbation_column, control_label, min_pct_ctrl, min_pct_pert = _resolve_de_aliases( perturbation_column=perturbation_column, groupby=groupby, control_label=control_label, reference=reference, min_pct_both=min_pct_both, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, fn_name="nb_glm_test", ) # Validate min_mu parameter if min_mu < 0: raise ValueError(f"min_mu must be >= 0, got {min_mu}") covariates = list(covariates or []) # Initialize profiler if enabled (timing + memory sampling) profiler = None if profiling: from .profiling import Profiler profiler = Profiler(timing=True, memory=True, memory_method="rss", sampling=True) profiler.start("total") profiler.start("fit") # Start fit timing (excludes shrinkage which is separate) # Check if dataset needs sorting for efficient I/O # Large datasets with many perturbations benefit from having cells sorted # by perturbation label, enabling contiguous reads instead of random access path = resolve_data_path(data) # Early-exit: if output already exists and force=False, reload without rerunning. # Use the original (pre-sort) path for output_path resolution so the location # is predictable for the caller regardless of internal sorting. _candidate_output_path = resolve_output_path( path, suffix="nb_glm", output_dir=output_dir, data_name=data_name ) if (r := _try_load_existing_de_result( _candidate_output_path, force=force, verbose=verbose, method_name="nb_glm", memory_limit_gb=memory_limit_gb, )): return r # Warn if the on-disk matrix is stored in CSC format – row-slicing # (used by size-factor computation and control-matrix loading) is # extremely slow on CSC. Recommend the CSR standardized file. _storage_fmt = get_matrix_storage_format(path) if _storage_fmt == "csc": warnings.warn( f"The input file '{path.name}' stores its matrix in CSC format. " "NB-GLM performs row-wise access (size factors, control matrix, " "per-perturbation slices) which is very slow on CSC. " "Use the CSR-format standardized file for much better performance.", UserWarning, stacklevel=2, ) if needs_sorting_for_nbglm(path, perturbation_column=perturbation_column): sorted_path = path.parent / f"{path.stem}_sorted.h5ad" if not sorted_path.exists(): logger.info( f"Large dataset detected with scattered cells. " f"Sorting by perturbation for efficient I/O..." ) path = sort_by_perturbation( path, perturbation_column=perturbation_column, control_label=control_label, output_path=sorted_path, ) else: logger.info(f"Using existing sorted dataset: {sorted_path}") path = sorted_path backed = read_backed(path) try: gene_symbols = ensure_gene_symbol_column(backed, gene_name_column) if perturbation_column not in backed.obs.columns: raise KeyError( f"Perturbation column '{perturbation_column}' was not found in adata.obs. Available columns: {list(backed.obs.columns)}" ) obs_df = backed.obs[[perturbation_column] + covariates].copy() labels = obs_df[perturbation_column].astype(str).to_numpy() control_label = resolve_control_label(list(labels), control_label) n_genes = backed.n_vars candidates = _resolve_candidates(labels, control_label, perturbations) control_mask = labels == control_label control_n = int(control_mask.sum()) if control_n == 0: raise ValueError("Control group contains no cells") for label in candidates: if not np.any(labels == label): raise ValueError(f"Perturbation '{label}' contains no cells") # Check if file is sorted by perturbation for efficient I/O perturbation_boundaries = None if "sorting_metadata" in backed.uns: metadata = backed.uns["sorting_metadata"] if metadata.get("sorted_by") == perturbation_column: perturbation_boundaries = metadata.get("perturbation_boundaries", {}) if perturbation_boundaries: logger.debug( f"Using sorted file with {len(perturbation_boundaries)} contiguous perturbation groups" ) finally: backed.file.close() n_cells_total = obs_df.shape[0] # Calculate adaptive chunk_size if not provided # Only reduces chunk_size when memory would be exceeded; successful datasets keep default (256) if chunk_size is None: chunk_size = calculate_nb_glm_chunk_size( n_obs=n_cells_total, n_vars=n_genes, n_groups=len(candidates), memory_limit_gb=memory_limit_gb, ) # For per_comparison size factors, we skip the expensive global computation # since size factors will be recomputed per-comparison anyway. # We use dummy size factors here (will be overwritten per-comparison). if size_factor_scope == "per_comparison" and size_factors is None: # Dummy size factors - will be replaced in worker size_factors = np.ones(n_cells_total, dtype=np.float64) logger.debug("Skipping global size factor computation for per_comparison mode") elif size_factors is None: if size_factor_method == "deseq2": size_factors = _deseq2_style_size_factors( path, chunk_size=chunk_size, scale=scale_size_factors ) else: # "sparse" (default) size_factors = _median_of_ratios_size_factors( path, chunk_size=chunk_size, scale=scale_size_factors ) else: size_factors = _validate_size_factors( size_factors, n_cells_total, scale=scale_size_factors ) offset = np.log(np.clip(size_factors, 1e-8, None)) # ========================================================================= # Independent fitting mode: per-perturbation approach # ========================================================================= beta_cov = None beta_intercept = None global_dispersion = None # ------------------------------------------------------------------------- # Worker function for parallel fitting of a single perturbation group # ------------------------------------------------------------------------- def _fit_perturbation_worker( group_idx: int, label: str, path: str | Path, labels: np.ndarray, control_mask: np.ndarray, control_matrix: np.ndarray | sp.csr_matrix, control_expr_counts: np.ndarray, control_n: int, obs_df: pd.DataFrame, covariates: list[str], size_factors: np.ndarray, offset: np.ndarray, n_genes: int, min_cells_expressed: int, min_pct_ctrl: float, min_pct_pert: float, min_mean_ctrl: float, min_mean_pert: float, min_total_count: float, max_iter: int, tol: float, min_mu: float, poisson_init_iter: int, dispersion_method: str, global_dispersion: np.ndarray | None, shrink_dispersion: bool, use_map_dispersion: bool, lfc_shrinkage_type: str, pts_rest_shared: np.ndarray, full_X: np.ndarray | sp.csr_matrix | None = None, per_comparison_sf: bool = False, se_method: str = "sandwich", perturbation_boundaries: dict | None = None, ) -> dict: """Fit NB-GLM for a single perturbation group and return results.""" group_mask = labels == label subset_mask = control_mask | group_mask subset_obs = obs_df.iloc[subset_mask] indicator = group_mask[subset_mask].astype(np.float64) # Compute per-comparison size factors if requested (matches PyDESeq2) if per_comparison_sf and full_X is not None: subset_sf = _compute_subset_size_factors(full_X, subset_mask, scale=True) subset_size_factors = subset_sf subset_offset = np.log(np.clip(subset_sf, 1e-8, None)) else: subset_size_factors = np.asarray(size_factors)[subset_mask] subset_offset = offset[subset_mask] if offset is not None else np.log(np.clip(subset_size_factors, 1e-8, None)) # Build design matrix # Build design matrix design, design_columns = build_design_matrix( subset_obs, covariate_columns=covariates, perturbation_indicator=indicator, intercept=True, ) perturbation_column_index = design_columns.index("perturbation") control_subset_mask = control_mask[subset_mask] group_subset_mask = group_mask[subset_mask] group_n = int(group_subset_mask.sum()) subset_n = int(subset_mask.sum()) n_control = int(control_subset_mask.sum()) # Load perturbation group cells # Use slice-based access if file is sorted, otherwise use mask backed = read_backed(path) try: if perturbation_boundaries is not None and label in perturbation_boundaries: # Slice-based access for sorted files (contiguous, fast) start, end = perturbation_boundaries[label] group_matrix = backed.X[start:end, :] else: # Mask-based access for unsorted files (random, slower) group_matrix = backed.X[group_mask, :] if sp.issparse(group_matrix): group_matrix = sp.csr_matrix(group_matrix, dtype=np.float64) else: group_matrix = np.asarray(group_matrix, dtype=np.float64) finally: backed.file.close() # Combine control and group matrices if sp.issparse(control_matrix) and sp.issparse(group_matrix): stacked = sp.vstack([control_matrix, group_matrix]) reorder = np.empty(subset_n, dtype=np.int32) reorder[np.where(control_subset_mask)[0]] = np.arange(n_control) reorder[np.where(group_subset_mask)[0]] = np.arange(n_control, n_control + group_n) subset_matrix = sp.csr_matrix(stacked[reorder, :]) else: if sp.issparse(control_matrix): ctrl = control_matrix.toarray() else: ctrl = control_matrix if sp.issparse(group_matrix): grp = group_matrix.toarray() else: grp = group_matrix stacked = np.vstack([ctrl, grp]) reorder = np.empty(subset_n, dtype=np.int32) reorder[np.where(control_subset_mask)[0]] = np.arange(n_control) reorder[np.where(group_subset_mask)[0]] = np.arange(n_control, n_control + group_n) subset_matrix = stacked[reorder, :] # Compute expression counts if sp.issparse(group_matrix): group_expr_counts = np.asarray(group_matrix.getnnz(axis=0)).ravel() else: group_expr_counts = np.sum(group_matrix > 0, axis=0) # Per-group means in size-factor-normalised count units. These feed # both the per-condition low-expression filter and ``result["mean"]`` # below; computing them once avoids a duplicate matrix pass. if sp.issparse(group_matrix): group_mean = ( np.asarray(group_matrix.sum(axis=0)).ravel() / max(group_n, 1) ) else: group_mean = np.asarray(group_matrix.sum(axis=0)).ravel() / max(group_n, 1) if sp.issparse(control_matrix): control_mean_raw = ( np.asarray(control_matrix.sum(axis=0)).ravel() / max(control_n, 1) ) else: control_mean_raw = np.asarray(control_matrix.sum(axis=0)).ravel() / max(control_n, 1) total_expr_counts = control_expr_counts + group_expr_counts valid_mask = total_expr_counts >= min_cells_expressed # Per-condition low-expression filter (jointly low in BOTH groups) low_both = _low_expr_in_both_mask( pert_expr_counts=group_expr_counts, control_expr_counts=control_expr_counts, pert_mean=group_mean, control_mean=control_mean_raw, n_pert_cells=group_n, n_control_cells=control_n, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, ) valid_mask = valid_mask & ~low_both valid_indices = np.where(valid_mask)[0] # Initialize result arrays result = { "group_idx": group_idx, "n_tested": int(valid_mask.sum()), "effect": np.full(n_genes, np.nan, dtype=np.float64), "statistic": np.full(n_genes, np.nan, dtype=np.float64), "pvalue": np.full(n_genes, np.nan, dtype=np.float64), "logfc": np.full(n_genes, np.nan, dtype=np.float64), "logfc_raw": np.full(n_genes, np.nan, dtype=np.float64), "intercept": np.full(n_genes, np.nan, dtype=np.float64), # MLE intercept for shrink_lfc "se": np.full(n_genes, np.nan, dtype=np.float64), "pts": np.zeros(n_genes, dtype=np.float32), "pts_rest": np.zeros(n_genes, dtype=np.float32), "dispersion": np.full(n_genes, np.nan, dtype=np.float64), "dispersion_raw": np.full(n_genes, np.nan, dtype=np.float64), "dispersion_trend": np.full(n_genes, np.nan, dtype=np.float64), "mean": np.zeros(n_genes, dtype=np.float64), "iterations": np.zeros(n_genes, dtype=np.int32), "converged": np.zeros(n_genes, dtype=bool), } # Compute pts pts = np.divide( group_expr_counts, group_n, out=np.zeros(n_genes, dtype=np.float32), where=group_n > 0, ) result["pts"] = np.where(valid_mask, pts, 0.0).astype(np.float32) result["pts_rest"] = np.where(valid_mask, pts_rest_shared, 0.0).astype(np.float32) # Compute mean expression if sp.issparse(subset_matrix): normalized = subset_matrix.multiply(1.0 / subset_size_factors[:, None]) mean_expr = np.asarray(normalized.sum(axis=0)).ravel() / subset_n else: normalized = subset_matrix / subset_size_factors[:, None] mean_expr = normalized.sum(axis=0) / subset_n result["mean"] = mean_expr if not np.any(valid_mask): return result # Fit valid genes fit_matrix = subset_matrix[:, valid_mask] batch_fitter = NBGLMBatchFitter( design, offset=subset_offset, max_iter=max_iter, tol=tol, poisson_init_iter=poisson_init_iter, dispersion_method=dispersion_method, min_mu=min_mu, min_total_count=min_total_count, ) batch_result = batch_fitter.fit_batch(fit_matrix) # Extract results result["converged"][valid_indices] = batch_result.converged result["iterations"][valid_indices] = batch_result.n_iter result["dispersion_raw"][valid_indices] = batch_result.dispersion coefs = batch_result.coef[:, perturbation_column_index] ses = batch_result.se[:, perturbation_column_index] intercepts = batch_result.coef[:, 0] # Intercept is always first column valid_results = ( batch_result.converged & np.isfinite(coefs) & np.isfinite(ses) & (ses > 0) ) for local_idx, gene_idx in enumerate(valid_indices): if not valid_results[local_idx]: continue coef = coefs[local_idx] se = ses[local_idx] statistic = coef / se pvalue = float(2.0 * norm.sf(abs(statistic))) result["statistic"][gene_idx] = statistic result["pvalue"][gene_idx] = pvalue result["logfc"][gene_idx] = coef result["se"][gene_idx] = se result["intercept"][gene_idx] = intercepts[local_idx] # Store fitted intercept # Handle dispersion if global_dispersion is not None: result["dispersion_raw"][:] = global_dispersion result["dispersion"][:] = global_dispersion result["dispersion_trend"][:] = global_dispersion elif shrink_dispersion: trend = fit_dispersion_trend(result["mean"], result["dispersion_raw"]) result["dispersion_trend"] = trend if use_map_dispersion: # Use PyDESeq2-style MAP estimation with proper fitted values # First, compute mu from the fitted model coefficients if sp.issparse(subset_matrix): Y = np.asarray(subset_matrix.todense(), dtype=np.float64) else: Y = np.asarray(subset_matrix, dtype=np.float64) # Get fitted values from the model: mu = exp(X @ beta + offset) # batch_result.coef has shape (n_valid_genes, n_design_cols) # design has shape (n_cells, n_design_cols) n_subset = subset_n mu = np.zeros((n_subset, n_genes), dtype=np.float64) # For valid genes, compute mu from fitted coefficients for local_idx, gene_idx in enumerate(valid_indices): if batch_result.converged[local_idx]: # eta = X @ beta + offset eta = design @ batch_result.coef[local_idx, :] + offset[subset_mask] mu[:, gene_idx] = np.exp(np.clip(eta, -30, 30)) # For invalid genes, use normalized counts as fallback invalid_genes = ~np.isin(np.arange(n_genes), valid_indices[batch_result.converged]) if np.any(invalid_genes): mu[:, invalid_genes] = Y[:, invalid_genes] / subset_size_factors[:, None] mu = np.maximum(mu, 1e-10) # Use PyDESeq2-style bounds: max(10, n_cells) max_disp = max(10.0, float(n_subset)) disp_map = estimate_dispersion_map( Y, mu, trend, max_disp=max_disp ) result["dispersion"] = disp_map # CRITICAL: Recompute SE using MAP dispersion (PyDESeq2 style) # SE from IRLS was computed using MoM dispersion, but Wald test # should use MAP dispersion for proper variance estimation ridge = 1e-6 for local_idx, gene_idx in enumerate(valid_indices): if not valid_results[local_idx]: continue # Compute weights with MAP dispersion: W = mu / (1 + mu * disp) mu_gene = mu[:, gene_idx] disp_gene = disp_map[gene_idx] W = mu_gene / (1.0 + mu_gene * disp_gene) # Compute (X'WX + ridge*I)^{-1} XtW = design.T * W[None, :] XtWX = XtW @ design XtWX += ridge * np.eye(design.shape[1]) try: inv_XtWX = np.linalg.inv(XtWX) se_new = np.sqrt(np.maximum(inv_XtWX[perturbation_column_index, perturbation_column_index], 1e-10)) except np.linalg.LinAlgError: se_new = result["se"][gene_idx] # Keep original SE on failure # Update SE, statistic, and p-value coef = result["logfc"][gene_idx] result["se"][gene_idx] = se_new statistic = coef / se_new result["statistic"][gene_idx] = statistic result["pvalue"][gene_idx] = float(2.0 * norm.sf(abs(statistic))) else: result["dispersion"] = shrink_dispersions(result["dispersion_raw"], trend) else: result["dispersion_trend"] = result["dispersion_raw"].copy() result["dispersion"] = result["dispersion_raw"].copy() result["logfc_raw"] = result["logfc"].copy() if lfc_shrinkage_type == "apeglm": # Full NB-GLM re-fitting with Cauchy prior (matches PyDESeq2) # Build mle_coef matrix (n_params, n_genes) from batch_result n_params = design.shape[1] mle_coef = np.full((n_params, n_genes), np.nan, dtype=np.float64) for local_idx, gene_idx in enumerate(valid_indices): if valid_results[local_idx]: mle_coef[:, gene_idx] = batch_result.coef[local_idx, :] # Densify the subset matrix if sparse if sp.issparse(subset_matrix): counts_dense = subset_matrix.toarray() else: counts_dense = np.asarray(subset_matrix, dtype=np.float64) # Call full apeGLM shrinkage with L-BFGS-B re-fitting # NOTE: PyDESeq2's lfc_shrink does NOT use min_mu during shrinkage shrunk_coef, shrunk_se_arr, shrink_converged = shrink_lfc_apeglm( counts=counts_dense, design_matrix=design, size_factors=subset_size_factors, dispersion=result["dispersion"], mle_coef=mle_coef, mle_se=result["se"], shrink_index=perturbation_column_index, prior_scale=None, # Estimate globally from MLE LFC distribution n_jobs=1, # Single-threaded within worker (parallelism at perturbation level) min_mu=0.0, # No min_mu - match PyDESeq2's lfc_shrink behavior ) # Extract shrunken LFC (perturbation coefficient) result["logfc"] = shrunk_coef[perturbation_column_index, :] result["effect"] = result["logfc"].copy() result["se"] = shrunk_se_arr # Posterior SE from inverse Hessian else: result["effect"] = result["logfc"].copy() return result # ------------------------------------------------------------------------- # ------------------------------------------------------------------------- # Optimized worker using precomputed control cache # ------------------------------------------------------------------------- def _fit_perturbation_worker_cached( group_idx: int, label: str, path: str | Path, labels: np.ndarray, control_cache: ControlStatisticsCache, size_factors: np.ndarray, n_genes: int, min_cells_expressed: int, min_pct_ctrl: float, min_pct_pert: float, min_mean_ctrl: float, min_mean_pert: float, min_total_count: float, max_iter: int, tol: float, min_mu: float, dispersion_method: str, shrink_dispersion: bool, use_map_dispersion: bool, lfc_shrinkage_type: str, se_method: str = "sandwich", perturbation_boundaries: dict | None = None, ) -> dict: """Fit NB-GLM for a perturbation group using cached control statistics. This is an optimized version that reuses precomputed control cell statistics (intercept, weights, XᵀWX contributions) to avoid redundant computation across perturbation comparisons. """ group_mask = labels == label group_n = int(group_mask.sum()) n_control = control_cache.control_n subset_n = n_control + group_n # Initialize result arrays result = { "group_idx": group_idx, "effect": np.full(n_genes, np.nan, dtype=np.float64), "statistic": np.full(n_genes, np.nan, dtype=np.float64), "pvalue": np.full(n_genes, np.nan, dtype=np.float64), "logfc": np.full(n_genes, np.nan, dtype=np.float64), "logfc_raw": np.full(n_genes, np.nan, dtype=np.float64), "intercept": np.full(n_genes, np.nan, dtype=np.float64), # MLE intercept for shrink_lfc "se": np.full(n_genes, np.nan, dtype=np.float64), "pts": np.zeros(n_genes, dtype=np.float32), "pts_rest": control_cache.pts_rest.copy(), "dispersion": np.full(n_genes, np.nan, dtype=np.float64), "dispersion_raw": np.full(n_genes, np.nan, dtype=np.float64), "dispersion_trend": np.full(n_genes, np.nan, dtype=np.float64), "mean": np.zeros(n_genes, dtype=np.float64), "iterations": np.zeros(n_genes, dtype=np.int32), "converged": np.zeros(n_genes, dtype=bool), } # Load perturbation group cells # Use slice-based access if file is sorted, otherwise use mask backed = read_backed(path) try: if perturbation_boundaries is not None and label in perturbation_boundaries: # Slice-based access for sorted files (contiguous, fast) start, end = perturbation_boundaries[label] group_matrix = backed.X[start:end, :] else: # Mask-based access for unsorted files (random, slower) group_matrix = backed.X[group_mask, :] if sp.issparse(group_matrix): group_matrix = sp.csr_matrix(group_matrix, dtype=np.float64) else: group_matrix = np.asarray(group_matrix, dtype=np.float64) finally: backed.file.close() # Compute expression counts for perturbation cells if sp.issparse(group_matrix): group_expr_counts = np.asarray(group_matrix.getnnz(axis=0)).ravel() else: group_expr_counts = np.sum(group_matrix > 0, axis=0) # Valid genes mask total_expr_counts = control_cache.control_expr_counts + group_expr_counts valid_mask = total_expr_counts >= min_cells_expressed # Compute pts pts = np.divide( group_expr_counts, group_n, out=np.zeros(n_genes, dtype=np.float32), where=group_n > 0, ) result["pts"] = np.where(valid_mask, pts, 0.0).astype(np.float32) # Compute mean expression subset_size_factors_group = np.asarray(size_factors)[group_mask] if sp.issparse(group_matrix): normalized_group = group_matrix.multiply(1.0 / subset_size_factors_group[:, None]) mean_group = np.asarray(normalized_group.sum(axis=0)).ravel() else: normalized_group = group_matrix / subset_size_factors_group[:, None] mean_group = normalized_group.sum(axis=0) # Combined mean result["mean"] = (control_cache.control_mean_expr * n_control + mean_group) / subset_n # Per-condition low-expression filter (jointly low in BOTH groups). # ``mean_group`` is the size-factor-normalised SUM; divide by group_n # for the per-cell mean used by the filter. group_mean_per_cell = ( mean_group / group_n if group_n > 0 else np.zeros_like(mean_group) ) low_both = _low_expr_in_both_mask( pert_expr_counts=group_expr_counts, control_expr_counts=control_cache.control_expr_counts, pert_mean=group_mean_per_cell, control_mean=control_cache.control_mean_expr, n_pert_cells=group_n, n_control_cells=n_control, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, ) valid_mask = valid_mask & ~low_both result["n_tested"] = int(valid_mask.sum()) if not np.any(valid_mask): return result # Get perturbation offset perturbation_offset = np.log(np.clip(subset_size_factors_group, 1e-8, None)) # Create perturbation indicator for the combined design (control first, then perturbation) perturbation_indicator = np.concatenate([ np.zeros(n_control, dtype=np.float64), np.ones(group_n, dtype=np.float64) ]) # Check if using frozen control mode (memory-efficient parallel fitting) if control_cache.use_frozen_control: # FROZEN CONTROL PATH: Use sufficient statistics instead of raw matrix # Per-worker memory: ~5GB → ~1MB (enables 32 workers instead of 2) # Create batch fitter with minimal design (only perturbation cells have data) batch_fitter = NBGLMBatchFitter( design=np.ones((group_n, 1)), # Placeholder, not used in frozen mode offset=perturbation_offset, # Only perturbation offsets needed max_iter=max_iter, tol=tol, dispersion_method=dispersion_method, min_mu=min_mu, min_total_count=min_total_count, ) batch_result = batch_fitter.fit_batch_with_frozen_control( perturbation_matrix=group_matrix, perturbation_offset=perturbation_offset, control_cache=control_cache, valid_mask=valid_mask, ) else: # STANDARD PATH: Full control_matrix available # Fit using the batch fitter with control cache batch_fitter = NBGLMBatchFitter( design=np.column_stack([np.ones(subset_n), perturbation_indicator]), offset=np.concatenate([control_cache.control_offset, perturbation_offset]), max_iter=max_iter, tol=tol, dispersion_method=dispersion_method, min_mu=min_mu, min_total_count=min_total_count, ) batch_result = batch_fitter.fit_batch_with_control_cache( perturbation_matrix=group_matrix, perturbation_offset=perturbation_offset, control_cache=control_cache, perturbation_indicator=perturbation_indicator, valid_mask=valid_mask, ) valid_indices = np.where(valid_mask)[0] # Extract results result["converged"][valid_indices] = batch_result.converged[valid_indices] result["iterations"][valid_indices] = batch_result.n_iter[valid_indices] # Perturbation coefficient is at index 1 - vectorized computation coefs = batch_result.coef[:, 1] # (n_genes,) ses = batch_result.se[:, 1] # (n_genes,) # Vectorized: compute statistics for all valid & converged genes at once # Build mask for valid, converged genes with finite coef/se valid_converged_mask = ( valid_mask & batch_result.converged & np.isfinite(coefs) & np.isfinite(ses) & (ses > 0) ) # Compute Wald statistic and p-value vectorized statistics = np.divide( coefs, ses, out=np.full(n_genes, np.nan), where=valid_converged_mask ) # Use norm.sf (survival function) to avoid underflow for large |z| # Note: ndtr(x) = CDF, so 1-ndtr(x) underflows for large x; sf(x) = 1-CDF is stable pvalues = np.where( valid_converged_mask, 2.0 * norm.sf(np.abs(statistics)), # 2-sided p-value np.nan ) result["statistic"] = statistics result["pvalue"] = pvalues result["logfc"] = np.where(valid_converged_mask, coefs, np.nan) result["se"] = np.where(valid_converged_mask, ses, np.nan) result["intercept"] = np.where(valid_converged_mask, batch_result.coef[:, 0], np.nan) # Store fitted intercept # Handle dispersion shrinkage if shrink_dispersion: # ================================================================ # MEMORY-OPTIMIZED: Use batched computation for dispersion and SE # ================================================================ # Instead of building full (n_control+n_group, n_genes) matrices, # process genes in batches to reduce peak memory. # ================================================================ # Get coefficients (needed for SE recomputation) beta0_all = batch_result.coef[:, 0] # (n_genes,) beta1_all = batch_result.coef[:, 1] # (n_genes,) # ================================================================ # OPTIMIZATION: Check for global dispersion FIRST (before MoM/trend) # When dispersion_scope='global', skip expensive per-comparison work # ================================================================ if control_cache.global_dispersion is not None: # FAST PATH: Use precomputed global dispersion # Skip MoM dispersion computation entirely # Skip trend fitting entirely logger.debug(f"Using precomputed global dispersion for {label}") result["dispersion"] = control_cache.global_dispersion.copy() if control_cache.global_dispersion_trend is not None: result["dispersion_trend"] = control_cache.global_dispersion_trend.copy() else: # Fallback: use global dispersion as trend proxy result["dispersion_trend"] = control_cache.global_dispersion.copy() # dispersion_raw not computed in global mode - set to NaN or copy global result["dispersion_raw"] = control_cache.global_dispersion.copy() # SE handling: frozen control vs standard if control_cache.use_frozen_control: # FROZEN CONTROL PATH: Use SE from fit_batch_with_frozen_control directly # No SE recomputation needed since dispersion is global (pre-fitted) # The SE already uses the correct dispersion from control_cache pass # SE already set from batch_result else: # STANDARD PATH: Recompute SE with global dispersion n_control = control_cache.control_matrix.shape[0] n_group = group_matrix.shape[0] if sp.issparse(group_matrix): Y_pert_dense = group_matrix.toarray() else: Y_pert_dense = np.asarray(group_matrix, dtype=np.float64) final_disp = result["dispersion"] recomputed_se = _compute_se_batched( Y_control=control_cache.control_matrix, Y_pert=Y_pert_dense, control_offset=control_cache.control_offset, pert_offset=perturbation_offset, beta0=beta0_all, beta1=beta1_all, dispersion=final_disp, gene_batch_size=5000, se_method=se_method, ) # Update result with recomputed SE result["se"] = np.where(valid_converged_mask, recomputed_se, np.nan) # Recompute Wald statistic and p-value with new SE coefs = batch_result.coef[:, 1] statistics = np.divide( coefs, recomputed_se, out=np.full(n_genes, np.nan), where=valid_converged_mask ) pvalues = np.where( valid_converged_mask, 2.0 * norm.sf(np.abs(statistics)), np.nan ) result["statistic"] = statistics result["pvalue"] = pvalues else: # ============================================================ # STANDARD PATH: Per-comparison dispersion (MoM → trend → MAP) # Requires control_matrix - not compatible with frozen control # ============================================================ if control_cache.use_frozen_control: raise ValueError( "Frozen control mode requires global dispersion. " "Set dispersion_scope='global' and shrink_dispersion=True when using freeze_control=True." ) # Prepare perturbation matrix as dense (needed for SE recomputation) n_control = control_cache.control_matrix.shape[0] n_group = group_matrix.shape[0] if sp.issparse(group_matrix): Y_pert_dense = group_matrix.toarray() else: Y_pert_dense = np.asarray(group_matrix, dtype=np.float64) # Compute MoM dispersion using batched processing mom_disp = _compute_mom_dispersion_batched( Y_control=control_cache.control_matrix, Y_pert=Y_pert_dense, control_offset=control_cache.control_offset, pert_offset=perturbation_offset, beta0=beta0_all, beta1=beta1_all, converged=batch_result.converged, gene_batch_size=5000, ) result["dispersion_raw"][valid_indices] = mom_disp[valid_indices] # Fit trend using corrected MoM dispersion trend = fit_dispersion_trend(result["mean"], result["dispersion_raw"]) result["dispersion_trend"] = trend if use_map_dispersion: # For MAP dispersion, we still need full matrices (can optimize later) # Build combined Y and mu for estimate_dispersion_map Y = np.empty((n_control + n_group, n_genes), dtype=np.float64) Y[:n_control, :] = control_cache.control_matrix Y[n_control:, :] = Y_pert_dense # Compute full mu matrix for MAP offset_combined = np.concatenate([control_cache.control_offset, perturbation_offset]) eta = ( beta0_all[None, :] + perturbation_indicator[:, None] * beta1_all[None, :] + offset_combined[:, None] ) np.clip(eta, -30, 30, out=eta) mu = np.exp(eta) del eta mu[:, ~batch_result.converged] = 1e-10 np.maximum(mu, 1e-10, out=mu) max_disp = max(10.0, float(subset_n)) result["dispersion"] = estimate_dispersion_map(Y, mu, trend, max_disp=max_disp) del Y, mu # Free large matrices else: result["dispersion"] = shrink_dispersions(result["dispersion_raw"], trend) # ================================================================ # Recompute SE using batched processing (memory-optimized) # ================================================================ final_disp = result["dispersion"] recomputed_se = _compute_se_batched( Y_control=control_cache.control_matrix, Y_pert=Y_pert_dense, control_offset=control_cache.control_offset, pert_offset=perturbation_offset, beta0=beta0_all, beta1=beta1_all, dispersion=final_disp, gene_batch_size=5000, se_method=se_method, ) # Update result with recomputed SE result["se"] = np.where(valid_converged_mask, recomputed_se, np.nan) # Recompute Wald statistic and p-value with new SE coefs = batch_result.coef[:, 1] statistics = np.divide( coefs, recomputed_se, out=np.full(n_genes, np.nan), where=valid_converged_mask ) pvalues = np.where( valid_converged_mask, 2.0 * norm.sf(np.abs(statistics)), # Use sf() for numerical stability np.nan ) result["statistic"] = statistics result["pvalue"] = pvalues else: result["dispersion_trend"] = result["dispersion_raw"].copy() result["dispersion"] = result["dispersion_raw"].copy() result["logfc_raw"] = result["logfc"].copy() if lfc_shrinkage_type == "apeglm": # Full NB-GLM re-fitting with Cauchy prior (matches PyDESeq2) # Build mle_coef matrix (n_params, n_genes) from batch_result n_params = 2 # Intercept + perturbation mle_coef = np.full((n_params, n_genes), np.nan, dtype=np.float64) mle_coef[0, :] = batch_result.coef[:, 0] # Intercept mle_coef[1, :] = batch_result.coef[:, 1] # LFC # Build combined count matrix (control + perturbation) if sp.issparse(group_matrix): Y_pert = group_matrix.toarray() else: Y_pert = np.asarray(group_matrix, dtype=np.float64) counts_combined = np.vstack([control_cache.control_matrix, Y_pert]) # Build combined size factors sf_combined = np.concatenate([ np.exp(control_cache.control_offset), subset_size_factors_group ]) # Build design matrix for combined data design_combined = np.zeros((n_control + group_n, 2), dtype=np.float64) design_combined[:, 0] = 1.0 # Intercept design_combined[n_control:, 1] = 1.0 # Perturbation indicator # Call full apeGLM shrinkage # NOTE: PyDESeq2's lfc_shrink does NOT use min_mu during shrinkage shrunk_coef, shrunk_se_arr, shrink_converged = shrink_lfc_apeglm( counts=counts_combined, design_matrix=design_combined, size_factors=sf_combined, dispersion=result["dispersion"], mle_coef=mle_coef, mle_se=result["se"], shrink_index=1, # LFC is at index 1 prior_scale=None, n_jobs=1, min_mu=0.0, # No min_mu - match PyDESeq2's lfc_shrink behavior ) result["logfc"] = shrunk_coef[1, :] # Shrunken LFC result["effect"] = result["logfc"].copy() result["se"] = shrunk_se_arr else: result["effect"] = result["logfc"].copy() return result # ------------------------------------------------------------------------- n_groups = len(candidates) # Determine output path first (needed for checkpoint) output_path = resolve_output_path( path, suffix="nb_glm", output_dir=output_dir, data_name=data_name, ) output_path.parent.mkdir(parents=True, exist_ok=True) checkpoint_path = output_path.with_suffix(".progress.json") # Handle resume logic if resume: candidates_to_run, completed_labels, failed_labels = _get_resumable_candidates( checkpoint_path, output_path, candidates, retry_failed=True ) # If all candidates are completed, load and return the existing result if len(candidates_to_run) == 0 and output_path.exists(): logger.info("All perturbations already completed. Loading existing result...") return _load_existing_nb_glm_result( output_path=output_path, candidates=candidates, gene_symbols=gene_symbols, perturbation_column=perturbation_column, control_label=control_label, corr_method=corr_method, ) else: candidates_to_run = candidates completed_labels = [] failed_labels = [] # Determine checkpoint interval eff_checkpoint_interval = _get_checkpoint_interval(len(candidates), checkpoint_interval) # Create index mappings candidate_to_idx = {label: idx for idx, label in enumerate(candidates)} with tempfile.TemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) def _create_memmap( name: str, dtype: np.dtype, *, fill: float | int | bool | None = np.nan ) -> np.memmap: mmap = np.memmap( tmp_path / f"{name}.dat", mode="w+", dtype=dtype, shape=(n_groups, n_genes), ) if fill is None: return mmap if isinstance(fill, float) and np.isnan(fill): mmap.fill(np.nan) else: mmap.fill(fill) return mmap effect_memmap = _create_memmap("effect", np.float64) statistic_memmap = _create_memmap("statistic", np.float64) pvalue_memmap = _create_memmap("pvalue", np.float64) logfc_memmap = _create_memmap("logfoldchange", np.float64) logfc_raw_memmap = _create_memmap("logfoldchange_raw", np.float64) intercept_memmap = _create_memmap("intercept", np.float64) # MLE intercept for shrink_lfc se_memmap = _create_memmap("standard_error", np.float64) pts_memmap = _create_memmap("pts", np.float32, fill=0.0) pts_rest_memmap = _create_memmap("pts_rest", np.float32, fill=0.0) dispersion_memmap = _create_memmap("dispersion", np.float64) dispersion_raw_memmap = _create_memmap("dispersion_raw", np.float64) dispersion_trend_memmap = _create_memmap("dispersion_trend", np.float64) mean_memmap = _create_memmap("mean", np.float64, fill=0.0) iter_memmap = _create_memmap("iterations", np.int32, fill=0) convergence_memmap = _create_memmap("converged", np.bool_, fill=False) # Load control cells matrix once (all genes) # For very large control groups, skip loading and use streaming path control_matrix_gb = control_n * n_genes * 8 / 1e9 # dense float64 if memory_limit_gb is not None: _ctrl_mem_limit = min(memory_limit_gb, _get_available_memory_mb() / 1000) else: _ctrl_mem_limit = _get_available_memory_mb() / 1000 # Streaming threshold: control dense + 3 work arrays > 30% available memory _ctrl_streaming_threshold = max_dense_fraction * _ctrl_mem_limit use_streaming_control = (control_matrix_gb * 4) > _ctrl_streaming_threshold if use_streaming_control: logger.info( f"Large control group: {control_n:,} cells × {n_genes:,} genes = " f"{control_matrix_gb:.1f} GB dense. Using streaming control statistics." ) control_matrix = None # Not loaded — streaming will read from disk # Compute control_expr_counts via streaming control_expr_counts = np.zeros(n_genes, dtype=np.int64) backed = read_backed(path) try: ctrl_indices = np.where(control_mask)[0] _chunk = 4096 for _start in range(0, control_n, _chunk): _end = min(_start + _chunk, control_n) _idx = ctrl_indices[_start:_end] _blk = backed.X[_idx, :] if sp.issparse(_blk): control_expr_counts += np.asarray(_blk.getnnz(axis=0)).ravel() else: control_expr_counts += np.asarray((_blk > 0).sum(axis=0)).ravel() finally: backed.file.close() drop_file_cache(path) # Free page cache from streaming reads else: backed = read_backed(path) try: control_matrix = backed.X[control_mask, :] if sp.issparse(control_matrix): control_matrix = sp.csr_matrix(control_matrix, dtype=np.float64) else: control_matrix = np.asarray(control_matrix, dtype=np.float64) finally: backed.file.close() drop_file_cache(path) # Free page cache from control matrix read # Pre-compute control expression counts if sp.issparse(control_matrix): control_expr_counts = np.asarray(control_matrix.getnnz(axis=0)).ravel() else: control_expr_counts = np.sum(control_matrix > 0, axis=0) # Compute pts_rest once (same for all perturbations) pts_rest_shared = np.divide( control_expr_counts, control_n, out=np.zeros(n_genes, dtype=np.float32), where=control_n > 0, ) # ===================================================================== # Parallel fitting of perturbation groups # ===================================================================== # Determine number of parallel workers with memory-awareness # For small n_groups, run sequentially to avoid joblib overhead # (profiling shows joblib.sleep takes 24s for 2 perturbations) cpu_count = os.cpu_count() or 1 if n_jobs is None or n_jobs == 0: effective_n_jobs = cpu_count elif n_jobs == -1: effective_n_jobs = cpu_count elif n_jobs < 0: effective_n_jobs = max(1, cpu_count + n_jobs + 1) else: effective_n_jobs = min(n_jobs, cpu_count) effective_n_jobs = max(1, effective_n_jobs) # Save original requested workers for auto-detection logic # (before memory-based reduction) requested_n_jobs = effective_n_jobs # Memory-aware worker limiting: Adaptive estimation based on dataset statistics # Compute group size statistics without loading the full matrix # Use numpy unique with counts since labels is a numpy array unique_labels, label_counts = np.unique(labels, return_counts=True) group_sizes = dict(zip(unique_labels, label_counts)) pert_group_sizes = [group_sizes.get(g, 0) for g in candidates] max_group_size = max(pert_group_sizes) if pert_group_sizes else 1 avg_group_size = max(1, (n_cells_total - control_n) // max(1, n_groups)) # Use p95 group size for realistic estimate (avoids over-conservative from outliers) if len(pert_group_sizes) >= 10: p95_group_size = int(np.percentile(pert_group_sizes, 95)) use_group_size = p95_group_size else: use_group_size = max_group_size # Decide early if we can use control cache (needed for memory estimation) can_use_cache_early = ( use_control_cache and len(covariates) == 0 and size_factor_scope == "global" ) # ===================================================================== # AUTO-DETECTION: Enable freeze_control for large datasets # ===================================================================== # When freeze_control=None (default), auto-enable if: # 1. Control matrix serialization would severely limit workers (<4) # 2. Required settings are met (dispersion_scope='global', shrink_dispersion=True) # # This provides optimal parallelization without user intervention. if freeze_control is None: # Check if settings are compatible with frozen control settings_compatible = ( can_use_cache_early and dispersion_scope == "global" and shrink_dispersion ) if settings_compatible: # Estimate per-worker memory in standard mode (matching actual formula below) control_matrix_mb_est = control_n * n_genes * 8 / 1e6 control_matrix_gb_est = control_matrix_mb_est / 1000 labels_mb_est = n_cells_total * 50 / 1e6 # ~50 bytes per string size_factors_mb_est = n_cells_total * 8 / 1e6 work_arrays_mb_est = (control_n + use_group_size) * n_genes * 8 * 4 / 1e6 serialized_args_mb_est = (control_matrix_mb_est + labels_mb_est + size_factors_mb_est) * 2.5 per_worker_standard_mb = serialized_args_mb_est + work_arrays_mb_est + 2000 # Get available memory if memory_limit_gb is not None: available_mb = memory_limit_gb * 1000 else: try: import psutil available_mb = psutil.virtual_memory().available / 1e6 except ImportError: available_mb = 8000.0 # ================================================================ # AUTO-ENABLE CONDITION 1: Large control matrix (>10 GB) # For datasets like Feng (110K control × 36K genes = 32 GB), # freeze_control significantly reduces per-worker memory. # ================================================================ if control_matrix_gb_est > 10.0: freeze_control = True logger.info( f"Auto-enabling freeze_control: large control matrix " f"({control_n:,} cells × {n_genes:,} genes = {control_matrix_gb_est:.1f} GB > 10 GB threshold)" ) else: # ================================================================ # AUTO-ENABLE CONDITION 2: Worker count limitation (<4 workers) # Original logic: enable if parallelization would be severely limited # ================================================================ base_memory_mb_est = control_matrix_mb_est + 1000 usable_mb = available_mb * 0.8 remaining_mb = max(usable_mb - base_memory_mb_est, per_worker_standard_mb) max_workers_standard = max(1, int(remaining_mb / per_worker_standard_mb)) if max_workers_standard < 4 and requested_n_jobs >= 4: freeze_control = True logger.info( f"Auto-enabling freeze_control: standard mode would limit to {max_workers_standard} workers " f"(control: {control_n:,} cells × {n_genes:,} genes = {control_matrix_mb_est:.0f} MB, " f"per_worker: {per_worker_standard_mb:.0f} MB). " f"Frozen control enables ~{requested_n_jobs} workers." ) else: freeze_control = False else: freeze_control = False # Check if frozen control mode is valid (after auto-detection) can_use_frozen_control = ( freeze_control and can_use_cache_early and dispersion_scope == "global" and shrink_dispersion ) # When frozen control is enabled but streaming was not, retroactively # switch to streaming. The dense IRLS peak is 4× control_matrix_gb; # streaming avoids that by processing chunks. The loaded control # matrix is freed because frozen stats replace it. if can_use_frozen_control and not use_streaming_control and control_matrix is not None: use_streaming_control = True del control_matrix control_matrix = None gc.collect() logger.info( "Switching to streaming control statistics for frozen control mode " f"(avoids {control_matrix_gb * 4:.1f} GB dense IRLS peak)." ) # Memory estimation for joblib parallel execution # IMPORTANT: joblib's loky backend serializes (pickles) all function arguments # for each worker process. This means control_matrix is copied to each worker, # not shared via copy-on-write as it would be with fork(). # # What each worker receives via pickle: # 1. control_cache.control_matrix: (control_n × n_genes) × 8 bytes # OR with freeze_control=True: frozen stats only (~1MB total) # 2. labels array: n_cells_total strings (pickled as object array) # 3. size_factors: n_cells_total × 8 bytes # 4. Other small arrays (control_offset, etc.) # # What each worker allocates during execution: # 1. group_matrix (loaded from disk, small: ~200 cells × n_genes) # 2. Intermediate arrays for SE recomputation, mu, etc. # 3. When dispersion_scope='per_comparison': full Y and mu matrices # # Pickle overhead is typically 1.5-2× the raw array size due to protocol # serialization and Python object overhead. if can_use_frozen_control: # FROZEN CONTROL MODE: Optimized memory estimation # # What each worker receives: # 1. control_cache with frozen stats (~1MB): W_sum, Wz_sum, etc. # 2. perturbation_boundaries dict: ~100KB for 10K perturbations # 3. size_factors: still full array (~16MB for 2M cells) - could optimize later # 4. labels: still needed for fallback, but could use boundaries only # # What each worker allocates: # 1. group_matrix from disk: (group_size × n_genes) × 8 bytes # 2. Work arrays: mu, W, z for perturbation cells only # 3. Result arrays: ~n_genes × 8 bytes × 10 arrays # Frozen stats: 6 arrays of shape (n_genes,) frozen_stats_mb = n_genes * 8 * 6 / 1e6 # W_sum, Wz_sum, mu_sum, etc. # Cache metadata (beta_intercept, dispersion, pts_rest, etc.) cache_metadata_mb = n_genes * 8 * 5 / 1e6 # Perturbation boundaries: ~50 bytes per perturbation boundaries_mb = n_groups * 50 / 1e6 # Labels and size_factors (still passed but could be optimized) labels_mb = n_cells_total * 50 / 1e6 size_factors_mb = n_cells_total * 8 / 1e6 # Serialized args with reduced pickle overhead (simpler objects) control_matrix_mb_for_pickle = frozen_stats_mb + cache_metadata_mb + boundaries_mb serialized_args_mb = (control_matrix_mb_for_pickle + labels_mb + size_factors_mb) * 2.0 # Work arrays: only perturbation cells (much smaller!) # mu_pert, W_pert, z_pert, Y_pert_valid for fitting work_arrays_mb = use_group_size * n_genes * 8 * 5 / 1e6 # Result arrays per worker result_arrays_mb = n_genes * 8 * 12 / 1e6 # 12 result fields # Reduced Python overhead for frozen control (simpler computation) python_overhead_mb = 500 # 500 MB instead of 2 GB per_worker_mb = serialized_args_mb + work_arrays_mb + result_arrays_mb + python_overhead_mb logger.debug( f"Frozen control memory estimate: serialized={serialized_args_mb:.1f}MB, " f"work_arrays={work_arrays_mb:.1f}MB, per_worker={per_worker_mb:.1f}MB" ) else: control_matrix_mb_for_pickle = control_n * n_genes * 8 / 1e6 # float64 # Context-aware work arrays based on dispersion mode if can_use_cache_early and dispersion_scope == "global": # Global dispersion: skip MoM/trend, but still need SE recomputation arrays work_arrays_mb = (control_n + use_group_size) * n_genes * 8 * 4 / 1e6 else: # Per-comparison: need full Y and mu matrices for MAP dispersion work_arrays_mb = (control_n + use_group_size) * n_genes * 8 * 6 / 1e6 # Standard mode: full labels and size_factors arrays labels_mb = n_cells_total * 50 / 1e6 # ~50 bytes per string (pickled) size_factors_mb = n_cells_total * 8 / 1e6 # Pickle overhead is 2-3× for complex objects due to Python object structure serialized_args_mb = (control_matrix_mb_for_pickle + labels_mb + size_factors_mb) * 2.5 # Per-worker total: serialized args + work arrays + Python/process overhead per_worker_mb = serialized_args_mb + work_arrays_mb + 2000 # For base memory and logging control_matrix_mb = control_n * n_genes * 8 / 1e6 # Base memory: parent process + one copy of control matrix + misc arrays base_memory_mb = control_matrix_mb + 1000 # Parent process overhead # Calculate available memory if memory_limit_gb is not None: available_mb = memory_limit_gb * 1000 else: try: import psutil available_mb = psutil.virtual_memory().available / 1e6 except ImportError: available_mb = 8000.0 # 8 GB default # Reserve 20% headroom for safety usable_mb = available_mb * 0.8 remaining_mb = max(usable_mb - base_memory_mb, per_worker_mb) max_workers_by_memory = max(1, int(remaining_mb / per_worker_mb)) # Dataset-size-aware caps full_matrix_gb = (control_n + avg_group_size * n_groups) * n_genes * 8 / 1e9 if full_matrix_gb < 1.0: # Tiny dataset: cap workers to reduce parallelization overhead max_workers_by_size = max(4, n_groups // 2) else: max_workers_by_size = n_groups # Apply all constraints # Use candidates_to_run (not n_groups) for worker limit since that's what we're actually running n_to_run = len(candidates_to_run) memory_limited_workers = min(max_workers_by_memory, max_workers_by_size, n_to_run) if memory_limited_workers < effective_n_jobs: # Apply memory limiting # Determine limiting factor for logging if memory_limited_workers == n_to_run and n_to_run < max_workers_by_memory and n_to_run < max_workers_by_size: limit_reason = "perturbation_count" elif memory_limited_workers == max_workers_by_memory: limit_reason = "memory" elif memory_limited_workers == max_workers_by_size: limit_reason = "small_dataset" else: limit_reason = "perturbation_count" logger.info( f"Memory-aware limiting: {effective_n_jobs} -> {memory_limited_workers} workers " f"(reason: {limit_reason}, base: {base_memory_mb:.0f}MB, per_worker: {per_worker_mb:.0f}MB, " f"available: {available_mb:.0f}MB, full_matrix: {full_matrix_gb:.1f}GB)" ) effective_n_jobs = memory_limited_workers effective_n_jobs = max(1, effective_n_jobs) # For small number of perturbations to run, run sequentially to avoid overhead use_parallel = n_to_run >= 4 and effective_n_jobs > 1 # Decide whether to use control cache optimization # Control cache is used when: no covariates, use_control_cache=True, global SF # Per-comparison size factors require fresh computation per comparison can_use_cache = ( use_control_cache and len(covariates) == 0 and size_factor_scope == "global" ) # ===================================================================== # Early memory check: determine if we need streaming mode # ===================================================================== # For very large datasets (e.g., Replogle-GW-k562: 2M cells × 8K genes), # loading the full matrix would exceed memory. Check this ONCE before # any full matrix loads (full_X for per-comparison SF, all_cell_matrix # for global dispersion). estimated_matrix_gb = n_cells_total * n_genes * 8 / 1e9 # float64 if memory_limit_gb is not None: effective_memory_limit_gb = min(available_mb / 1000, memory_limit_gb) else: effective_memory_limit_gb = available_mb / 1000 memory_budget_gb = max_dense_fraction * effective_memory_limit_gb use_streaming_mode = estimated_matrix_gb > memory_budget_gb if use_streaming_mode: logger.info( f"Large dataset detected: {n_cells_total:,} cells × {n_genes:,} genes = " f"{estimated_matrix_gb:.1f} GB > {memory_budget_gb:.1f} GB budget. " f"Using streaming mode for memory efficiency." ) # For per-comparison size factors, we need the full count matrix # Skip if streaming mode - worker will fall back to global SF if size_factor_scope == "per_comparison": if use_streaming_mode: logger.warning( f"Dataset too large ({estimated_matrix_gb:.1f} GB) for per-comparison " f"size factors (requires loading full matrix). " f"Falling back to global size factors for memory efficiency." ) full_X = None # Worker will use global SF else: logger.info("Using per-comparison size factors for PyDESeq2 compatibility...") backed = read_backed(path) try: full_X = backed.X[:] if sp.issparse(full_X): full_X = sp.csr_matrix(full_X, dtype=np.float64) else: full_X = np.asarray(full_X, dtype=np.float64) finally: backed.file.close() else: full_X = None # Precompute control statistics once if using cache control_cache = None if can_use_cache: # Validate freeze_control requirements (for explicit freeze_control=True) if freeze_control: if dispersion_scope != "global": raise ValueError( "freeze_control=True requires dispersion_scope='global'. " "Per-comparison dispersion needs raw control matrix." ) if not shrink_dispersion: raise ValueError( "freeze_control=True requires shrink_dispersion=True. " "Global dispersion must be computed for frozen control mode." ) # Note: Auto-detected freeze_control already logged in auto-detection block logger.info("Precomputing control cell statistics for cache optimization...") control_offset_arr = offset[control_mask] if use_streaming_control: # Streaming path: read control cells from disk in chunks # Forces freeze_control=True (raw matrix never materialised) if not freeze_control: logger.info( "Forcing freeze_control=True for streaming control statistics " "(raw control matrix too large to fit in memory)." ) freeze_control = True control_cache = precompute_control_statistics_streaming( path=path, control_mask=control_mask, control_offset=control_offset_arr, max_iter=max_iter, tol=tol, min_mu=min_mu, global_size_factors=size_factors, freeze_control=True, ) drop_file_cache(path) # Free page cache from streaming IRLS else: control_cache = precompute_control_statistics( control_matrix=control_matrix, control_offset=control_offset_arr, max_iter=max_iter, tol=tol, min_mu=min_mu, dispersion_method="moments", # Fast initial estimate global_size_factors=size_factors, # Store global SF in cache freeze_control=freeze_control, # Enable frozen control mode ) # If frozen control, the cache holds sufficient statistics — # the raw control_matrix is no longer needed. if freeze_control and control_matrix is not None: del control_matrix control_matrix = None gc.collect() # Precompute global dispersion if dispersion_scope='global' if dispersion_scope == "global" and shrink_dispersion and use_map_dispersion: logger.info("Precomputing global dispersion trend (dispersion_scope='global')...") if use_streaming_mode: # Use path-based streaming for large datasets # Reads chunks from disk, never loads full matrix control_cache = precompute_global_dispersion_from_path( path=path, control_cache=control_cache, all_cell_offset=offset, fit_type="parametric", ) drop_file_cache(path) # Free page cache from streaming reads else: # Load all cells for global dispersion estimation backed = read_backed(path) try: all_cell_matrix = backed.X[:] if sp.issparse(all_cell_matrix): all_cell_matrix = sp.csr_matrix(all_cell_matrix, dtype=np.float64) else: all_cell_matrix = np.asarray(all_cell_matrix, dtype=np.float64) finally: backed.file.close() # Compute global dispersion using all cells # Use fast_mode=True for speed (MoM + trend shrinkage instead of MAP) # Memory-adaptive: switches to streaming if matrix too large control_cache = precompute_global_dispersion( control_cache=control_cache, all_cell_matrix=all_cell_matrix, all_cell_offset=offset, n_grid=25, fit_type="parametric", fast_mode=True, # ~50× faster than full MAP max_dense_fraction=max_dense_fraction, memory_limit_gb=memory_limit_gb, ) del all_cell_matrix # Free memory gc.collect() # Force garbage collection before spawning workers drop_file_cache(path) # Free page cache from full matrix read logger.info(f"Global dispersion precomputed: prior_var={control_cache.global_disp_prior_var:.4f}") # Log progress info if resume and completed_labels: logger.info(f"Fitting {n_to_run}/{n_groups} remaining perturbations with {effective_n_jobs} workers...") else: logger.info(f"Fitting {n_groups} perturbations with {effective_n_jobs} workers...") # Track completed labels during this run newly_completed = list(completed_labels) # Start with already completed newly_failed = list(failed_labels) n_processed = 0 # Helper function to write result to memmap def _write_result_to_memmap(res: dict, label: str) -> None: idx = candidate_to_idx[label] effect_memmap[idx, :] = res["effect"] statistic_memmap[idx, :] = res["statistic"] pvalue_memmap[idx, :] = res["pvalue"] logfc_memmap[idx, :] = res["logfc"] logfc_raw_memmap[idx, :] = res["logfc_raw"] intercept_memmap[idx, :] = res["intercept"] # MLE intercept for shrink_lfc se_memmap[idx, :] = res["se"] pts_memmap[idx, :] = res["pts"] pts_rest_memmap[idx, :] = res["pts_rest"] dispersion_memmap[idx, :] = res["dispersion"] dispersion_raw_memmap[idx, :] = res["dispersion_raw"] dispersion_trend_memmap[idx, :] = res["dispersion_trend"] mean_memmap[idx, :] = res["mean"] iter_memmap[idx, :] = res["iterations"] convergence_memmap[idx, :] = res["converged"] # Helper to save checkpoint def _save_checkpoint() -> None: checkpoint_data = { "total": n_groups, "completed": newly_completed, "failed": newly_failed, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "method": "nb_glm", "control_label": control_label, } _write_checkpoint_atomic(checkpoint_path, checkpoint_data) # Run fitting with progress tracking n_tested_list: list[int] = [] with _create_progress_context(n_to_run, "NB-GLM DE", verbose) as pbar: if use_parallel: # Use joblib.Parallel with loky backend for true process-based parallelism # This avoids GIL contention that limits ThreadPoolExecutor performance if can_use_cache: results = Parallel( n_jobs=effective_n_jobs, backend="loky", prefer="processes", return_as="generator", # Stream results for progress updates )( delayed(_fit_perturbation_worker_cached)( group_idx=candidate_to_idx[label], label=label, path=path, labels=labels, control_cache=control_cache, size_factors=size_factors, n_genes=n_genes, min_cells_expressed=min_cells_expressed, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, min_total_count=min_total_count, max_iter=max_iter, tol=tol, min_mu=min_mu, dispersion_method=dispersion_method, shrink_dispersion=shrink_dispersion, use_map_dispersion=use_map_dispersion, lfc_shrinkage_type=lfc_shrinkage_type, se_method=se_method, perturbation_boundaries=perturbation_boundaries, ) for label in candidates_to_run ) else: results = Parallel( n_jobs=effective_n_jobs, backend="loky", prefer="processes", return_as="generator", )( delayed(_fit_perturbation_worker)( group_idx=candidate_to_idx[label], label=label, path=path, labels=labels, control_mask=control_mask, control_matrix=control_matrix, control_expr_counts=control_expr_counts, control_n=control_n, obs_df=obs_df, covariates=covariates, size_factors=size_factors, offset=offset, n_genes=n_genes, min_cells_expressed=min_cells_expressed, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, min_total_count=min_total_count, max_iter=max_iter, tol=tol, poisson_init_iter=poisson_init_iter, dispersion_method=dispersion_method, global_dispersion=global_dispersion, shrink_dispersion=shrink_dispersion, use_map_dispersion=use_map_dispersion, lfc_shrinkage_type=lfc_shrinkage_type, pts_rest_shared=pts_rest_shared, full_X=full_X, per_comparison_sf=(size_factor_scope == "per_comparison"), se_method=se_method, perturbation_boundaries=perturbation_boundaries, ) for label in candidates_to_run ) # Process results as they stream in for idx, res in enumerate(results): label = candidates_to_run[idx] try: _write_result_to_memmap(res, label) newly_completed.append(label) n_tested_list.append(res.get("n_tested", 0)) _print_de_perturbation_verbose(verbose, label, res.get("n_tested", 0), n_genes) logger.debug(f"Completed perturbation: {label}") except Exception as e: logger.error(f"Failed perturbation {label}: {e}") newly_failed.append(label) n_processed += 1 pbar.update(1) # Save checkpoint periodically if n_processed % eff_checkpoint_interval == 0: _save_checkpoint() else: # Run sequentially for label in candidates_to_run: group_idx = candidate_to_idx[label] try: if can_use_cache: res = _fit_perturbation_worker_cached( group_idx=group_idx, label=label, path=path, labels=labels, control_cache=control_cache, size_factors=size_factors, n_genes=n_genes, min_cells_expressed=min_cells_expressed, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, min_total_count=min_total_count, max_iter=max_iter, tol=tol, min_mu=min_mu, dispersion_method=dispersion_method, shrink_dispersion=shrink_dispersion, use_map_dispersion=use_map_dispersion, lfc_shrinkage_type=lfc_shrinkage_type, se_method=se_method, perturbation_boundaries=perturbation_boundaries, ) else: res = _fit_perturbation_worker( group_idx=group_idx, label=label, path=path, labels=labels, control_mask=control_mask, control_matrix=control_matrix, control_expr_counts=control_expr_counts, control_n=control_n, obs_df=obs_df, covariates=covariates, size_factors=size_factors, offset=offset, n_genes=n_genes, min_cells_expressed=min_cells_expressed, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, min_total_count=min_total_count, max_iter=max_iter, tol=tol, min_mu=min_mu, poisson_init_iter=poisson_init_iter, dispersion_method=dispersion_method, global_dispersion=global_dispersion, shrink_dispersion=shrink_dispersion, use_map_dispersion=use_map_dispersion, lfc_shrinkage_type=lfc_shrinkage_type, pts_rest_shared=pts_rest_shared, full_X=full_X, per_comparison_sf=(size_factor_scope == "per_comparison"), se_method=se_method, perturbation_boundaries=perturbation_boundaries, ) _write_result_to_memmap(res, label) newly_completed.append(label) n_tested_list.append(res.get("n_tested", 0)) _print_de_perturbation_verbose(verbose, label, res.get("n_tested", 0), n_genes) logger.debug(f"Completed perturbation: {label}") except Exception as e: logger.error(f"Failed perturbation {label}: {e}") newly_failed.append(label) n_processed += 1 pbar.update(1) # Save checkpoint periodically if n_processed % eff_checkpoint_interval == 0: _save_checkpoint() # Final checkpoint save _save_checkpoint() logger.info(f"Completed {len(newly_completed)}/{n_groups} perturbations") _print_de_summary(verbose, "NB-GLM DE", len(newly_completed), n_groups, n_tested_list, n_genes) if newly_failed: logger.warning(f"Failed {len(newly_failed)} perturbations: {newly_failed[:5]}{'...' if len(newly_failed) > 5 else ''}") pvalue_adj_memmap = np.memmap( tmp_path / "pvalue_adj.dat", mode="w+", dtype=np.float64, shape=(n_groups, n_genes) ) _adjust_pvalue_matrix(pvalue_memmap, corr_method, out=pvalue_adj_memmap) gene_symbols = pd.Index(gene_symbols).astype(str) statistic_for_order = np.where( np.isfinite(statistic_memmap), np.abs(statistic_memmap), -np.inf ) order_matrix = np.argsort(-statistic_for_order, axis=1, kind="mergesort") effect_matrix = np.array(effect_memmap) statistic_matrix = np.array(statistic_memmap) pvalue_matrix = np.array(pvalue_memmap) pvalue_adj_matrix = np.array(pvalue_adj_memmap) logfc_matrix = np.array(logfc_memmap) logfc_raw_matrix = np.array(logfc_raw_memmap) intercept_matrix = np.array(intercept_memmap) # MLE intercept for shrink_lfc se_matrix = np.array(se_memmap) dispersion_matrix = np.array(dispersion_memmap) dispersion_raw_matrix = np.array(dispersion_raw_memmap) dispersion_trend_matrix = np.array(dispersion_trend_memmap) mean_matrix = np.array(mean_memmap) iter_matrix = np.array(iter_memmap) convergence_matrix = np.array(convergence_memmap) pts_matrix = np.array(pts_memmap, dtype=np.float32) pts_rest_matrix = np.array(pts_rest_memmap, dtype=np.float32) obs_index = pd.Index(candidates, name="perturbation").astype(str) obs = pd.DataFrame({perturbation_column: obs_index.to_list()}, index=obs_index) var = pd.DataFrame(index=gene_symbols) # Store ln-scale raw values BEFORE log2 conversion (for shrink_lfc post-hoc) logfc_raw_ln_matrix = logfc_raw_matrix.copy() se_ln_matrix = se_matrix.copy() # Convert from natural log to log2 if requested (PyDESeq2/edgeR convention) if lfc_base == "log2": ln2 = np.log(2) effect_matrix = effect_matrix / ln2 logfc_matrix = logfc_matrix / ln2 logfc_raw_matrix = logfc_raw_matrix / ln2 se_matrix = se_matrix / ln2 adata = ad.AnnData(effect_matrix, obs=obs, var=var) adata.layers["z_score"] = statistic_matrix adata.layers["pvalue"] = pvalue_matrix adata.layers["pvalue_adj"] = pvalue_adj_matrix adata.layers["logfoldchanges"] = logfc_matrix adata.layers["logfoldchange_raw"] = logfc_raw_matrix adata.layers["logfoldchange_raw_ln"] = logfc_raw_ln_matrix # Always ln-scale for shrink_lfc adata.layers["intercept"] = intercept_matrix # MLE intercept (ln-scale) for shrink_lfc adata.layers["standard_error"] = se_matrix adata.layers["standard_error_ln"] = se_ln_matrix # Always ln-scale for shrink_lfc adata.layers["dispersion"] = dispersion_matrix adata.layers["dispersion_raw"] = dispersion_raw_matrix adata.layers["dispersion_trend"] = dispersion_trend_matrix adata.layers["converged"] = convergence_matrix.astype(np.float32) adata.layers["iterations"] = iter_matrix.astype(np.float32) adata.layers["pts"] = pts_matrix adata.layers["pts_rest"] = pts_rest_matrix adata.uns["lfc_base"] = lfc_base # Store for downstream tools adata.uns["method"] = "nb_glm" adata.uns["fit_method"] = "independent" adata.uns["control_label"] = control_label adata.uns["perturbation_column"] = perturbation_column adata.uns["covariates"] = covariates adata.uns["size_factors"] = size_factors adata.uns["original_dataset_path"] = str(path) # For shrink_lfc to reload data adata.uns["size_factor_method"] = size_factor_method adata.uns["size_factor_scope"] = size_factor_scope adata.uns["dispersion_scope"] = dispersion_scope adata.uns["de_filter"] = { "min_cells_expressed": int(min_cells_expressed), "min_pct_ctrl": float(min_pct_ctrl), "min_pct_pert": float(min_pct_pert), "min_mean_ctrl": float(min_mean_ctrl), "min_mean_pert": float(min_mean_pert), } # Store profiling results or "NA" for production if profiling and profiler is not None: profiler.stop("fit") profiler.snapshot("fit_end") profiler.stop("total") profiler.stop_sampling() stats = profiler.get_stats() adata.uns["profiling"] = { "profiling_enabled": True, "fit_seconds": stats.get("timing", {}).get("sections", {}).get("fit", {}).get("seconds", 0.0), "fit_peak_memory_mb": stats.get("memory", {}).get("peak_mb", 0.0), "total_seconds": stats.get("timing", {}).get("total_seconds", 0.0), } else: adata.uns["profiling"] = "NA" # output_path already resolved earlier for checkpoint adata.write(output_path) # Clean up checkpoint file on successful completion if checkpoint_path.exists(): try: checkpoint_path.unlink() except Exception: pass # Ignore cleanup errors result = RankGenesGroupsResult( genes=gene_symbols, groups=candidates, statistics=statistic_matrix, pvalues=pvalue_matrix, pvalues_adj=pvalue_adj_matrix, logfoldchanges=logfc_matrix, effect_size=effect_matrix, u_statistics=np.zeros_like(effect_matrix), pts=pts_matrix, pts_rest=pts_rest_matrix, order=order_matrix, groupby=perturbation_column, method="nb_glm", control_label=control_label, tie_correct=False, pvalue_correction=corr_method, result=AnnData(output_path), ) # Optionally write Scanpy-compatible rank_genes_groups structure if scanpy_format: _write_rank_genes_groups_hdf5(output_path, result) # Reload to pick up the new uns structure result.result = AnnData(output_path) return result
def _wilcoxon_test_streaming( path: Path, *, gene_symbols: pd.Index, perturbation_column: str, control_label: str, candidates: list[str], n_genes: int, chunk_size: int, min_cells_expressed: int, min_pct_ctrl: float = 0.01, min_pct_pert: float = 0.002, min_mean_ctrl: float = 0.05, min_mean_pert: float = 0.005, tie_correct: bool, corr_method: str, output_path: Path, checkpoint_path: Path, checkpoint_interval: int | None, scanpy_format: bool, verbose: int | bool, resume: bool, group_batch_size: int, memory_limit_gb: float | None = None, ) -> "RankGenesGroupsResult": """Memory-efficient Wilcoxon test that streams over perturbation group batches. Instead of allocating output arrays for ALL groups at once, processes groups in batches and writes results incrementally to the output h5ad via h5py. This keeps peak memory bounded by ``group_batch_size * n_genes`` rather than ``n_groups * n_genes``. """ n_groups = len(candidates) n_gene_chunks = (n_genes + chunk_size - 1) // chunk_size n_batches = (n_groups + group_batch_size - 1) // group_batch_size gene_symbols = pd.Index(gene_symbols).astype(str) # Pre-create h5ad scaffold with obs/var/uns and empty layer datasets obs_index = pd.Index(candidates, name="perturbation").astype(str) obs = pd.DataFrame({perturbation_column: obs_index.to_list()}, index=obs_index) var = pd.DataFrame(index=gene_symbols) # Write a minimal AnnData to establish the h5ad structure scaffold = ad.AnnData( np.zeros((n_groups, n_genes), dtype=np.float32), # placeholder X, will be overwritten obs=obs, var=var, ) scaffold.write(output_path) del scaffold gc.collect() # Now re-open with h5py and create layer datasets for streaming writes with h5py.File(output_path, "r+") as hf: # Overwrite X with chunked dataset for streaming writes if "X" in hf: del hf["X"] hf.create_dataset("X", shape=(n_groups, n_genes), dtype="float64", chunks=(min(group_batch_size, n_groups), n_genes)) layer_names = ["z_score", "pvalue", "pvalue_adj", "logfoldchanges", "u_statistic", "pts", "pts_rest"] layer_dtypes = { "z_score": "float64", "pvalue": "float64", "pvalue_adj": "float64", "logfoldchanges": "float64", "u_statistic": "float64", "pts": "float32", "pts_rest": "float32", } layers_group = hf.require_group("layers") for name in layer_names: if name in layers_group: del layers_group[name] layers_group.create_dataset( name, shape=(n_groups, n_genes), dtype=layer_dtypes[name], chunks=(min(group_batch_size, n_groups), n_genes), ) # Write uns metadata uns = hf.require_group("uns") for key in ["method", "control_label", "perturbation_column", "tie_correct", "pvalue_correction"]: if key in uns: del uns[key] uns.create_dataset("method", data="wilcoxon") uns.create_dataset("control_label", data=control_label) uns.create_dataset("perturbation_column", data=perturbation_column) uns.create_dataset("tie_correct", data=tie_correct) uns.create_dataset("pvalue_correction", data=corr_method) # Resume logic: track which group batches are already done last_completed_batch = -1 if resume and checkpoint_path.exists(): checkpoint = _read_checkpoint(checkpoint_path) if checkpoint is not None: last_completed_batch = checkpoint.get("last_group_batch", -1) logger.info(f"Resuming from group batch {last_completed_batch + 1}") eff_checkpoint_interval = _get_checkpoint_interval(n_batches, checkpoint_interval) logger.info( f"Wilcoxon streaming mode: {n_groups} groups in {n_batches} batches " f"(batch_size={group_batch_size}), {n_gene_chunks} gene chunks " f"(chunk_size={chunk_size})" ) def _check_not_count_like_streaming(chunk: sp.spmatrix) -> None: """Raise ValueError if the first gene chunk looks like raw counts.""" if np.issubdtype(chunk.dtype, np.integer): raise ValueError( "Detected integer count data in wilcoxon_test. " "Please log-normalize your data first (e.g. cx.pp.normalize_total_log1p)." ) if np.issubdtype(chunk.dtype, np.floating): non_zero = chunk.data[chunk.data > 0] is_count_like = non_zero.size > 0 and np.all(np.isclose(non_zero, np.round(non_zero))) if is_count_like: raise ValueError( "Detected count-like (integer-valued) floating point data in wilcoxon_test. " "Please log-normalize your data first (e.g. cx.pp.normalize_total_log1p)." ) def _save_streaming_checkpoint(batch_idx: int) -> None: checkpoint_data = { "total_group_batches": n_batches, "last_group_batch": batch_idx, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "method": "wilcoxon", "mode": "streaming", "control_label": control_label, } _write_checkpoint_atomic(checkpoint_path, checkpoint_data) with _create_progress_context(n_batches, "Wilcoxon DE (group batches)", verbose) as pbar: for batch_idx in range(n_batches): if batch_idx <= last_completed_batch: pbar.update(1) continue batch_start = batch_idx * group_batch_size batch_end = min(batch_start + group_batch_size, n_groups) batch_candidates = candidates[batch_start:batch_end] bs = len(batch_candidates) # Allocate result arrays for this batch only batch_effect = np.zeros((bs, n_genes), dtype=np.float64) batch_u = np.zeros((bs, n_genes), dtype=np.float64) batch_z = np.zeros((bs, n_genes), dtype=np.float64) batch_p = np.ones((bs, n_genes), dtype=np.float64) batch_lfc = np.zeros((bs, n_genes), dtype=np.float64) batch_pts = np.zeros((bs, n_genes), dtype=np.float32) batch_pts_rest = np.zeros((bs, n_genes), dtype=np.float32) # Stream gene chunks from backed file backed = read_backed(path) try: labels = backed.obs[perturbation_column].astype(str).to_numpy() control_mask = labels == control_label control_n = int(control_mask.sum()) # Precompute integer row indices (faster than boolean indexing # on the dense block: O(n_pert) vs O(n_cells) per group) control_idx = np.where(control_mask)[0] batch_pert_idx = {label: np.where(labels == label)[0] for label in batch_candidates} dtype_checked_streaming = False for slc, block in iter_matrix_chunks( backed, axis=1, chunk_size=chunk_size, convert_to_dense=False ): if not sp.issparse(block): raise ValueError( "wilcoxon_test only supports sparse input matrices." ) if not dtype_checked_streaming and batch_idx == 0: _check_not_count_like_streaming(block) dtype_checked_streaming = True csr_block = sp.csr_matrix(block) # Keep native dtype (float32) n_chunk_genes = csr_block.shape[1] # Control stats for this gene chunk (same for all groups) control_values = csr_block[control_mask, :] control_expr = np.asarray(control_values.getnnz(axis=0)).ravel() control_mean = ( np.asarray(control_values.mean(axis=0)).ravel() if control_values.nnz else np.zeros(n_chunk_genes, dtype=np.float64) ) control_mean_expm1 = np.expm1(control_mean) + 1e-9 control_pts_chunk = np.divide( control_expr, control_n, out=np.zeros_like(control_expr, dtype=float), where=control_n > 0, ) # Pre-compute perturbation summary stats from sparse pert_expr_counts = [] pert_means = [] pert_n_cells = [] for label in batch_candidates: gv = csr_block[batch_pert_idx[label], :] n_pc = gv.shape[0] pert_n_cells.append(n_pc) ge = np.asarray(gv.getnnz(axis=0)).ravel() pert_expr_counts.append(ge) gm = ( np.asarray(gv.mean(axis=0)).ravel() if gv.nnz else np.zeros(n_chunk_genes, dtype=np.float64) ) pert_means.append(gm) # Determine valid genes per perturbation using the shared # per-condition low-expression filter (drop genes that are # jointly low in BOTH groups by both pct and mean). valid_masks = [] low_both_masks = [] for idx in range(bs): ge = pert_expr_counts[idx] gm = pert_means[idx] total_expr = control_expr + ge valid = total_expr >= min_cells_expressed low_both = _low_expr_in_both_mask( pert_expr_counts=ge, control_expr_counts=control_expr, pert_mean=gm, control_mean=control_mean, n_pert_cells=pert_n_cells[idx], n_control_cells=control_n, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, ) low_both_masks.append(low_both) valid_masks.append(valid & ~low_both) any_valid = np.zeros(n_chunk_genes, dtype=bool) for v in valid_masks: any_valid |= v valid_gene_indices = np.where(any_valid)[0] n_valid_genes = len(valid_gene_indices) if n_valid_genes > 0: # Convert valid-gene block to dense ONCE for all cells, # then use integer indexing per group (O(n_pert) vs # O(n_cells) for boolean masks on 400K-row arrays). # Keep native dtype (float32 for typical h5ad) to halve # working-set memory, consistent with Scanpy's wilcoxon. # ctrl_sorted_flat is always float64 inside _presort_control_nonzeros. all_valid_dense = csr_block[:, valid_gene_indices].toarray() control_dense = all_valid_dense[control_idx, :] # Pre-sort control non-zeros once per chunk (~14x # speedup: avoids redundant sort across 4955 groups) ctrl_sorted_flat, ctrl_offsets, ctrl_n_nz, ctrl_n_z = \ _presort_control_nonzeros(control_dense) ctrl_tie_sums_s = _compute_ctrl_tie_sums( ctrl_sorted_flat, ctrl_offsets, ctrl_n_nz ) # Per-perturbation: pre-stack then call batched kernel # (single prange over n_perts, avoids serial launch overhead) bs_local = len(batch_candidates) total_pert_cells_s = sum( len(batch_pert_idx[label]) for label in batch_candidates ) all_pert_stacked_s = np.empty( (total_pert_cells_s, n_valid_genes), dtype=all_valid_dense.dtype ) pert_row_offsets_s = np.zeros(bs_local + 1, dtype=np.int64) valid_masks_2d_s = np.empty( (bs_local, n_valid_genes), dtype=np.bool_ ) valid_u_s = np.zeros((bs_local, n_valid_genes), dtype=np.float64) valid_z_s = np.zeros((bs_local, n_valid_genes), dtype=np.float64) valid_p_s = np.ones((bs_local, n_valid_genes), dtype=np.float64) valid_eff_s = np.zeros((bs_local, n_valid_genes), dtype=np.float64) for idx, label in enumerate(batch_candidates): rows = batch_pert_idx[label] start_s = pert_row_offsets_s[idx] end_s = start_s + len(rows) pert_row_offsets_s[idx + 1] = end_s all_pert_stacked_s[start_s:end_s] = all_valid_dense[rows, :] valid_masks_2d_s[idx] = valid_masks[idx][valid_gene_indices] _wilcoxon_batch_perts_presorted_numba( control_dense, ctrl_sorted_flat, ctrl_offsets, ctrl_n_nz, ctrl_n_z, ctrl_tie_sums_s, all_pert_stacked_s, pert_row_offsets_s, valid_masks_2d_s, tie_correct, _ZERO_PARTITION_THRESHOLD, valid_u_s, valid_z_s, valid_p_s, valid_eff_s, ) gene_pos = slc.start + valid_gene_indices # Vectorized 2-D fancy-index write (bs_local calls → 4 calls) batch_u[:, gene_pos] = valid_u_s batch_z[:, gene_pos] = valid_z_s batch_p[:, gene_pos] = valid_p_s batch_effect[:, gene_pos] = valid_eff_s # LFC and pts for this chunk for idx in range(bs): ge = pert_expr_counts[idx] gm = pert_means[idx] np_ = pert_n_cells[idx] valid = valid_masks[idx] low_both = low_both_masks[idx] pts = np.divide( ge, float(np_), out=np.zeros_like(ge, dtype=float), where=np_ > 0, ) pts = np.where(valid, pts, 0.0) pts_rest = np.where(valid, control_pts_chunk, 0.0) lfc = np.log2((np.expm1(gm) + 1e-9) / control_mean_expm1) lfc = np.where(valid, lfc, 0.0) # Mark genes excluded by the per-condition filter as # NaN so downstream tools see them as untested. if low_both.any(): lfc = np.where(low_both, np.nan, lfc) gene_pos = np.arange(slc.start, slc.stop) batch_lfc[idx, gene_pos] = lfc batch_pts[idx, gene_pos] = pts batch_pts_rest[idx, gene_pos] = pts_rest # Mark stat / pvalue / effect / u / z as NaN for genes # excluded by the low-expression filter. if low_both.any(): low_pos = slc.start + np.where(low_both)[0] batch_u[idx, low_pos] = np.nan batch_z[idx, low_pos] = np.nan batch_p[idx, low_pos] = np.nan batch_effect[idx, low_pos] = np.nan finally: backed.file.close() # P-value adjustment and ordering for this batch batch_pvalue_adj = np.ones_like(batch_p) _adjust_pvalue_matrix(batch_p, corr_method, out=batch_pvalue_adj) # Write batch results to h5ad sl = slice(batch_start, batch_end) with h5py.File(output_path, "r+") as hf: hf["X"][sl, :] = batch_effect hf["layers/z_score"][sl, :] = batch_z hf["layers/pvalue"][sl, :] = batch_p hf["layers/pvalue_adj"][sl, :] = batch_pvalue_adj hf["layers/logfoldchanges"][sl, :] = batch_lfc hf["layers/u_statistic"][sl, :] = batch_u hf["layers/pts"][sl, :] = batch_pts hf["layers/pts_rest"][sl, :] = batch_pts_rest del batch_effect, batch_u, batch_z, batch_p, batch_lfc del batch_pts, batch_pts_rest, batch_pvalue_adj gc.collect() _release_chunk_memory() # return freed batch-array pages to OS before next batch # Checkpoint if (batch_idx + 1) % eff_checkpoint_interval == 0 or batch_idx == n_batches - 1: _save_streaming_checkpoint(batch_idx) pbar.update(1) logger.info(f"Completed all {n_batches} group batches") if int(verbose) >= 1: print(f"[crispyx] Wilcoxon DE: {n_groups} perturbations complete, {n_genes} genes") # Build RankGenesGroupsResult by reading back from h5ad. # Uses _build_result_from_h5ad which skips loading for very large results # (>25% of memory budget) to avoid OOM on datasets like Huang-HEK293T. result = _build_result_from_h5ad( output_path, candidates=candidates, gene_symbols=gene_symbols, perturbation_column=perturbation_column, control_label=control_label, tie_correct=tie_correct, corr_method=corr_method, memory_limit_gb=memory_limit_gb, ) if scanpy_format and result.statistics.size > 0: _write_rank_genes_groups_hdf5(output_path, result) # Clean up checkpoint on success if checkpoint_path.exists(): try: checkpoint_path.unlink() except Exception: pass return result
[docs] def wilcoxon_test( data: str | Path | AnnData | ad.AnnData, *, perturbation_column: str | None = None, groupby: str | None = None, control_label: str | None = None, reference: str | None = None, gene_name_column: str | None = None, perturbations: Iterable[str] | None = None, min_cells_expressed: int = 0, min_pct_ctrl: float = 0.01, min_pct_pert: float = 0.002, min_pct_both: float | None = None, min_mean_ctrl: float = 0.05, min_mean_pert: float = 0.005, chunk_size: int | None = None, tie_correct: bool = True, corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg", output_dir: str | Path | None = None, data_name: str | None = None, n_jobs: int | None = None, verbose: int | bool = False, resume: bool = False, checkpoint_interval: int | None = None, scanpy_format: bool = False, memory_limit_gb: float | None = None, force: bool = False, ) -> RankGenesGroupsResult: """Perform a Wilcoxon rank-sum (Mann-Whitney U) test for each gene. Input data **must already be library-size normalised and log-transformed**. The function operates directly on the provided matrix without additional preprocessing. As a safeguard, the first sparse chunk is inspected and a warning is emitted if the data appear to be raw counts (integer or count-like floats), encouraging explicit preprocessing upstream. Parameters ---------- data Path to an h5ad file, or a crispyx/anndata AnnData object containing normalised, log-transformed data. perturbation_column Column in `adata.obs` indicating perturbation labels. groupby Alias for ``perturbation_column`` (Scanpy-compatible). Mutually exclusive with ``perturbation_column``. control_label Label for the control/reference group. If None, infers from common patterns. reference Alias for ``control_label`` (Scanpy-compatible). Mutually exclusive with ``control_label``. gene_name_column Column in `adata.var` with gene symbols. If None, uses `adata.var_names`. perturbations Specific perturbations to test. If None, tests all non-control groups. min_cells_expressed Minimum total cells (control + perturbation) expressing a gene for testing. Genes below this threshold are assigned p-value=1 and effect_size=0. min_pct_ctrl Minimum fraction of expressing cells for the *control* side. A gene is excluded only when *both* sides are jointly low. Default ``0.01``. min_pct_pert Minimum fraction of expressing cells for the *perturbed* side. Default ``0.002`` (lower than ctrl; induction from near-zero baseline is biologically valid). Combined with ``min_mean_pert`` this forms a dual condition more robust than pct alone. min_pct_both If not ``None``, overrides both ``min_pct_ctrl`` and ``min_pct_pert`` with the same value. min_mean_ctrl Minimum mean log1p expression for the *control* side. Default ``0.05``. Excluded genes are written as NaN in ``score`` / ``pvalue`` / ``logfoldchanges`` / ``effect_size``; ``pts`` / ``pts_rest`` remain populated. Set to ``0.0`` together with pct thresholds to disable. min_mean_pert Minimum mean expression for the *perturbed* side. Default ``0.005``. chunk_size Number of genes to process per chunk (memory vs. speed tradeoff). Smaller values stream more, reducing peak memory at the cost of additional I/O. tie_correct Whether to apply tie correction to the U statistic. Default True for more accurate p-values when ties are present in the data. corr_method Method for p-value correction: "benjamini-hochberg" or "bonferroni". output_dir Directory for output h5ad file. Defaults to input file's directory. data_name Custom name for output file. If None, uses "wilcoxon" suffix. n_jobs Number of parallel workers for computing statistics across perturbations. If None, uses all available cores. If 1, runs sequentially. verbose If True, show a progress bar for gene chunk processing. Requires tqdm. resume If True, attempt to resume from a previous interrupted run. Reads the checkpoint file to determine which gene chunks have already been completed and skips them. checkpoint_interval Number of gene chunks between checkpoint saves. If None, auto-determined based on dataset size. The checkpoint file `<output>.progress.json` is written atomically to prevent corruption. scanpy_format If True, write Scanpy-compatible ``uns['rank_genes_groups']`` structure in addition to the layer-based storage. Adds ~2-6 seconds of I/O overhead for large datasets. Default False for performance. memory_limit_gb Maximum memory budget in GB. Controls whether the streaming path is used for very large datasets. When ``None`` (default), detects available system memory via ``psutil``. For HPC environments with fixed allocations, set this to your SLURM ``--mem`` value (e.g., ``memory_limit_gb=128``). force If True, rerun the analysis even when the output h5ad file already exists. If False (default), load and return the existing result instead of rerunning. Returns ------- RankGenesGroupsResult Differential expression results. Access results via dict-like interface: `result[label].effect_size`, `result[label].pvalue`, etc. The h5ad file path is available at `result.result_path`. """ perturbation_column, control_label, min_pct_ctrl, min_pct_pert = _resolve_de_aliases( perturbation_column=perturbation_column, groupby=groupby, control_label=control_label, reference=reference, min_pct_both=min_pct_both, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, fn_name="wilcoxon_test", ) path = resolve_data_path(data) output_path = resolve_output_path(path, suffix="wilcoxon", output_dir=output_dir, data_name=data_name) if (r := _try_load_existing_de_result( output_path, force=force, verbose=verbose, method_name="wilcoxon", memory_limit_gb=memory_limit_gb, )): return r backed = read_backed(path) try: gene_symbols = ensure_gene_symbol_column(backed, gene_name_column) if perturbation_column not in backed.obs.columns: raise KeyError( f"Perturbation column '{perturbation_column}' was not found in adata.obs. Available columns: {list(backed.obs.columns)}" ) labels = backed.obs[perturbation_column].astype(str).to_numpy() control_label = resolve_control_label(labels, control_label) n_genes = backed.n_vars candidates = _resolve_candidates(labels, control_label, perturbations) control_mask = labels == control_label control_n = int(control_mask.sum()) if control_n == 0: raise ValueError("Control group contains no cells") # Memory-efficient validation: avoid creating per-group bool masks (n_groups × n_cells # memory, e.g. 18K groups × 3.4M cells = 62 GB for Huang-HCT116). Instead use # np.unique to get the set of observed labels in a single O(n_cells) pass. _observed_labels = set(np.unique(labels).tolist()) _missing_perts = [lbl for lbl in candidates if lbl not in _observed_labels] if _missing_perts: raise ValueError( f"Perturbation(s) {_missing_perts[:3]}" f"{'...' if len(_missing_perts) > 3 else ''} contain no cells" ) # Calculate adaptive gene chunk_size if not provided. # Use the dedicated Wilcoxon calculator: no n_groups cap (output arrays # are memmapped to disk so RAM is independent of n_groups). if chunk_size is None: chunk_size = calculate_wilcoxon_chunk_size( backed.n_obs, backed.n_vars, available_memory_gb=memory_limit_gb, ) finally: backed.file.close() n_groups = len(candidates) # Determine output path and checkpoint path output_path.parent.mkdir(parents=True, exist_ok=True) checkpoint_path = output_path.with_suffix(".progress.json") # ========================================================================= # Adaptive dispatch: use group-batch streaming for large datasets # ========================================================================= use_streaming, _, _, group_batch_size = _should_use_streaming( n_groups, n_genes, memory_limit_gb=memory_limit_gb, ) # Only stream when multiple batches are actually needed. # If group_batch_size >= n_groups (one batch = all groups), the standard memmap # path is strictly better: memmaps are OS-pageable and glibc returns their pages # before readback, keeping peak ~10 GB. The streaming path keeps all result arrays # in Python heap and skips malloc_trim, leading to 20-35 GB peaks for mid-size # datasets like Feng-gwsnf (4,955 groups × 32,373 genes). if use_streaming and group_batch_size < n_groups: return _wilcoxon_test_streaming( path, gene_symbols=gene_symbols, perturbation_column=perturbation_column, control_label=control_label, candidates=candidates, n_genes=n_genes, chunk_size=chunk_size, min_cells_expressed=min_cells_expressed, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, tie_correct=tie_correct, corr_method=corr_method, output_path=output_path, checkpoint_path=checkpoint_path, checkpoint_interval=checkpoint_interval, scanpy_format=scanpy_format, verbose=verbose, resume=resume, group_batch_size=group_batch_size, memory_limit_gb=memory_limit_gb, ) # ========================================================================= # Standard single-pass path (unchanged for small/medium datasets) # ========================================================================= # For wilcoxon, we track gene chunk progress (not perturbation progress) # Resume logic: read checkpoint to get last completed gene chunk last_completed_chunk = -1 if resume and checkpoint_path.exists(): checkpoint = _read_checkpoint(checkpoint_path) if checkpoint is not None: last_completed_chunk = checkpoint.get("last_gene_chunk", -1) logger.info(f"Resuming from gene chunk {last_completed_chunk + 1}") # Determine checkpoint interval (number of gene chunks between saves) n_gene_chunks = (n_genes + chunk_size - 1) // chunk_size eff_checkpoint_interval = _get_checkpoint_interval(n_gene_chunks, checkpoint_interval) with tempfile.TemporaryDirectory() as tmpdir: tmpdir_path = Path(tmpdir) def _create_memmap(name: str, dtype: np.dtype, *, fill: float | int = 0): path = tmpdir_path / f"{name}.dat" mmap = np.memmap(path, dtype=dtype, mode="w+", shape=(n_groups, n_genes)) if fill != 0: mmap[:] = fill else: mmap.fill(0) return mmap effect_matrix = _create_memmap("effect", np.float64) u_matrix = _create_memmap("u_stat", np.float64) pvalue_matrix = _create_memmap("pvalue", np.float64, fill=1.0) z_matrix = _create_memmap("z_score", np.float64) lfc_matrix = _create_memmap("logfoldchange", np.float64) pts_matrix = _create_memmap("pts", np.float32) pts_rest_matrix = _create_memmap("pts_rest", np.float32) backed = read_backed(path) try: labels = backed.obs[perturbation_column].astype(str).to_numpy() control_mask = labels == control_label # Precompute integer row indices (faster than boolean indexing # on the dense block: O(n_pert) vs O(n_cells) per group) control_idx = np.where(control_mask)[0] pert_idx = {label: np.where(labels == label)[0] for label in candidates} # Pre-build flat perturbation indices and row offsets once # (avoids rebuilding inside the per-chunk stacking loop). all_pert_flat_idx = np.concatenate([pert_idx[label] for label in candidates]) total_pert_cells = len(all_pert_flat_idx) pert_row_offsets = np.zeros(n_groups + 1, dtype=np.int64) for idx, label in enumerate(candidates): pert_row_offsets[idx + 1] = pert_row_offsets[idx] + len(pert_idx[label]) dtype_checked = False def _check_not_count_like(chunk: sp.spmatrix) -> None: """Raise ValueError if the chunk looks like raw counts.""" if np.issubdtype(chunk.dtype, np.integer): raise ValueError( "Detected integer count data in wilcoxon_test. " "Please log-normalize your data first (e.g. cx.pp.normalize_total_log1p)." ) if np.issubdtype(chunk.dtype, np.floating): non_zero = chunk.data[chunk.data > 0] is_count_like = non_zero.size > 0 and np.all(np.isclose(non_zero, np.round(non_zero))) if is_count_like: raise ValueError( "Detected count-like (integer-valued) floating point data in wilcoxon_test. " "Please log-normalize your data first (e.g. cx.pp.normalize_total_log1p)." ) # Track progress current_chunk = 0 n_chunks_processed = 0 _track_gene_counts = int(verbose) >= 1 if _track_gene_counts: _valid_gene_counts = np.zeros(n_groups, dtype=np.int32) # Helper to save checkpoint def _save_wilcoxon_checkpoint(chunk_idx: int) -> None: checkpoint_data = { "total_gene_chunks": n_gene_chunks, "last_gene_chunk": chunk_idx, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "method": "wilcoxon", "control_label": control_label, } _write_checkpoint_atomic(checkpoint_path, checkpoint_data) with _create_progress_context(n_gene_chunks, "Wilcoxon DE (gene chunks)", verbose) as pbar: for slc, block in iter_matrix_chunks( backed, axis=1, chunk_size=chunk_size, convert_to_dense=False ): # Skip already processed chunks on resume if current_chunk <= last_completed_chunk: current_chunk += 1 pbar.update(1) continue if not dtype_checked: if not sp.issparse(block): raise ValueError( "wilcoxon_test only supports sparse input matrices. Please provide a scipy sparse matrix (e.g., CSR/CSC)." ) _check_not_count_like(block) dtype_checked = True csr_block = sp.csr_matrix(block) # Keep native dtype (float32) n_chunk_genes = csr_block.shape[1] # ===== OPTIMIZED BATCH PROCESSING ===== # 1. Extract control and perturbation data once control_values = csr_block[control_mask, :] control_expr = np.asarray(control_values.getnnz(axis=0)).ravel() control_mean = ( np.asarray(control_values.mean(axis=0)).ravel() if control_values.nnz else np.zeros(n_chunk_genes, dtype=np.float64) ) control_mean_expm1 = np.expm1(control_mean) + 1e-9 control_pts = np.divide( control_expr, control_n, out=np.zeros_like(control_expr, dtype=float), where=control_n > 0, ) chunk_gene_indices = np.arange(slc.start, slc.stop) # 2. Pre-compute perturbation expression counts pert_expr_counts = [] pert_means = [] pert_n_cells = [] for label in candidates: group_values = csr_block[pert_idx[label], :] n_pert_cells = group_values.shape[0] pert_n_cells.append(n_pert_cells) group_expr = np.asarray(group_values.getnnz(axis=0)).ravel() pert_expr_counts.append(group_expr) group_mean = ( np.asarray(group_values.mean(axis=0)).ravel() if group_values.nnz else np.zeros(n_chunk_genes, dtype=np.float64) ) pert_means.append(group_mean) # 3. Determine valid genes per perturbation using the # shared per-condition low-expression filter (drop genes # that are jointly low in BOTH groups by both pct and mean). valid_masks = [] low_both_masks = [] for idx, label in enumerate(candidates): group_expr = pert_expr_counts[idx] group_mean = pert_means[idx] total_expr = control_expr + group_expr valid = total_expr >= min_cells_expressed low_both = _low_expr_in_both_mask( pert_expr_counts=group_expr, control_expr_counts=control_expr, pert_mean=group_mean, control_mean=control_mean, n_pert_cells=pert_n_cells[idx], n_control_cells=control_n, min_pct_ctrl=min_pct_ctrl, min_pct_pert=min_pct_pert, min_mean_ctrl=min_mean_ctrl, min_mean_pert=min_mean_pert, ) low_both_masks.append(low_both) valid_masks.append(valid & ~low_both) # Accumulate per-perturbation valid gene counts for verbose output if _track_gene_counts: for _vi in range(len(valid_masks)): _valid_gene_counts[_vi] += int(valid_masks[_vi].sum()) # 4. Find union of all valid genes (to minimize dense conversion) any_valid = np.zeros(n_chunk_genes, dtype=bool) for valid in valid_masks: any_valid |= valid valid_gene_indices = np.where(any_valid)[0] n_valid_genes = len(valid_gene_indices) # 5. Initialize output arrays for this chunk chunk_u = np.zeros((n_groups, n_chunk_genes), dtype=np.float64) chunk_z = np.zeros((n_groups, n_chunk_genes), dtype=np.float64) chunk_p = np.full((n_groups, n_chunk_genes), np.nan, dtype=np.float64) chunk_effect = np.zeros((n_groups, n_chunk_genes), dtype=np.float64) chunk_lfc = np.zeros((n_groups, n_chunk_genes), dtype=np.float64) chunk_pts = np.zeros((n_groups, n_chunk_genes), dtype=np.float32) chunk_pts_rest = np.zeros((n_groups, n_chunk_genes), dtype=np.float32) if n_valid_genes > 0: # 6. Convert valid-gene block to dense ONCE for all cells, # then use integer indexing per group (O(n_pert) vs # O(n_cells) for boolean masks on large arrays). # Keep native dtype (float32 for typical h5ad) to halve # working-set memory, consistent with Scanpy's wilcoxon. # ctrl_sorted_flat is always float64 inside _presort_control_nonzeros. all_valid_dense = csr_block[:, valid_gene_indices].toarray() control_dense = all_valid_dense[control_idx, :] # Pre-sort control non-zeros once per chunk (~14x # speedup: avoids redundant sort across all groups) ctrl_sorted_flat, ctrl_offsets, ctrl_n_nz, ctrl_n_z = \ _presort_control_nonzeros(control_dense) ctrl_tie_sums = _compute_ctrl_tie_sums( ctrl_sorted_flat, ctrl_offsets, ctrl_n_nz ) # 7. Pre-allocate output arrays for valid genes valid_u = np.zeros((n_groups, n_valid_genes), dtype=np.float64) valid_z = np.zeros((n_groups, n_valid_genes), dtype=np.float64) valid_p = np.ones((n_groups, n_valid_genes), dtype=np.float64) valid_effect = np.zeros((n_groups, n_valid_genes), dtype=np.float64) # 8. Stack all pert dense matrices and call batched kernel. # Single prange(n_perts) replaces n_perts serial kernel # launches, eliminating the ~25ms prange thread-pool startup # overhead per call (~50x speedup on kernel time). # Vectorised: single fancy-index replaces n_groups-iteration loop. all_pert_stacked = all_valid_dense[all_pert_flat_idx, :] valid_masks_2d = np.array([vm[valid_gene_indices] for vm in valid_masks]) _wilcoxon_batch_perts_presorted_numba( control_dense, ctrl_sorted_flat, ctrl_offsets, ctrl_n_nz, ctrl_n_z, ctrl_tie_sums, all_pert_stacked, pert_row_offsets, valid_masks_2d, tie_correct, _ZERO_PARTITION_THRESHOLD, valid_u, valid_z, valid_p, valid_effect, ) # 9. Map results back to full chunk gene indices # Vectorised: single 2-D fancy-index write per array # replaces n_groups Python-loop iterations. chunk_u[:, valid_gene_indices] = valid_u chunk_z[:, valid_gene_indices] = valid_z chunk_p[:, valid_gene_indices] = valid_p chunk_effect[:, valid_gene_indices] = valid_effect # 10. Compute LFC and pts — batch vectorised # (replaces n_groups Python-loop iterations) all_expr = np.array(pert_expr_counts) # (n_groups, n_chunk_genes) all_means = np.array(pert_means) # (n_groups, n_chunk_genes) all_n = np.array(pert_n_cells, dtype=np.float64) # (n_groups,) valid_arr = np.array(valid_masks) # (n_groups, n_chunk_genes) n_col = all_n[:, np.newaxis] # (n_groups, 1) raw_pts = np.where(n_col > 0, all_expr / n_col, 0.0) chunk_pts[:] = np.where(valid_arr, raw_pts, 0.0).astype(np.float32) chunk_pts_rest[:] = np.where( valid_arr, control_pts[np.newaxis, :], 0.0 ).astype(np.float32) raw_lfc = np.log2( (np.expm1(all_means) + 1e-9) / control_mean_expm1[np.newaxis, :] ) chunk_lfc[:] = np.where(valid_arr, raw_lfc, 0.0) # Mark genes excluded by the per-condition low-expression # filter as NaN so downstream tools see them as untested. low_both_arr = np.array(low_both_masks) # (n_groups, n_chunk_genes) if low_both_arr.any(): chunk_u[low_both_arr] = np.nan chunk_z[low_both_arr] = np.nan chunk_p[low_both_arr] = np.nan chunk_effect[low_both_arr] = np.nan chunk_lfc[low_both_arr] = np.nan # 13. Write results to memmap — vectorized 2-D slice # (7 calls instead of 7 × n_groups; better cache locality) u_matrix[:, slc] = chunk_u pvalue_matrix[:, slc] = chunk_p effect_matrix[:, slc] = chunk_effect z_matrix[:, slc] = chunk_z lfc_matrix[:, slc] = chunk_lfc pts_matrix[:, slc] = chunk_pts pts_rest_matrix[:, slc] = chunk_pts_rest # Release transient chunk arrays and return freed pages to OS # (prevents glibc arena fragmentation across many gene chunks) del csr_block, chunk_u, chunk_z, chunk_p, chunk_effect, chunk_lfc del chunk_pts, chunk_pts_rest, pert_expr_counts, pert_means del pert_n_cells, valid_masks, low_both_masks, low_both_arr if n_valid_genes > 0: del all_valid_dense, control_dense del ctrl_sorted_flat, ctrl_offsets, ctrl_n_nz, ctrl_n_z, ctrl_tie_sums del all_pert_stacked, valid_u, valid_z, valid_p, valid_effect del valid_masks_2d _release_chunk_memory() # Update progress and checkpoint n_chunks_processed += 1 pbar.update(1) if n_chunks_processed % eff_checkpoint_interval == 0: _save_wilcoxon_checkpoint(current_chunk) current_chunk += 1 # Final checkpoint _save_wilcoxon_checkpoint(current_chunk - 1) logger.info(f"Completed {n_chunks_processed} gene chunks") if _track_gene_counts: if int(verbose) >= 2: for _gi, _label in enumerate(candidates): _print_de_perturbation_verbose(verbose, _label, int(_valid_gene_counts[_gi]), n_genes) _mean = int(_valid_gene_counts.mean()) if n_groups > 0 else 0 _pct = 100.0 * _mean / n_genes if n_genes else 0 print(f"[crispyx] Wilcoxon DE: {n_groups} perturbations complete, mean {_mean}/{n_genes} genes tested ({_pct:.0f}%)") finally: backed.file.close() gene_symbols = pd.Index(gene_symbols).astype(str) pvalue_adj_matrix = _create_memmap("pvalue_adj", np.float64) _adjust_pvalue_matrix(pvalue_matrix, corr_method, out=pvalue_adj_matrix) # Write h5ad directly from memmaps via h5py (avoids triple allocation: # memmap + np.array copy + AnnData that previously caused OOM) _write_wilcoxon_result_h5ad( output_path, effect_matrix=effect_matrix, z_matrix=z_matrix, pvalue_matrix=pvalue_matrix, pvalue_adj_matrix=pvalue_adj_matrix, lfc_matrix=lfc_matrix, u_matrix=u_matrix, pts_matrix=pts_matrix, pts_rest_matrix=pts_rest_matrix, candidates=candidates, gene_symbols=gene_symbols, perturbation_column=perturbation_column, control_label=control_label, tie_correct=tie_correct, corr_method=corr_method, ) # Memmaps will be released when TemporaryDirectory exits below # tmpdir is now cleaned up — all memmaps released. # Read back from h5ad for the result object. _release_chunk_memory() result = _build_result_from_h5ad( output_path, candidates=candidates, gene_symbols=gene_symbols, perturbation_column=perturbation_column, control_label=control_label, tie_correct=tie_correct, corr_method=corr_method, memory_limit_gb=memory_limit_gb, ) # Optionally write Scanpy-compatible rank_genes_groups structure if scanpy_format and result.statistics.size > 0: _write_rank_genes_groups_hdf5(output_path, result) # Clean up checkpoint on successful completion if checkpoint_path.exists(): try: checkpoint_path.unlink() except Exception: pass return result
[docs] def shrink_lfc( data: str | Path | AnnData | ad.AnnData, *, output_dir: str | Path | None = None, data_name: str | None = None, method: Literal["stats", "full"] = "stats", prior_scale_mode: Literal["global", "per_comparison"] = "global", min_mu: float = 0.0, n_jobs: int = -1, batch_size: int = 128, profiling: bool = False, memory_limit_gb: float | None = None, ) -> RankGenesGroupsResult: """Apply apeGLM log-fold change shrinkage to existing NB-GLM results. This function applies apeGLM shrinkage using a Cauchy prior on the LFC coefficient. Two methods are available: - **stats** (default, recommended): Uses pre-computed MLE LFC and SE from the h5ad file with vectorized Newton-Raphson optimization. ~35× faster than full and maintains consistency with stored MLE coefficients. - **full**: Re-loads original count data and runs per-gene L-BFGS-B optimization. May produce different results from stored MLE for lowly expressed genes due to min_mu clamping differences in the likelihood. .. note:: The "stats" method is recommended because CRISPYx NB-GLM fitting uses min_mu=0.5 clamping for numerical stability, which affects the stored MLE coefficients. The "full" method re-evaluates the likelihood without this constraint, potentially finding different optima for lowly expressed genes. The "stats" method preserves shrinkage direction (always toward zero) by working directly with the stored statistics. This enables separating the base NB-GLM fitting from shrinkage for: - Benchmarking: measure base fitting and shrinkage times separately - Flexibility: apply shrinkage to existing results - Speed: use stats method for production Parameters ---------- data Path to an h5ad file, or a crispyx/anndata AnnData object containing NB-GLM results from `nb_glm_test`. Must have required layers and metadata in `uns`. output_dir Directory for output h5ad file. Defaults to input file's directory. data_name Custom name for output file. If None, appends "_shrunk" to input name. method Shrinkage method to use: - ``"stats"`` (default): Fast vectorized shrinkage using pre-computed MLE statistics. Uses Newton-Raphson optimization across all genes simultaneously. ~35× faster than "full" and maintains consistency with stored MLE coefficients. - ``"full"``: Full model re-fitting with L-BFGS-B per gene. Re-loads original count data. Note: May produce different results for lowly expressed genes due to min_mu clamping differences. Use "stats" for consistent shrinkage behavior. prior_scale_mode How to estimate the Cauchy prior scale parameter: - ``"global"`` (default): Estimate prior scale once from all perturbations' MLE LFCs. Faster and often more stable. - ``"per_comparison"``: Estimate prior scale separately for each perturbation vs control comparison. Matches PyDESeq2's behavior exactly. Use for benchmarking to demonstrate parity. min_mu Minimum mean threshold for shrinkage likelihood evaluation. Default: 0.0 (no clamping), matching PyDESeq2's lfc_shrink which omits min_mu entirely. PyDESeq2 uses min_mu=0.5 for NB-GLM fitting but does NOT pass min_mu to the shrinkage optimizer. This is intentional: the MLE coefficients represent the best fit, and shrinkage should evaluate the same likelihood surface. n_jobs Number of parallel jobs for per-gene optimization (only used when method="full"). Default -1 uses all available cores. batch_size Number of genes per joblib batch (only used when method="full"). Default: 128, matching PyDESeq2. profiling If True, enable timing and memory profiling. When enabled, stores profiling data in `adata.uns["profiling"]` with fields: - `shrinkage_seconds`: Time for lfcShrink operation - `shrinkage_peak_memory_mb`: Peak memory during shrinkage - `profiling_enabled`: True When False (default), `adata.uns["profiling"]` is set to "NA". memory_limit_gb Optional memory budget in gigabytes. When ``method="full"``, this limits the number of parallel ``n_jobs`` so that joblib workers stay within the budget. When ``None`` (default), detects available system memory via ``psutil``. Returns ------- RankGenesGroupsResult Updated differential expression results with shrunken LFCs. The result h5ad has: - `logfoldchange`: shrunken LFC values - `logfoldchange_raw`: original MLE LFC values (preserved) - `standard_error`: posterior SE reflecting shrinkage uncertainty - `X`: updated to shrunken LFC (effect_size) Examples -------- >>> # Fast default (recommended for production) >>> shrunk = crispyx.de.shrink_lfc("nb_glm_result.h5ad") >>> # Benchmark-accurate (matches PyDESeq2 exactly) >>> shrunk = crispyx.de.shrink_lfc( ... "nb_glm_result.h5ad", ... method="full", ... prior_scale_mode="per_comparison", ... ) >>> # First run NB-GLM without shrinkage >>> result = crispyx.de.nb_glm_test( ... "data.h5ad", ... perturbation_column="perturbation", ... lfc_shrinkage_type="none", # No shrinkage during fitting ... ) >>> # Then apply shrinkage as a separate step >>> shrunk_result = crispyx.de.shrink_lfc(result.result_path) """ path = resolve_data_path(data) # Validate min_mu parameter if min_mu < 0: raise ValueError(f"min_mu must be >= 0, got {min_mu}") # Initialize profiler if enabled (timing + memory sampling) profiler = None if profiling: from .profiling import Profiler profiler = Profiler(timing=True, memory=True, memory_method="rss", sampling=True) profiler.start("total") # Load the NB-GLM result adata = ad.read_h5ad(path) # Establish memory baseline AFTER loading h5ad for fair benchmarking # This way, profiling measures only shrinkage memory, not h5ad loading if profiling and profiler is not None: profiler.snapshot("after_load") profiler.reset_peak() # Reset peak memory after h5ad load profiler.start("shrinkage") # Validate that this is an NB-GLM result with required layers if "logfoldchange_raw" not in adata.layers: raise ValueError( f"Input file '{path}' does not have 'logfoldchange_raw' layer. " "This function requires NB-GLM results from nb_glm_test. " "Ensure the NB-GLM was run with a version that stores raw LFCs." ) if "standard_error" not in adata.layers: raise ValueError( f"Input file '{path}' does not have 'standard_error' layer. " "This function requires NB-GLM results with standard errors." ) if "dispersion" not in adata.layers: raise ValueError( f"Input file '{path}' does not have 'dispersion' layer. " "This function requires NB-GLM results with dispersion estimates." ) # Get required metadata control_label = adata.uns.get("control_label", "control") perturbation_column = adata.uns.get("perturbation_column", "perturbation") # Get raw LFC, SE, and dispersion in ln-scale (required for apeGLM optimization) # Use ln-scale layers if available (v0.5.0+), otherwise fall back with conversion lfc_base = adata.uns.get("lfc_base", "log2") if "logfoldchange_raw_ln" in adata.layers: raw_lfc = adata.layers["logfoldchange_raw_ln"] # Already ln-scale se = adata.layers["standard_error_ln"] else: raise ValueError( f"Input file '{path}' lacks ln-scale layers " "('logfoldchange_raw_ln', 'standard_error_ln'). " "Re-run nb_glm_test with the current version of crispyx." ) dispersion = adata.layers["dispersion"] # Get fitted intercept from NB-GLM (ln-scale, critical for accurate shrinkage) if "intercept" not in adata.layers: raise ValueError( f"Input file '{path}' lacks 'intercept' layer. " "method='full' requires NB-GLM results from nb_glm_test v0.5.1+. " "Use method='stats' or re-run nb_glm_test." ) fitted_intercept = adata.layers["intercept"] # Already ln-scale logger.info("Using fitted intercept from NB-GLM for shrinkage") n_groups, n_genes = raw_lfc.shape candidates = list(adata.obs_names.astype(str)) gene_symbols = pd.Index(adata.var_names).astype(str) # Estimate global prior scale from ALL perturbations' MLE LFCs all_mle_lfc = raw_lfc.ravel() all_mle_se = se.ravel() valid_mask = np.isfinite(all_mle_lfc) & np.isfinite(all_mle_se) & (all_mle_se > 0) global_prior_scale = _estimate_apeglm_prior_scale( all_mle_lfc[valid_mask], all_mle_se[valid_mask] ) logger.info(f"Global prior scale for apeGLM: {global_prior_scale:.4f}") # Initialize output arrays shrunk_lfc = np.zeros_like(raw_lfc) shrunk_se = np.zeros_like(se) total_converged = 0 total_genes_processed = 0 if method == "stats": # Fast stats-based shrinkage using vectorized Newton-Raphson # No need to load original dataset - uses pre-computed MLE stats logger.info(f"Using fast stats-based shrinkage (method='stats')") # Try to derive base_mean from fitted intercept for gene-specific priors # intercept is ln(mu_control), so base_mean ≈ mean(exp(intercept)) across perturbations # This is a proxy for expression level used in gene-specific prior scaling if np.any(np.isfinite(fitted_intercept)): # Use mean intercept across perturbations for each gene # Suppress "Mean of empty slice" RuntimeWarning when all values are NaN for a gene import warnings with warnings.catch_warnings(): warnings.filterwarnings('ignore', message='Mean of empty slice') mean_intercept = np.nanmean(fitted_intercept, axis=0) # Shape: (n_genes,) base_mean_proxy = np.exp(np.clip(mean_intercept, -20, 20)) # Clamp to avoid overflow logger.debug("Using intercept-derived base_mean for gene-specific priors") else: base_mean_proxy = None logger.debug("base_mean not available, using uniform prior scale") # Track genes that need full re-fitting all_needs_refit = np.zeros((n_groups, n_genes), dtype=bool) for group_idx, pert_label in enumerate(candidates): logger.debug(f"Shrinking LFC for perturbation {group_idx + 1}/{n_groups}: {pert_label}") mle_lfc_group = raw_lfc[group_idx] se_group = se[group_idx] # Determine prior scale based on mode if prior_scale_mode == "per_comparison": pert_prior_scale = _estimate_apeglm_prior_scale(mle_lfc_group, se_group) else: pert_prior_scale = global_prior_scale # Vectorized shrinkage using Newton-Raphson # NOTE: use_gene_specific_prior=False for PyDESeq2 parity (gene-specific # prior scaling is non-standard and causes accuracy regression) shrunk_lfc_group, shrunk_se_group, converged, needs_refit = shrink_lfc_apeglm_from_stats( mle_lfc=mle_lfc_group, mle_se=se_group, prior_scale=pert_prior_scale, base_mean=base_mean_proxy, use_gene_specific_prior=False, hybrid_fallback=True, ) shrunk_lfc[group_idx] = shrunk_lfc_group shrunk_se[group_idx] = shrunk_se_group all_needs_refit[group_idx] = needs_refit n_converged = converged.sum() total_converged += n_converged total_genes_processed += n_genes logger.debug(f" {n_converged}/{n_genes} genes converged for {pert_label}") # Log warning if many genes didn't converge convergence_rate = total_converged / total_genes_processed if total_genes_processed > 0 else 1.0 if convergence_rate < 0.95: logger.warning( f"Low convergence rate: {total_converged}/{total_genes_processed} " f"({convergence_rate:.1%}) genes converged within max_iter. " "Consider using method='full' for better accuracy." ) # Log info about genes that might need full re-fitting total_needs_refit = all_needs_refit.sum() if total_needs_refit > 0: refit_rate = total_needs_refit / (n_groups * n_genes) logger.info( f"Hybrid fallback: {total_needs_refit} gene-perturbation pairs " f"({refit_rate:.1%}) flagged for potential accuracy improvement with method='full'" ) else: # method == "full" # Full model re-fitting with L-BFGS-B per gene logger.info(f"Using full model re-fitting (method='full')") # Validate original dataset path for full method original_dataset_path = adata.uns.get("original_dataset_path") if original_dataset_path is None: raise ValueError( f"Input file '{path}' does not have 'original_dataset_path' in uns. " "method='full' requires NB-GLM results from nb_glm_test v0.4.0+. " "Use method='stats' or re-run nb_glm_test." ) # Strip /workspace/ prefix from Docker paths if present if original_dataset_path.startswith("/workspace/"): original_dataset_path = original_dataset_path[len("/workspace/"):] original_path = Path(original_dataset_path) if not original_path.exists(): # Also try the input file's parent as base directory input_parent = path.parent relative_name = Path(original_dataset_path).name alternative_paths = [ input_parent.parent / ".cache" / relative_name, # Try ../..cache/ input_parent / ".cache" / relative_name, # Try ../.cache/ input_parent.parent.parent / ".cache" / relative_name, # Try ../../../.cache/ ] for alt_path in alternative_paths: if alt_path.exists(): original_path = alt_path break else: raise FileNotFoundError( f"Original dataset not found: {original_dataset_path}. " "method='full' requires access to the original count data. " "Use method='stats' to shrink without original data." ) # Load original dataset for streaming backed = read_backed(original_path) try: labels = backed.obs[perturbation_column].astype(str).to_numpy() # Get or compute size factors size_factors_global = adata.uns.get("size_factors") if size_factors_global is not None: size_factors_all = np.asarray(size_factors_global, dtype=np.float64) else: size_factors_all = _median_of_ratios_size_factors(original_path) # Control cell indices control_mask = labels == control_label control_idx = np.where(control_mask)[0] if len(control_idx) == 0: raise ValueError(f"No control cells found with label '{control_label}'") # Load control cells once (they're reused for all perturbations) control_counts = backed.X[control_idx, :].toarray() if sp.issparse(backed.X[control_idx, :]) else np.asarray(backed.X[control_idx, :]) control_size_factors = size_factors_all[control_idx] # Process each perturbation for group_idx, pert_label in enumerate(candidates): logger.debug(f"Shrinking LFC for perturbation {group_idx + 1}/{n_groups}: {pert_label}") # Get perturbation cell indices pert_mask = labels == pert_label pert_idx = np.where(pert_mask)[0] if len(pert_idx) == 0: shrunk_lfc[group_idx] = raw_lfc[group_idx] shrunk_se[group_idx] = se[group_idx] continue # Load perturbation cells pert_counts = backed.X[pert_idx, :].toarray() if sp.issparse(backed.X[pert_idx, :]) else np.asarray(backed.X[pert_idx, :]) pert_size_factors = size_factors_all[pert_idx] # Combine control and perturbation combined_counts = np.vstack([control_counts, pert_counts]) combined_size_factors = np.concatenate([control_size_factors, pert_size_factors]) # Build design matrix n_control = len(control_idx) n_pert = len(pert_idx) n_combined = n_control + n_pert design_matrix = np.zeros((n_combined, 2), dtype=np.float64) design_matrix[:, 0] = 1.0 design_matrix[n_control:, 1] = 1.0 # Get MLE coefficients from NB-GLM mle_intercept = fitted_intercept[group_idx] mle_lfc_group = raw_lfc[group_idx] mle_coef = np.vstack([mle_intercept, mle_lfc_group]) disp_group = dispersion[group_idx] se_group = se[group_idx] # Determine prior scale based on mode if prior_scale_mode == "per_comparison": pert_prior_scale = _estimate_apeglm_prior_scale(mle_lfc_group, se_group) else: pert_prior_scale = global_prior_scale # Full apeGLM shrinkage with L-BFGS-B per gene # NOTE: By default (min_mu=0.0), no min_mu clamping is applied, # matching PyDESeq2's lfc_shrink which omits min_mu entirely. # PyDESeq2 uses min_mu=0.5 for NB-GLM fitting but does NOT pass # min_mu to the shrinkage optimizer. shrunk_coef, shrunk_se_group, converged = shrink_lfc_apeglm( counts=combined_counts, design_matrix=design_matrix, size_factors=combined_size_factors, dispersion=disp_group, mle_coef=mle_coef, mle_se=se_group, shrink_index=1, prior_scale=pert_prior_scale, n_jobs=n_jobs, batch_size=batch_size, min_mu=min_mu, ) shrunk_lfc[group_idx] = shrunk_coef[1, :] shrunk_se[group_idx] = shrunk_se_group n_converged = converged.sum() total_converged += n_converged total_genes_processed += n_genes logger.debug(f" {n_converged}/{n_genes} genes converged for {pert_label}") finally: backed.file.close() # Convert shrunk results from ln-scale to log2 if original output was log2 if lfc_base == "log2": ln2 = np.log(2) shrunk_lfc = shrunk_lfc / ln2 shrunk_se = shrunk_se / ln2 # Update layers adata.layers["logfoldchanges"] = shrunk_lfc adata.layers["standard_error"] = shrunk_se # Posterior SE adata.X = shrunk_lfc # Update effect_size matrix # Update metadata adata.uns["lfc_shrinkage_type"] = "apeglm" adata.uns["apeglm_prior_scale"] = global_prior_scale adata.uns["shrinkage_method"] = method adata.uns["prior_scale_mode"] = prior_scale_mode # Store profiling results or "NA" for production if profiling and profiler is not None: profiler.stop("shrinkage") profiler.snapshot("shrinkage_end") profiler.stop("total") profiler.stop_sampling() stats = profiler.get_stats() adata.uns["profiling"] = { "profiling_enabled": True, "shrinkage_seconds": stats.get("timing", {}).get("sections", {}).get("shrinkage", {}).get("seconds", 0.0), "shrinkage_peak_memory_mb": stats.get("memory", {}).get("peak_mb", 0.0), "total_seconds": stats.get("timing", {}).get("total_seconds", 0.0), } else: adata.uns["profiling"] = "NA" # Determine output path if data_name is None: # Append _shrunk to the stem stem = path.stem # Remove crispyx_ prefix if present for cleaner naming if stem.startswith("crispyx_"): stem = stem[8:] if stem.endswith("_shrunk"): # Already shrunk, use as-is data_name = stem else: data_name = f"{stem}_shrunk" # Use the data_name directly as output filename if output_dir is None: output_dir = path.parent else: output_dir = Path(output_dir) # Ensure crispyx prefix if not data_name.startswith("crispyx_"): output_filename = f"crispyx_{data_name}.h5ad" else: output_filename = f"{data_name}.h5ad" output_path = output_dir / output_filename adata.write(output_path) # Build result object corr_method = adata.uns.get("pvalue_correction", "benjamini-hochberg") # Get matrices from layers statistic_matrix = adata.layers.get("z_score", np.zeros_like(shrunk_lfc)) pvalue_matrix = adata.layers.get("pvalue", np.ones_like(shrunk_lfc)) pvalue_adj_matrix = adata.layers.get("pvalue_adj", np.ones_like(shrunk_lfc)) pts_matrix = adata.layers.get("pts", np.zeros((n_groups, n_genes), dtype=np.float32)) pts_rest_matrix = adata.layers.get("pts_rest", np.zeros((n_groups, n_genes), dtype=np.float32)) # Create order matrix statistic_for_order = np.where( np.isfinite(statistic_matrix), np.abs(statistic_matrix), -np.inf ) order_matrix = np.argsort(-statistic_for_order, axis=1, kind="mergesort") result = RankGenesGroupsResult( genes=gene_symbols, groups=candidates, statistics=np.asarray(statistic_matrix), pvalues=np.asarray(pvalue_matrix), pvalues_adj=np.asarray(pvalue_adj_matrix), logfoldchanges=shrunk_lfc, effect_size=shrunk_lfc, u_statistics=np.zeros_like(shrunk_lfc), pts=np.asarray(pts_matrix, dtype=np.float32), pts_rest=np.asarray(pts_rest_matrix, dtype=np.float32), order=order_matrix, groupby=perturbation_column, method="nb_glm", control_label=control_label, tie_correct=False, pvalue_correction=corr_method, ) result.result = AnnData(output_path) logger.info( f"Applied apeGLM LFC shrinkage (method={method}, prior_scale_mode={prior_scale_mode}) " f"to {n_groups} perturbations, {n_genes} genes (prior_scale={global_prior_scale:.4f}). " f"Output: {output_path}" ) return result