Source code for crispyx.data

"""Helpers for working with AnnData ``.h5ad`` files in a streaming friendly way."""

from __future__ import annotations

import logging
import os as _os
import re as _re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterator, Literal, Mapping, Sequence

import h5py
import anndata as ad
import numpy as np
import pandas as pd
import scipy.sparse as sp

logger = logging.getLogger(__name__)


[docs] def drop_file_cache(path: str | Path) -> None: """Advise the kernel to drop page cache for *path* (Linux only). On cgroup-limited systems (SLURM), page-cache pages count toward the memory limit. Calling this after a streaming read prevents the cached file data from consuming the cgroup budget. The call is a no-op on non-Linux platforms or when the file cannot be opened. """ try: fd = _os.open(str(path), _os.O_RDONLY) try: _os.posix_fadvise(fd, 0, 0, _os.POSIX_FADV_DONTNEED) finally: _os.close(fd) except (OSError, AttributeError): pass
from numba import njit, prange # Numba-accelerated helpers for dense→CSR conversion (60x faster than scipy) @njit(parallel=True) def _numba_count_row_nnz(dense: np.ndarray) -> np.ndarray: """Count non-zeros per row using parallel numba.""" n_rows = dense.shape[0] row_nnz = np.zeros(n_rows, dtype=np.int64) for i in prange(n_rows): count = 0 for j in range(dense.shape[1]): if dense[i, j] != 0: count += 1 row_nnz[i] = count return row_nnz @njit(parallel=True) def _numba_extract_csr_data( dense: np.ndarray, indptr: np.ndarray, data: np.ndarray, indices: np.ndarray, ) -> None: """Extract CSR data/indices in parallel from dense array.""" n_rows = dense.shape[0] for i in prange(n_rows): pos = indptr[i] for j in range(dense.shape[1]): val = dense[i, j] if val != 0: data[pos] = val indices[pos] = j pos += 1 ENSEMBL_PREFIXES = ("ENS", "FBgn", "YAL", "YBL", "YCL", "YDL", "YEL", "YFL", "YGL", "YHL", "YIL", "YJL", "YKL", "YLL", "YML", "YNL", "YOL", "YPL", "YQL", "YRL", "YSL", "YTL", "YUL", "YVL", "YWL", "YXL")
[docs] def is_dense_storage(path: str | Path) -> bool: """Check if h5ad file stores X matrix as dense array. Parameters ---------- path Path to h5ad file. Returns ------- bool True if X is stored as dense array, False if sparse (CSR/CSC). """ with h5py.File(path, 'r') as f: if 'X' not in f: return False x_obj = f['X'] if isinstance(x_obj, h5py.Dataset): # Dense array stored directly as dataset return True elif isinstance(x_obj, h5py.Group): # Check encoding-type attribute encoding = x_obj.attrs.get('encoding-type', b'') if isinstance(encoding, bytes): encoding = encoding.decode('utf-8') return encoding == 'array' return False
[docs] def get_matrix_storage_format(path: str | Path) -> str: """Detect matrix storage format in h5ad file. Parameters ---------- path Path to h5ad file. Returns ------- str One of: 'csr', 'csc', 'dense' """ with h5py.File(path, 'r') as f: if 'X' not in f: return 'dense' x_obj = f['X'] if isinstance(x_obj, h5py.Dataset): # Dense array stored directly as dataset return 'dense' elif isinstance(x_obj, h5py.Group): # Check encoding-type attribute encoding = x_obj.attrs.get('encoding-type', b'') if isinstance(encoding, bytes): encoding = encoding.decode('utf-8') if encoding == 'csc_matrix': return 'csc' elif encoding == 'csr_matrix': return 'csr' elif encoding == 'array': return 'dense' return 'dense'
_MISSING = object() class _LazyFrameAccessor: """Provide a read-friendly view over ``obs``/``var`` tables with ``.load``.""" def __init__(self, parent: "AnnData", attr: str) -> None: self._parent = parent self._attr = attr self._cache: pd.DataFrame | None = None def load(self) -> pd.DataFrame: if self._cache is None: loaded = getattr(self._parent.to_memory(), self._attr) if isinstance(loaded, pd.DataFrame): loaded = loaded.copy() else: loaded = pd.DataFrame(loaded) self._cache = loaded return self._cache def head(self, n: int = 5) -> pd.DataFrame: return self.load().head(n) def __len__(self) -> int: return len(self.load()) def __iter__(self): # pragma: no cover - passthrough for convenience return iter(self.load()) def __getitem__(self, item): # pragma: no cover - passthrough for convenience return self.load().__getitem__(item) def __getattr__(self, name: str): # pragma: no cover - passthrough for convenience return getattr(self.load(), name) def __repr__(self) -> str: # pragma: no cover - display preview frame = self.head() if frame.empty: return f"<{self._attr}: empty DataFrame>" return f"<{self._attr} preview>\n{frame}" def _preview_uns_value(value: Any, n: int = 5) -> Any: if isinstance(value, pd.DataFrame): return value.head(n) if isinstance(value, np.ndarray): return value[:n] if isinstance(value, Mapping): preview: dict[Any, Any] = {} for idx, (key, item) in enumerate(value.items()): if idx >= n: preview["…"] = "…" break preview[key] = _preview_uns_value(item, n) return preview if isinstance(value, (list, tuple)): return type(value)(value[:n]) return value class _LazyUnsEntry: """Deferred loader for a single ``uns`` key.""" def __init__(self, parent: "AnnData", key: str) -> None: self._parent = parent self._key = key self._cache: Any = _MISSING def load(self) -> Any: if self._cache is _MISSING: self._cache = self._parent.to_memory().uns[self._key] return self._cache def preview(self, n: int = 5) -> Any: return _preview_uns_value(self.load(), n) def __getattr__(self, name: str): # pragma: no cover - passthrough for convenience return getattr(self.load(), name) def __getitem__(self, item): # pragma: no cover - passthrough for convenience return self.load()[item] def __repr__(self) -> str: # pragma: no cover - display preview return repr(self.preview()) class _LazyUnsMapping(Mapping[str, _LazyUnsEntry]): """Mapping-style accessor exposing ``uns`` keys with lazy loading.""" def __init__(self, parent: "AnnData") -> None: self._parent = parent self._cache: dict[str, _LazyUnsEntry] = {} def _keys(self) -> list[str]: try: return list(self._parent.backed.uns.keys()) except AttributeError: return [] def __getitem__(self, key: str) -> _LazyUnsEntry: keys = self._keys() if key not in keys: raise KeyError(key) if key not in self._cache: self._cache[key] = _LazyUnsEntry(self._parent, key) return self._cache[key] def __iter__(self): # pragma: no cover - passthrough for convenience return iter(self._keys()) def __len__(self) -> int: return len(self._keys()) def keys(self): # pragma: no cover - convenience mirror return self._keys() def items(self): # pragma: no cover - convenience mirror return [(key, self[key]) for key in self._keys()] def __repr__(self) -> str: # pragma: no cover - display preview keys = self._keys() if not keys: return "<uns: empty>" previews = {key: self[key].preview() for key in keys} return repr(previews)
[docs] class AnnData: """Thin wrapper around a backed :class:`anndata.AnnData` handle.""" def __init__(self, path: str | Path, *, mode: str = "r") -> None: """Open a backed AnnData wrapper. Parameters ---------- path : str or Path Path to an ``.h5ad`` file. mode : str, optional HDF5 file access mode (default ``'r'``). """ self.path = Path(path) self._mode = mode self._backed: ad.AnnData | None = None self._obs_view: _LazyFrameAccessor | None = None self._var_view: _LazyFrameAccessor | None = None self._uns_view: _LazyUnsMapping | None = None @property def filename(self) -> str: """Return the underlying filename for compatibility with Scanpy.""" return str(self.path) def __fspath__(self) -> str: # pragma: no cover - filesystem protocol return str(self.path) def __str__(self) -> str: # pragma: no cover - helpful when printing paths return str(self.path) @property def backed(self) -> ad.AnnData: """Return the lazily opened backed AnnData handle.""" if self._backed is None: self._backed = ad.read_h5ad(str(self.path), backed=self._mode) return self._backed
[docs] def close(self) -> None: """Close the underlying file handle if it is open.""" if self._backed is not None: try: self._backed.file.close() finally: self._backed = None self._obs_view = None self._var_view = None self._uns_view = None
[docs] def to_memory(self) -> ad.AnnData: """Materialise the backed AnnData into memory.""" return ad.read_h5ad(str(self.path))
@property def obs(self) -> _LazyFrameAccessor: """Lazy accessor for observation (cell) metadata.""" if self._obs_view is None: self._obs_view = _LazyFrameAccessor(self, "obs") return self._obs_view @property def var(self) -> _LazyFrameAccessor: """Lazy accessor for variable (gene) metadata.""" if self._var_view is None: self._var_view = _LazyFrameAccessor(self, "var") return self._var_view @property def uns(self) -> _LazyUnsMapping: """Lazy accessor for unstructured annotations.""" if self._uns_view is None: self._uns_view = _LazyUnsMapping(self) return self._uns_view def __enter__(self) -> "AnnData": self.backed # ensure handle is opened return self def __exit__(self, exc_type, exc, tb) -> None: self.close() def __getattr__(self, name: str): # Guard: during unpickling __init__ has not run yet so _backed doesn't # exist. Without this guard, accessing self.backed calls self._backed # which triggers __getattr__("_backed") → infinite recursion. if "_backed" not in self.__dict__: raise AttributeError(name) return getattr(self.backed, name) def __getstate__(self) -> dict: """Return picklable state (path + mode only; file handle is not serialised).""" return {"path": self.path, "_mode": self._mode} def __setstate__(self, state: dict) -> None: """Restore from pickle; file handle will be reopened lazily on next access.""" self.path = state["path"] self._mode = state["_mode"] self._backed = None self._obs_view = None self._var_view = None self._uns_view = None def __repr__(self) -> str: # pragma: no cover - debugging helper return f"AnnData(path={self.path!s}, mode='{self._mode}')" def __del__(self) -> None: # pragma: no cover - defensive cleanup try: self.close() except Exception: pass
[docs] def read_backed(path: str | Path) -> ad.AnnData: """Open an ``.h5ad`` file in backed mode for low-memory access.""" return ad.read_h5ad(str(path), backed="r")
# ----------------------------------------------------------------------------- # H5AD Write Helpers (for close-write-reopen pattern) # -----------------------------------------------------------------------------
[docs] def write_obsm_to_h5ad(path: str | Path, key: str, data: np.ndarray) -> None: """Write a dense array to obsm/{key} in an h5ad file. Parameters ---------- path Path to the h5ad file. key Key under obsm (e.g., 'X_pca'). data Dense numpy array of shape (n_obs, n_dims). """ with h5py.File(path, "r+") as f: if "obsm" not in f: f.create_group("obsm") obsm = f["obsm"] if key in obsm: del obsm[key] ds = obsm.create_dataset(key, data=data, compression="gzip", compression_opts=4) ds.attrs["encoding-type"] = "array" ds.attrs["encoding-version"] = "0.2.0"
[docs] def write_varm_to_h5ad(path: str | Path, key: str, data: np.ndarray) -> None: """Write a dense array to varm/{key} in an h5ad file. Parameters ---------- path Path to the h5ad file. key Key under varm (e.g., 'PCs'). data Dense numpy array of shape (n_vars, n_dims). """ with h5py.File(path, "r+") as f: if "varm" not in f: f.create_group("varm") varm = f["varm"] if key in varm: del varm[key] ds = varm.create_dataset(key, data=data, compression="gzip", compression_opts=4) ds.attrs["encoding-type"] = "array" ds.attrs["encoding-version"] = "0.2.0"
[docs] def write_uns_dict_to_h5ad(path: str | Path, key: str, data: dict) -> None: """Write a dict to uns/{key} in an h5ad file. Handles scalar values, numpy arrays, and nested dicts. Uses AnnData-compatible encoding for proper round-trip compatibility. Parameters ---------- path Path to the h5ad file. key Key under uns (e.g., 'pca'). data Dictionary with string keys and scalar/array values. """ # Variable-length string type for h5py str_dtype = h5py.string_dtype(encoding='utf-8') def _write_value(grp: h5py.Group, k: str, v): if k in grp: del grp[k] if isinstance(v, dict): sub = grp.create_group(k) for sub_k, sub_v in v.items(): _write_value(sub, sub_k, sub_v) elif isinstance(v, np.ndarray): ds = grp.create_dataset(k, data=v) ds.attrs["encoding-type"] = "array" ds.attrs["encoding-version"] = "0.2.0" elif isinstance(v, (list, tuple)): arr = np.array(v) ds = grp.create_dataset(k, data=arr) ds.attrs["encoding-type"] = "array" ds.attrs["encoding-version"] = "0.2.0" elif isinstance(v, str): ds = grp.create_dataset(k, data=v, dtype=str_dtype) ds.attrs["encoding-type"] = "string" ds.attrs["encoding-version"] = "0.2.0" elif isinstance(v, bool): # Store bool as numpy bool_ to avoid confusion with int ds = grp.create_dataset(k, data=np.bool_(v)) ds.attrs["encoding-type"] = "numeric-scalar" ds.attrs["encoding-version"] = "0.2.0" elif isinstance(v, (int, float, np.integer, np.floating)): ds = grp.create_dataset(k, data=v) ds.attrs["encoding-type"] = "numeric-scalar" ds.attrs["encoding-version"] = "0.2.0" else: # Fallback: try as string ds = grp.create_dataset(k, data=str(v), dtype=str_dtype) ds.attrs["encoding-type"] = "string" ds.attrs["encoding-version"] = "0.2.0" with h5py.File(path, "r+") as f: if "uns" not in f: f.create_group("uns") uns = f["uns"] if key in uns: del uns[key] grp = uns.create_group(key) for k, v in data.items(): _write_value(grp, k, v)
[docs] def write_obsp_to_h5ad(path: str | Path, key: str, data: sp.spmatrix) -> None: """Write a sparse matrix to obsp/{key} in an h5ad file. Stores in CSR format following AnnData conventions. Parameters ---------- path Path to the h5ad file. key Key under obsp (e.g., 'distances', 'connectivities'). data Sparse matrix of shape (n_obs, n_obs). """ csr = sp.csr_matrix(data) with h5py.File(path, "r+") as f: if "obsp" not in f: f.create_group("obsp") obsp = f["obsp"] if key in obsp: del obsp[key] grp = obsp.create_group(key) grp.attrs["encoding-type"] = np.bytes_("csr_matrix") grp.attrs["encoding-version"] = np.bytes_("0.1.0") grp.attrs["shape"] = np.array(csr.shape, dtype=np.int64) grp.create_dataset("data", data=csr.data, compression="gzip", compression_opts=4) grp.create_dataset("indices", data=csr.indices) grp.create_dataset("indptr", data=csr.indptr)
[docs] def resolve_data_path( data: str | Path | "AnnData" | ad.AnnData, *, require_exists: bool = True, ) -> Path: """Resolve the on-disk path for a backed AnnData object or path-like input. This utility supports flexible input types for crispyx functions, allowing users to pass either a file path or an AnnData object. Parameters ---------- data One of: - A string or Path to an h5ad file - A crispyx.AnnData wrapper (has .path attribute) - A backed anndata.AnnData object (has .filename attribute) require_exists If True (default), verify the resolved path exists. Returns ------- Path The resolved file path. Raises ------ TypeError If data is an in-memory (non-backed) AnnData or unsupported type. FileNotFoundError If require_exists is True and the path does not exist. Examples -------- >>> from crispyx.data import resolve_data_path >>> path = resolve_data_path("data/counts.h5ad") >>> path = resolve_data_path(adata_wrapper) >>> path = resolve_data_path(backed_adata) """ if isinstance(data, (str, Path)): result = Path(data) elif isinstance(data, AnnData): result = data.path elif isinstance(data, ad.AnnData): filename = getattr(data, "filename", None) if filename: result = Path(filename) else: raise TypeError( "Operations in crispyx expect a backed AnnData object or file path. " "The provided AnnData appears to be in-memory (no .filename attribute)." ) else: raise TypeError( f"Expected a path-like value or backed AnnData; received {type(data)!r}." ) if require_exists and not result.exists(): raise FileNotFoundError(f"Data file not found: {result}") return result
[docs] def resolve_output_path( input_path: str | Path, *, suffix: str, output_dir: str | Path | None = None, data_name: str | None = None, module: str = "crispyx", ) -> Path: """Construct an informative output path for an intermediate ``.h5ad`` file.""" input_path = Path(input_path) output_dir = Path(output_dir) if output_dir is not None else input_path.parent if data_name: # Preserve any existing module prefix supplied by the caller. base = data_name if module and not base.startswith(f"{module}_"): base = f"{module}_{base}" # If the provided name does not already encode the suffix, append it to avoid # different intermediates overwriting each other when the same ``data_name`` # is reused across pipeline steps. if not base.endswith(f"_{suffix}"): base = f"{base}_{suffix}" return output_dir / f"{base}.h5ad" return output_dir / f"{module}_{suffix}.h5ad"
[docs] def ensure_gene_symbol_column( adata: ad.AnnData | ad._core.anndata.AnnDataMixin, gene_name_column: str | None, ) -> pd.Index: """Return a vector of gene symbols and verify they look like symbols, not Ensembl IDs.""" if gene_name_column is None: raw_names = adata.var_names logger.info( "No gene_name_column provided; using adata.var_names for gene identifiers." ) else: if gene_name_column not in adata.var.columns: if gene_name_column == "gene_symbols": raw_names = adata.var_names logger.info( "Column 'gene_symbols' not found in adata.var; " "using adata.var_names for gene identifiers." ) else: raise KeyError( f"Gene name column '{gene_name_column}' was not found in adata.var. Available columns: {list(adata.var.columns)}" ) else: raw_names = adata.var[gene_name_column] names = pd.Index(raw_names).astype(str) _validate_gene_symbols(names) return names
def _validate_gene_symbols(names: Sequence[str]) -> None: """Perform a basic sanity check that the provided gene identifiers look like symbols.""" if len(names) == 0: raise ValueError("No gene names were provided.") names = pd.Index(names).astype(str) prefixes = names.str.upper().str.slice(0, 3) ensembl_like = prefixes.isin([p[:3] for p in ENSEMBL_PREFIXES]).sum() if ensembl_like > len(names) / 2: raise ValueError( "The majority of provided gene identifiers appear to be Ensembl-style IDs. " "Please supply a column containing gene symbols." )
[docs] def resolve_control_label( labels: Sequence[str], control_label: str | None, *, verbose: bool = True, ) -> str: """Return an explicit control label, inferring one when necessary.""" if control_label is not None: return str(control_label) index = pd.Index(labels).astype(str) if index.empty: raise ValueError( "Cannot infer control label because no perturbation labels were provided." ) lower = index.str.lower() exact_terms = {"ctrl", "control", "nontarget", "non-target", "non_target"} substring_terms = ("ctrl", "control", "nontarget", "non-target", "non_target") def _select(predicate) -> str | None: for label, lowered in zip(index, lower): if predicate(lowered): return str(label) return None candidate = _select(lambda text: text in exact_terms) if candidate is None: candidate = _select(lambda text: any(term in text for term in substring_terms)) if candidate is None: candidate = _select(lambda text: ("non" in text) and ("target" in text)) if candidate is None: raise ValueError( "Unable to infer control label automatically. Please provide 'control_label' explicitly." ) if verbose: logger.info("Inferred control label '%s' from perturbation labels.", candidate) return candidate
[docs] def read_h5ad_ondisk( path: str | Path, *, n_obs: int = 5, n_vars: int = 5, ) -> AnnData: """Open an ``.h5ad`` file on disk, print a preview, and return a read-only view.""" adata_ro = AnnData(path) backed = adata_ro.backed try: print(backed) if n_obs > 0 and backed.n_obs > 0: print("First obs rows:") print(backed.obs.head(n_obs)) if n_vars > 0 and backed.n_vars > 0: print("First var rows:") print(backed.var.head(n_vars)) except Exception: adata_ro.close() raise return adata_ro
[docs] def iter_matrix_chunks( adata: ad.AnnData | ad._core.anndata.AnnDataMixin, *, axis: int = 0, chunk_size: int = 1024, convert_to_dense: bool = True, ) -> Iterator[tuple[slice, np.ndarray | sp.spmatrix]]: """Yield chunks of the expression matrix.""" if axis not in (0, 1): raise ValueError("axis must be 0 (rows) or 1 (columns)") n_obs, n_vars = adata.n_obs, adata.n_vars length = n_obs if axis == 0 else n_vars for start in range(0, length, chunk_size): end = min(start + chunk_size, length) if axis == 0: block = adata.X[start:end] else: block = adata.X[:, start:end] if convert_to_dense: block = _to_dense(block) yield slice(start, end), block
def _to_dense(matrix: np.ndarray) -> np.ndarray: if hasattr(matrix, "toarray"): matrix = matrix.toarray() return np.asarray(matrix)
[docs] def normalize_total_block( block: np.ndarray | sp.spmatrix, *, library_size: np.ndarray | None = None, target_sum: float = 1e4, dtype: np.dtype | type = np.float64, ) -> tuple[np.ndarray, np.ndarray]: """Return a library-size normalised dense view of ``block``. Parameters ---------- block: A slice of the expression matrix with shape ``(n_cells, n_genes)``. library_size: Optional precomputed library sizes for the cells in ``block``. When ``None`` the library size is computed from ``block`` directly. target_sum: Target total counts per cell after normalisation, matching the default used by :func:`scanpy.pp.normalize_total`. dtype: Data type for the output dense array. Defaults to float64. Returns ------- tuple A tuple ``(normalised, library_size)`` where ``normalised`` is a dense array containing the normalised counts and ``library_size`` contains the per-cell library sizes that were used. """ dense = _to_dense(block).astype(dtype, copy=True) if dense.ndim != 2: raise ValueError("block must be two-dimensional") if library_size is None: library_size = dense.sum(axis=1) else: library_size = np.asarray(library_size, dtype=dtype) if library_size.shape[0] != dense.shape[0]: raise ValueError( "library_size length does not match the number of cells in block" ) scale = np.divide( float(target_sum), library_size, out=np.zeros_like(library_size, dtype=dtype), where=library_size > 0, ) dense *= scale[:, None] return dense, library_size
def _ensure_csr(matrix: np.ndarray | sp.spmatrix, *, dtype: np.dtype | None = None) -> sp.csr_matrix: """Convert the provided matrix to CSR format.""" if sp.isspmatrix_csr(matrix): csr = matrix elif sp.issparse(matrix): csr = matrix.tocsr() else: csr = sp.csr_matrix(np.asarray(matrix)) if dtype is not None and csr.dtype != dtype: csr = csr.astype(dtype) return csr def _extract_csr_components_dense( block: np.ndarray, dtype: np.dtype, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: """Extract CSR data, indices, and row_nnz from dense block efficiently. Uses numba-accelerated parallel extraction when available (60x faster), falling back to numpy vectorized operations otherwise. Parameters ---------- block Dense 2D array of shape (n_rows, n_cols). dtype Target dtype for data array. Returns ------- tuple (data, indices, row_nnz, total_nnz) where data and indices are flattened CSR components and row_nnz is counts per row. """ if block.size == 0: empty_data = np.array([], dtype=dtype) empty_indices = np.array([], dtype=np.int32) empty_nnz = np.zeros(block.shape[0], dtype=np.int64) return empty_data, empty_indices, empty_nnz, 0 # Ensure C-contiguous for numba if not block.flags['C_CONTIGUOUS']: block = np.ascontiguousarray(block) # Use numba-accelerated parallel extraction (60x faster than scipy) row_nnz = _numba_count_row_nnz(block) indptr = np.zeros(block.shape[0] + 1, dtype=np.int64) indptr[1:] = np.cumsum(row_nnz) total_nnz = int(indptr[-1]) if total_nnz == 0: return np.array([], dtype=dtype), np.array([], dtype=np.int32), row_nnz, 0 data = np.empty(total_nnz, dtype=dtype) indices = np.empty(total_nnz, dtype=np.int32) _numba_extract_csr_data(block.astype(dtype, copy=False), indptr, data, indices) return data, indices, row_nnz, total_nnz
[docs] def write_filtered_subset( source_path: str | Path, *, cell_mask: np.ndarray, gene_mask: np.ndarray, output_path: str | Path, chunk_size: int = 4096, var_assignments: dict[str, Sequence] | None = None, row_nnz: np.ndarray | None = None, total_nnz: int | None = None, data_dtype: np.dtype | None = None, chunk_cache: Any = None, ) -> None: """Stream a filtered AnnData view to disk without materialising ``X``. Parameters ---------- source_path Path to source h5ad file. cell_mask Boolean mask for cells to include. gene_mask Boolean mask for genes to include. output_path Path for output h5ad file. chunk_size Number of cells to process per chunk. var_assignments Optional dict of column assignments for var DataFrame. row_nnz Optional pre-computed non-zero counts per row. When provided along with total_nnz and data_dtype, skips the counting pass. total_nnz Optional pre-computed total non-zero count. data_dtype Optional pre-computed data type for the sparse matrix. chunk_cache Optional _ChunkCache object from qc module. When provided, reads CSR data from cache instead of re-reading the source matrix. """ source_path = Path(source_path) output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) backed = read_backed(source_path) try: obs = backed.obs.iloc[cell_mask].copy() var = backed.var.iloc[gene_mask].copy() obs.index = obs.index.astype(str) var.index = var.index.astype(str) # Drop stale categories so downstream tools (e.g. scanpy) only see # groups that have at least one cell in the filtered subset. for _col in obs.columns: if isinstance(obs[_col].dtype, pd.CategoricalDtype): obs[_col] = obs[_col].cat.remove_unused_categories() if var_assignments: for key, values in var_assignments.items(): if len(values) != var.shape[0]: raise ValueError( f"Length mismatch for column '{key}': expected {var.shape[0]}, received {len(values)}" ) var[key] = np.asarray(values) finally: backed.file.close() n_obs = int(cell_mask.sum()) n_vars = int(gene_mask.sum()) gene_indices = np.flatnonzero(gene_mask) # Use pre-computed values if all three are provided, otherwise compute need_counting_pass = row_nnz is None or total_nnz is None or data_dtype is None if need_counting_pass: row_nnz = np.zeros(n_obs, dtype=np.int64) total_nnz = 0 data_dtype_local: np.dtype | None = None backed = read_backed(source_path) try: row_offset = 0 for slc, block in iter_matrix_chunks(backed, axis=0, chunk_size=chunk_size, convert_to_dense=False): local_mask = cell_mask[slc] if not np.any(local_mask): continue block = block[local_mask] if gene_indices.size: block = block[:, gene_indices] else: block = block[:, []] csr = _ensure_csr(block) counts = np.diff(csr.indptr) size = counts.size row_nnz[row_offset : row_offset + size] = counts total_nnz += int(csr.nnz) if data_dtype_local is None and csr.nnz: data_dtype_local = csr.data.dtype row_offset += size finally: backed.file.close() if data_dtype_local is None: data_dtype_local = np.float32 data_dtype = data_dtype_local placeholder = sp.csr_matrix((n_obs, n_vars), dtype=data_dtype) adata = ad.AnnData(placeholder, obs=obs, var=var) adata.write(output_path) if n_obs == 0 or n_vars == 0: with h5py.File(output_path, "r+", libver='latest') as dest: if "X" in dest: del dest["X"] grp = dest.create_group("X") grp.attrs["encoding-type"] = np.bytes_("csr_matrix") grp.attrs["encoding-version"] = np.bytes_("0.1.0") grp.create_dataset("data", shape=(0,), dtype=data_dtype) grp.create_dataset("indices", shape=(0,), dtype=np.int32) grp.create_dataset("indptr", data=np.zeros(n_obs + 1, dtype=np.int64)) grp.attrs["shape"] = np.array([n_obs, n_vars], dtype=np.int64) return indptr = np.zeros(n_obs + 1, dtype=np.int64) np.cumsum(row_nnz, out=indptr[1:]) # Phase 2 I/O Optimization: Use larger HDF5 chunks and write buffering # Optimal chunk size: balance between I/O overhead and cache efficiency # Target ~1MB chunks for data (assuming float32 = 4 bytes -> ~256K elements) hdf5_chunk_size = min(262144, max(8192, total_nnz // 16)) # 8K to 256K elements # Write buffer size: accumulate data before writing to reduce I/O syscalls write_buffer_size = hdf5_chunk_size * 2 # Buffer 2x chunk size before flushing with h5py.File(output_path, "r+", libver='latest') as dest: if "X" in dest: del dest["X"] grp = dest.create_group("X") grp.attrs["encoding-type"] = np.bytes_("csr_matrix") grp.attrs["encoding-version"] = np.bytes_("0.1.0") # Use explicit chunk sizes for better I/O performance data_ds = grp.create_dataset( "data", shape=(total_nnz,), dtype=data_dtype, chunks=(hdf5_chunk_size,) if total_nnz >= hdf5_chunk_size else None ) indices_ds = grp.create_dataset( "indices", shape=(total_nnz,), dtype=np.int32, chunks=(hdf5_chunk_size,) if total_nnz >= hdf5_chunk_size else None ) grp.create_dataset("indptr", data=indptr) grp.attrs["shape"] = np.array([n_obs, n_vars], dtype=np.int64) # Stream data with write buffering if chunk_cache is not None: # Read from cached CSR chunks (avoids re-reading the original matrix) offset = 0 data_buffer = [] indices_buffer = [] buffer_nnz = 0 for filtered_data, filtered_indices, n_cells in chunk_cache.iter_filtered_chunks( gene_indices, data_dtype ): nnz = len(filtered_data) if nnz: data_buffer.append(filtered_data) indices_buffer.append(filtered_indices) buffer_nnz += nnz # Flush buffer when it exceeds threshold if buffer_nnz >= write_buffer_size: combined_data = np.concatenate(data_buffer) combined_indices = np.concatenate(indices_buffer) data_ds[offset : offset + buffer_nnz] = combined_data indices_ds[offset : offset + buffer_nnz] = combined_indices offset += buffer_nnz data_buffer = [] indices_buffer = [] buffer_nnz = 0 # Flush remaining buffer if buffer_nnz > 0: combined_data = np.concatenate(data_buffer) combined_indices = np.concatenate(indices_buffer) data_ds[offset : offset + buffer_nnz] = combined_data indices_ds[offset : offset + buffer_nnz] = combined_indices else: # Read from source matrix (fallback when cache not available) backed = read_backed(source_path) try: offset = 0 data_buffer = [] indices_buffer = [] buffer_nnz = 0 for slc, block in iter_matrix_chunks( backed, axis=0, chunk_size=chunk_size, convert_to_dense=False ): local_mask = cell_mask[slc] if not np.any(local_mask): continue block = block[local_mask] if gene_indices.size: block = block[:, gene_indices] else: block = block[:, []] # Use optimized extraction for dense, scipy for sparse if sp.issparse(block): csr = _ensure_csr(block, dtype=data_dtype) chunk_data = csr.data chunk_indices = csr.indices.astype(np.int32, copy=False) nnz = int(csr.nnz) else: # Dense: use numba-accelerated direct extraction chunk_data, chunk_indices, _row_nnz, nnz = _extract_csr_components_dense( block, data_dtype ) if nnz: data_buffer.append(chunk_data) indices_buffer.append(chunk_indices) buffer_nnz += nnz # Flush buffer when it exceeds threshold if buffer_nnz >= write_buffer_size: combined_data = np.concatenate(data_buffer) combined_indices = np.concatenate(indices_buffer) data_ds[offset : offset + buffer_nnz] = combined_data indices_ds[offset : offset + buffer_nnz] = combined_indices offset += buffer_nnz data_buffer = [] indices_buffer = [] buffer_nnz = 0 # Flush remaining buffer if buffer_nnz > 0: combined_data = np.concatenate(data_buffer) combined_indices = np.concatenate(indices_buffer) data_ds[offset : offset + buffer_nnz] = combined_data indices_ds[offset : offset + buffer_nnz] = combined_indices finally: backed.file.close()
[docs] def normalize_total_log1p( data: str | Path | "AnnData" | ad.AnnData, output_path: str | Path | None = None, *, normalize: bool = True, log1p: bool = True, target_sum: float = 1e4, chunk_size: int = 4096, output_dir: str | Path | None = None, data_name: str | None = None, verbose: bool = True, ) -> "AnnData": """Stream normalize and/or log-transform an h5ad file without loading it fully into memory. This function processes the source file in chunks, optionally applying: 1. Total-count normalization (scanpy.pp.normalize_total equivalent) 2. Log1p transformation (scanpy.pp.log1p equivalent) The output is written as a sparse CSR matrix. This is the streaming equivalent of calling ``scanpy.pp.normalize_total`` followed by ``scanpy.pp.log1p``. Parameters ---------- data Path to source h5ad file, or a backed AnnData object. output_path Path for output h5ad file. If None, uses output_dir/data_name pattern. normalize Whether to apply total-count normalization. Default True. log1p Whether to apply log1p transformation. Default True. target_sum Target total counts per cell after normalization. Default 1e4. Only used if normalize=True. chunk_size Number of cells to process per chunk. Default 4096. output_dir Directory for output file. Defaults to input file's directory. data_name Custom name for output file. If None, uses "normalized" suffix. verbose Print progress information. Returns ------- AnnData Read-only AnnData wrapper pointing to the output file. Examples -------- >>> # Full normalization + log1p (default) >>> adata_norm = cx.pp.normalize_total_log1p(adata_ro, output_dir=OUTPUT_DIR, data_name="normalized") >>> # Only log1p (no normalization) >>> adata_log = cx.pp.normalize_total_log1p(adata_ro, normalize=False, output_dir=OUTPUT_DIR) >>> # Only normalization (no log1p) >>> adata_norm = cx.pp.normalize_total_log1p(adata_ro, log1p=False, output_dir=OUTPUT_DIR) >>> # Use explicit output path >>> adata_norm = cx.pp.normalize_total_log1p(adata_ro, "results/normalized.h5ad") """ if not normalize and not log1p: raise ValueError("At least one of normalize or log1p must be True") # Resolve input path from various input types source_path = resolve_data_path(data, require_exists=True) # Resolve output path if output_path is not None: output_path = Path(output_path) else: # Build suffix based on options if normalize and log1p: suffix = "normalized_log1p" elif normalize: suffix = "normalized" else: suffix = "log1p" output_path = resolve_output_path( source_path, suffix=suffix, output_dir=output_dir, data_name=data_name, ) output_path.parent.mkdir(parents=True, exist_ok=True) ops = [] if normalize: ops.append("normalize") if log1p: ops.append("log1p") if verbose: print(f"Generating preprocessed dataset (streaming, {'+'.join(ops)}): {output_path}") # First pass: count non-zeros and get metadata backed = read_backed(source_path) try: n_obs = backed.n_obs n_vars = backed.n_vars obs = backed.obs.copy() var = backed.var.copy() obs.index = obs.index.astype(str) var.index = var.index.astype(str) # Count non-zeros per row (after normalization, same sparsity as input) row_nnz = np.zeros(n_obs, dtype=np.int64) total_nnz = 0 row_offset = 0 for slc, block in iter_matrix_chunks( backed, axis=0, chunk_size=chunk_size, convert_to_dense=False ): csr = _ensure_csr(block) counts = np.diff(csr.indptr) row_nnz[row_offset : row_offset + len(counts)] = counts total_nnz += int(csr.nnz) row_offset += len(counts) finally: backed.file.close() if total_nnz == 0: # Empty matrix: write placeholder placeholder = sp.csr_matrix((n_obs, n_vars), dtype=np.float32) adata = ad.AnnData(placeholder, obs=obs, var=var) adata.write(output_path) return output_path # Choose consistent index dtype for both indptr and indices. # scipy requires indptr and indices to share the same integer dtype; # mixed int32/int64 triggers "Output dtype not compatible" in scipy >= 1.15. idx_dtype = np.int32 if total_nnz <= np.iinfo(np.int32).max else np.int64 # Compute indptr indptr = np.zeros(n_obs + 1, dtype=idx_dtype) np.cumsum(row_nnz, out=indptr[1:]) # HDF5 chunk sizing hdf5_chunk_size = min(262144, max(8192, total_nnz // 16)) # Create output file with placeholder placeholder = sp.csr_matrix((n_obs, n_vars), dtype=np.float32) adata = ad.AnnData(placeholder, obs=obs, var=var) adata.write(output_path) # Second pass: normalize, log1p, and write with h5py.File(output_path, "r+", libver='latest') as dest: if "X" in dest: del dest["X"] grp = dest.create_group("X") grp.attrs["encoding-type"] = np.bytes_("csr_matrix") grp.attrs["encoding-version"] = np.bytes_("0.1.0") data_ds = grp.create_dataset( "data", shape=(total_nnz,), dtype=np.float32, chunks=(hdf5_chunk_size,) if total_nnz >= hdf5_chunk_size else None, ) indices_ds = grp.create_dataset( "indices", shape=(total_nnz,), dtype=idx_dtype, chunks=(hdf5_chunk_size,) if total_nnz >= hdf5_chunk_size else None, ) grp.create_dataset("indptr", data=indptr) grp.attrs["shape"] = np.array([n_obs, n_vars], dtype=np.int64) backed = read_backed(source_path) try: offset = 0 for slc, block in iter_matrix_chunks( backed, axis=0, chunk_size=chunk_size, convert_to_dense=False ): csr = _ensure_csr(block) # Start with original data processed_data = csr.data.astype(np.float32) # Apply normalization if requested if normalize: # Library size per cell lib_sizes = np.asarray(csr.sum(axis=1)).ravel() # Avoid division by zero scale = np.divide( target_sum, lib_sizes, out=np.zeros_like(lib_sizes, dtype=np.float64), where=lib_sizes > 0, ) # Apply normalization to data (CSR stores data in row-major order) # For each row i, data[indptr[i]:indptr[i+1]] are the values for i in range(csr.shape[0]): start_idx = csr.indptr[i] end_idx = csr.indptr[i + 1] processed_data[start_idx:end_idx] = ( processed_data[start_idx:end_idx].astype(np.float64) * scale[i] ).astype(np.float32) # Apply log1p if requested if log1p: processed_data = np.log1p(processed_data) # Write to HDF5 nnz = len(processed_data) if nnz: data_ds[offset : offset + nnz] = processed_data indices_ds[offset : offset + nnz] = csr.indices.astype(idx_dtype) offset += nnz finally: backed.file.close() if verbose: print(f" ✓ Preprocessed dataset written: {n_obs} cells × {n_vars} genes") return AnnData(output_path)
[docs] def convert_to_csc( data: str | Path | "AnnData" | ad.AnnData, *, output_path: str | Path | None = None, chunk_size: int | None = None, output_dir: str | Path | None = None, data_name: str | None = None, verbose: bool = True, ) -> "AnnData": """Convert a backed h5ad file's matrix from CSR (or dense) to CSC format. CSC format allows O(nnz_in_chunk) column-slicing instead of O(total_nnz) that CSR requires. This is required for efficient Wilcoxon rank-sum testing, which iterates over gene chunks with ``axis=1`` access patterns. The conversion is done in two streaming passes over the source file so the peak memory is bounded by ``total_nnz × sizeof(float32 + row_dtype)`` bytes (the output buffers) plus one row-chunk working buffer. The result is written to disk in a single sequential write which is as fast as possible on HDD/SSD. If the input file is already CSC, no file is written; the function returns a backed AnnData pointing to the original file. Parameters ---------- data Path to source h5ad file (CSR or dense), or a backed AnnData. output_path Explicit path for the output file. If ``None``, a path is derived from ``output_dir``/``data_name`` with ``"_csc"`` appended to the stem. chunk_size Number of rows (cells) to read at a time during both passes. Default 4096. output_dir Directory for the output file. Defaults to the source file's directory. data_name Custom name used when building the output filename. verbose Print progress messages. Returns ------- AnnData Backed (read-only) AnnData pointing to the written CSC h5ad file, or to the source file if it was already CSC. Examples -------- >>> adata_csc = cx.pp.convert_to_csc(preprocessed_path, output_dir=OUTPUT_DIR) """ source_path = resolve_data_path(data, require_exists=True) # Fast path: input is already CSC — return it directly. if get_matrix_storage_format(source_path) == "csc": if verbose: print(f"File is already CSC, skipping conversion: {source_path}") return AnnData(source_path) # Resolve output path. if output_path is None: output_path = resolve_output_path( source_path, suffix="csc", output_dir=output_dir, data_name=data_name, ) output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) if verbose: print(f"Converting to CSC (two-pass streaming): {source_path}{output_path}") # ------------------------------------------------------------------ Pass 1 # Read all rows in chunks; count non-zeros per *column*; collect metadata. backed = read_backed(source_path) try: n_obs = backed.n_obs n_vars = backed.n_vars if chunk_size is None: chunk_size = calculate_optimal_chunk_size(n_obs, n_vars) obs = backed.obs.copy() var = backed.var.copy() obs.index = obs.index.astype(str) var.index = var.index.astype(str) col_nnz = np.zeros(n_vars, dtype=np.int64) total_nnz = 0 for _slc, block in iter_matrix_chunks( backed, axis=0, chunk_size=chunk_size, convert_to_dense=False ): csr = _ensure_csr(block) np.add.at(col_nnz, csr.indices, 1) total_nnz += csr.nnz finally: backed.file.close() # Empty matrix edge case. if total_nnz == 0: placeholder = sp.csc_matrix((n_obs, n_vars), dtype=np.float32) adata = ad.AnnData(placeholder, obs=obs, var=var) adata.write(output_path) return AnnData(output_path) # CSC indptr: length n_vars + 1. Use int64 when NNZ exceeds INT32_MAX. idx_dtype = np.int32 if total_nnz <= np.iinfo(np.int32).max else np.int64 indptr = np.zeros(n_vars + 1, dtype=idx_dtype) np.cumsum(col_nnz, out=indptr[1:]) # Row-index dtype: int32 suffices for up to ~2 billion cells. row_dtype = np.int32 if n_obs <= np.iinfo(np.int32).max else np.int64 # ------------------------------------------------------------------ Pass 2 # Scatter CSR non-zeros into in-memory CSC buffers, then write sequentially. # Memory cost: total_nnz * (4 + sizeof_row_dtype) bytes. out_data = np.empty(total_nnz, dtype=np.float32) out_row_indices = np.empty(total_nnz, dtype=row_dtype) # offset[c] = next write position in the CSC arrays for column c. # Must be int64 so positions can exceed INT32_MAX when total_nnz > 2^31. offset = indptr[:-1].astype(np.int64) row_global = 0 backed = read_backed(source_path) try: for _slc, block in iter_matrix_chunks( backed, axis=0, chunk_size=chunk_size, convert_to_dense=False ): csr = _ensure_csr(block) n_chunk = csr.shape[0] if csr.nnz == 0: row_global += n_chunk continue # Global row index for every non-zero in this chunk. local_row_ids = np.repeat( np.arange(n_chunk, dtype=row_dtype), np.diff(csr.indptr) ) + row_dtype(row_global) cols = csr.indices # column index of each non-zero vals = csr.data.astype(np.float32) # Sort non-zeros by column so we can process contiguous groups. col_order = np.argsort(cols, kind="stable") sorted_cols = cols[col_order] sorted_vals = vals[col_order] sorted_rows = local_row_ids[col_order] # Compute within-column sequential ranks: 0, 1, 2, … per column. unique_cols, col_counts = np.unique(sorted_cols, return_counts=True) col_ends = np.cumsum(col_counts) col_starts = col_ends - col_counts within_col = np.arange(len(sorted_cols)) - np.repeat(col_starts, col_counts) # Absolute write positions: base position for each column + rank. positions = np.repeat(offset[unique_cols], col_counts) + within_col out_data[positions] = sorted_vals out_row_indices[positions] = sorted_rows # Advance column write offsets. offset[unique_cols] += col_counts.astype(np.int64) row_global += n_chunk finally: backed.file.close() # ------------------------------------------------------------------ Write hdf5_chunk_size = min(262144, max(8192, total_nnz // 16)) # Bootstrap a minimal skeleton so anndata writes valid obs/var groups. placeholder = sp.csr_matrix((n_obs, n_vars), dtype=np.float32) adata = ad.AnnData(placeholder, obs=obs, var=var) adata.write(output_path) # Replace the X group with a proper CSC encoding. with h5py.File(output_path, "r+", libver="latest") as dest: if "X" in dest: del dest["X"] grp = dest.create_group("X") grp.attrs["encoding-type"] = np.bytes_("csc_matrix") grp.attrs["encoding-version"] = np.bytes_("0.1.0") grp.attrs["shape"] = np.array([n_obs, n_vars], dtype=np.int64) chunk_arg = (hdf5_chunk_size,) if total_nnz >= hdf5_chunk_size else None grp.create_dataset("data", data=out_data, chunks=chunk_arg) grp.create_dataset("indices", data=out_row_indices, chunks=chunk_arg) grp.create_dataset("indptr", data=indptr) if verbose: print( f" ✓ CSC file written: {n_obs} cells × {n_vars} genes," f" {total_nnz:,} non-zeros" ) return AnnData(output_path)
[docs] def convert_to_csr( data: str | Path | "AnnData" | ad.AnnData, *, output_path: str | Path | None = None, chunk_size: int | None = None, output_dir: str | Path | None = None, data_name: str | None = None, verbose: bool = True, ) -> "AnnData": """Convert a backed h5ad file's matrix from CSC (or dense) to CSR format. CSR format allows O(nnz_in_chunk) row-slicing instead of O(total_nnz) that CSC requires. This is needed for efficient NB-GLM, size factor computation, and any operation that iterates over cell (row) chunks. The conversion mirrors :func:`convert_to_csc`: two streaming passes over the source file so peak memory is bounded by ``total_nnz × (sizeof(float32) + sizeof(col_dtype))`` bytes (the output buffers) plus one chunk working buffer. If the input file is already CSR, no file is written; the function returns a backed AnnData pointing to the original file. Parameters ---------- data Path to source h5ad file (CSC or dense), or a backed AnnData. output_path Explicit path for the output file. If ``None``, a path is derived from ``output_dir``/``data_name`` with ``"_csr"`` appended to the stem. chunk_size Number of rows (cells) to read at a time during both passes. Default is calculated automatically. output_dir Directory for the output file. Defaults to the source file's directory. data_name Custom name used when building the output filename. verbose Print progress messages. Returns ------- AnnData Backed (read-only) AnnData pointing to the written CSR h5ad file, or to the source file if it was already CSR. """ source_path = resolve_data_path(data, require_exists=True) # Fast path: input is already CSR — return it directly. fmt = get_matrix_storage_format(source_path) if fmt == "csr": if verbose: print(f"File is already CSR, skipping conversion: {source_path}") return AnnData(source_path) # Resolve output path. if output_path is None: output_path = resolve_output_path( source_path, suffix="csr", output_dir=output_dir, data_name=data_name, ) output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) source_is_csc = fmt == "csc" if verbose: print(f"Converting {fmt}→CSR (two-pass streaming): {source_path}{output_path}") # Choose the optimal reading axis based on source format. # CSC: column-chunks (axis=1) are fast, row-chunks are O(total_nnz). # Dense / CSR: row-chunks (axis=0) are fast. read_axis = 1 if source_is_csc else 0 # ------------------------------------------------------------------ Pass 1 # Count non-zeros per row and collect metadata. backed = read_backed(source_path) try: n_obs = backed.n_obs n_vars = backed.n_vars if chunk_size is None: chunk_size = calculate_optimal_chunk_size(n_obs, n_vars) obs = backed.obs.copy() var = backed.var.copy() obs.index = obs.index.astype(str) var.index = var.index.astype(str) row_nnz = np.zeros(n_obs, dtype=np.int64) total_nnz = 0 if source_is_csc: # Column chunks: convert each to CSC, count NNZ per row via indices. for _slc, block in iter_matrix_chunks( backed, axis=1, chunk_size=chunk_size, convert_to_dense=False ): if sp.issparse(block): csc = sp.csc_matrix(block) np.add.at(row_nnz, csc.indices, 1) total_nnz += csc.nnz else: # Dense column block: count non-zeros per row. dense = np.asarray(block) nz_mask = dense != 0 row_nnz += nz_mask.sum(axis=1) total_nnz += int(nz_mask.sum()) else: # Row chunks: convert to CSR, count NNZ per row via indptr diffs. for _slc, block in iter_matrix_chunks( backed, axis=0, chunk_size=chunk_size, convert_to_dense=False ): csr = _ensure_csr(block) row_counts = np.diff(csr.indptr) row_nnz[_slc] += row_counts total_nnz += csr.nnz finally: backed.file.close() # Empty matrix edge case. if total_nnz == 0: placeholder = sp.csr_matrix((n_obs, n_vars), dtype=np.float32) adata = ad.AnnData(placeholder, obs=obs, var=var) adata.write(output_path) return AnnData(output_path) # CSR indptr: length n_obs + 1. idx_dtype = np.int32 if total_nnz <= np.iinfo(np.int32).max else np.int64 indptr = np.zeros(n_obs + 1, dtype=idx_dtype) np.cumsum(row_nnz, out=indptr[1:]) # Column-index dtype: int32 suffices for up to ~2 billion genes. col_dtype = np.int32 if n_vars <= np.iinfo(np.int32).max else np.int64 # ------------------------------------------------------------------ Pass 2 # Re-read and scatter non-zeros into the global CSR output arrays. out_data = np.empty(total_nnz, dtype=np.float32) out_col_indices = np.empty(total_nnz, dtype=col_dtype) # offset[r] = next write position in the CSR arrays for row r. offset = indptr[:-1].astype(np.int64) backed = read_backed(source_path) try: if source_is_csc: # Column-chunk reading: fast on CSC. Mirror of convert_to_csc's # row-chunk scatter, but transposed. col_global = 0 for _slc, block in iter_matrix_chunks( backed, axis=1, chunk_size=chunk_size, convert_to_dense=False ): if sp.issparse(block): csc = sp.csc_matrix(block) else: csc = sp.csc_matrix(np.asarray(block)) n_chunk_cols = csc.shape[1] if csc.nnz == 0: col_global += n_chunk_cols continue # Global column index for every non-zero in this chunk. local_col_ids = np.repeat( np.arange(n_chunk_cols, dtype=col_dtype), np.diff(csc.indptr), ) + col_dtype(col_global) rows = csc.indices # row index of each non-zero vals = csc.data.astype(np.float32) # Sort non-zeros by row so we can process contiguous groups. row_order = np.argsort(rows, kind="stable") sorted_rows = rows[row_order] sorted_vals = vals[row_order] sorted_cols = local_col_ids[row_order] # Compute within-row sequential ranks: 0, 1, 2, … per row. unique_rows, row_counts = np.unique(sorted_rows, return_counts=True) row_ends = np.cumsum(row_counts) row_starts = row_ends - row_counts within_row = np.arange(len(sorted_rows)) - np.repeat(row_starts, row_counts) # Absolute write positions: base offset for each row + rank. positions = np.repeat(offset[unique_rows], row_counts) + within_row out_data[positions] = sorted_vals out_col_indices[positions] = sorted_cols # Advance row write offsets. offset[unique_rows] += row_counts.astype(np.int64) col_global += n_chunk_cols else: # Row-chunk reading: fast on dense. Each chunk is already CSR-ordered. for _slc, block in iter_matrix_chunks( backed, axis=0, chunk_size=chunk_size, convert_to_dense=False ): csr = _ensure_csr(block) if csr.nnz == 0: continue # Bulk copy: the chunk spans rows [_slc.start : _slc.stop]. # Global write position: indptr[_slc.start] to indptr[_slc.stop]. dst_start = int(indptr[_slc.start]) dst_end = dst_start + csr.nnz out_data[dst_start:dst_end] = csr.data.astype(np.float32) out_col_indices[dst_start:dst_end] = csr.indices.astype(col_dtype) finally: backed.file.close() # Restore indptr (was not mutated in the row-chunk path, but reset for # the column-chunk path which advanced offset instead). indptr[0] = 0 np.cumsum(row_nnz, out=indptr[1:]) # ------------------------------------------------------------------ Write hdf5_chunk_size = min(262144, max(8192, total_nnz // 16)) # Bootstrap a minimal skeleton so anndata writes valid obs/var groups. placeholder = sp.csr_matrix((n_obs, n_vars), dtype=np.float32) adata = ad.AnnData(placeholder, obs=obs, var=var) adata.write(output_path) # Replace the X group with a proper CSR encoding. with h5py.File(output_path, "r+", libver="latest") as dest: if "X" in dest: del dest["X"] grp = dest.create_group("X") grp.attrs["encoding-type"] = np.bytes_("csr_matrix") grp.attrs["encoding-version"] = np.bytes_("0.1.0") grp.attrs["shape"] = np.array([n_obs, n_vars], dtype=np.int64) chunk_arg = (hdf5_chunk_size,) if total_nnz >= hdf5_chunk_size else None grp.create_dataset("data", data=out_data, chunks=chunk_arg) grp.create_dataset("indices", data=out_col_indices, chunks=chunk_arg) grp.create_dataset("indptr", data=indptr) if verbose: print( f" ✓ CSR file written: {n_obs} cells × {n_vars} genes," f" {total_nnz:,} non-zeros" ) return AnnData(output_path)
[docs] def calculate_optimal_chunk_size( n_obs: int, n_vars: int, available_memory_gb: float | None = None, safety_factor: float = 8.0, min_chunk: int = 512, max_chunk: int = 4096, ) -> int: """Calculate optimal chunk size based on dataset dimensions and available memory. Parameters ---------- n_obs Number of observations (cells) in the dataset. n_vars Number of variables (genes) in the dataset. available_memory_gb Available memory in gigabytes. If None, auto-detects using psutil. safety_factor Safety multiplier to account for overhead (default 8.0 for backed operations). min_chunk Minimum chunk size to return (default 512). max_chunk Maximum chunk size to return (default 4096). Returns ------- int Recommended chunk size, clamped to [min_chunk, max_chunk]. Examples -------- >>> calculate_optimal_chunk_size(100000, 20000, available_memory_gb=32) 2000 """ if available_memory_gb is None: try: import psutil available_memory_gb = psutil.virtual_memory().available / 1e9 except ImportError: logger.warning( "psutil not installed, using default 16GB for chunk size calculation. " "Install with: pip install psutil" ) available_memory_gb = 16.0 # Calculate chunk size based on memory # Each chunk uses approximately: chunk_size * n_vars * 8 bytes (float64) # Multiply by safety_factor for overhead bytes_per_chunk = n_vars * 8 * safety_factor max_chunk_from_memory = int((available_memory_gb * 1e9) / bytes_per_chunk) # Clamp to reasonable range chunk_size = max(min_chunk, min(max_chunk, max_chunk_from_memory)) logger.info( f"Calculated chunk size: {chunk_size} " f"(dataset: {n_obs} cells × {n_vars} genes, " f"available memory: {available_memory_gb:.1f}GB)" ) return chunk_size
[docs] def calculate_optimal_gene_chunk_size( n_obs: int, n_vars: int, n_groups: int | None = None, available_memory_gb: float | None = None, safety_factor: float = 8.0, memory_fraction: float = 0.5, min_chunk: int = 32, max_chunk: int = 512, ) -> int: """Calculate optimal gene chunk size for column-wise operations. For operations that iterate over genes (columns), such as Wilcoxon tests, each chunk loads all cells for a subset of genes. Memory usage is dominated by n_obs × chunk_size rather than chunk_size × n_vars. Enhanced to account for the number of perturbation groups, which significantly impacts memory usage due to output array allocation. Parameters ---------- n_obs Number of observations (cells) in the dataset. n_vars Number of variables (genes) in the dataset. n_groups Number of perturbation groups. If provided, used to estimate memory for output arrays. Large group counts require smaller chunks. available_memory_gb Available memory in gigabytes. If None, auto-detects using psutil. safety_factor Safety multiplier to account for overhead (default 8.0). memory_fraction Fraction of available memory to use (default 0.5). Leave headroom for memory-mapped arrays and system overhead. min_chunk Minimum chunk size to return (default 32). max_chunk Maximum chunk size to return (default 512). Returns ------- int Recommended gene chunk size, clamped to [min_chunk, max_chunk]. Examples -------- >>> calculate_optimal_gene_chunk_size(100000, 20000, available_memory_gb=32) 512 >>> calculate_optimal_gene_chunk_size(4000000, 38000, n_groups=18000, available_memory_gb=128) 64 """ if available_memory_gb is None: try: import psutil available_memory_gb = psutil.virtual_memory().available / 1e9 except ImportError: logger.warning( "psutil not installed, using default 16GB for chunk size calculation. " "Install with: pip install psutil" ) available_memory_gb = 16.0 # Usable memory = fraction of available (default 50%) usable_memory_bytes = available_memory_gb * memory_fraction * 1e9 # Base memory: dense conversion of n_obs × chunk_size (float64) base_memory_per_gene = n_obs * 8 # Group memory: output arrays of n_groups × chunk_size (float64) × ~8 arrays # (effect, u_stat, pvalue, z_score, lfc, pts, pts_rest, order) group_memory_per_gene = (n_groups * 8 * 8) if n_groups else 0 # Total memory per gene with safety factor total_memory_per_gene = (base_memory_per_gene + group_memory_per_gene) * safety_factor # Calculate max chunk from memory max_chunk_from_memory = int(usable_memory_bytes / total_memory_per_gene) if total_memory_per_gene > 0 else max_chunk # Dynamic max_chunk based on n_groups (datasets with many groups need smaller chunks) effective_max_chunk = max_chunk if n_groups is not None: if n_groups > 10000: effective_max_chunk = min(effective_max_chunk, 128) elif n_groups > 5000: effective_max_chunk = min(effective_max_chunk, 256) elif n_groups > 2000: effective_max_chunk = min(effective_max_chunk, 384) # Cell-count-based caps for very large datasets (Wilcoxon ranking is memory-intensive) # On memory-constrained machines (< 32 GB available) these caps act as hard safety guards # to avoid OOM. On large-memory machines the memory-formula above (max_chunk_from_memory) # already accounts for available RAM, so over-riding it with small fixed values would be # unnecessarily conservative — e.g. Feng-gwsf on 128 GB can safely use 384-gene chunks. _CELL_COUNT_CAP_MEMORY_THRESHOLD_GB = 32.0 if available_memory_gb < _CELL_COUNT_CAP_MEMORY_THRESHOLD_GB: if n_obs > 1_000_000: effective_max_chunk = min(effective_max_chunk, 32) # Very conservative for >1M cells elif n_obs > 500_000: effective_max_chunk = min(effective_max_chunk, 64) # Conservative for >500K cells elif n_obs > 300_000: effective_max_chunk = min(effective_max_chunk, 128) # Moderate for >300K cells # else: trust max_chunk_from_memory (computed from available_memory_gb above) # Additional cell-count cap for ALL memory tiers: the per-chunk transient memory # for Wilcoxon (dense block + control presort + pert stacking) scales as # ~n_obs × chunk × 12 bytes. On high-cell datasets this can exceed 5% of available # memory per chunk, causing glibc arena fragmentation across many chunks. _PER_CHUNK_BUDGET_FRACTION = 0.05 per_chunk_bytes_per_gene = n_obs * 12 # dense(f32) + ctrl(f64) + pert(f32) per_chunk_budget = available_memory_gb * _PER_CHUNK_BUDGET_FRACTION * 1e9 if per_chunk_bytes_per_gene > 0 and per_chunk_bytes_per_gene * effective_max_chunk > per_chunk_budget: cell_cap = max(min_chunk, int(per_chunk_budget / per_chunk_bytes_per_gene)) effective_max_chunk = min(effective_max_chunk, cell_cap) # Clamp to reasonable range chunk_size = max(min_chunk, min(effective_max_chunk, max_chunk_from_memory)) logger.info( f"Calculated gene chunk size: {chunk_size} " f"(dataset: {n_obs} cells × {n_vars} genes, " f"groups: {n_groups or 'unknown'}, " f"available memory: {available_memory_gb:.1f}GB)" ) return chunk_size
[docs] def calculate_wilcoxon_chunk_size( n_obs: int, n_vars: int, *, available_memory_gb: float | None = None, min_chunk: int = 32, max_chunk: int = 4096, ) -> int: """Calculate optimal gene chunk size for Wilcoxon rank-sum tests. Unlike :func:`calculate_optimal_gene_chunk_size`, this function has **no n_groups cap**. Wilcoxon writes all output arrays (effect, pvalue, z-score, etc.) to on-disk memory-mapped files immediately, so peak RAM per chunk is independent of the number of perturbation groups. The only effective cap is a per-chunk transient-memory budget: transient ≈ chunk_size × (n_obs × 4 + n_ctrl × 8 + n_pert × 4) ≈ chunk_size × n_obs × 12 bytes The budget is set to 15 % of ``available_memory_gb`` so that a single chunk never exceeds ~1/7th of the node RAM. Parameters ---------- n_obs Number of cells in the dataset. n_vars Number of genes in the dataset (used only for logging). available_memory_gb Available memory in GB. When *None*, detected via :mod:`psutil`. On HPC nodes, pass the SLURM ``--mem`` value so the cap reflects the actual job allocation rather than system-wide free memory. min_chunk Floor for the returned chunk size (default 32). max_chunk Ceiling for the returned chunk size (default 4096). The cell-budget cap is applied *before* this ceiling, so ``max_chunk`` is only active for small/sparse datasets where the cell cap would be very large. Returns ------- int Recommended gene chunk size, clamped to ``[min_chunk, max_chunk]``. Examples -------- >>> # Feng-gwsnf: 396K cells, 128 GB → 4067 >>> calculate_wilcoxon_chunk_size(396458, 32373, available_memory_gb=128) 4067 >>> # Feng-ts: 1.16M cells, 128 GB → 1378 >>> calculate_wilcoxon_chunk_size(1161864, 33165, available_memory_gb=128) 1378 """ if available_memory_gb is None: try: import psutil available_memory_gb = psutil.virtual_memory().available / 1e9 except ImportError: logger.warning( "psutil not installed, using default 16GB for Wilcoxon chunk size calculation. " "Install with: pip install psutil" ) available_memory_gb = 16.0 # Per-chunk transient memory budget: 15% of available RAM. # transient ≈ chunk_size × n_obs × 12 bytes # (dense float32 block + ctrl float64 presort + pert float32 stack). _PER_CHUNK_BUDGET_FRACTION = 0.15 per_chunk_budget = available_memory_gb * _PER_CHUNK_BUDGET_FRACTION * 1e9 per_chunk_bytes_per_gene = n_obs * 12 cell_cap = max(min_chunk, int(per_chunk_budget / per_chunk_bytes_per_gene)) chunk_size = max(min_chunk, min(cell_cap, max_chunk)) logger.info( f"Calculated Wilcoxon chunk size: {chunk_size} " f"(dataset: {n_obs} cells × {n_vars} genes, " f"available memory: {available_memory_gb:.1f}GB)" ) return chunk_size
[docs] def calculate_nb_glm_chunk_size( n_obs: int, n_vars: int, n_groups: int | None = None, available_memory_gb: float | None = None, memory_limit_gb: float | None = None, safety_factor: float = 8.0, memory_fraction: float = 0.5, min_chunk: int = 32, max_chunk: int = 256, ) -> int: """Calculate optimal gene chunk size for NB-GLM operations. NB-GLM iterates over genes (columns) and for each chunk: - Loads dense count data: n_obs × chunk_size - Computes dispersion: requires cell-level statistics - Fits GLM: design matrix operations This function is specifically tuned for NB-GLM memory patterns, which differ from Wilcoxon (ranking) and t-test (simple statistics). Parameters ---------- n_obs Number of observations (cells) in the dataset. n_vars Number of variables (genes) in the dataset. n_groups Number of perturbation groups. If provided, used to estimate memory for output arrays and design matrix overhead. available_memory_gb Available memory in gigabytes. If None, auto-detects using psutil. memory_limit_gb Optional hard memory limit in GB. If provided, uses the minimum of available memory and this limit. safety_factor Safety multiplier to account for overhead (default 8.0). memory_fraction Fraction of available memory to use (default 0.5). min_chunk Minimum chunk size to return (default 32). max_chunk Maximum chunk size to return (default 256). Returns ------- int Recommended gene chunk size, clamped to [min_chunk, max_chunk]. For datasets where memory is sufficient, returns max_chunk (256). Only reduces chunk size when memory would be exceeded. Examples -------- >>> calculate_nb_glm_chunk_size(100000, 20000, n_groups=100, available_memory_gb=128) 256 >>> calculate_nb_glm_chunk_size(1200000, 36000, n_groups=500, available_memory_gb=128) 143 """ if available_memory_gb is None: try: import psutil available_memory_gb = psutil.virtual_memory().available / 1e9 except ImportError: logger.warning( "psutil not installed, using default 16GB for NB-GLM chunk size calculation." ) available_memory_gb = 16.0 # Apply memory limit if provided if memory_limit_gb is not None: available_memory_gb = min(available_memory_gb, memory_limit_gb) # Usable memory = fraction of available usable_memory_bytes = available_memory_gb * memory_fraction * 1e9 # NB-GLM memory per gene (conservative estimate): # - Dense counts: n_obs × 8 bytes (float64) # - Design matrix contribution: n_obs × 2 × 8 bytes # - Working matrices (mu, var, residuals): n_obs × 8 × 3 # - Output arrays: n_groups × 8 × 8 (lfc, stat, pvalue, se, etc.) base_memory_per_gene = n_obs * 8 * 6 # ~48 bytes per cell per gene group_memory_per_gene = (n_groups * 8 * 8) if n_groups else 0 total_memory_per_gene = (base_memory_per_gene + group_memory_per_gene) * safety_factor # Calculate max chunk from memory if total_memory_per_gene > 0: max_chunk_from_memory = int(usable_memory_bytes / total_memory_per_gene) else: max_chunk_from_memory = max_chunk # Clamp to reasonable range chunk_size = max(min_chunk, min(max_chunk, max_chunk_from_memory)) logger.info( f"Calculated NB-GLM chunk size: {chunk_size} " f"(dataset: {n_obs} cells × {n_vars} genes, " f"groups: {n_groups or 'unknown'}, " f"available memory: {available_memory_gb:.1f}GB)" ) return chunk_size
[docs] def calculate_pca_chunk_size( n_obs: int, n_vars: int, n_comps: int = 50, available_memory_gb: float | None = None, method: str = "auto", memory_fraction: float = 0.5, min_chunk: int = 256, max_chunk: int = 4096, ) -> tuple[int, str]: """Calculate optimal chunk size for streaming PCA. PCA memory usage depends on the method: - sparse_cov: O(genes²) for covariance matrix, fast for small gene counts - incremental: O(chunk × genes) for data chunks, better for large gene counts Parameters ---------- n_obs Number of observations (cells) in the dataset. n_vars Number of variables (genes) in the dataset. n_comps Number of principal components to compute. Default 50. available_memory_gb Available memory in gigabytes. If None, auto-detects using psutil. method PCA method: 'auto', 'sparse_cov', or 'incremental'. 'auto' selects based on gene count and available memory. memory_fraction Fraction of available memory to use (default 0.5). min_chunk Minimum chunk size to return (default 256). max_chunk Maximum chunk size to return (default 4096). Returns ------- tuple[int, str] (chunk_size, selected_method) where selected_method is 'sparse_cov' or 'incremental'. Examples -------- >>> calculate_pca_chunk_size(100000, 8000, available_memory_gb=32) (2048, 'sparse_cov') >>> calculate_pca_chunk_size(100000, 50000, available_memory_gb=32) (1024, 'incremental') """ if available_memory_gb is None: try: import psutil available_memory_gb = psutil.virtual_memory().available / 1e9 except ImportError: logger.warning( "psutil not installed, using default 16GB for PCA chunk size calculation." ) available_memory_gb = 16.0 usable_memory_gb = available_memory_gb * memory_fraction # Estimate covariance matrix memory (genes × genes × 8 bytes) cov_memory_gb = n_vars * n_vars * 8 / 1e9 # Select method if auto if method == "auto": # Use sparse_cov if covariance fits in 30% of usable memory if cov_memory_gb < usable_memory_gb * 0.3: selected_method = "sparse_cov" else: selected_method = "incremental" else: selected_method = method # Calculate chunk size based on method if selected_method == "sparse_cov": # Need: XTX (genes² × 8), sums (genes × 8), chunk (chunk × genes × 8) reserved_gb = cov_memory_gb + n_vars * 8 / 1e9 remaining_gb = max(0.1, usable_memory_gb - reserved_gb) else: # Need: mean (genes × 8), IPCA internals (~2 × comps × genes × 8), chunk ipca_internal_gb = 2 * n_comps * n_vars * 8 / 1e9 reserved_gb = n_vars * 8 / 1e9 + ipca_internal_gb remaining_gb = max(0.1, usable_memory_gb - reserved_gb) # Chunk memory: chunk_size × n_vars × 8 bytes (float64 during computation) bytes_per_row = n_vars * 8 max_chunk_from_memory = int(remaining_gb * 0.5 * 1e9 / bytes_per_row) chunk_size = max(min_chunk, min(max_chunk, max_chunk_from_memory)) logger.info( f"PCA chunk size: {chunk_size}, method: {selected_method} " f"(dataset: {n_obs} cells × {n_vars} genes, " f"cov matrix: {cov_memory_gb:.2f} GB, " f"available: {available_memory_gb:.1f} GB)" ) return chunk_size, selected_method
[docs] def calculate_adaptive_qc_thresholds( adata: ad.AnnData, perturbation_column: str, mode: str = "conservative", sample_size: int = 10000, chunk_size: int | None = None, ) -> dict: """Calculate adaptive QC thresholds based on data distribution. Uses percentile-based approach to determine appropriate QC parameters that retain most of the data while filtering outliers. Parameters ---------- adata AnnData object (can be backed). perturbation_column Column in adata.obs containing perturbation labels. mode 'conservative' (10th percentile, retains ~90%) or 'aggressive' (5th percentile, retains ~95%). sample_size Maximum number of cells to sample for gene expression analysis. chunk_size Optional fixed chunk size to use. If None, calculated adaptively. Returns ------- dict Dictionary with keys: min_genes, min_cells_per_perturbation, min_cells_per_gene, chunk_size. Examples -------- >>> adata = ad.read_h5ad("data.h5ad", backed='r') >>> thresholds = calculate_adaptive_qc_thresholds(adata, "perturbation") >>> adata.file.close() """ percentile = 10.0 if mode == "conservative" else 5.0 # Analyze perturbation sizes if perturbation_column not in adata.obs.columns: raise KeyError( f"Perturbation column '{perturbation_column}' not found in adata.obs. " f"Available columns: {list(adata.obs.columns)}" ) pert_counts = adata.obs[perturbation_column].value_counts() p_percentile = pert_counts.quantile(percentile / 100.0) min_cells_per_pert = int(max(5, min(50, p_percentile))) # Analyze gene expression (sample if dataset is large) n_sample = min(sample_size, adata.n_obs) if n_sample < adata.n_obs: sample_idx = np.random.choice(adata.n_obs, n_sample, replace=False) sample_idx = np.sort(sample_idx) # Sort for efficient backed access else: sample_idx = None # Calculate cells per gene and genes per cell efficiently # Use chunked processing to handle backed datasets cells_per_gene = np.zeros(adata.n_vars, dtype=np.int64) genes_per_cell = np.zeros(n_sample if sample_idx is not None else adata.n_obs, dtype=np.int64) if chunk_size is None: chunk_size = calculate_optimal_chunk_size(adata.n_obs, adata.n_vars) cell_idx = 0 for chunk_start in range(0, adata.n_obs, chunk_size): chunk_end = min(chunk_start + chunk_size, adata.n_obs) # Get the chunk indices if sample_idx is not None: # Get indices that fall in this chunk mask = (sample_idx >= chunk_start) & (sample_idx < chunk_end) if not mask.any(): continue chunk_indices = sample_idx[mask] - chunk_start X_chunk = adata.X[chunk_start:chunk_end][chunk_indices] else: X_chunk = adata.X[chunk_start:chunk_end] # Process chunk (works with both sparse and dense) if sp.issparse(X_chunk): # Count non-zeros per gene (column) cells_per_gene += np.asarray(np.diff(X_chunk.tocsc().indptr)) # Count non-zeros per cell (row) chunk_genes_per_cell = np.asarray(np.diff(X_chunk.tocsr().indptr)) else: # Dense matrix X_chunk_bool = X_chunk > 0 cells_per_gene += np.asarray(X_chunk_bool.sum(axis=0)).ravel() chunk_genes_per_cell = np.asarray(X_chunk_bool.sum(axis=1)).ravel() # Store genes per cell for this chunk n_cells_in_chunk = chunk_genes_per_cell.shape[0] genes_per_cell[cell_idx:cell_idx + n_cells_in_chunk] = chunk_genes_per_cell cell_idx += n_cells_in_chunk # Trim genes_per_cell if we didn't fill it completely genes_per_cell = genes_per_cell[:cell_idx] # Calculate thresholds from the collected statistics gene_percentile = np.percentile(cells_per_gene, percentile) min_cells_per_gene = int(max(5, min(100, gene_percentile))) median_genes = int(np.median(genes_per_cell)) min_genes = max(5, min(50, median_genes // 10)) # 10% of median # Calculate optimal chunk size if not provided if chunk_size is None: chunk_size = calculate_optimal_chunk_size(adata.n_obs, adata.n_vars) thresholds = { "min_genes": min_genes, "min_cells_per_perturbation": min_cells_per_pert, "min_cells_per_gene": min_cells_per_gene, "chunk_size": chunk_size, } logger.info( f"Adaptive QC thresholds ({mode} mode):\n" f" - min_genes: {min_genes}\n" f" - min_cells_per_perturbation: {min_cells_per_pert} " f"({percentile}th percentile: {p_percentile:.1f})\n" f" - min_cells_per_gene: {min_cells_per_gene} " f"({percentile}th percentile: {gene_percentile:.1f})\n" f" - chunk_size: {chunk_size}" ) return thresholds
[docs] def standardize_dataset( dataset_path: Path | str, perturbation_column: str, control_label: str | None, gene_name_column: str | None, output_dir: Path | str, force: bool = False, ) -> Path: """Standardize dataset column names and control labels with caching. Creates a standardized copy of the dataset with: - perturbation_column renamed to 'perturbation' - control labels standardized to 'control' - gene_name_column set as var.index if specified This function uses a streaming approach to avoid loading the X matrix into memory, making it suitable for very large datasets (>1M cells). Standardized files are cached in {output_dir}/.cache/ and reused unless force=True. Parameters ---------- dataset_path Path to original dataset (.h5ad file). perturbation_column Name of perturbation column in original dataset. control_label Control label to standardize. If None, auto-detects. gene_name_column Gene name column to use as var.index. If None, uses existing var.index. output_dir Directory for cached standardized files. force If True, regenerate standardized file even if cache exists. Returns ------- Path Path to standardized dataset (either cached or newly created). Examples -------- >>> standardized_path = standardize_dataset( ... "data/original.h5ad", ... perturbation_column="gene", ... control_label=None, ... gene_name_column="gene_symbols", ... output_dir="results", ... force=False ... ) """ import datetime import shutil dataset_path = Path(dataset_path) output_dir = Path(output_dir) cache_dir = output_dir / ".cache" cache_dir.mkdir(parents=True, exist_ok=True) standardized_path = cache_dir / f"standardized_{dataset_path.stem}.h5ad" # Check if cached version exists if standardized_path.exists() and not force: logger.info(f"Using cached standardized dataset: {standardized_path}") return standardized_path logger.info(f"Standardizing dataset: {dataset_path.name}") logger.info(f" - Perturbation column: '{perturbation_column}' → 'perturbation'") # Track standardization metadata metadata = { "original_path": str(dataset_path), "standardization_timestamp": datetime.datetime.now().isoformat(), "column_mappings": {}, "label_mappings": {}, } # Read obs/var metadata only (without loading X into memory) adata = ad.read_h5ad(dataset_path, backed='r') obs_df = adata.obs.copy() var_df = adata.var.copy() uns_dict = dict(adata.uns) # shallow copy of uns adata.file.close() # Standardize perturbation column in obs if perturbation_column != "perturbation": if perturbation_column not in obs_df.columns: raise KeyError( f"Perturbation column '{perturbation_column}' not found. " f"Available: {list(obs_df.columns)}" ) obs_df.rename(columns={perturbation_column: "perturbation"}, inplace=True) metadata["column_mappings"]["perturbation"] = perturbation_column logger.info(f" - Renamed '{perturbation_column}' → 'perturbation'") # Standardize control label labels = obs_df["perturbation"].astype(str).to_numpy() detected_control = resolve_control_label(labels, control_label, verbose=False) if detected_control != "control": obs_df["perturbation"] = ( obs_df["perturbation"] .astype(str) .replace({detected_control: "control"}) ) metadata["label_mappings"][detected_control] = "control" logger.info(f" - Standardized control: '{detected_control}' → 'control'") else: logger.info(f" - Control label already standardized: 'control'") # Standardize gene names in var if gene_name_column is not None: if gene_name_column in var_df.columns: if not (var_df.index == var_df[gene_name_column]).all(): var_df.index = var_df[gene_name_column].values metadata["column_mappings"]["var.index"] = gene_name_column logger.info(f" - Set var.index from '{gene_name_column}'") else: logger.warning( f"Gene column '{gene_name_column}' not found in var. " f"Using existing var.index." ) # Store metadata if "standardization_metadata" not in uns_dict: uns_dict["standardization_metadata"] = {} uns_dict["standardization_metadata"].update(metadata) # Copy h5ad file at filesystem level (streaming - no memory load) logger.info(f" - Copying dataset (streaming, no X matrix load)...") shutil.copy2(dataset_path, standardized_path) # Modify obs/var/uns in-place using h5py logger.info(f" - Updating metadata in copied file...") with h5py.File(standardized_path, 'r+') as f: # Update obs - need to rewrite the obs group # Read current obs structure and update _update_h5ad_dataframe(f, 'obs', obs_df) # Update var - need to rewrite the var group _update_h5ad_dataframe(f, 'var', var_df) # Update uns - handle the standardization_metadata key _update_h5ad_uns(f, 'uns', uns_dict) logger.info(f"Saved standardized dataset: {standardized_path}") return standardized_path
def _update_h5ad_dataframe(h5file: h5py.File, group_name: str, df: pd.DataFrame) -> None: """Update obs or var DataFrame in an h5ad file in-place. This function handles the anndata HDF5 format where DataFrames are stored as groups with individual columns as datasets. """ if group_name not in h5file: return grp = h5file[group_name] # Update the index (stored as _index attribute or separate dataset) index_key = grp.attrs.get('_index', '_index') if isinstance(index_key, bytes): index_key = index_key.decode('utf-8') if index_key in grp: del grp[index_key] # Store index as variable-length strings index_vals = df.index.astype(str).values grp.create_dataset(index_key, data=index_vals.astype('O'), dtype=h5py.special_dtype(vlen=str)) # Update column names (stored as 'column-order' attribute) col_names = list(df.columns) # HDF5 attributes don't support object dtype; encode strings as bytes grp.attrs['column-order'] = np.array([c.encode('utf-8') for c in col_names]) # Update each column for col in df.columns: if col in grp: del grp[col] col_data = df[col] # Handle categorical columns if hasattr(col_data, 'cat'): # Store as categorical (anndata format) cat_grp = grp.create_group(col) cat_grp.attrs['encoding-type'] = 'categorical' cat_grp.attrs['encoding-version'] = '0.2.0' cat_grp.attrs['ordered'] = col_data.cat.ordered # Required for anndata # Store categories categories = col_data.cat.categories.astype(str).values cat_grp.create_dataset('categories', data=categories.astype('O'), dtype=h5py.special_dtype(vlen=str)) # Store codes cat_grp.create_dataset('codes', data=col_data.cat.codes.values) elif col_data.dtype == object or col_data.dtype.kind in ('U', 'S'): # String column - store as variable-length strings str_vals = col_data.astype(str).values grp.create_dataset(col, data=str_vals.astype('O'), dtype=h5py.special_dtype(vlen=str)) else: # Numeric column grp.create_dataset(col, data=col_data.values) def _update_h5ad_uns(h5file: h5py.File, group_name: str, uns_dict: dict) -> None: """Update uns dict in an h5ad file in-place. Only updates the standardization_metadata key to avoid breaking other uns data. """ if group_name not in h5file: h5file.create_group(group_name) grp = h5file[group_name] # Only update standardization_metadata to be safe key = 'standardization_metadata' if key in uns_dict: if key in grp: del grp[key] # Store as a group with string datasets for each subkey meta_grp = grp.create_group(key) meta_grp.attrs['encoding-type'] = 'dict' meta_grp.attrs['encoding-version'] = '0.1.0' for subkey, subval in uns_dict[key].items(): if isinstance(subval, dict): # Nested dict - store as JSON string import json meta_grp.create_dataset(subkey, data=json.dumps(subval), dtype=h5py.special_dtype(vlen=str)) elif isinstance(subval, str): meta_grp.create_dataset(subkey, data=subval, dtype=h5py.special_dtype(vlen=str)) else: meta_grp.create_dataset(subkey, data=str(subval), dtype=h5py.special_dtype(vlen=str))
[docs] def needs_sorting_for_nbglm( path: str | Path, perturbation_column: str = "perturbation", *, min_cells: int = 360_000, min_perturbations: int = 100, contiguity_threshold: float = 0.5, ) -> bool: """Check if a dataset would benefit from sorting by perturbation for NB-GLM. Large datasets with scattered cells benefit from having cells sorted by perturbation label, as this enables contiguous I/O reads instead of random access. This is especially important when the data is stored on HDD (rotational disk). The default thresholds are based on I/O overhead analysis: - At ~100 IOPS (typical HDD), 360K cells = 1 hour of random I/O overhead - min_perturbations=100 ensures sufficient parallel workload to benefit - contiguity_threshold=0.5 catches scattered datasets Parameters ---------- path Path to h5ad file. perturbation_column Column in obs containing perturbation labels. min_cells Minimum number of cells for sorting to be recommended. Default: 360,000 (~1 hour of random I/O on HDD at 100 IOPS). min_perturbations Minimum number of perturbations for sorting to be recommended. contiguity_threshold If average contiguity is below this threshold, sorting is recommended. Contiguity is the fraction of a perturbation's cells that are in a contiguous block (1.0 = perfectly contiguous, 0.0 = completely scattered). Returns ------- bool True if sorting is recommended, False otherwise. """ backed = read_backed(path) try: # First check if file is already sorted if "sorting_metadata" in backed.uns: metadata = backed.uns["sorting_metadata"] if metadata.get("sorted_by") == perturbation_column: logger.debug(f"Dataset is already sorted by '{perturbation_column}'") return False n_cells = backed.n_obs # Check cell count threshold if n_cells < min_cells: logger.debug(f"Dataset has {n_cells:,} cells < {min_cells:,}, sorting not needed") return False # Get perturbation labels if perturbation_column not in backed.obs.columns: logger.warning(f"Perturbation column '{perturbation_column}' not found") return False labels = backed.obs[perturbation_column].astype(str).to_numpy() unique_labels = np.unique(labels) n_perts = len(unique_labels) # Check perturbation count threshold if n_perts < min_perturbations: logger.debug(f"Dataset has {n_perts} perturbations < {min_perturbations}, sorting not needed") return False # Sample perturbations to estimate contiguity sample_perts = unique_labels[:min(20, n_perts)] total_contiguity = 0.0 for pert in sample_perts: indices = np.where(labels == pert)[0] if len(indices) < 2: total_contiguity += 1.0 continue span = indices.max() - indices.min() + 1 contiguity = len(indices) / span total_contiguity += contiguity avg_contiguity = total_contiguity / len(sample_perts) if avg_contiguity >= contiguity_threshold: logger.debug(f"Dataset contiguity {avg_contiguity:.1%} >= {contiguity_threshold:.1%}, sorting not needed") return False logger.info( f"Dataset would benefit from sorting: {n_cells:,} cells, {n_perts} perturbations, " f"contiguity {avg_contiguity:.1%} < {contiguity_threshold:.1%}" ) return True finally: backed.file.close()
[docs] def sort_by_perturbation( path: str | Path, perturbation_column: str = "perturbation", control_label: str | None = None, *, output_path: str | Path | None = None, chunk_size: int = 4096, force: bool = False, ) -> Path: """Sort cells by perturbation label for contiguous I/O access. Creates a new h5ad file with cells reordered so that all cells from each perturbation are contiguous. Control cells are placed first, followed by each perturbation group in alphabetical order. This enables efficient sequential reads when processing perturbations in parallel. The function works in streaming mode to handle datasets larger than memory. Sorting information is stored in uns['sorting_metadata']. Parameters ---------- path Path to input h5ad file. perturbation_column Column in obs containing perturbation labels. control_label Label for control cells. If None, auto-detected from common patterns. output_path Path for output sorted file. If None, appends '_sorted' to input name. chunk_size Number of cells to process per chunk during streaming write. force If True, recreate sorted file even if it already exists. Returns ------- Path Path to the sorted h5ad file. Examples -------- >>> sorted_path = sort_by_perturbation( ... "data/large_dataset.h5ad", ... perturbation_column="perturbation", ... control_label="control", ... ) >>> # sorted_path is now "data/large_dataset_sorted.h5ad" Notes ----- The sorted file contains additional metadata in uns['sorting_metadata']: - original_path: Path to the original unsorted file - sort_order: Array mapping new indices to original indices - perturbation_boundaries: Dict mapping perturbation labels to (start, end) indices - sorted_by: The column used for sorting - timestamp: When the sorting was performed For sparse inputs the output is always written in CSR format, since sorting benefits row-wise (per-perturbation) access patterns used by NB-GLM. CSC files used by Wilcoxon do not need perturbation sorting. """ import datetime path = Path(path) # Determine output path if output_path is None: output_path = path.parent / f"{path.stem}_sorted.h5ad" else: output_path = Path(output_path) # Check if already sorted if output_path.exists() and not force: # Verify it's properly sorted try: backed = read_backed(output_path) has_metadata = "sorting_metadata" in backed.uns backed.file.close() if has_metadata: logger.info(f"Using existing sorted file: {output_path}") return output_path except Exception: pass # File exists but invalid, recreate logger.info(f"Sorting dataset by perturbation: {path}") # Read metadata backed = read_backed(path) try: n_cells = backed.n_obs n_genes = backed.n_vars # Get perturbation labels if perturbation_column not in backed.obs.columns: raise KeyError( f"Perturbation column '{perturbation_column}' not found. " f"Available: {list(backed.obs.columns)}" ) labels = backed.obs[perturbation_column].astype(str).to_numpy() # Detect control label if not specified if control_label is None: control_label = resolve_control_label(labels, None) # Create sort order: control first, then alphabetical unique_labels = sorted(set(labels) - {control_label}) label_order = [control_label] + unique_labels # Create mapping from label to sort priority label_priority = {label: i for i, label in enumerate(label_order)} # Get sort indices (stable sort to preserve original order within groups) sort_keys = np.array([label_priority[l] for l in labels]) sort_indices = np.argsort(sort_keys, kind='stable') # Compute perturbation boundaries sorted_labels = labels[sort_indices] boundaries = {} current_label = None start_idx = 0 for i, label in enumerate(sorted_labels): if label != current_label: if current_label is not None: boundaries[current_label] = [start_idx, i] # Use list, not tuple current_label = label start_idx = i if current_label is not None: boundaries[current_label] = [start_idx, len(sorted_labels)] # Use list, not tuple # Read obs and var obs_sorted = backed.obs.iloc[sort_indices].copy() var = backed.var.copy() # Get uns (convert to regular dict for modification) uns = dict(backed.uns) finally: backed.file.close() # Add sorting metadata - convert all to h5ad-compatible types # Note: boundaries values are [start, end] lists (tuples not serializable) sorting_metadata = { "original_path": str(path), "sorted_by": perturbation_column, "control_label": control_label, "timestamp": datetime.datetime.now().isoformat(), "n_perturbations": len(label_order), "perturbation_boundaries": boundaries, # Dict[str, List[int]] } # Don't store full sort_order for large datasets (memory waste) if len(sort_indices) < 100000: sorting_metadata["sort_order"] = sort_indices.tolist() uns["sorting_metadata"] = sorting_metadata # Create output file output_path.parent.mkdir(parents=True, exist_ok=True) logger.info(f" Reordering {n_cells:,} cells into {len(label_order)} contiguous groups...") # Check if source is dense or sparse storage_format = get_matrix_storage_format(path) is_dense = storage_format == "dense" try: if is_dense: # For dense storage: write directly as dense _write_sorted_dense( source_path=path, output_path=output_path, sort_indices=sort_indices, obs_sorted=obs_sorted, var=var, uns=uns, chunk_size=chunk_size, ) else: # For sparse storage: stream as CSR _write_sorted_sparse( source_path=path, output_path=output_path, sort_indices=sort_indices, obs_sorted=obs_sorted, var=var, uns=uns, chunk_size=chunk_size, ) except Exception: # Remove partial output to avoid corrupt file on next run if output_path.exists(): logger.warning(f" Removing partial sorted file: {output_path}") output_path.unlink() raise logger.info(f"Saved sorted dataset: {output_path}") logger.info(f" Perturbation groups: {len(label_order)} (control + {len(unique_labels)} perturbations)") return output_path
def _write_sorted_dense( source_path: Path, output_path: Path, sort_indices: np.ndarray, obs_sorted: pd.DataFrame, var: pd.DataFrame, uns: dict, chunk_size: int, ) -> None: """Write sorted file for dense input matrix. Uses chunked reading to avoid loading full matrix into memory. Writes h5ad in two passes: first X matrix via h5py, then metadata via anndata. """ backed = read_backed(source_path) try: n_cells = backed.n_obs n_genes = backed.n_vars # Get dtype from first chunk sample = backed.X[:min(100, n_cells), :] dtype = sample.dtype # Create output with chunked writing for memory efficiency with h5py.File(output_path, 'w') as f: # Create dataset with chunking for efficient access X_out = f.create_dataset( 'X', shape=(n_cells, n_genes), dtype=dtype, chunks=(min(chunk_size, n_cells), n_genes), ) # Write in chunks based on output order for start in range(0, n_cells, chunk_size): end = min(start + chunk_size, n_cells) # Get the original indices for this output chunk chunk_indices = sort_indices[start:end] # Read from source (optimize by sorting indices for sequential read) read_order = np.argsort(chunk_indices) sorted_chunk_indices = chunk_indices[read_order] # Read data in optimized order chunk_data = backed.X[sorted_chunk_indices, :] # Reorder to match output order inverse_order = np.argsort(read_order) chunk_data = chunk_data[inverse_order] X_out[start:end, :] = chunk_data if (start // chunk_size) % 100 == 0: logger.debug(f" Written {end:,}/{n_cells:,} cells...") finally: backed.file.close() # Write metadata using anndata (proper h5ad structure) # First create a temp file with correct metadata structure temp_path = output_path.with_suffix('.meta.h5ad') adata_meta = ad.AnnData( X=sp.csr_matrix((len(obs_sorted), len(var)), dtype=np.float32), # Placeholder obs=obs_sorted, var=var, uns=uns, ) adata_meta.write(temp_path) # Copy metadata from temp to main file with h5py.File(temp_path, 'r') as src: with h5py.File(output_path, 'a') as dst: for key in ['obs', 'var', 'uns']: if key in src: if key in dst: del dst[key] src.copy(key, dst) # Cleanup temp file temp_path.unlink() def _write_sorted_sparse( source_path: Path, output_path: Path, sort_indices: np.ndarray, obs_sorted: pd.DataFrame, var: pd.DataFrame, uns: dict, chunk_size: int, ) -> None: """Write sorted file for sparse input matrix. Uses chunked I/O to avoid loading the full matrix into memory. Rows are read in output order (chunk_size at a time), converted to CSR components, and appended to resizable HDF5 datasets. Peak memory is proportional to chunk_size × n_genes, not n_cells × n_genes. """ backed = read_backed(source_path) try: n_cells = backed.n_obs n_genes = backed.n_vars # Determine dtype from a small sample sample = backed.X[:1] if sp.issparse(sample): data_dtype = sample.dtype else: data_dtype = np.float32 logger.info( f" Streaming sorted sparse write: {n_cells:,} cells, " f"chunk_size={chunk_size}" ) # Build CSR structure incrementally via h5py with h5py.File(output_path, "w") as f: x_grp = f.create_group("X") x_grp.attrs["encoding-type"] = "csr_matrix" x_grp.attrs["encoding-version"] = "0.1.0" x_grp.attrs["shape"] = np.array([n_cells, n_genes], dtype=np.int64) # Resizable datasets for data and indices; indptr is pre-allocated ds_data = x_grp.create_dataset( "data", shape=(0,), maxshape=(None,), dtype=data_dtype, chunks=(min(262144, max(1, n_cells)),), ) ds_indices = x_grp.create_dataset( "indices", shape=(0,), maxshape=(None,), dtype=np.int32, chunks=(min(262144, max(1, n_cells)),), ) ds_indptr = x_grp.create_dataset( "indptr", shape=(n_cells + 1,), dtype=np.int64, ) nnz_written = 0 ds_indptr[0] = 0 for start in range(0, n_cells, chunk_size): end = min(start + chunk_size, n_cells) # Original row indices for this output chunk orig_idx = sort_indices[start:end] # Read rows from source (sorted for sequential access) read_order = np.argsort(orig_idx) sorted_orig_idx = orig_idx[read_order] block = backed.X[sorted_orig_idx, :] # Undo read_order to restore output order inverse_order = np.argsort(read_order) if sp.issparse(block): block = sp.csr_matrix(block)[inverse_order, :] else: block = sp.csr_matrix(block[inverse_order, :]) csr = sp.csr_matrix(block, dtype=data_dtype) chunk_nnz = csr.nnz if chunk_nnz > 0: new_total = nnz_written + chunk_nnz ds_data.resize((new_total,)) ds_indices.resize((new_total,)) ds_data[nnz_written:new_total] = csr.data ds_indices[nnz_written:new_total] = csr.indices # Write indptr for this chunk (int64 to avoid overflow # when cumulative nnz exceeds INT32_MAX ≈ 2.1 billion) chunk_indptr = csr.indptr[1:].astype(np.int64) + nnz_written ds_indptr[start + 1 : end + 1] = chunk_indptr nnz_written += chunk_nnz if (start // chunk_size) % 50 == 0: logger.debug( f" Written {end:,}/{n_cells:,} cells " f"({nnz_written:,} nnz)..." ) finally: backed.file.close() # Write obs, var, uns metadata using anndata temp_path = output_path.with_suffix(".meta.h5ad") adata_meta = ad.AnnData( X=sp.csr_matrix((len(obs_sorted), len(var)), dtype=np.float32), obs=obs_sorted, var=var, uns=uns, ) adata_meta.write(temp_path) with h5py.File(temp_path, "r") as src: with h5py.File(output_path, "a") as dst: for key in ["obs", "var", "uns"]: if key in src: if key in dst: del dst[key] src.copy(key, dst) temp_path.unlink()
[docs] def get_perturbation_slice( adata_or_path: str | Path | ad.AnnData, perturbation_label: str, perturbation_column: str = "perturbation", ) -> tuple[slice | None, bool]: """Get slice for a perturbation's cells, checking if file is sorted. If the file is sorted by perturbation, returns a contiguous slice. Otherwise, returns None for the slice (caller should use boolean mask). Parameters ---------- adata_or_path Path to h5ad file, or an already-opened AnnData object. perturbation_label Label of the perturbation to get slice for. perturbation_column Column containing perturbation labels. Returns ------- tuple[slice | None, bool] (slice object or None, is_sorted flag). If is_sorted is True, slice is valid for contiguous access. If is_sorted is False, slice is None and caller should use mask. """ # Handle both path and AnnData input if isinstance(adata_or_path, ad.AnnData): adata = adata_or_path should_close = False else: adata = read_backed(adata_or_path) should_close = True try: # Check for sorting metadata if "sorting_metadata" in adata.uns: metadata = adata.uns["sorting_metadata"] if metadata.get("sorted_by") == perturbation_column: boundaries = metadata.get("perturbation_boundaries", {}) if perturbation_label in boundaries: start, end = boundaries[perturbation_label] return slice(start, end), True return None, False finally: if should_close: adata.file.close()
# ============================================================================= # Feature 1 — Backed metadata helpers (load/write obs and var without X) # ============================================================================= def _read_dataframe_from_h5(grp: "h5py.Group") -> pd.DataFrame: """Read an AnnData-encoded HDF5 group as a pandas DataFrame.""" _idx_raw = grp.attrs.get("_index", "_index") index_key = _idx_raw.decode("utf-8") if isinstance(_idx_raw, bytes) else str(_idx_raw) _col_raw = grp.attrs.get("column-order", []) column_order = [x.decode("utf-8") if isinstance(x, bytes) else str(x) for x in _col_raw] raw_index = grp[index_key][()] if raw_index.dtype.kind in ("S", "O"): raw_index = np.array([x.decode("utf-8") if isinstance(x, bytes) else x for x in raw_index]) index = pd.Index(raw_index) all_keys = [k for k in grp.keys() if k != index_key] ordered_keys = [k for k in column_order if k in grp] + [ k for k in all_keys if k not in set(column_order) ] columns: dict[str, Any] = {} for key in ordered_keys: item = grp[key] if isinstance(item, h5py.Group): enc = str(item.attrs.get("encoding-type", "")) if enc == "categorical": cats_raw = item["categories"][()] if cats_raw.dtype.kind in ("S", "O"): cats_raw = np.array([x.decode("utf-8") if isinstance(x, bytes) else x for x in cats_raw]) cats = pd.Index(cats_raw) codes = item["codes"][()].astype(np.intp) ordered = bool(item.attrs.get("ordered", False)) columns[key] = pd.Categorical.from_codes( codes, categories=cats, ordered=ordered ) else: val = item[()] if val.dtype.kind in ("S", "O"): val = np.array([x.decode("utf-8") if isinstance(x, bytes) else x for x in val]) columns[key] = val return pd.DataFrame(columns, index=index) def _write_dataframe_to_h5(grp: "h5py.Group", df: pd.DataFrame) -> None: """Write a pandas DataFrame to an h5py Group in AnnData 0.2.0 encoding.""" str_dtype = h5py.string_dtype(encoding="utf-8") grp.attrs["encoding-type"] = "dataframe" grp.attrs["encoding-version"] = "0.2.0" grp.attrs["_index"] = "_index" grp.attrs["column-order"] = np.array(list(df.columns), dtype=object) idx_ds = grp.create_dataset( "_index", data=df.index.astype(str).to_numpy(), dtype=str_dtype ) idx_ds.attrs["encoding-type"] = "string-array" idx_ds.attrs["encoding-version"] = "0.2.0" for col in df.columns: series = df[col] if isinstance(series.dtype, pd.CategoricalDtype): cat_grp = grp.create_group(col) cat_grp.attrs["encoding-type"] = "categorical" cat_grp.attrs["encoding-version"] = "0.2.0" cat_grp.attrs["ordered"] = bool(series.cat.ordered) cats = series.cat.categories.astype(str).to_numpy() cd = cat_grp.create_dataset("categories", data=cats, dtype=str_dtype) cd.attrs["encoding-type"] = "string-array" cd.attrs["encoding-version"] = "0.2.0" codes = series.cat.codes.to_numpy() codes_dtype = np.int8 if len(series.cat.categories) < 128 else np.int16 co = cat_grp.create_dataset("codes", data=codes.astype(codes_dtype)) co.attrs["encoding-type"] = "array" co.attrs["encoding-version"] = "0.2.0" elif series.dtype.kind in ("O", "U", "S"): vals = series.fillna("").astype(str).to_numpy() ds = grp.create_dataset(col, data=vals, dtype=str_dtype) ds.attrs["encoding-type"] = "string-array" ds.attrs["encoding-version"] = "0.2.0" else: ds = grp.create_dataset(col, data=series.to_numpy()) ds.attrs["encoding-type"] = "array" ds.attrs["encoding-version"] = "0.2.0"
[docs] def load_obs(path: str | Path) -> pd.DataFrame: """Load the obs metadata table from an h5ad file without reading X. Parameters ---------- path Path to the h5ad file. Returns ------- pd.DataFrame Full obs DataFrame in memory. """ with h5py.File(Path(path), "r") as f: if "obs" not in f: raise KeyError("h5ad file has no 'obs' group.") return _read_dataframe_from_h5(f["obs"])
[docs] def load_var(path: str | Path) -> pd.DataFrame: """Load the var metadata table from an h5ad file without reading X. Parameters ---------- path Path to the h5ad file. Returns ------- pd.DataFrame Full var DataFrame in memory. """ with h5py.File(Path(path), "r") as f: if "var" not in f: raise KeyError("h5ad file has no 'var' group.") return _read_dataframe_from_h5(f["var"])
[docs] def write_obs(path: str | Path, df: pd.DataFrame) -> None: """Overwrite the obs metadata table in an h5ad file without touching X. Parameters ---------- path Path to the h5ad file (modified in-place). df New obs DataFrame. Must have the same number of rows as the existing obs table. Index values are written as cell barcodes. Raises ------ ValueError If the DataFrame length does not match the existing n_obs. """ path = Path(path) with h5py.File(path, "r+") as f: old_n = len(f["obs"]["_index"]) if len(df) != old_n: raise ValueError( f"DataFrame has {len(df)} rows but the file has {old_n} cells." ) del f["obs"] grp = f.create_group("obs") _write_dataframe_to_h5(grp, df)
[docs] def write_var(path: str | Path, df: pd.DataFrame) -> None: """Overwrite the var metadata table in an h5ad file without touching X. Parameters ---------- path Path to the h5ad file (modified in-place). df New var DataFrame. Must have the same number of rows as the existing var table. Raises ------ ValueError If the DataFrame length does not match the existing n_vars. """ path = Path(path) with h5py.File(path, "r+") as f: old_n = len(f["var"]["_index"]) if len(df) != old_n: raise ValueError( f"DataFrame has {len(df)} rows but the file has {old_n} genes." ) del f["var"] grp = f.create_group("var") _write_dataframe_to_h5(grp, df)
# ============================================================================= # Feature 2 — Gene name standardisation # =============================================================================
[docs] def standardise_gene_names( path: str | Path, *, column: str | None = None, strip_version: bool = True, normalise_mt_prefix: bool = True, lookup_symbols: bool = False, species: str = "human", unmapped_action: Literal["keep", "error", "warn"] = "warn", inplace: bool = True, ) -> "pd.Series | None": """Standardise gene identifiers in the var metadata table. Applies a deterministic normalisation pipeline: 1. Strip Ensembl version suffixes (``ENSG00000123.4`` → ``ENSG00000123``). 2. Normalise ``mt-`` prefix to ``MT-`` (human mitochondrial convention). 3. Optionally resolve Ensembl IDs to HGNC symbols via ``mygene`` (requires ``pip install mygene``). A ``tqdm`` progress bar is shown during batched lookups. Parameters ---------- path Path to the h5ad file. column var column to normalise. ``None`` normalises the index (var_names). strip_version Strip ``".N"`` Ensembl version suffixes. normalise_mt_prefix Convert lower-case ``mt-`` prefix to ``MT-``. lookup_symbols If True, query ``mygene`` to map Ensembl IDs → gene symbols. species Species string passed to ``mygene`` (default ``"human"``). unmapped_action What to do for IDs not found by mygene: ``"keep"`` leaves them unchanged, ``"warn"`` emits a warning, ``"error"`` raises. inplace If True, write the result back to the file and return ``None``. If False, return a Series without modifying the file. Returns ------- pd.Series or None Normalised gene names when ``inplace=False``, else ``None``. """ path = Path(path) df = load_var(path) if column is None: names = pd.Series(df.index.astype(str).to_numpy(), name="_index") else: if column not in df.columns: raise KeyError( f"Column '{column}' not found in var. " f"Available: {list(df.columns)}" ) names = df[column].astype(str).copy() # Step 1: strip Ensembl version suffix if strip_version: names = names.str.replace(r"\.\d+$", "", regex=True) # Step 2: normalise mt- prefix if normalise_mt_prefix: names = names.str.replace(r"^mt-", "MT-", regex=True) # Step 3: optional online lookup via mygene if lookup_symbols: try: import mygene # type: ignore[import] except ImportError: raise ImportError( "The 'mygene' package is required for online symbol lookup. " "Install it with: pip install mygene" ) mg = mygene.MyGeneInfo() unique_ids = names.unique().tolist() symbol_map: dict[str, str] = {} batch_size = 1000 batches = list(range(0, len(unique_ids), batch_size)) try: from tqdm import tqdm as _tqdm # type: ignore[import] it = _tqdm(batches, desc="mygene lookup", unit="batch") except ImportError: it = iter(batches) for start in it: batch = unique_ids[start : start + batch_size] hits = mg.querymany( batch, scopes="ensembl.gene,symbol", fields="symbol", species=species, verbose=False, as_dataframe=False, ) for hit in hits: query_id = hit.get("query", "") symbol = hit.get("symbol", "") if symbol and not hit.get("notfound", False): symbol_map[query_id] = symbol unmapped = [i for i in unique_ids if i not in symbol_map] if unmapped: msg = f"{len(unmapped)} gene IDs could not be mapped to symbols." if unmapped_action == "error": raise ValueError(msg + f" First 10: {unmapped[:10]}") elif unmapped_action == "warn": logger.warning("%s They will be left unchanged.", msg) names = names.map(lambda x: symbol_map.get(x, x)) if not inplace: return names if column is None: df.index = pd.Index(names.to_numpy(), name=df.index.name) else: df[column] = names.to_numpy() write_var(path, df) return None
# ============================================================================= # Feature 3 — Perturbation label normalisation # ============================================================================= _DEFAULT_CONTROL_ALIASES: frozenset[str] = frozenset({ "ntc", "non-targeting", "non_targeting", "nontarget", "non-target", "non_target", "control", "ctrl", "scramble", "scrambled", "non-targeting control", "non-targeting-control", })
[docs] def normalise_perturbation_labels( path: str | Path, column: str, *, strip_prefixes: list[str] | None = None, strip_suffixes: list[str] | None = None, strip_suffix_regex: str | None = None, control_aliases: list[str] | None = None, canonical_control: str = "NTC", inplace: bool = True, ) -> "pd.Series | None": """Normalise perturbation labels stored in an obs column. Applies transformations in order: 1. Strip specified prefixes via vectorised ``pd.Series.str.replace``. 2. Strip specified suffixes via vectorised ``pd.Series.str.replace``. 3. Apply a custom regex substitution (``strip_suffix_regex``). 4. Map known control aliases to ``canonical_control``. Parameters ---------- path Path to the h5ad file. column obs column containing perturbation labels. strip_prefixes List of prefix strings to remove (e.g. ``["sg-", "sg"]``). strip_suffixes List of suffix strings to remove (e.g. ``["_KO", "_KD", "_P1P2"]``). strip_suffix_regex A Python regex applied via ``pd.Series.str.replace`` after prefix/suffix stripping. control_aliases Additional strings (case-insensitive) treated as control labels. The built-in aliases (``ntc``, ``ctrl``, ``scramble``, …) are always included. canonical_control Canonical control label substituted for all matched aliases. inplace If True, write result back and return ``None``. If False, return a Series without modifying the file. Returns ------- pd.Series or None Normalised labels when ``inplace=False``, else ``None``. """ path = Path(path) df = load_obs(path) if column not in df.columns: raise KeyError( f"Column '{column}' not found in obs. " f"Available: {list(df.columns)}" ) labels = df[column].astype(str) # Step 1: strip prefixes (vectorised) if strip_prefixes: for prefix in strip_prefixes: labels = labels.str.replace( "^" + _re.escape(prefix), "", regex=True ) # Step 2: strip suffixes (vectorised) if strip_suffixes: for suffix in strip_suffixes: labels = labels.str.replace( _re.escape(suffix) + "$", "", regex=True ) # Step 3: custom regex substitution (vectorised) if strip_suffix_regex: labels = labels.str.replace(strip_suffix_regex, "", regex=True) # Step 4: unify control labels all_aliases = set(_DEFAULT_CONTROL_ALIASES) if control_aliases: all_aliases.update(a.lower() for a in control_aliases) is_control = labels.str.lower().isin(all_aliases) labels = labels.where(~is_control, other=canonical_control) if not inplace: return labels if isinstance(df[column].dtype, pd.CategoricalDtype): df[column] = pd.Categorical(labels.to_numpy()) else: df[column] = labels.to_numpy() write_obs(path, df) return None
# ============================================================================= # Feature 4 — Auto-detection of metadata columns # ============================================================================= _PERTURBATION_COL_ALIASES: frozenset[str] = frozenset({ "perturbation", "gene", "gene_target", "condition", "guide_identity", "target_gene_name", "gene_name", "sgrna", "guide", "guide_id", "sgrna_name", "target", }) _GENE_SYMBOL_COL_ALIASES: frozenset[str] = frozenset({ "gene_symbols", "gene_name", "gene", "symbol", "hgnc_symbol", "gene_symbol", "feature_name", }) _CTRL_TERMS: frozenset[str] = frozenset({ "ctrl", "control", "nontarget", "non-target", "non_target", "ntc", "scramble", "scrambled", })
[docs] def detect_perturbation_column( adata: "str | Path | AnnData | ad.AnnData", *, control_label: str | None = None, min_unique: int = 2, verbose: bool = True, ) -> str | None: """Heuristically identify the obs column containing perturbation labels. Scoring: * +3 if column name matches known aliases (``perturbation``, ``gene_target``, …). * +2 if dtype is categorical or object. * +1 if unique-value count is in [``min_unique``, 5000]. * +2 if at least one value matches a known control synonym or ``control_label`` when provided. Parameters ---------- adata Backed AnnData, :class:`~crispyx.data.AnnData`, or path to h5ad file. control_label Known control label; boosts the score for columns containing it. min_unique Minimum number of unique values required for the +1 bonus. verbose Log the detected column name. Returns ------- str or None Column name with the highest score, or ``None`` if no column scores above zero. """ path = resolve_data_path(adata) obs = load_obs(path) scores: dict[str, int] = {} for col in obs.columns: score = 0 if col.lower() in _PERTURBATION_COL_ALIASES: score += 3 dtype = obs[col].dtype if isinstance(dtype, pd.CategoricalDtype) or dtype == object: score += 2 try: n_unique = int(obs[col].nunique()) except Exception: n_unique = 0 if min_unique <= n_unique <= 5000: score += 1 lower_vals = obs[col].astype(str).str.lower() if control_label is not None: if (lower_vals == control_label.lower()).any(): score += 2 else: if lower_vals.isin(_CTRL_TERMS).any(): score += 2 scores[col] = score if not scores: return None best_col, best_score = max(scores.items(), key=lambda x: x[1]) if best_score <= 0: return None if verbose: logger.info( "Detected perturbation column: '%s' (score=%d).", best_col, best_score ) return best_col
[docs] def detect_gene_symbol_column( adata: "str | Path | AnnData | ad.AnnData", *, verbose: bool = True, ) -> str | None: """Heuristically identify the var column containing gene symbols. Scoring: * +3 if column name matches known aliases (``gene_symbols``, ``symbol``, …). * +2 if values pass :func:`_validate_gene_symbols` without error. * +1 if values do **not** start with Ensembl prefixes. Returns ``None`` when no column qualifies, which signals that ``var_names`` should be used as a fallback. Parameters ---------- adata Backed AnnData, :class:`~crispyx.data.AnnData`, or path to h5ad file. verbose Log the detected column name. Returns ------- str or None """ path = resolve_data_path(adata) var = load_var(path) ensembl_3 = frozenset(p[:3].upper() for p in ENSEMBL_PREFIXES) scores: dict[str, int] = {} for col in var.columns: score = 0 if col.lower() in _GENE_SYMBOL_COL_ALIASES: score += 3 try: _validate_gene_symbols(var[col].astype(str)) score += 2 except ValueError: pass prefixes = var[col].astype(str).str.upper().str.slice(0, 3) if not prefixes.isin(ensembl_3).any(): score += 1 scores[col] = score if not scores: return None best_col, best_score = max(scores.items(), key=lambda x: x[1]) if best_score <= 0: return None if verbose: logger.info( "Detected gene symbol column: '%s' (score=%d).", best_col, best_score ) return best_col
[docs] def infer_columns( adata: "str | Path | AnnData | ad.AnnData", *, control_label: str | None = None, verbose: bool = True, ) -> dict[str, str | None]: """Detect perturbation and gene-symbol columns in a single call. Parameters ---------- adata Backed AnnData, :class:`~crispyx.data.AnnData`, or path to h5ad file. control_label Known control label forwarded to :func:`detect_perturbation_column`. verbose Log detected column names. Returns ------- dict ``{"perturbation_column": ..., "gene_name_column": ...}`` where each value is the detected column name or ``None``. """ return { "perturbation_column": detect_perturbation_column( adata, control_label=control_label, verbose=verbose ), "gene_name_column": detect_gene_symbol_column(adata, verbose=verbose), }
# ============================================================================= # Feature 5 — Overlap analysis utilities # =============================================================================
[docs] @dataclass class OverlapResult: """Pairwise overlap statistics between named sets. Attributes ---------- count_matrix (n_sets × n_sets) DataFrame of pairwise intersection sizes. jaccard_matrix (n_sets × n_sets) DataFrame of Jaccard similarity coefficients. set_sizes Series of sizes for each input set. """ count_matrix: pd.DataFrame jaccard_matrix: pd.DataFrame set_sizes: pd.Series
[docs] def compute_overlap( sets_dict: dict[str, "set | list"], *, metric: Literal["count", "jaccard", "both"] = "both", ) -> OverlapResult: """Compute pairwise overlap statistics between named sets. Parameters ---------- sets_dict Mapping of name → set (or list, converted to set). metric Which matrices to populate: ``"count"``, ``"jaccard"``, or ``"both"`` (default). Returns ------- OverlapResult Object with ``count_matrix``, ``jaccard_matrix``, and ``set_sizes``. Examples -------- >>> result = cx.tl.compute_overlap({ ... "dataset_A": {"BRCA1", "TP53", "EGFR"}, ... "dataset_B": {"TP53", "KRAS"}, ... }) >>> result.jaccard_matrix """ names = list(sets_dict.keys()) sets: dict[str, set] = {k: set(v) for k, v in sets_dict.items()} n = len(names) count_arr = np.zeros((n, n), dtype=np.int64) jaccard_arr = np.zeros((n, n), dtype=np.float64) for i, name_i in enumerate(names): for j, name_j in enumerate(names): inter = len(sets[name_i] & sets[name_j]) if metric in ("count", "both"): count_arr[i, j] = inter if metric in ("jaccard", "both"): union = len(sets[name_i] | sets[name_j]) jaccard_arr[i, j] = inter / union if union > 0 else 0.0 sizes = pd.Series({k: len(v) for k, v in sets.items()}, name="set_size") return OverlapResult( count_matrix=pd.DataFrame(count_arr, index=names, columns=names), jaccard_matrix=pd.DataFrame(jaccard_arr, index=names, columns=names), set_sizes=sizes, )