"""Scanpy-style namespace classes (pp, pb, tl, pl) for crispyx."""
from __future__ import annotations
from pathlib import Path
from typing import Iterable, Literal
import anndata as ad
import numpy as np
from .data import (
AnnData,
compute_overlap,
convert_to_csc,
convert_to_csr,
ensure_gene_symbol_column,
normalize_total_log1p,
read_backed,
resolve_control_label,
resolve_data_path,
resolve_output_path,
)
from .de import (
RankGenesGroupsResult,
_adjust_pvalue_matrix,
nb_glm_test,
shrink_lfc,
t_test,
wilcoxon_test,
)
from .plotting import (
materialize_rank_genes_groups,
plot_ma,
plot_overlap_heatmap,
plot_pca,
plot_pca_loadings,
plot_pca_variance_ratio,
plot_qc_perturbation_counts,
plot_qc_summary,
plot_top_genes_bar,
plot_umap,
plot_volcano,
rank_genes_groups as plot_rank_genes_groups,
rank_genes_groups_df,
)
from .pseudobulk import (
compute_average_log_expression,
compute_pseudobulk_expression,
)
from .qc import (
filter_cells_by_gene_count,
filter_genes_by_cell_count,
filter_perturbations_by_cell_count,
quality_control_summary,
)
# ---------------------------------------------------------------------------
# Helpers used only by _ToolsNamespace
# ---------------------------------------------------------------------------
def _infer_control_label(
    path: Path,
    perturbation_column: str,
    control_label: str | None,
) -> str:
    """Return the control label to use for ``perturbation_column``.

    An explicitly supplied ``control_label`` is returned as a string without
    opening the file.  Otherwise the perturbation column is read (as strings)
    from the backed file at ``path`` and passed to ``resolve_control_label``
    with no preferred label.

    Raises
    ------
    KeyError
        If ``perturbation_column`` is not a column of ``adata.obs``.
    """
    if control_label is None:
        handle = read_backed(path)
        try:
            obs_columns = handle.obs.columns
            if perturbation_column not in obs_columns:
                raise KeyError(
                    "Perturbation column '%s' was not found in adata.obs. Available columns: %s"
                    % (perturbation_column, list(obs_columns))
                )
            observed = handle.obs[perturbation_column].astype(str).to_numpy()
        finally:
            # Release the backing file whether or not the lookup succeeded.
            handle.file.close()
        return resolve_control_label(observed, None)
    return str(control_label)
