# Source code for crispyx.plotting

"""Plotting utilities for crispyx with Scanpy-style helpers.

These functions are designed to work with on-disk AnnData objects and
avoid loading full count matrices into memory.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Iterable, Literal, Sequence

import anndata as ad
import h5py
import numpy as np
import pandas as pd
import scipy.sparse as sp

from .data import (
    AnnData,
    OverlapResult,
    iter_matrix_chunks,
    normalize_total_block,
    read_backed,
    resolve_data_path,
)
from .qc import QualityControlResult

# Module-level logger; used below e.g. to warn when matplotlib is missing or a
# requested ``gene_symbols`` column is not present in ``var``.
logger = logging.getLogger(__name__)

# Inputs accepted by the public helpers: a filesystem path (``str``/``Path``),
# crispyx's own ``AnnData`` type (from ``.data``) or an ``anndata.AnnData``.
# Public helpers turn any of these into a path via ``resolve_data_path``.
PlotInput = str | Path | AnnData | ad.AnnData


# -----------------------------------------------------------------------------
# Internal helpers
# -----------------------------------------------------------------------------


def _decode_strings(values: Iterable) -> list[str]:
    decoded: list[str] = []
    for value in values:
        if isinstance(value, (bytes, np.bytes_)):
            decoded.append(value.decode("utf-8"))
        else:
            decoded.append(str(value))
    return decoded


def _decode_scalar(value):
    if isinstance(value, (bytes, np.bytes_)):
        return value.decode("utf-8")
    if isinstance(value, np.ndarray):
        if value.shape == ():
            return _decode_scalar(value.item())
        if value.size == 1:
            return _decode_scalar(value.reshape(-1)[0])
    return value


def _resolve_path(data: PlotInput) -> Path:
    """Return the on-disk file backing *data*, as determined by ``resolve_data_path``."""
    resolved = resolve_data_path(data)
    return resolved


def _read_uns_scalar(path: Path, key: str):
    """Best-effort read of the dataset stored at ``uns/<key>``.

    Returns the value decoded via ``_decode_scalar``, or ``None`` when
    ``uns`` or ``uns/<key>`` is absent, when ``uns/<key>`` is not a dataset
    (e.g. a nested group), or when reading fails for any reason. Callers
    treat ``None`` as "use a default", so failures are not raised; they are
    logged at DEBUG level instead of being silently discarded.
    """
    try:
        with h5py.File(path, "r") as handle:
            if "uns" not in handle or key not in handle["uns"]:
                return None
            dataset = handle["uns"][key]
            if not isinstance(dataset, h5py.Dataset):
                return None
            return _decode_scalar(dataset[()])
    except Exception as exc:  # deliberate best-effort: callers fall back to defaults
        logger.debug("Could not read uns/%s from %s: %s", key, path, exc)
        return None


def _read_group_names(path: Path, key: str) -> list[str]:
    """Return the group labels associated with ``uns/<key>``.

    Resolution order:

    1. ``uns/<key>/names`` is a structured (record) dataset: its field names.
    2. ``uns/<key>/names`` is a non-scalar plain dataset: its decoded values.
    3. Otherwise: the file's ``obs_names`` as strings.
    """
    group_path = f"uns/{key}"
    with h5py.File(path, "r") as handle:
        names_node = None
        if group_path in handle and "names" in handle[group_path]:
            names_node = handle[group_path]["names"]
        if names_node is not None:
            record_fields = names_node.dtype.names
            if record_fields is not None:
                return [str(field) for field in record_fields]
            if names_node.shape:
                return _decode_strings(names_node[()])

    # No usable names dataset: fall back to observation names.
    backed = read_backed(path)
    try:
        labels = backed.obs_names.astype(str)
        return labels.tolist()
    finally:
        backed.file.close()


def _read_gene_names(path: Path) -> np.ndarray:
    """Return gene names as a ``str`` array.

    ``uns['genes']`` takes precedence when present; otherwise ``var_names``
    is used. The backed file handle is closed before returning.
    """
    backed = read_backed(path)
    try:
        if "genes" not in backed.uns:
            result = backed.var_names.astype(str).to_numpy()
        else:
            result = np.asarray(backed.uns["genes"]).astype(str)
    finally:
        backed.file.close()
    return result


def _read_var_column(path: Path, column: str) -> pd.Series | None:
    """Return a copy of ``var[column]`` from the file, or ``None`` if the column is absent."""
    backed = read_backed(path)
    try:
        var_frame = backed.var
        return var_frame[column].copy() if column in var_frame.columns else None
    finally:
        backed.file.close()


def _infer_rgg_params(path: Path, key: str) -> dict:
    """Return rank_genes_groups parameters for ``uns/<key>``, filling gaps with defaults.

    Values are first looked up on ``uns/<key>/params``. Its HDF5 attributes
    are checked first. When ``params`` is a group, its child datasets are
    checked as well. (anndata's ``write_elem`` stores the entries of a dict
    as child datasets of a group, so Scanpy/anndata-written files keep them
    there rather than in attributes.)

    Anything still missing falls back to a top-level ``uns`` scalar and then
    to a fixed default:

    - ``groupby``: ``uns/perturbation_column``, else ``"group"``
    - ``reference``: ``uns/control_label``, else ``"reference"``
    - ``corr_method``: ``uns/pvalue_correction``, else ``"benjamini-hochberg"``
    - ``method``: ``uns/method``, else ``"unknown"``
    - ``tie_correct``: ``False``
    - ``use_raw``: ``False``

    Fallback ``uns`` scalars are read only for keys that are actually missing.
    """
    param_names = ("groupby", "method", "reference", "tie_correct", "corr_method", "use_raw")
    params: dict[str, object] = {}
    with h5py.File(path, "r") as handle:
        rgg_path = f"uns/{key}"
        if rgg_path in handle:
            rgg = handle[rgg_path]
            if "params" in rgg:
                params_node = rgg["params"]
                attrs = params_node.attrs
                is_group = isinstance(params_node, h5py.Group)
                for name in param_names:
                    if name in attrs:
                        params[name] = _decode_scalar(attrs[name])
                    elif (
                        is_group
                        and name in params_node
                        and isinstance(params_node[name], h5py.Dataset)
                    ):
                        params[name] = _decode_scalar(params_node[name][()])

    uns_fallbacks = (
        ("groupby", "perturbation_column", "group"),
        ("reference", "control_label", "reference"),
        ("corr_method", "pvalue_correction", "benjamini-hochberg"),
        ("method", "method", "unknown"),
    )
    for name, uns_key, default in uns_fallbacks:
        # ``dict.setdefault`` would evaluate the fallback eagerly and re-open
        # the file even when the value is already known.
        if name not in params:
            params[name] = _read_uns_scalar(path, uns_key) or default
    params.setdefault("tie_correct", False)
    params.setdefault("use_raw", False)
    return params


def _to_recarray(arrays: list[np.ndarray], names: Sequence[str]) -> np.recarray:
    return np.rec.fromarrays(arrays, names=[str(name) for name in names])


def _build_rgg_from_full(
    rgg: h5py.Group,
    groups: list[str],
    group_indices: list[int],
    genes: np.ndarray,
    n_genes: int | None,
) -> dict:
    """Build a Scanpy-style rank_genes_groups dict from the ``full`` layout.

    ``rgg["full"][metric]`` holds one row per group, with gene-aligned
    columns. ``rgg["order"]``, when present, holds each group's ranking as
    gene indices; without it genes keep their stored order. Rows are
    selected by ``group_indices``, which pair up with ``groups``.

    Parameters
    ----------
    rgg
        The ``uns/<key>`` HDF5 group containing ``full``.
    groups
        Output field names, one per entry of ``group_indices``.
    group_indices
        Row index of each group within ``full``/``order``.
    genes
        Gene names aligned with the columns of ``full``.
    n_genes
        Keep only the first ``n_genes`` ranked genes; ``None`` keeps all.

    Returns
    -------
    dict
        ``"names"`` plus every metric present in ``full``, each a record
        array with one field per group.

    Raises
    ------
    KeyError
        If ``full`` has no ``scores`` dataset.
    """
    full = rgg["full"]
    if "scores" not in full:
        raise KeyError("rank_genes_groups/full is missing required 'scores' dataset")

    candidate_metrics = (
        "scores",
        "logfoldchanges",
        "pvals",
        "pvals_adj",
        "pts",
        "pts_rest",
        "auc",
        "u_stat",
    )
    order_ds = rgg.get("order")
    present = [name for name in candidate_metrics if name in full]
    gene_count = len(genes)

    names_per_group: list[np.ndarray] = []
    values_per_metric: dict[str, list[np.ndarray]] = {name: [] for name in present}

    for row_idx, _label in zip(group_indices, groups):
        ranking = np.arange(gene_count, dtype=int) if order_ds is None else order_ds[row_idx]
        if n_genes is not None:
            ranking = ranking[:n_genes]
        names_per_group.append(genes[ranking].astype(str))
        for name in present:
            values_per_metric[name].append(np.take(full[name][row_idx], ranking))

    result = {"names": _to_recarray(names_per_group, groups)}
    result.update(
        {name: _to_recarray(values, groups) for name, values in values_per_metric.items()}
    )
    return result


def _build_rgg_from_recarray(
    rgg: h5py.Group,
    groups: list[str],
    n_genes: int | None,
) -> dict:
    """Build a rank_genes_groups dict from Scanpy-style record-array datasets.

    ``rgg["names"]`` must be a structured dataset with one field per group,
    where rows are ranked genes. Every metric stored the same way is sliced
    to the first ``n_genes`` rows for the requested ``groups``.

    Parameters
    ----------
    rgg
        The ``uns/<key>`` HDF5 group.
    groups
        Groups (record fields) to extract, in output order.
    n_genes
        Number of top-ranked rows to keep. ``None`` keeps all rows. ``0``
        yields empty rankings, consistent with ``_build_rgg_from_full``.

    Returns
    -------
    dict
        ``"names"`` plus every structured metric dataset containing all
        requested groups. Each value is a record array with one field per
        group.

    Raises
    ------
    KeyError
        If ``names`` is not structured or a requested group is absent.
    """
    names_arr = rgg["names"][()]
    available = set(names_arr.dtype.names or ())
    if not available:
        raise KeyError("rank_genes_groups names dataset is not structured")

    for group in groups:
        if group not in available:
            raise KeyError(f"Group '{group}' not found in rank_genes_groups names")

    # ``n_genes or ...`` would silently treat an explicit 0 as "all genes".
    limit = names_arr.shape[0] if n_genes is None else n_genes

    name_arrays = [
        np.asarray(_decode_strings(names_arr[group][:limit]), dtype=object)
        for group in groups
    ]
    rgg_dict = {"names": _to_recarray(name_arrays, groups)}

    for key in (
        "scores",
        "logfoldchanges",
        "pvals",
        "pvals_adj",
        "pts",
        "pts_rest",
        "auc",
        "u_stat",
    ):
        if key not in rgg:
            continue
        node = rgg[key]
        # Scanpy stores ``pts``/``pts_rest`` as DataFrames (written as HDF5
        # groups), not per-group record arrays. They cannot be sliced per
        # group here, so skip them instead of failing.
        if not isinstance(node, h5py.Dataset) or node.dtype.names is None:
            logger.debug("Skipping rank_genes_groups/%s: not a structured dataset", key)
            continue
        fields = set(node.dtype.names)
        if any(group not in fields for group in groups):
            logger.debug("Skipping rank_genes_groups/%s: missing requested groups", key)
            continue
        metric_arr = node[()]
        rgg_dict[key] = _to_recarray([metric_arr[group][:limit] for group in groups], groups)

    return rgg_dict


def _materialize_rank_genes_groups_uns(
    path: Path,
    *,
    key: str,
    groups: list[str],
    n_genes: int | None,
) -> dict:
    """Assemble an in-memory, Scanpy-style ``rank_genes_groups`` dict.

    Three on-disk layouts are handled:

    * ``uns/<key>/full`` exists: per-group rows addressed by group position
      (``_build_rgg_from_full``);
    * ``uns/<key>`` exists without ``full``: Scanpy record arrays
      (``_build_rgg_from_recarray``);
    * no ``uns/<key>``: rankings derived from per-group layers
      (``_materialize_rank_genes_groups_from_layers``).

    The parameters from ``_infer_rgg_params`` are attached under
    ``"params"``.

    Raises
    ------
    KeyError
        For the ``full`` layout, if a requested group is not among the
        stored group names.
    """
    genes = _read_gene_names(path)
    params = _infer_rgg_params(path, key)
    rgg_path = f"uns/{key}"

    with h5py.File(path, "r") as handle:
        has_results = rgg_path in handle
    if not has_results:
        # Delegate only after our handle is closed, so the layer-based path
        # is not reading a file this function still holds open.
        return _materialize_rank_genes_groups_from_layers(
            path,
            groups=groups,
            genes=genes,
            n_genes=n_genes,
            params=params,
        )

    all_groups = _read_group_names(path, key)
    with h5py.File(path, "r") as handle:
        rgg = handle[rgg_path]
        if "full" in rgg:
            # O(1) lookups; first occurrence wins, like ``list.index``.
            position: dict[str, int] = {}
            for i, name in enumerate(all_groups):
                position.setdefault(name, i)
            missing = [group for group in groups if group not in position]
            if missing:
                raise KeyError(f"Groups not found in rank_genes_groups: {missing}")
            group_indices = [position[group] for group in groups]
            rgg_dict = _build_rgg_from_full(rgg, groups, group_indices, genes, n_genes)
        else:
            rgg_dict = _build_rgg_from_recarray(rgg, groups, n_genes)

    rgg_dict["params"] = params
    return rgg_dict


def _materialize_rank_genes_groups_from_layers(
    path: Path,
    *,
    groups: list[str],
    genes: np.ndarray,
    n_genes: int | None,
    params: dict,
) -> dict:
    """Derive rank_genes_groups rankings from per-group result layers.

    Used when the file has no ``uns/<key>`` entry. Each requested group must
    be an observation name. Its row in the first available score layer
    (``z_score``, ``u_statistic``, ``u_stat`` or ``scores``) is ranked by
    descending absolute value using a stable sort. The same ordering is then
    applied to the optional logfoldchange, p-value and pts layers.

    Returns
    -------
    dict
        ``"names"``, one entry per available metric layer (each a record
        array with one field per group), and ``"params"``.

    Raises
    ------
    KeyError
        If no score layer exists.
    ValueError
        If a requested group is not an observation name.
    """

    def dense_row(layer, row: int) -> np.ndarray:
        """Return row ``row`` of ``layer`` as a flat dense array."""
        values = layer[row]
        # ``np.asarray`` on a scipy sparse row gives a 0-d object array, not
        # the values, so densify sparse layers explicitly.
        if sp.issparse(values):
            values = values.toarray()
        return np.asarray(values).ravel()

    backed = read_backed(path)
    try:
        obs_names = backed.obs_names.astype(str).tolist()
        group_indices = [obs_names.index(group) for group in groups]

        def pick_layer(options: Sequence[str]) -> str | None:
            for name in options:
                if name in backed.layers:
                    return name
            return None

        score_layer = pick_layer(["z_score", "u_statistic", "u_stat", "scores"])
        if score_layer is None:
            raise KeyError("No score layer found for rank_genes_groups materialization")

        layer_map = {
            "scores": score_layer,
            "logfoldchanges": pick_layer(["logfoldchange", "logfoldchanges"]),
            "pvals": pick_layer(["pvalue", "pvals"]),
            "pvals_adj": pick_layer(["pvalue_adj", "pvals_adj"]),
            "pts": pick_layer(["pts"]),
            "pts_rest": pick_layer(["pts_rest"]),
        }
        present = {metric: layer_name for metric, layer_name in layer_map.items() if layer_name}

        arrays_by_metric: dict[str, list[np.ndarray]] = {metric: [] for metric in present}
        name_arrays: list[np.ndarray] = []

        for idx in group_indices:
            scores = dense_row(backed.layers[score_layer], idx)
            order = np.argsort(-np.abs(scores), kind="mergesort")
            if n_genes is not None:
                order = order[:n_genes]
            name_arrays.append(genes[order].astype(str))
            for metric, layer_name in present.items():
                # The score row was already read above; reuse it.
                if layer_name == score_layer:
                    values = scores
                else:
                    values = dense_row(backed.layers[layer_name], idx)
                arrays_by_metric[metric].append(np.take(values, order))
    finally:
        backed.file.close()

    rgg_dict = {"names": _to_recarray(name_arrays, groups)}
    for metric, arrays in arrays_by_metric.items():
        rgg_dict[metric] = _to_recarray(arrays, groups)
    rgg_dict["params"] = params
    return rgg_dict


# -----------------------------------------------------------------------------
# Public helpers
# -----------------------------------------------------------------------------


def materialize_rank_genes_groups(
    data: PlotInput,
    *,
    key: str = "rank_genes_groups",
    groups: Sequence[str] | None = None,
    n_genes: int | None = None,
    gene_symbols: str | None = None,
) -> ad.AnnData:
    """Create a minimal in-memory AnnData with Scanpy-style rank_genes_groups.

    This helper is intended for plotting only. It never loads the expression
    matrix from disk. Instead it constructs an empty sparse matrix and
    injects Scanpy-compatible ``uns['rank_genes_groups']``.

    Parameters
    ----------
    data
        Path or AnnData holding the results (resolved via ``resolve_data_path``).
    key
        Name of the ``uns`` entry with rank_genes_groups results.
    groups
        Group label(s) to include. A single string is treated as one group
        (as in ``rank_genes_groups_df``). ``None`` selects all groups.
    n_genes
        Keep only the top ``n_genes`` genes per group. ``None`` keeps all.
    gene_symbols
        Optional ``var`` column to copy into the returned object's ``var``.
        A warning is logged if the column does not exist.

    Returns
    -------
    anndata.AnnData
        One observation per selected group and all genes as ``var``, with an
        all-zero sparse ``X`` and ``uns[key]`` populated.

    Raises
    ------
    KeyError
        If any requested group is not present in the results.
    """
    path = _resolve_path(data)
    all_groups = _read_group_names(path, key)
    if groups is None:
        selected_groups = all_groups
    elif isinstance(groups, str):
        # A bare string is one group label, not a sequence of characters.
        selected_groups = [groups]
    else:
        selected_groups = [str(group) for group in groups]
    known = set(all_groups)
    missing = [group for group in selected_groups if group not in known]
    if missing:
        raise KeyError(f"Groups not found in rank_genes_groups: {missing}")

    genes = _read_gene_names(path)
    rgg = _materialize_rank_genes_groups_uns(
        path,
        key=key,
        groups=selected_groups,
        n_genes=n_genes,
    )

    obs = pd.DataFrame(index=pd.Index(selected_groups, name="group"))
    groupby = rgg.get("params", {}).get("groupby")
    if groupby:
        obs[str(groupby)] = selected_groups

    var = pd.DataFrame(index=pd.Index(genes, name="gene"))
    if gene_symbols is not None:
        var_col = _read_var_column(path, gene_symbols)
        if var_col is not None:
            var[gene_symbols] = var_col.to_numpy()
        else:
            logger.warning("gene_symbols column '%s' not found in var", gene_symbols)

    # Placeholder matrix: plotting only needs uns/obs/var, not expression.
    X = sp.csr_matrix((len(selected_groups), len(genes)), dtype=np.float32)
    adata = ad.AnnData(X, obs=obs, var=var)
    adata.uns[key] = rgg
    return adata


def rank_genes_groups(
    data: PlotInput,
    groups: Sequence[str] | None = None,
    *,
    n_genes: int = 20,
    gene_symbols: str | None = None,
    key: str = "rank_genes_groups",
    **kwargs,
):
    """Plot rank_genes_groups results stored on disk using Scanpy.

    The results are first materialised into a lightweight in-memory AnnData
    (see ``materialize_rank_genes_groups``). Then ``scanpy.pl.rank_genes_groups``
    is called with the same ``groups``, ``n_genes``, ``gene_symbols`` and
    ``key``, plus any extra keyword arguments. Returns whatever Scanpy
    returns. Raises ``ImportError`` if Scanpy is not installed.
    """
    try:
        import scanpy as sc
    except ImportError as exc:  # pragma: no cover - depends on optional dependency
        raise ImportError("scanpy is required for rank_genes_groups plotting") from exc

    shared_options = {
        "groups": groups,
        "n_genes": n_genes,
        "gene_symbols": gene_symbols,
        "key": key,
    }
    adata = materialize_rank_genes_groups(data, **shared_options)
    return sc.pl.rank_genes_groups(adata, **shared_options, **kwargs)


def _rank_genes_groups_to_frame(rgg: dict, groups: Sequence[str]) -> pd.DataFrame:
    """Flatten a per-group rank_genes_groups dict into one long DataFrame.

    One row per (group, ranked gene). Columns are ``names``, then every metric
    present in *rgg* in a fixed order, then ``group``.
    """
    metric_keys = (
        "scores",
        "logfoldchanges",
        "pvals",
        "pvals_adj",
        "pts",
        "pts_rest",
        "auc",
        "u_stat",
    )
    frames: list[pd.DataFrame] = []
    for group_name in groups:
        frame = pd.DataFrame({"names": rgg["names"][group_name]})
        for metric in metric_keys:
            if metric in rgg:
                frame[metric] = rgg[metric][group_name]
        frame["group"] = group_name
        frames.append(frame)
    return pd.concat(frames, ignore_index=True)


def rank_genes_groups_df(
    data: PlotInput,
    group: str | Sequence[str],
    *,
    key: str = "rank_genes_groups",
    n_genes: int | None = None,
) -> pd.DataFrame:
    """Return a tidy DataFrame for rank_genes_groups results from disk.

    Parameters
    ----------
    data
        Path or AnnData holding the results.
    group
        One group label or a sequence of labels.
    key
        Name of the ``uns`` entry with rank_genes_groups results. When it is
        absent, rankings are derived from per-group layers instead.
    n_genes
        Keep only the top ``n_genes`` genes per group. ``None`` keeps all.

    Returns
    -------
    pandas.DataFrame
        Columns ``names``, the available metrics (``scores``,
        ``logfoldchanges``, ``pvals``, ...) and ``group``. There is one row
        per ranked gene per requested group. An empty ``group`` sequence
        yields an empty frame with ``names`` and ``group`` columns.

    Raises
    ------
    KeyError
        If a requested group is not present in stored results.
    """
    path = _resolve_path(data)
    groups = [str(group)] if isinstance(group, str) else [str(g) for g in group]
    if not groups:
        return pd.DataFrame(columns=["names", "group"])
    # Record arrays cannot hold duplicate field names, so materialise each
    # group once. The output still repeats duplicated groups as requested.
    unique_groups = list(dict.fromkeys(groups))

    rgg_path = f"uns/{key}"
    with h5py.File(path, "r") as handle:
        layout = None
        if rgg_path in handle:
            layout = "full" if "full" in handle[rgg_path] else "recarray"

    # Each branch below opens the file on its own; no handle is held across
    # calls that re-open the same file.
    if layout is None:
        rgg = _materialize_rank_genes_groups_uns(
            path,
            key=key,
            groups=unique_groups,
            n_genes=n_genes,
        )
    elif layout == "full":
        genes = _read_gene_names(path)
        group_names = _read_group_names(path, key)
        for group_name in unique_groups:
            if group_name not in group_names:
                raise KeyError(f"Group '{group_name}' not found in rank_genes_groups")
        group_indices = [group_names.index(group_name) for group_name in unique_groups]
        with h5py.File(path, "r") as handle:
            rgg = _build_rgg_from_full(
                handle[rgg_path], unique_groups, group_indices, genes, n_genes
            )
    else:
        # Record-array layout. The builder decodes byte-string gene names
        # consistently via ``_decode_strings``.
        with h5py.File(path, "r") as handle:
            rgg = _build_rgg_from_recarray(handle[rgg_path], unique_groups, n_genes)

    return _rank_genes_groups_to_frame(rgg, groups)


# -----------------------------------------------------------------------------
# Differential expression plots
# -----------------------------------------------------------------------------


def _require_matplotlib():
    """Return ``matplotlib.pyplot`` if it can be imported.

    Otherwise log a warning and return ``None``.
    """
    plt = None
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # A failed import leaves ``plt`` bound to None.
        logger.warning("matplotlib not installed; cannot create plot")
    return plt


def plot_volcano(
    *,
    data: PlotInput | None = None,
    group: str | None = None,
    de_df: pd.DataFrame | None = None,
    key: str = "rank_genes_groups",
    p_cut: float = 0.05,
    lfc_cut: float = 1.0,
    ax=None,
    show: bool | None = None,
    savepath: str | Path | None = None,
):
    """Volcano plot for a single group.

    Provide ``de_df`` directly or supply ``data`` and ``group`` to read from disk.

    Adjusted p-values (``pvals_adj``) are used when available, otherwise raw
    ``pvals``. The y-axis label states which was used. A gene is significant
    when ``p < p_cut`` and ``|logfoldchange| >= lfc_cut``. Significant up-
    and down-regulated genes are drawn in red and blue, others in grey.

    Returns
    -------
    matplotlib Axes, or ``None`` if matplotlib is unavailable or ``show``
    is truthy.

    Raises
    ------
    ValueError
        If neither ``de_df`` nor both ``data`` and ``group`` are given.
    KeyError
        If p-value or logfoldchange columns are missing.
    """
    plt = _require_matplotlib()
    if plt is None:
        return None
    if de_df is None:
        if data is None or group is None:
            raise ValueError("Provide either de_df or both data and group")
        de_df = rank_genes_groups_df(data, group=group, key=key)
    if group is None:
        if "group" in de_df.columns and not de_df["group"].empty:
            group = str(de_df["group"].iloc[0])
        else:
            group = ""

    if "pvals_adj" in de_df.columns:
        p_column, y_label = "pvals_adj", "-log10(adj p)"
    elif "pvals" in de_df.columns:
        p_column, y_label = "pvals", "-log10(p)"
    else:
        raise KeyError("p-values not found in de_df (expected pvals or pvals_adj)")
    if "logfoldchanges" not in de_df.columns:
        raise KeyError("logfoldchanges not found in de_df")

    df = de_df.copy()
    # Coerce to float so nullable/object columns (with NA) behave like NaN.
    pvals = df[p_column].to_numpy(dtype=float, na_value=np.nan)
    lfc = df["logfoldchanges"].to_numpy(dtype=float, na_value=np.nan)
    # Clip to avoid -log10(0) = inf for extremely small p-values.
    df["neglog10_p"] = -np.log10(np.clip(pvals, 1e-300, None))
    sig = (pvals < p_cut) & (np.abs(lfc) >= lfc_cut)
    up = sig & (lfc > 0)
    down = sig & (lfc < 0)

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))
    ax.scatter(
        df.loc[~sig, "logfoldchanges"], df.loc[~sig, "neglog10_p"],
        s=6, alpha=0.4, color="tab:gray",
    )
    ax.scatter(
        df.loc[down, "logfoldchanges"], df.loc[down, "neglog10_p"],
        s=8, alpha=0.9, color="tab:blue",
    )
    ax.scatter(
        df.loc[up, "logfoldchanges"], df.loc[up, "neglog10_p"],
        s=8, alpha=0.9, color="tab:red",
    )
    ax.axhline(-np.log10(p_cut), linestyle="--", linewidth=1)
    ax.axvline(-lfc_cut, linestyle="--", linewidth=1)
    ax.axvline(lfc_cut, linestyle="--", linewidth=1)
    ax.set_title(f"Volcano: {group}")
    ax.set_xlabel("log2FC")
    ax.set_ylabel(y_label)
    if savepath:
        ax.figure.savefig(savepath, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
        return None
    return ax


def plot_top_genes_bar(
    *,
    data: PlotInput | None = None,
    group: str | None = None,
    de_df: pd.DataFrame | None = None,
    key: str = "rank_genes_groups",
    topn: int = 15,
    ax=None,
    show: bool | None = None,
    savepath: str | Path | None = None,
):
    """Horizontal bar plot of top-ranked genes for a group.

    Shows up to ``topn`` genes with the highest ``scores``. Bars are red for
    positive and blue for non-positive ``logfoldchanges`` (grey if that
    column is absent). The title reports how many genes are actually drawn.

    Returns
    -------
    matplotlib Axes, or ``None`` if matplotlib is unavailable or ``show``
    is truthy.

    Raises
    ------
    ValueError
        If neither ``de_df`` nor both ``data`` and ``group`` are given.
    KeyError
        If ``de_df`` has no ``scores`` column.
    """
    plt = _require_matplotlib()
    if plt is None:
        return None
    if de_df is None:
        if data is None or group is None:
            raise ValueError("Provide either de_df or both data and group")
        de_df = rank_genes_groups_df(data, group=group, key=key)
    if group is None:
        if "group" in de_df.columns and not de_df["group"].empty:
            group = str(de_df["group"].iloc[0])
        else:
            group = ""
    if "scores" not in de_df.columns:
        raise KeyError("scores not found in de_df")

    # Highest scores first, then reversed so the best gene ends up at the top
    # of the horizontal bar chart.
    df = de_df.sort_values("scores", ascending=False).head(topn).iloc[::-1]
    if "logfoldchanges" in df.columns:
        colors = np.where(df["logfoldchanges"].to_numpy() > 0, "tab:red", "tab:blue")
    else:
        colors = "tab:gray"

    if ax is None:
        _, ax = plt.subplots(figsize=(7, 5))
    ax.barh(df["names"], df["scores"], color=colors)
    # Fewer than ``topn`` genes may be available; report what is drawn.
    ax.set_title(f"Top {len(df)} genes: {group}")
    ax.set_xlabel("score")
    if savepath:
        ax.figure.savefig(savepath, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
        return None
    return ax


def _compute_library_sizes(adata: ad.AnnData, chunk_size: int) -> np.ndarray:
    """Return per-observation totals: row sums over chunks from ``iter_matrix_chunks(axis=0)``."""
    library_size = np.zeros(adata.n_obs, dtype=np.float64)
    for slc, block in iter_matrix_chunks(adata, axis=0, chunk_size=chunk_size, convert_to_dense=False):
        if sp.issparse(block):
            library_size[slc] = np.asarray(block.sum(axis=1)).ravel()
        else:
            library_size[slc] = np.asarray(block).sum(axis=1)
    return library_size


def _mean_expression_by_group(
    adata: ad.AnnData,
    gene_indices: np.ndarray,
    group_mask: np.ndarray,
    ref_mask: np.ndarray,
    *,
    chunk_size: int,
    mean_mode: str,
    target_sum: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Mean expression of ``gene_indices`` within the group and reference cells.

    Parameters
    ----------
    adata
        AnnData whose matrix is streamed in row chunks.
    gene_indices
        Integer column positions of the genes to average.
    group_mask, ref_mask
        Boolean per-observation masks selecting group / reference cells.
    chunk_size
        Rows per chunk.
    mean_mode
        ``"log1p"`` averages ``log1p`` of library-size-normalised values
        (scaled to ``target_sum``); any other value averages the raw values.
    target_sum
        Normalisation target used in ``"log1p"`` mode.

    Returns
    -------
    (group_mean, reference_mean), each of length ``len(gene_indices)``.

    Raises
    ------
    ValueError
        If either mask selects no cells.
    """
    n_group = int(group_mask.sum())
    n_ref = int(ref_mask.sum())
    # Fail before any chunked I/O; log1p mode would otherwise stream the
    # matrix twice only to discover an empty group.
    if n_group == 0 or n_ref == 0:
        raise ValueError("Group or reference has zero cells")

    subset = adata[:, gene_indices]
    n_genes = len(gene_indices)
    group_sum = np.zeros(n_genes, dtype=np.float64)
    ref_sum = np.zeros(n_genes, dtype=np.float64)

    # Library sizes must come from all genes, not just the selected subset.
    library_size = (
        _compute_library_sizes(adata, chunk_size=chunk_size) if mean_mode == "log1p" else None
    )

    for slc, block in iter_matrix_chunks(subset, axis=0, chunk_size=chunk_size, convert_to_dense=False):
        if mean_mode == "log1p":
            block, _ = normalize_total_block(
                block,
                library_size=library_size[slc],
                target_sum=target_sum,
            )
            # Keep sparse blocks sparse; np.log1p is not a reliable op on them.
            block = block.log1p() if sp.issparse(block) else np.log1p(block)
        elif sp.issparse(block):
            block = block.toarray()
        else:
            block = np.asarray(block)

        # np.asarray(...).ravel() flattens the (1, n) matrix that sparse /
        # np.matrix sums return so it can accumulate into the (n,) buffers.
        rows_in_group = group_mask[slc]
        if rows_in_group.any():
            group_sum += np.asarray(block[rows_in_group].sum(axis=0)).ravel()
        rows_in_ref = ref_mask[slc]
        if rows_in_ref.any():
            ref_sum += np.asarray(block[rows_in_ref].sum(axis=0)).ravel()

    return group_sum / n_group, ref_sum / n_ref
