Source code for crispyx.pseudobulk

"""Pseudo-bulk effect size estimators operating directly on ``.h5ad`` files."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable

import anndata as ad
import numpy as np
import pandas as pd

from .data import (
    AnnData,
    calculate_optimal_chunk_size,
    ensure_gene_symbol_column,
    iter_matrix_chunks,
    normalize_total_block,
    read_backed,
    resolve_control_label,
    resolve_data_path,
    resolve_output_path,
)


def _resolve_candidates(
    labels: np.ndarray,
    control_label: str,
    perturbations: Iterable[str] | None,
) -> list[str]:
    """Return the perturbation labels to analyse, excluding the control.

    Parameters
    ----------
    labels
        Per-observation perturbation labels.
    control_label
        Label of the control group; it is always removed from the result.
    perturbations
        Explicit labels to analyse. ``None`` selects every distinct value in
        ``labels`` in order of first appearance. A single string is treated
        as one label rather than being iterated character by character.

    Returns
    -------
    list[str]
        Candidate labels in first-seen order, without duplicates and without
        ``control_label``.
    """
    if perturbations is None:
        requested = pd.Index(labels).unique().tolist()
    elif isinstance(perturbations, str):
        # A bare string is an iterable of characters; callers almost certainly
        # meant a single label.
        requested = [perturbations]
    else:
        # dict.fromkeys de-duplicates while preserving the caller's order, so
        # a repeated label cannot yield duplicate output rows downstream.
        requested = list(dict.fromkeys(str(p) for p in perturbations))
    return [label for label in requested if label != control_label]


def compute_average_log_expression(
    data: str | Path | AnnData | ad.AnnData,
    *,
    perturbation_column: str,
    control_label: str | None = None,
    gene_name_column: str | None = None,
    perturbations: Iterable[str] | None = None,
    chunk_size: int | None = None,
    output_dir: str | Path | None = None,
    data_name: str | None = None,
) -> AnnData:
    """Compute average log-normalised expression per perturbation relative to control.

    The matrix is streamed in row chunks. Each chunk is passed through
    ``normalize_total_block`` and ``np.log1p``, and per-label sums and cell
    counts are accumulated. The effect of a perturbation is its mean
    log-expression minus the control mean.

    Parameters
    ----------
    data
        Input dataset; resolved to a path with ``resolve_data_path``.
    perturbation_column
        Column of ``obs`` holding the perturbation label of every cell.
    control_label
        Label of the control group; resolved with ``resolve_control_label``.
    gene_name_column
        Forwarded to ``ensure_gene_symbol_column`` to obtain gene symbols.
    perturbations
        Labels to analyse. ``None`` analyses every non-control label.
    chunk_size
        Rows per streamed chunk. ``None`` derives a value with
        ``calculate_optimal_chunk_size``.
    output_dir, data_name
        Forwarded to ``resolve_output_path`` (suffix ``"avg_log_effects"``).

    Returns
    -------
    AnnData
        Wrapper around the written result file. ``X`` holds the effects (one
        row per perturbation), ``layers["perturbation_mean"]`` the
        per-perturbation means (only when at least one perturbation is
        present) and ``uns["control_mean"]`` the control mean.

    Raises
    ------
    KeyError
        If ``perturbation_column`` is not a column of ``obs``.
    ValueError
        If the control group or any requested perturbation has no cells.
    """
    path = resolve_data_path(data)
    backed = read_backed(path)
    try:
        # Calculate adaptive chunk_size if not provided
        if chunk_size is None:
            chunk_size = calculate_optimal_chunk_size(backed.n_obs, backed.n_vars)
        gene_symbols = ensure_gene_symbol_column(backed, gene_name_column)
        if perturbation_column not in backed.obs.columns:
            raise KeyError(
                f"Perturbation column '{perturbation_column}' was not found in adata.obs. "
                f"Available columns: {list(backed.obs.columns)}"
            )
        labels = backed.obs[perturbation_column].astype(str).to_numpy()
        control_label = resolve_control_label(labels, control_label)
        n_genes = backed.n_vars
        candidates = _resolve_candidates(labels, control_label, perturbations)
        groups = [control_label] + candidates

        # Fail fast: a label absent from ``labels`` can never accumulate any
        # cells, so report it before streaming the whole matrix.
        observed = set(pd.unique(labels))
        if control_label not in observed:
            raise ValueError("Control group contains no cells")
        for label in candidates:
            if label not in observed:
                raise ValueError(f"Perturbation '{label}' contains no cells")

        sums = {label: np.zeros(n_genes, dtype=np.float64) for label in groups}
        counts = {label: 0 for label in groups}
        for slc, block in iter_matrix_chunks(backed, axis=0, chunk_size=chunk_size):
            slice_labels = labels[slc]
            normalised_block, _ = normalize_total_block(block)
            log_block = np.log1p(normalised_block)
            # Iterate the dict keys (unique) rather than ``groups`` so that a
            # label listed twice is not accumulated twice.
            for label in sums:
                mask = slice_labels == label
                if not np.any(mask):
                    continue
                sums[label] += log_block[mask].sum(axis=0)
                counts[label] += int(mask.sum())
    finally:
        backed.file.close()

    # Defensive re-check on the counts actually accumulated from the chunks.
    if counts[control_label] == 0:
        raise ValueError("Control group contains no cells")
    control_mean = sums[control_label] / counts[control_label]

    pert_means = []
    for label in candidates:
        if counts[label] == 0:
            raise ValueError(f"Perturbation '{label}' contains no cells")
        pert_means.append(sums[label] / counts[label])

    if pert_means:
        mean_matrix = np.vstack(pert_means)
        obs_index = pd.Index(candidates, name="perturbation").astype(str)
        obs = pd.DataFrame({perturbation_column: obs_index.to_list()}, index=obs_index)
        var = pd.DataFrame(index=pd.Index(gene_symbols).astype(str))
        # Broadcasting subtracts the control mean from every perturbation row.
        adata = ad.AnnData(mean_matrix - control_mean, obs=obs, var=var)
        adata.layers["perturbation_mean"] = mean_matrix
    else:
        # No perturbation to report: emit an empty (0 x n_genes) result.
        adata = ad.AnnData(
            np.zeros((0, gene_symbols.shape[0])),
            obs=pd.DataFrame(index=pd.Index([], name="perturbation")),
            var=pd.DataFrame(index=gene_symbols),
        )
    # Stored for empty results too, mirroring how compute_pseudobulk_expression
    # always records its control profile.
    adata.uns["control_mean"] = control_mean

    output_path = resolve_output_path(
        path, suffix="avg_log_effects", output_dir=output_dir, data_name=data_name
    )
    adata.write(output_path)
    return AnnData(output_path)
def compute_pseudobulk_expression(
    data: str | Path | AnnData | ad.AnnData,
    *,
    perturbation_column: str,
    control_label: str | None = None,
    gene_name_column: str | None = None,
    perturbations: Iterable[str] | None = None,
    baseline_count: float = 1.0,
    chunk_size: int | None = None,
    output_dir: str | Path | None = None,
    data_name: str | None = None,
) -> AnnData:
    """Compute pseudo-bulk log-fold changes relative to control.

    The matrix is streamed in row chunks. Each chunk is passed through
    ``normalize_total_block`` and per-label sums and cell counts are
    accumulated. The bulk profile of a group is
    ``log1p(baseline_count * mean)``, and the effect of a perturbation is its
    bulk profile minus the control bulk profile.

    Parameters
    ----------
    data
        Input dataset; resolved to a path with ``resolve_data_path``.
    perturbation_column
        Column of ``obs`` holding the perturbation label of every cell.
    control_label
        Label of the control group; resolved with ``resolve_control_label``.
    gene_name_column
        Forwarded to ``ensure_gene_symbol_column`` to obtain gene symbols.
    perturbations
        Labels to analyse. ``None`` analyses every non-control label.
    baseline_count
        Positive factor applied to the mean normalised expression before
        ``log1p``.
    chunk_size
        Rows per streamed chunk. ``None`` derives a value with
        ``calculate_optimal_chunk_size``.
    output_dir, data_name
        Forwarded to ``resolve_output_path`` (suffix ``"pseudobulk_effects"``).

    Returns
    -------
    AnnData
        Wrapper around the written result file. ``X`` holds the effects (one
        row per perturbation), ``layers["perturbation_bulk"]`` the
        per-perturbation bulk profiles (only when at least one perturbation
        is present), ``uns["control_bulk"]`` the control profile and
        ``uns["baseline_count"]`` the factor used.

    Raises
    ------
    KeyError
        If ``perturbation_column`` is not a column of ``obs``.
    ValueError
        If ``baseline_count`` is not positive, or if the control group or any
        requested perturbation has no cells.
    """
    if baseline_count <= 0:
        raise ValueError("baseline_count must be positive")

    path = resolve_data_path(data)
    backed = read_backed(path)
    try:
        # Calculate adaptive chunk_size if not provided
        if chunk_size is None:
            chunk_size = calculate_optimal_chunk_size(backed.n_obs, backed.n_vars)
        gene_symbols = ensure_gene_symbol_column(backed, gene_name_column)
        if perturbation_column not in backed.obs.columns:
            raise KeyError(
                f"Perturbation column '{perturbation_column}' was not found in adata.obs. "
                f"Available columns: {list(backed.obs.columns)}"
            )
        labels = backed.obs[perturbation_column].astype(str).to_numpy()
        control_label = resolve_control_label(labels, control_label)
        n_genes = backed.n_vars
        candidates = _resolve_candidates(labels, control_label, perturbations)
        groups = [control_label] + candidates

        # Fail fast: a label absent from ``labels`` can never accumulate any
        # cells, so report it before streaming the whole matrix.
        observed = set(pd.unique(labels))
        if control_label not in observed:
            raise ValueError("Control group contains no cells")
        for label in candidates:
            if label not in observed:
                raise ValueError(f"Perturbation '{label}' contains no cells")

        sums = {label: np.zeros(n_genes, dtype=np.float64) for label in groups}
        counts = {label: 0 for label in groups}
        for slc, block in iter_matrix_chunks(backed, axis=0, chunk_size=chunk_size):
            slice_labels = labels[slc]
            normalised_block, _ = normalize_total_block(block)
            # Iterate the dict keys (unique) rather than ``groups`` so that a
            # label listed twice is not accumulated twice.
            for label in sums:
                mask = slice_labels == label
                if not np.any(mask):
                    continue
                sums[label] += normalised_block[mask].sum(axis=0)
                counts[label] += int(mask.sum())
    finally:
        backed.file.close()

    # Defensive re-check on the counts actually accumulated from the chunks.
    if counts[control_label] == 0:
        raise ValueError("Control group contains no cells")
    control_bulk = np.log1p(baseline_count * sums[control_label] / counts[control_label])

    pert_bulks = []
    for label in candidates:
        if counts[label] == 0:
            raise ValueError(f"Perturbation '{label}' contains no cells")
        pert_bulks.append(np.log1p(baseline_count * sums[label] / counts[label]))

    if pert_bulks:
        bulk_matrix = np.vstack(pert_bulks)
        obs_index = pd.Index(candidates, name="perturbation").astype(str)
        obs = pd.DataFrame({perturbation_column: obs_index.to_list()}, index=obs_index)
        var = pd.DataFrame(index=pd.Index(gene_symbols).astype(str))
        # Broadcasting subtracts the control profile from every perturbation row.
        adata = ad.AnnData(bulk_matrix - control_bulk, obs=obs, var=var)
        adata.layers["perturbation_bulk"] = bulk_matrix
    else:
        # No perturbation to report: emit an empty (0 x n_genes) result.
        adata = ad.AnnData(
            np.zeros((0, gene_symbols.shape[0])),
            obs=pd.DataFrame(index=pd.Index([], name="perturbation")),
            var=pd.DataFrame(index=gene_symbols),
        )
    adata.uns["control_bulk"] = control_bulk
    adata.uns["baseline_count"] = float(baseline_count)

    output_path = resolve_output_path(
        path, suffix="pseudobulk_effects", output_dir=output_dir, data_name=data_name
    )
    adata.write(output_path)
    return AnnData(output_path)