def _t_test_results_to_rank_genes(
    path: Path,
    results,
    *,
    gene_name_column: str | None,
    perturbation_column: str,
    control_label: str,
    corr_method: str,
    output_dir: str | Path | None,
    data_name: str | None,
) -> RankGenesGroupsResult:
    """Pack per-group t-test results into a ``RankGenesGroupsResult``.

    Parameters
    ----------
    path
        Input file; only opened when ``results`` is empty, to recover the
        gene names.
    results
        Mapping of group label to a per-group result exposing ``genes``,
        ``effect_size``, ``statistic``, ``pvalue`` and ``result``.
    gene_name_column
        Column passed to ``ensure_gene_symbol_column`` in the empty case;
        ``None`` uses ``var_names``.
    perturbation_column
        Stored as ``groupby`` on the returned object.
    control_label
        Stored on the returned object and in ``uns["control_label"]``.
    corr_method
        ``'benjamini-hochberg'`` or ``'bonferroni'``.
    output_dir, data_name
        Forwarded to ``resolve_output_path`` in the empty case.

    Returns
    -------
    RankGenesGroupsResult
        One row per group.  ``u_statistics``, ``pts`` and ``pts_rest`` are
        zero-filled, and ``logfoldchanges`` and ``effect_size`` both carry the
        effect-size matrix.

    Raises
    ------
    ValueError
        If ``corr_method`` is not one of the two supported names.
    """
    # Reject a bad correction name before stacking matrices or touching disk.
    if corr_method not in {"benjamini-hochberg", "bonferroni"}:
        raise ValueError(
            "corr_method must be 'benjamini-hochberg' or 'bonferroni' for t-tests"
        )

    groups = list(results.keys())
    if groups:
        first = results[groups[0]]
        genes = first.genes
        effect_matrix = np.vstack([results[group].effect_size for group in groups])
        statistic_matrix = np.vstack([results[group].statistic for group in groups])
        pvalue_matrix = np.vstack([results[group].pvalue for group in groups])
        result_view = first.result
    else:
        # No groups were tested: build empty (0 x n_genes) matrices so the
        # result object still has a consistent shape.
        backed = read_backed(path)
        try:
            if gene_name_column is None:
                genes = backed.var_names.astype(str)
            else:
                genes = ensure_gene_symbol_column(backed, gene_name_column)
        finally:
            backed.file.close()
        effect_matrix = np.zeros((0, genes.size), dtype=float)
        statistic_matrix = np.zeros_like(effect_matrix)
        pvalue_matrix = np.ones_like(effect_matrix)
        result_path = resolve_output_path(
            path,
            suffix="t_test_de",
            output_dir=output_dir,
            data_name=data_name,
        )
        result_view = AnnData(result_path)

    pvalue_adj = (
        _adjust_pvalue_matrix(pvalue_matrix, corr_method)
        if pvalue_matrix.size
        else np.zeros_like(pvalue_matrix)
    )
    # Rank genes per group by descending |statistic|; mergesort keeps ties in
    # their original gene order.
    order = (
        np.argsort(-np.abs(statistic_matrix), axis=1, kind="mergesort")
        if statistic_matrix.size
        else np.zeros(statistic_matrix.shape, dtype=int)
    )
    # A single zero array is shared by the three fields below.
    zeros = np.zeros_like(statistic_matrix)
    result = RankGenesGroupsResult(
        genes=genes,
        groups=groups,
        statistics=statistic_matrix,
        pvalues=pvalue_matrix,
        pvalues_adj=pvalue_adj,
        logfoldchanges=effect_matrix,
        effect_size=effect_matrix,
        u_statistics=zeros,
        pts=zeros,
        pts_rest=zeros,
        order=order,
        groupby=perturbation_column,
        method="t_test",
        control_label=control_label,
        tie_correct=False,
        pvalue_correction=corr_method,
        result=result_view,
    )

    backing = result.result
    if backing is not None:
        target = backing.path
        try:
            memory = backing.to_memory()
            memory.uns["rank_genes_groups"] = result.to_rank_genes_groups_dict()
            memory.uns["genes"] = genes.to_numpy()
            memory.uns["method"] = "t_test"
            memory.uns["control_label"] = control_label
            memory.uns["tie_correct"] = False
            memory.uns["pvalue_correction"] = corr_method
            memory.write(target)
        finally:
            # Close the old handle even if loading or writing failed.
            backing.close()
        # Re-open so the returned view reflects the rewritten file.
        result.result = AnnData(target)
    return result