def plot_ma(
    *,
    data: PlotInput,
    de_result: PlotInput | None = None,
    group: str,
    reference: str | None = None,
    perturbation_column: str | None = None,
    key: str = "rank_genes_groups",
    de_df: pd.DataFrame | None = None,
    mean_mode: str = "raw",
    target_sum: float = 1e4,
    n_genes: int | None = None,
    p_cut: float = 0.05,
    lfc_cut: float = 1.0,
    chunk_size: int = 1024,
    ax=None,
    show: bool | None = None,
    savepath: str | Path | None = None,
):
    """MA plot using raw counts or normalized log1p means.

    Parameters
    ----------
    data
        Path or backed AnnData containing raw counts.
    de_result
        Path or AnnData with rank_genes_groups results. Defaults to ``data``.
    group
        Label (in ``perturbation_column``) of the group being plotted.
    reference
        Label of the comparison cells. Inferred from the stored
        rank_genes_groups params when omitted. ``"rest"`` (scanpy's default)
        selects every cell outside ``group``, unless a label literally named
        ``"rest"`` exists.
    perturbation_column
        ``.obs`` column holding the labels; inferred from the stored
        ``groupby`` param when omitted.
    de_df
        Pre-computed DE table; read from ``de_result`` when omitted.
    mean_mode
        "raw" or "log1p" (normalized log1p means).
    target_sum
        Normalisation target for ``mean_mode="log1p"``.
    n_genes
        Forwarded to ``rank_genes_groups_df`` when reading ``de_df``.
    p_cut, lfc_cut
        Significance thresholds for highlighting points.
    chunk_size
        Rows per chunk when streaming the matrix.

    Returns
    -------
    matplotlib Axes, or ``None`` when matplotlib is unavailable or ``show``
    is true.

    Raises
    ------
    ValueError
        Bad ``mean_mode``, no DE gene present in ``data``, or an empty
        group / reference.
    KeyError
        Missing ``logfoldchanges`` column or missing perturbation column.
    """
    plt = _require_matplotlib()
    if plt is None:
        return None
    if mean_mode not in {"raw", "log1p"}:
        raise ValueError("mean_mode must be 'raw' or 'log1p'")

    if de_result is None:
        de_result = data
    if de_df is None:
        de_df = rank_genes_groups_df(de_result, group=group, key=key, n_genes=n_genes)
    if "logfoldchanges" not in de_df.columns:
        raise KeyError("logfoldchanges not found in de_df")

    # Stored params are only needed to fill omitted arguments; read them once.
    if perturbation_column is None or reference is None:
        params = _infer_rgg_params(_resolve_path(de_result), key)
        if perturbation_column is None:
            perturbation_column = str(params.get("groupby", "group"))
        if reference is None:
            reference = str(params.get("reference", "reference"))

    backed = read_backed(_resolve_path(data))
    try:
        if perturbation_column not in backed.obs.columns:
            raise KeyError(
                f"Perturbation column '{perturbation_column}' not found in adata.obs"
            )
        labels = backed.obs[perturbation_column].astype(str).to_numpy()
        var_names = pd.Index(backed.var_names.astype(str))

        gene_indexer = var_names.get_indexer(de_df["names"].astype(str).to_numpy())
        valid_mask = gene_indexer >= 0
        if not np.any(valid_mask):
            raise ValueError("None of the DE genes were found in the data var_names")
        gene_indexer = gene_indexer[valid_mask]

        group_mask = labels == group
        if reference == "rest" and not np.any(labels == "rest"):
            # scanpy convention: "rest" means every cell outside the group.
            ref_mask = ~group_mask
        else:
            ref_mask = labels == reference

        mean_group, mean_ref = _mean_expression_by_group(
            backed,
            gene_indexer,
            group_mask,
            ref_mask,
            chunk_size=chunk_size,
            mean_mode=mean_mode,
            target_sum=target_sum,
        )
    finally:
        backed.file.close()

    # Keep only DE rows whose gene exists in ``data``, in the same order as
    # ``gene_indexer`` / the computed means.
    de_valid = de_df.iloc[np.flatnonzero(valid_mask)]
    lfc = de_valid["logfoldchanges"].to_numpy()
    if "pvals_adj" in de_valid.columns:
        pvals = de_valid["pvals_adj"].to_numpy()
    elif "pvals" in de_valid.columns:
        pvals = de_valid["pvals"].to_numpy()
    else:
        pvals = None

    if mean_mode == "raw":
        A = np.log1p((mean_group + mean_ref) / 2.0)
    else:
        A = (mean_group + mean_ref) / 2.0

    sig = None
    if pvals is not None:
        sig = (pvals < p_cut) & (np.abs(lfc) >= lfc_cut)

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))

    if sig is None:
        ax.scatter(A, lfc, s=6, alpha=0.6)
    else:
        ax.scatter(A[~sig], lfc[~sig], s=6, alpha=0.4)
        ax.scatter(A[sig], lfc[sig], s=8, alpha=0.9)

    ax.axhline(0, linestyle="--", linewidth=1)
    ax.axhline(lfc_cut, linestyle="--", linewidth=1)
    ax.axhline(-lfc_cut, linestyle="--", linewidth=1)
    ax.set_title(f"MA plot: {group} vs {reference}")
    ax.set_xlabel("log1p(mean expression)" if mean_mode == "raw" else "mean log1p expression")
    ax.set_ylabel("log2FC")

    if savepath:
        ax.figure.savefig(savepath, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
        return None
    return ax
# -----------------------------------------------------------------------------
# QC plots
# -----------------------------------------------------------------------------
def plot_qc_perturbation_counts(
    *,
    data: PlotInput,
    perturbation_column: str,
    cell_mask: np.ndarray | None = None,
    top_n: int | None = None,
    ax=None,
    show: bool | None = None,
    savepath: str | Path | None = None,
):
    """Plot per-perturbation cell counts (optionally after QC filtering).

    Labels are read from ``obs[perturbation_column]`` of the on-disk data
    (the file is closed before plotting). ``cell_mask``, when given, indexes
    the label array before counting. ``top_n`` keeps only the most frequent
    labels.

    Returns
    -------
    matplotlib Axes, or ``None`` when matplotlib is unavailable or ``show``
    is true.

    Raises
    ------
    KeyError
        If ``perturbation_column`` is not an ``.obs`` column.
    """
    plt = _require_matplotlib()
    if plt is None:
        return None

    handle = read_backed(_resolve_path(data))
    try:
        obs_frame = handle.obs
        if perturbation_column not in obs_frame.columns:
            raise KeyError(
                f"Perturbation column '{perturbation_column}' not found in adata.obs"
            )
        perturbation_labels = obs_frame[perturbation_column].astype(str).to_numpy()
    finally:
        handle.file.close()

    if cell_mask is not None:
        perturbation_labels = perturbation_labels[cell_mask]

    tally = pd.Series(perturbation_labels).value_counts()
    if top_n is not None:
        tally = tally.head(top_n)

    if ax is None:
        ax = plt.subplots(figsize=(8, 4))[1]

    ax.bar(tally.index.astype(str), tally.to_numpy())
    ax.set_ylabel("Cells")
    ax.set_xlabel("Perturbation")
    ax.set_title("Perturbation composition")
    ax.tick_params(axis="x", rotation=45)
    # Right-align rotated labels so each one ends under its bar.
    for tick_label in ax.get_xticklabels():
        tick_label.set_ha("right")

    if savepath:
        ax.figure.savefig(savepath, dpi=300, bbox_inches="tight")
    if not show:
        return ax
    plt.show()
    return None
def plot_qc_summary(
    qc_result: QualityControlResult,
    *,
    bins: int = 50,
    min_genes: int | None = None,
    min_cells_per_gene: int | None = None,
    ax=None,
    show: bool | None = None,
    savepath: str | Path | None = None,
):
    """Plot QC summary distributions from a QualityControlResult.

    Draws two histograms side by side: ``qc_result.cell_gene_counts`` (genes
    per cell) and ``qc_result.gene_cell_counts`` (cells per gene). Optional
    thresholds are drawn as dashed vertical lines.

    Parameters
    ----------
    ax
        ``None`` to create a new 1x2 figure, or a list/tuple/ndarray of
        exactly two axes.

    Returns
    -------
    The two axes, or ``None`` when matplotlib is unavailable or ``show`` is
    true.

    Raises
    ------
    ValueError
        If ``ax`` is neither ``None`` nor a sequence of two axes.
    """
    plt = _require_matplotlib()
    if plt is None:
        return None

    if ax is None:
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    elif isinstance(ax, (list, tuple, np.ndarray)) and len(ax) == 2:
        axes = ax
        fig = ax[0].figure
    else:
        raise ValueError("ax must be a sequence of two matplotlib axes")

    # (result attribute, colour, title, x label, y label, threshold)
    panel_specs = (
        ("cell_gene_counts", "tab:blue", "Genes per cell", "genes", "cells", min_genes),
        ("gene_cell_counts", "tab:green", "Cells per gene", "cells", "genes", min_cells_per_gene),
    )
    for panel, (attr, colour, heading, x_text, y_text, cutoff) in zip(axes, panel_specs):
        panel.hist(getattr(qc_result, attr), bins=bins, color=colour, alpha=0.7)
        panel.set_title(heading)
        panel.set_xlabel(x_text)
        panel.set_ylabel(y_text)
        if cutoff is not None:
            panel.axvline(cutoff, linestyle="--", color="black")

    if savepath:
        fig.savefig(savepath, dpi=300, bbox_inches="tight")
    if not show:
        return axes
    plt.show()
    return None
# -----------------------------------------------------------------------------
# PCA Plotting Functions
# -----------------------------------------------------------------------------
def plot_pca(
    data: PlotInput,
    *,
    color: str | Sequence[str] | None = None,
    use_raw: bool | None = None,
    layer: str | None = None,
    sort_order: bool = True,
    groups: str | Sequence[str] | None = None,
    projection: str = "2d",
    components: str | Sequence[str] | None = None,
    palette=None,
    na_color: str = "lightgray",
    na_in_legend: bool = True,
    size: float | None = None,
    frameon: bool | None = None,
    legend_fontsize: int | float | str | None = None,
    legend_fontweight: int | str | None = None,
    legend_loc: str = "right margin",
    legend_fontoutline: int | None = None,
    colorbar_loc: str | None = "right",
    ncols: int = 4,
    wspace: float | None = None,
    hspace: float = 0.25,
    title: str | Sequence[str] | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    ax=None,
    return_fig: bool | None = None,
    **kwargs,
):
    """Plot PCA scatter from on-disk crispyx/backed AnnData or in-memory AnnData.

    Wrapper around scanpy.pl.pca that works with backed and in-memory AnnData.
    Only ``obsm['X_pca']``, ``.obs`` and ``uns['pca']`` are copied into a
    lightweight plotting object; expression values and gene names are not.
    Any file opened here is closed before plotting.

    Parameters
    ----------
    data
        Path to h5ad, crispyx.AnnData, or anndata.AnnData with X_pca computed.
    color
        Column name(s) in ``.obs`` to colour by. Gene names are not available
        because the plotting object carries no expression data or var names.
    use_raw, layer
        Forwarded to scanpy; note the plotting object has no ``.raw`` or
        layers.
    components
        e.g. '1,2' or ['1,2', '3,4']. Default first 2 components.
    projection
        '2d' or '3d'.
    palette
        Color palette for categorical annotations.
    size
        Point size.
    show
        Show the figure.
    save
        Save the figure.
    **kwargs
        Passed to scanpy.pl.pca.

    Returns
    -------
    matplotlib.axes.Axes or list of Axes, or None if show=True.

    Raises
    ------
    ImportError
        If scanpy is not installed.
    ValueError
        If ``X_pca`` is missing from ``.obsm``.
    """
    try:
        import scanpy as sc
    except ImportError as exc:
        raise ImportError("scanpy is required for PCA plotting") from exc

    # In-memory AnnData is used directly; anything else is opened here, so
    # this function owns that handle and must close it.
    owns_handle = not (isinstance(data, ad.AnnData) and not hasattr(data, "path"))
    adata = read_backed(_resolve_path(data)) if owns_handle else data
    try:
        if "X_pca" not in adata.obsm:
            raise ValueError(
                "X_pca not found in adata.obsm. Run cx.pp.pca() first."
            )

        # Materialise everything the plot needs before the handle is closed.
        adata_plot = ad.AnnData(
            X=sp.csr_matrix((adata.n_obs, adata.n_vars), dtype=np.float32),
            obs=adata.obs.copy() if hasattr(adata.obs, "copy") else pd.DataFrame(adata.obs),
        )
        adata_plot.obsm["X_pca"] = np.asarray(adata.obsm["X_pca"])
        if "pca" in adata.uns:
            adata_plot.uns["pca"] = {
                name: value[()] if isinstance(value, h5py.Dataset) else value
                for name, value in dict(adata.uns["pca"]).items()
            }
    finally:
        if owns_handle:
            adata.file.close()

    return sc.pl.pca(
        adata_plot,
        color=color,
        use_raw=use_raw,
        layer=layer,
        sort_order=sort_order,
        groups=groups,
        projection=projection,
        components=components,
        palette=palette,
        na_color=na_color,
        na_in_legend=na_in_legend,
        size=size,
        frameon=frameon,
        legend_fontsize=legend_fontsize,
        legend_fontweight=legend_fontweight,
        legend_loc=legend_loc,
        legend_fontoutline=legend_fontoutline,
        colorbar_loc=colorbar_loc,
        ncols=ncols,
        wspace=wspace,
        hspace=hspace,
        title=title,
        show=show,
        save=save,
        ax=ax,
        return_fig=return_fig,
        **kwargs,
    )
def plot_pca_variance_ratio(
    data: PlotInput,
    *,
    n_pcs: int | None = None,
    log: bool = False,
    show: bool | None = None,
    save: str | bool | None = None,
):
    """Plot variance ratio explained by each PC.

    Wrapper around scanpy.pl.pca_variance_ratio that works with backed AnnData.
    Only ``uns['pca']`` is copied into a minimal plotting object; any file
    opened here is closed before plotting.

    Parameters
    ----------
    data
        Path to h5ad, crispyx.AnnData, or anndata.AnnData with PCA computed.
    n_pcs
        Number of PCs to show. Default shows all computed.
    log
        Plot on log scale.
    show
        Show the figure.
    save
        Save the figure.

    Returns
    -------
    matplotlib.axes.Axes or None if show=True.

    Raises
    ------
    ImportError
        If scanpy is not installed.
    ValueError
        If ``uns['pca']['variance_ratio']`` is missing.
    """
    try:
        import scanpy as sc
    except ImportError as exc:
        raise ImportError("scanpy is required for PCA plotting") from exc

    # In-memory AnnData is used directly; anything else is opened here, so
    # this function owns that handle and must close it.
    owns_handle = not (isinstance(data, ad.AnnData) and not hasattr(data, "path"))
    adata = read_backed(_resolve_path(data)) if owns_handle else data
    try:
        if "pca" not in adata.uns or "variance_ratio" not in adata.uns["pca"]:
            raise ValueError(
                "PCA variance info not found. Run cx.pp.pca() first."
            )

        # Minimal AnnData carrying only the PCA metadata, materialised before
        # the handle is closed.
        adata_plot = ad.AnnData(
            X=sp.csr_matrix((1, 1), dtype=np.float32),
        )
        adata_plot.uns["pca"] = {
            name: value[()] if isinstance(value, h5py.Dataset) else value
            for name, value in dict(adata.uns["pca"]).items()
        }
    finally:
        if owns_handle:
            adata.file.close()

    # Only pass n_pcs if specified, otherwise let scanpy use its default.
    plot_kwargs = {"log": log, "show": show, "save": save}
    if n_pcs is not None:
        plot_kwargs["n_pcs"] = n_pcs

    return sc.pl.pca_variance_ratio(adata_plot, **plot_kwargs)
def plot_pca_loadings(
    data: PlotInput,
    *,
    components: int | str | Sequence[int] | None = None,
    include_lowest: bool = True,
    show: bool | None = None,
    save: str | bool | None = None,
):
    """Plot gene loadings for principal components.

    Wrapper around scanpy.pl.pca_loadings that works with backed and in-memory AnnData.
    Only ``.var``, ``varm['PCs']`` and ``uns['pca']`` are copied into a
    lightweight plotting object; any file opened here is closed before
    plotting.

    Parameters
    ----------
    data
        Path to h5ad, crispyx.AnnData, or anndata.AnnData with PCA computed.
    components
        Which PCs to plot loadings for. e.g. [1, 2, 3] or '1,2,3'.
        Default shows first few components.
    include_lowest
        Show genes with lowest loadings (most negative) as well.
    show
        Show the figure.
    save
        Save the figure.

    Returns
    -------
    matplotlib.axes.Axes or None if show=True.

    Raises
    ------
    ImportError
        If scanpy is not installed.
    ValueError
        If ``varm['PCs']`` is missing.
    """
    try:
        import scanpy as sc
    except ImportError as exc:
        raise ImportError("scanpy is required for PCA plotting") from exc

    # In-memory AnnData is used directly; anything else is opened here, so
    # this function owns that handle and must close it.
    owns_handle = not (isinstance(data, ad.AnnData) and not hasattr(data, "path"))
    adata = read_backed(_resolve_path(data)) if owns_handle else data
    try:
        if "PCs" not in adata.varm:
            raise ValueError(
                "PCA loadings (varm['PCs']) not found. Run cx.pp.pca() first."
            )

        # Materialise var and loadings before the handle is closed.
        adata_plot = ad.AnnData(
            X=sp.csr_matrix((1, adata.n_vars), dtype=np.float32),
            var=adata.var.copy() if hasattr(adata.var, "copy") else pd.DataFrame(adata.var),
        )
        adata_plot.varm["PCs"] = np.asarray(adata.varm["PCs"])
        if "pca" in adata.uns:
            adata_plot.uns["pca"] = {
                name: value[()] if isinstance(value, h5py.Dataset) else value
                for name, value in dict(adata.uns["pca"]).items()
            }
    finally:
        if owns_handle:
            adata.file.close()

    return sc.pl.pca_loadings(
        adata_plot,
        components=components,
        include_lowest=include_lowest,
        show=show,
        save=save,
    )
# -----------------------------------------------------------------------------
# UMAP Plotting Functions
# -----------------------------------------------------------------------------
def plot_umap(
    data: PlotInput,
    *,
    color: str | Sequence[str] | None = None,
    use_raw: bool | None = None,
    layer: str | None = None,
    sort_order: bool = True,
    groups: str | Sequence[str] | None = None,
    components: str | Sequence[str] | None = None,
    palette=None,
    na_color: str = "lightgray",
    na_in_legend: bool = True,
    size: float | None = None,
    frameon: bool | None = None,
    legend_fontsize: int | float | str | None = None,
    legend_fontweight: int | str | None = None,
    legend_loc: str = "right margin",
    legend_fontoutline: int | None = None,
    colorbar_loc: str | None = "right",
    ncols: int = 4,
    wspace: float | None = None,
    hspace: float = 0.25,
    title: str | Sequence[str] | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    ax=None,
    return_fig: bool | None = None,
    **kwargs,
):
    """Plot UMAP embedding from on-disk crispyx/backed AnnData or in-memory AnnData.

    Wrapper around scanpy.pl.umap that works with backed and in-memory AnnData.
    Only ``obsm['X_umap']``, ``.obs`` and ``uns['umap']`` are copied into a
    lightweight plotting object; expression values and gene names are not.
    Any file opened here is closed before plotting.

    Parameters
    ----------
    data
        Path to h5ad, crispyx.AnnData, or anndata.AnnData with X_umap computed.
    color
        Column name(s) in ``.obs`` to colour by. Gene names are not available
        because the plotting object carries no expression data or var names.
    use_raw, layer
        Forwarded to scanpy; note the plotting object has no ``.raw`` or
        layers.
    components
        1-based component pairs as strings, e.g. '1,2' or ['1,2', '2,3'].
        Default first 2.
    palette
        Color palette for categorical annotations.
    size
        Point size.
    show
        Show the figure.
    save
        Save the figure.
    **kwargs
        Passed to scanpy.pl.umap.

    Returns
    -------
    matplotlib.axes.Axes or list of Axes, or None if show=True.

    Raises
    ------
    ImportError
        If scanpy is not installed.
    ValueError
        If ``X_umap`` is missing from ``.obsm``.

    Examples
    --------
    >>> import crispyx as cx
    >>> adata = cx.read_backed("data.h5ad")
    >>> cx.pl.umap(adata, color="perturbation")

    See Also
    --------
    cx.tl.umap : Compute UMAP embedding.
    cx.pp.neighbors : Compute neighbor graph (required for UMAP).
    """
    try:
        import scanpy as sc
    except ImportError as exc:
        raise ImportError("scanpy is required for UMAP plotting") from exc

    # In-memory AnnData is used directly; anything else is opened here, so
    # this function owns that handle and must close it.
    owns_handle = not (isinstance(data, ad.AnnData) and not hasattr(data, "path"))
    adata = read_backed(_resolve_path(data)) if owns_handle else data
    try:
        if "X_umap" not in adata.obsm:
            raise ValueError(
                "X_umap not found in adata.obsm. Run cx.tl.umap() first."
            )

        # Materialise everything the plot needs before the handle is closed.
        adata_plot = ad.AnnData(
            X=sp.csr_matrix((adata.n_obs, adata.n_vars), dtype=np.float32),
            obs=adata.obs.copy() if hasattr(adata.obs, "copy") else pd.DataFrame(adata.obs),
        )
        adata_plot.obsm["X_umap"] = np.asarray(adata.obsm["X_umap"])
        if "umap" in adata.uns:
            adata_plot.uns["umap"] = {
                name: value[()] if isinstance(value, h5py.Dataset) else value
                for name, value in dict(adata.uns["umap"]).items()
            }
    finally:
        if owns_handle:
            adata.file.close()

    return sc.pl.umap(
        adata_plot,
        color=color,
        use_raw=use_raw,
        layer=layer,
        sort_order=sort_order,
        groups=groups,
        components=components,
        palette=palette,
        na_color=na_color,
        na_in_legend=na_in_legend,
        size=size,
        frameon=frameon,
        legend_fontsize=legend_fontsize,
        legend_fontweight=legend_fontweight,
        legend_loc=legend_loc,
        legend_fontoutline=legend_fontoutline,
        colorbar_loc=colorbar_loc,
        ncols=ncols,
        wspace=wspace,
        hspace=hspace,
        title=title,
        show=show,
        save=save,
        ax=ax,
        return_fig=return_fig,
        **kwargs,
    )
# =============================================================================
# Overlap heatmap (Feature 5)
# =============================================================================
def plot_overlap_heatmap(
    result,
    *,
    metric="jaccard",
    ax=None,
    cmap="Blues",
    annot=True,
    fmt=None,
    title=None,
):
    """Heatmap of pairwise overlap statistics.

    Parameters
    ----------
    result
        :class:`~crispyx.data.OverlapResult` from
        :func:`~crispyx.data.compute_overlap`.
    metric
        ``"jaccard"`` (default) or ``"count"``. Any value other than
        ``"jaccard"`` selects the count matrix.
    ax
        Existing :class:`matplotlib.axes.Axes`; a new figure is created if
        ``None``.
    cmap
        Matplotlib colormap name.
    annot
        Annotate each cell with the numeric value.
    fmt
        Number format. Defaults to ``".2f"`` for Jaccard and ``"d"`` for counts.
    title
        Figure title. Defaults to ``"Jaccard similarity"`` or ``"Overlap counts"``.

    Returns
    -------
    matplotlib.axes.Axes, or None when matplotlib is unavailable.

    Notes
    -----
    Uses seaborn when it is importable and falls back to plain matplotlib
    ``imshow`` otherwise.
    """
    plt = _require_matplotlib()
    if plt is None:
        return None

    is_jaccard = metric == "jaccard"
    data = result.jaccard_matrix if is_jaccard else result.count_matrix
    fmt = fmt or (".2f" if is_jaccard else "d")
    title = title or ("Jaccard similarity" if is_jaccard else "Overlap counts")
    vmax = 1.0 if is_jaccard else None
    n = len(data)

    # Only the import decides the backend; errors raised while drawing must
    # not silently trigger the fallback.
    try:
        import seaborn as sns  # type: ignore[import]
    except ImportError:
        sns = None

    if ax is None:
        _, ax = plt.subplots(figsize=(max(4, n * 0.8), max(3, n * 0.7)))

    if sns is not None:
        sns.heatmap(
            data,
            ax=ax,
            cmap=cmap,
            annot=annot,
            fmt=fmt,
            linewidths=0.5,
            square=True,
            vmin=0,
            vmax=vmax,
            cbar_kws={"shrink": 0.6},
        )
    else:
        arr = data.to_numpy(dtype=float)
        n_rows, n_cols = arr.shape
        im = ax.imshow(arr, aspect="auto", cmap=cmap, vmin=0, vmax=vmax)
        ax.set_xticks(range(n_cols))
        ax.set_xticklabels(data.columns, rotation=45, ha="right")
        ax.set_yticks(range(n_rows))
        ax.set_yticklabels(data.index)

        if annot:
            # ``arr`` is float, but integer presentation types such as the
            # default count format "d" reject floats, so round first.
            integer_format = fmt[-1:] in {"b", "d", "o", "x", "X"}

            def _cell_text(value: float) -> str:
                if not integer_format:
                    return format(value, fmt)
                if not np.isfinite(value):
                    return str(value)
                return format(int(round(value)), fmt)

            for i in range(n_rows):
                for j in range(n_cols):
                    ax.text(
                        j,
                        i,
                        _cell_text(arr[i, j]),
                        ha="center",
                        va="center",
                        fontsize=8,
                    )
        ax.figure.colorbar(im, ax=ax, shrink=0.6)

    ax.set_title(title)
    return ax