# ---------------------------------------------------------------------------
# Namespace classes
# ---------------------------------------------------------------------------
class _PreprocessingNamespace:
    """Scanpy-style preprocessing entry points (``cx.pp``)."""

    def filter_cells(
        self,
        data: str | Path | ad.AnnData,
        *,
        min_genes: int = 100,
        gene_name_column: str | None = None,
        chunk_size: int = 2048,
    ):
        """Run ``filter_cells_by_gene_count`` on the file behind ``data``.

        ``data`` is resolved with ``resolve_data_path``; the keyword arguments
        are forwarded unchanged and the helper's return value is returned.
        """
        path = resolve_data_path(data)
        return filter_cells_by_gene_count(
            path,
            min_genes=min_genes,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
        )

    def filter_genes(
        self,
        data: str | Path | ad.AnnData,
        *,
        min_cells: int = 100,
        cell_mask: np.ndarray | None = None,
        gene_name_column: str | None = None,
        chunk_size: int = 2048,
    ):
        """Run ``filter_genes_by_cell_count`` on the file behind ``data``.

        ``data`` is resolved with ``resolve_data_path``; the keyword arguments
        are forwarded unchanged and the helper's return value is returned.
        """
        path = resolve_data_path(data)
        return filter_genes_by_cell_count(
            path,
            min_cells=min_cells,
            cell_mask=cell_mask,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
        )

    def filter_perturbations(
        self,
        data: str | Path | ad.AnnData,
        *,
        perturbation_column: str,
        control_label: str | None = None,
        min_cells: int = 50,
        base_mask: np.ndarray | None = None,
    ):
        """Run ``filter_perturbations_by_cell_count`` on the file behind ``data``.

        ``data`` is resolved with ``resolve_data_path``; the keyword arguments
        are forwarded unchanged and the helper's return value is returned.
        """
        path = resolve_data_path(data)
        return filter_perturbations_by_cell_count(
            path,
            perturbation_column=perturbation_column,
            control_label=control_label,
            min_cells=min_cells,
            base_mask=base_mask,
        )

    def qc_summary(
        self,
        data: str | Path | ad.AnnData,
        *,
        min_genes: int = 100,
        min_cells_per_perturbation: int = 50,
        min_cells_per_gene: int = 100,
        perturbation_column: str,
        control_label: str | None = None,
        gene_name_column: str | None = None,
        chunk_size: int = 2048,
        output_dir: str | Path | None = None,
        data_name: str | None = None,
        cache_mode: Literal['memory', 'memmap', 'none'] = 'memmap',
    ):
        """Run ``quality_control_summary`` and return its ``filtered`` attribute.

        ``data`` is resolved with ``resolve_data_path``; the keyword arguments
        are forwarded unchanged.  Only ``result.filtered`` is returned; the
        rest of the summary object is dropped.
        """
        path = resolve_data_path(data)
        result = quality_control_summary(
            path,
            min_genes=min_genes,
            min_cells_per_perturbation=min_cells_per_perturbation,
            min_cells_per_gene=min_cells_per_gene,
            perturbation_column=perturbation_column,
            control_label=control_label,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
            output_dir=output_dir,
            data_name=data_name,
            cache_mode=cache_mode,
        )
        return result.filtered

    def convert_to_csc(
        self,
        data: str | Path | ad.AnnData,
        *,
        output_path: str | Path | None = None,
        chunk_size: int = 4096,
        output_dir: str | Path | None = None,
        data_name: str | None = None,
        verbose: bool = True,
    ) -> AnnData:
        """Convert a backed h5ad file's matrix to CSC format.

        Parameters
        ----------
        data
            Path to h5ad file or backed AnnData.
        output_path
            Explicit output path. If None, derived from output_dir/data_name.
        chunk_size
            Rows per streaming chunk. Default 4096.
        output_dir
            Output directory. Defaults to input file's directory.
        data_name
            Custom name suffix.
        verbose
            Print progress.

        Returns
        -------
        AnnData
            Backed AnnData pointing to the CSC output file.
        """
        # Inside a method the bare name resolves to the module-level function,
        # not to this method (class scope is not searched from method bodies).
        return convert_to_csc(
            data,
            output_path=output_path,
            chunk_size=chunk_size,
            output_dir=output_dir,
            data_name=data_name,
            verbose=verbose,
        )

    def convert_to_csr(
        self,
        data: str | Path | ad.AnnData,
        *,
        output_path: str | Path | None = None,
        chunk_size: int | None = None,
        output_dir: str | Path | None = None,
        data_name: str | None = None,
        verbose: bool = True,
    ) -> AnnData:
        """Convert a backed h5ad file's matrix to CSR format.

        Parameters
        ----------
        data
            Path to h5ad file or backed AnnData.
        output_path
            Explicit output path. If None, derived from output_dir/data_name.
        chunk_size
            Rows (or columns for CSC source) per streaming chunk. Default auto.
        output_dir
            Output directory. Defaults to input file's directory.
        data_name
            Custom name suffix.
        verbose
            Print progress.

        Returns
        -------
        AnnData
            Backed AnnData pointing to the CSR output file.
        """
        # Resolves to the module-level ``convert_to_csr`` function.
        return convert_to_csr(
            data,
            output_path=output_path,
            chunk_size=chunk_size,
            output_dir=output_dir,
            data_name=data_name,
            verbose=verbose,
        )

    def normalize_total_log1p(
        self,
        data: str | Path | ad.AnnData,
        output_path: str | Path | None = None,
        *,
        normalize: bool = True,
        log1p: bool = True,
        target_sum: float = 1e4,
        chunk_size: int = 4096,
        output_dir: str | Path | None = None,
        data_name: str | None = None,
        verbose: bool = True,
    ) -> AnnData:
        """Stream normalize and/or log-transform an h5ad file.

        Parameters
        ----------
        data
            Path to h5ad file or backed AnnData.
        output_path
            Path for output. If None, uses output_dir/data_name pattern.
        normalize
            Apply total-count normalization. Default True.
        log1p
            Apply log1p transformation. Default True.
        target_sum
            Target counts per cell. Default 1e4.
        chunk_size
            Cells per chunk. Default 4096.
        output_dir
            Output directory. Defaults to input file's directory.
        data_name
            Custom output name suffix.
        verbose
            Print progress.

        Returns
        -------
        AnnData
            Read-only AnnData wrapper pointing to output file.
        """
        # Resolves to the module-level ``normalize_total_log1p`` function.
        return normalize_total_log1p(
            data,
            output_path,
            normalize=normalize,
            log1p=log1p,
            target_sum=target_sum,
            chunk_size=chunk_size,
            output_dir=output_dir,
            data_name=data_name,
            verbose=verbose,
        )

    def pca(
        self,
        data: str | Path | ad.AnnData,
        n_comps: int = 50,
        method: str = "auto",
        use_highly_variable: bool = True,
        chunk_size: int | None = None,
        random_state: int = 0,
        copy: bool = False,
        show_progress: bool = True,
    ) -> ad.AnnData | None:
        """Compute streaming PCA on backed AnnData.

        Parameters
        ----------
        data
            Path to h5ad file or backed AnnData.
        n_comps
            Number of principal components. Default 50.
        method
            'auto', 'sparse_cov', or 'incremental'. Default 'auto'.
        use_highly_variable
            Use only HVGs if available. Default True.
        chunk_size
            Cells per chunk. Auto-calculated if None.
        random_state
            Random seed.
        copy
            If True, return copy with results instead of in-place.
        show_progress
            Show progress bars.

        Returns
        -------
        AnnData or None
            Modified AnnData if copy=True, else None.
        """
        # Imported lazily so the dimred module is only loaded when needed.
        from .dimred import pca as _pca

        # NOTE(review): a path is opened read-only here and the handle is
        # neither closed nor returned by this method; with copy=False the
        # caller has no reference to the object that was operated on.
        # Confirm that ``dimred.pca`` persists its results itself.
        if isinstance(data, (str, Path)):
            adata = ad.read_h5ad(data, backed='r')
        else:
            adata = data
        return _pca(
            adata,
            n_comps=n_comps,
            method=method,
            use_highly_variable=use_highly_variable,
            chunk_size=chunk_size,
            random_state=random_state,
            copy=copy,
            show_progress=show_progress,
        )

    def neighbors(
        self,
        data: str | Path | ad.AnnData,
        n_neighbors: int = 15,
        n_pcs: int | None = None,
        use_rep: str = "X_pca",
        metric: str = "euclidean",
        method: str = "umap",
        random_state: int = 0,
        copy: bool = False,
        show_progress: bool = True,
    ) -> ad.AnnData | None:
        """Compute k-nearest neighbors graph from embeddings.

        Parameters
        ----------
        data
            Path to h5ad file or backed AnnData with PCA results.
        n_neighbors
            Number of neighbors. Default 15.
        n_pcs
            Number of PCs to use. Default None uses all.
        use_rep
            Key in .obsm for embeddings. Default 'X_pca'.
        metric
            Distance metric. Default 'euclidean'.
        method
            'umap' (fast, pynndescent) or 'sklearn' (exact).
        random_state
            Random seed.
        copy
            If True, return copy with results.
        show_progress
            Show progress.

        Returns
        -------
        AnnData or None
            Modified AnnData if copy=True, else None.
        """
        # Imported lazily so the dimred module is only loaded when needed.
        from .dimred import neighbors as _neighbors

        # NOTE(review): same read-only, never-closed handle as in ``pca``
        # when ``data`` is a path; confirm how results are persisted.
        if isinstance(data, (str, Path)):
            adata = ad.read_h5ad(data, backed='r')
        else:
            adata = data
        return _neighbors(
            adata,
            n_neighbors=n_neighbors,
            n_pcs=n_pcs,
            use_rep=use_rep,
            metric=metric,
            method=method,
            random_state=random_state,
            copy=copy,
            show_progress=show_progress,
        )
class _PseudobulkNamespace:
    """Pseudo-bulk estimators (``cx.pb``)."""

    def average_log_expression(
        self,
        data: str | Path | ad.AnnData,
        *,
        perturbation_column: str,
        control_label: str | None = None,
        gene_name_column: str | None = None,
        perturbations: Iterable[str] | None = None,
        chunk_size: int = 2048,
        output_dir: str | Path | None = None,
        data_name: str | None = None,
    ):
        """Run ``compute_average_log_expression`` on the file behind ``data``.

        ``data`` is resolved with ``resolve_data_path``; the keyword arguments
        are forwarded unchanged and the helper's return value is returned.
        """
        path = resolve_data_path(data)
        return compute_average_log_expression(
            path,
            perturbation_column=perturbation_column,
            control_label=control_label,
            gene_name_column=gene_name_column,
            perturbations=perturbations,
            chunk_size=chunk_size,
            output_dir=output_dir,
            data_name=data_name,
        )

    def pseudobulk(
        self,
        data: str | Path | ad.AnnData,
        *,
        perturbation_column: str,
        control_label: str | None = None,
        gene_name_column: str | None = None,
        perturbations: Iterable[str] | None = None,
        baseline_count: float = 1.0,
        chunk_size: int = 2048,
        output_dir: str | Path | None = None,
        data_name: str | None = None,
    ):
        """Run ``compute_pseudobulk_expression`` on the file behind ``data``.

        ``data`` is resolved with ``resolve_data_path``; the keyword arguments
        (including ``baseline_count``) are forwarded unchanged and the helper's
        return value is returned.
        """
        path = resolve_data_path(data)
        return compute_pseudobulk_expression(
            path,
            perturbation_column=perturbation_column,
            control_label=control_label,
            gene_name_column=gene_name_column,
            perturbations=perturbations,
            baseline_count=baseline_count,
            chunk_size=chunk_size,
            output_dir=output_dir,
            data_name=data_name,
        )
class _PlottingNamespace:
    """Scanpy-style plotting entry points (``cx.pl``).

    Every method forwards its arguments to the same-purpose function imported
    from the plotting module and returns that function's result.
    """

    def rank_genes_groups(self, data, **kwargs):
        """Forward to the plotting module's ``rank_genes_groups``."""
        return plot_rank_genes_groups(data, **kwargs)

    def rank_genes_groups_df(self, data, group, **kwargs):
        """Forward to the module-level ``rank_genes_groups_df``."""
        # The bare name resolves to the imported function, not this method.
        return rank_genes_groups_df(data, group, **kwargs)

    def volcano(self, **kwargs):
        """Forward keyword arguments to ``plot_volcano``."""
        return plot_volcano(**kwargs)

    def ma(self, **kwargs):
        """Forward keyword arguments to ``plot_ma``."""
        return plot_ma(**kwargs)

    def top_genes_bar(self, **kwargs):
        """Forward keyword arguments to ``plot_top_genes_bar``."""
        return plot_top_genes_bar(**kwargs)

    def qc_perturbation_counts(self, **kwargs):
        """Forward keyword arguments to ``plot_qc_perturbation_counts``."""
        return plot_qc_perturbation_counts(**kwargs)

    def qc_summary(self, qc_result, **kwargs):
        """Forward ``qc_result`` and keyword arguments to ``plot_qc_summary``."""
        return plot_qc_summary(qc_result, **kwargs)

    def materialize_rank_genes_groups(self, data, **kwargs):
        """Forward to the module-level ``materialize_rank_genes_groups``."""
        # The bare name resolves to the imported function, not this method.
        return materialize_rank_genes_groups(data, **kwargs)

    def pca(self, data, **kwargs):
        """Plot PCA scatter."""
        return plot_pca(data, **kwargs)

    def pca_variance_ratio(self, data, **kwargs):
        """Plot PCA variance ratio."""
        return plot_pca_variance_ratio(data, **kwargs)

    def pca_loadings(self, data, **kwargs):
        """Plot PCA loadings."""
        return plot_pca_loadings(data, **kwargs)

    def umap(self, data, **kwargs):
        """Plot UMAP embedding."""
        return plot_umap(data, **kwargs)

    def overlap_heatmap(self, result, **kwargs):
        """Plot pairwise overlap heatmap. See :func:`crispyx.plot_overlap_heatmap`."""
        return plot_overlap_heatmap(result, **kwargs)