"""Quality control utilities for large ``.h5ad`` datasets."""
from __future__ import annotations
import gc
import logging
import shutil
import tempfile
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, Literal, Tuple, Union
import anndata as ad
import numpy as np
import pandas as pd
import scipy.sparse as sp
from .data import (
AnnData,
_ensure_csr,
calculate_optimal_chunk_size,
ensure_gene_symbol_column,
get_matrix_storage_format,
is_dense_storage,
iter_matrix_chunks,
read_backed,
resolve_control_label,
resolve_data_path,
resolve_output_path,
write_filtered_subset,
)
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
@dataclass
class _CellFilterResult:
"""Result of cell filtering with both cell and gene statistics."""
cell_mask: np.ndarray
gene_counts_per_cell: np.ndarray
gene_cell_counts_all: np.ndarray # cells per gene for ALL cells (before perturbation filter)
class _ChunkCache:
"""In-memory cache for CSR chunk data during QC.
Stores CSR chunks in memory during the gene filtering pass, then
streams them to the write phase without re-reading the original matrix.
For very large datasets, consider using use_chunk_cache=False to avoid
memory overhead.
Parameters
----------
output_path
Base path for the output file (not used for caching, kept for API compat).
"""
def __init__(self, output_path: Path | str) -> None:
self.output_path = Path(output_path)
self._chunks: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = [] # (data, indices, indptr_diff)
self._n_cols: int = 0
def write_chunk(
self,
chunk_idx: int,
data: np.ndarray,
indices: np.ndarray,
indptr_diff: np.ndarray,
n_cols: int,
) -> None:
"""Store CSR chunk data in memory."""
# Ensure list is large enough
while len(self._chunks) <= chunk_idx:
self._chunks.append(None)
self._chunks[chunk_idx] = (data, indices, indptr_diff)
self._n_cols = n_cols
def iter_filtered_chunks(
self,
gene_indices: np.ndarray,
data_dtype: np.dtype,
) -> Iterable[tuple[np.ndarray, np.ndarray, int]]:
"""Iterate through cached chunks, yielding filtered CSR data.
Uses vectorized operations for efficiency.
Parameters
----------
gene_indices
Indices of genes to keep.
data_dtype
Target dtype for data array.
Yields
------
tuple
(filtered_data, filtered_indices, n_cells_in_chunk)
"""
# Build vectorized index remapping
gene_set = set(gene_indices.tolist())
# Create a dense lookup array: old_idx -> new_idx (or -1 if not kept)
remap = np.full(self._n_cols, -1, dtype=np.int32)
remap[gene_indices] = np.arange(len(gene_indices), dtype=np.int32)
for chunk in self._chunks:
if chunk is None:
continue
data, indices, indptr_diff = chunk
n_cells = len(indptr_diff)
# Vectorized filtering: find which entries have kept genes
new_col_indices = remap[indices] # -1 for dropped genes
keep_mask = new_col_indices >= 0
# Filter data and indices
filtered_data = data[keep_mask].astype(data_dtype, copy=False)
filtered_indices = new_col_indices[keep_mask]
yield filtered_data, filtered_indices, n_cells
def cleanup(self) -> None:
"""Clear cached data from memory."""
self._chunks.clear()
@property
def chunk_count(self) -> int:
"""Number of chunks cached."""
return len(self._chunks)
class _MemmapChunkCache:
"""Memory-mapped chunk cache for QC streaming.
Uses numpy memmap to store CSR chunks on disk with memory-mapped access,
reducing RAM usage while maintaining fast access through OS page caching.
Similar to the approach used in crispyx.glm for large streaming operations.
Parameters
----------
output_path
Base path for output file. Cache files created in a temp directory.
estimated_nnz
Estimated total non-zeros for pre-allocation.
data_dtype
Data type for values array.
"""
def __init__(
self,
output_path: Path | str,
estimated_nnz: int,
data_dtype: np.dtype = np.float32,
) -> None:
self.output_path = Path(output_path)
self._estimated_nnz = estimated_nnz
self._data_dtype = data_dtype
# Create temp directory for memmap files
self._cache_dir = Path(tempfile.mkdtemp(prefix="crispyx_qc_cache_"))
# Pre-allocate memory-mapped arrays
self._data_mmap = np.memmap(
self._cache_dir / "data.mmap",
dtype=data_dtype,
mode='w+',
shape=(estimated_nnz,),
)
self._indices_mmap = np.memmap(
self._cache_dir / "indices.mmap",
dtype=np.int32,
mode='w+',
shape=(estimated_nnz,),
)
# Track chunk boundaries: list of (start, end, indptr_diff)
self._chunk_info: list[tuple[int, int, np.ndarray]] = []
self._current_offset = 0
self._n_cols: int = 0
def write_chunk(
self,
chunk_idx: int,
data: np.ndarray,
indices: np.ndarray,
indptr_diff: np.ndarray,
n_cols: int,
) -> None:
"""Write CSR chunk data to memory-mapped files."""
nnz = len(data)
if nnz == 0:
# Store empty chunk info
self._chunk_info.append((self._current_offset, self._current_offset, indptr_diff.copy()))
self._n_cols = n_cols
return
start = self._current_offset
end = start + nnz
# Check if we need to expand the memmap (shouldn't happen with good estimate)
if end > self._estimated_nnz:
logger.warning(
"Memmap cache overflow: estimated %d nnz but need %d. "
"Expanding cache (may be slower).",
self._estimated_nnz, end
)
self._expand_memmaps(end)
# Write to memmap (data stays on disk, not in RAM)
self._data_mmap[start:end] = data
self._indices_mmap[start:end] = indices
# Store chunk metadata (indptr_diff is small, keep in memory)
self._chunk_info.append((start, end, indptr_diff.copy()))
self._current_offset = end
self._n_cols = n_cols
def _expand_memmaps(self, new_size: int) -> None:
"""Expand memory-mapped arrays to accommodate more data."""
# Add 20% buffer to avoid repeated expansions
new_size = int(new_size * 1.2)
# Create new larger memmaps
new_data_path = self._cache_dir / "data_expanded.mmap"
new_indices_path = self._cache_dir / "indices_expanded.mmap"
new_data = np.memmap(new_data_path, dtype=self._data_dtype, mode='w+', shape=(new_size,))
new_indices = np.memmap(new_indices_path, dtype=np.int32, mode='w+', shape=(new_size,))
# Copy existing data
new_data[:self._current_offset] = self._data_mmap[:self._current_offset]
new_indices[:self._current_offset] = self._indices_mmap[:self._current_offset]
# Close and delete old memmaps
del self._data_mmap
del self._indices_mmap
(self._cache_dir / "data.mmap").unlink(missing_ok=True)
(self._cache_dir / "indices.mmap").unlink(missing_ok=True)
# Rename new files
new_data_path.rename(self._cache_dir / "data.mmap")
new_indices_path.rename(self._cache_dir / "indices.mmap")
# Update references
self._data_mmap = new_data
self._indices_mmap = new_indices
self._estimated_nnz = new_size
def iter_filtered_chunks(
self,
gene_indices: np.ndarray,
data_dtype: np.dtype,
) -> Iterable[tuple[np.ndarray, np.ndarray, int]]:
"""Iterate cached chunks, filtering genes on-the-fly.
Parameters
----------
gene_indices
Indices of genes to keep.
data_dtype
Target dtype for data array.
Yields
------
tuple
(filtered_data, filtered_indices, n_cells_in_chunk)
"""
# Build gene index remapping
remap = np.full(self._n_cols, -1, dtype=np.int32)
remap[gene_indices] = np.arange(len(gene_indices), dtype=np.int32)
for start, end, indptr_diff in self._chunk_info:
n_cells = len(indptr_diff)
if start == end:
# Empty chunk
yield np.array([], dtype=data_dtype), np.array([], dtype=np.int32), n_cells
continue
# Read from memmap (OS handles paging efficiently)
data = np.array(self._data_mmap[start:end]) # Copy to regular array
indices = np.array(self._indices_mmap[start:end])
# Filter genes
new_indices = remap[indices]
keep_mask = new_indices >= 0
yield data[keep_mask].astype(data_dtype, copy=False), new_indices[keep_mask], n_cells
def cleanup(self) -> None:
"""Delete memory-mapped files and cache directory."""
# Close memmap references
del self._data_mmap
del self._indices_mmap
# Remove cache directory
shutil.rmtree(self._cache_dir, ignore_errors=True)
@property
def chunk_count(self) -> int:
"""Number of chunks cached."""
return len(self._chunk_info)
[docs]
@dataclass
class QualityControlResult:
    """Masks, counts and (optionally) the written subset produced by QC filtering."""

    # Boolean mask over the input cells that were retained.
    cell_mask: np.ndarray
    # Boolean mask over the input genes that were retained.
    gene_mask: np.ndarray
    # Retention decision for each perturbation label.
    perturbation_keep: Dict[str, bool]
    # Filtered dataset on disk; None if output_dir was not provided.
    filtered: AnnData | None
    # Number of expressed genes in each input cell.
    cell_gene_counts: np.ndarray
    # Number of expressing cells for each gene.
    gene_cell_counts: np.ndarray

    @property
    def filtered_path(self) -> Path | None:
        """Compatibility accessor exposing the on-disk filename.

        Yields None when nothing was written (output_dir was None); otherwise
        the ``path`` attribute of :attr:`filtered`.
        """
        written = self.filtered
        return None if written is None else written.path
[docs]
def filter_cells_by_gene_count(
    data: str | Path | AnnData | ad.AnnData,
    *,
    min_genes: int = 100,
    gene_name_column: str | None = None,
    chunk_size: int = 2048,
    return_counts: bool = False,
    return_full_result: bool = False,
) -> np.ndarray | Tuple[np.ndarray, np.ndarray] | _CellFilterResult:
    """Return a boolean mask selecting cells with at least ``min_genes`` expressed genes.

    The matrix is streamed once in row chunks. Genes-per-cell (row nnz) is
    always tallied. When ``return_full_result`` is set, cells-per-gene
    (column nnz) is tallied in the same pass, so a later gene-counting pass
    is not needed.

    Parameters
    ----------
    data
        Path to h5ad file, or a crispyx/anndata AnnData object.
    min_genes
        Minimum number of expressed genes per cell.
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of cells to process per chunk.
    return_counts
        If True, return (mask, counts) tuple instead of just mask.
        Ignored if return_full_result is True.
    return_full_result
        If True, return a _CellFilterResult containing cell_mask,
        gene_counts_per_cell, and gene_cell_counts_all (cells per gene
        for all cells, before any perturbation filtering).

    Returns
    -------
    mask or (mask, counts) or _CellFilterResult
        Boolean mask, optionally with counts, or full result dataclass.
    """
    backed = read_backed(resolve_data_path(data))
    try:
        ensure_gene_symbol_column(backed, gene_name_column)
        n_cells, n_genes = backed.n_obs, backed.n_vars
        genes_in_cell = np.zeros(n_cells, dtype=np.int64)
        # The per-gene tally is only needed for the full result.
        cells_with_gene = np.zeros(n_genes, dtype=np.int64) if return_full_result else None
        chunks = iter_matrix_chunks(backed, axis=0, chunk_size=chunk_size, convert_to_dense=False)
        for rows, block in chunks:
            sparse = sp.issparse(block)
            genes_in_cell[rows] = (
                np.asarray(block.getnnz(axis=1)).ravel()
                if sparse
                else np.count_nonzero(block, axis=1)
            )
            if cells_with_gene is None:
                continue
            cells_with_gene += (
                np.asarray(block.getnnz(axis=0)).ravel()
                if sparse
                else np.count_nonzero(block, axis=0)
            )
    finally:
        backed.file.close()

    passes = genes_in_cell >= min_genes
    if return_full_result:
        return _CellFilterResult(
            cell_mask=passes,
            gene_counts_per_cell=genes_in_cell,
            gene_cell_counts_all=cells_with_gene,
        )
    return (passes, genes_in_cell) if return_counts else passes
[docs]
def filter_perturbations_by_cell_count(
    data: str | Path | AnnData | ad.AnnData,
    *,
    perturbation_column: str,
    control_label: str | None = None,
    min_cells: int = 50,
    base_mask: np.ndarray | None = None,
    return_counts: bool = False,
) -> np.ndarray | Tuple[np.ndarray, pd.Series]:
    """Return a mask keeping cells whose perturbation has sufficient representation.

    A cell is kept when it is inside ``base_mask`` and either carries the
    control label or belongs to a label with at least ``min_cells`` cells
    inside ``base_mask``.

    Parameters
    ----------
    data
        Path to h5ad file, or a crispyx/anndata AnnData object.
    perturbation_column
        Column in obs containing perturbation labels.
    control_label
        Label identifying control cells. If None, auto-detected.
    min_cells
        Minimum number of cells required per perturbation.
    base_mask
        Optional mask for cells to consider (e.g., from prior filtering).
    return_counts
        If True, return (mask, cell_counts_per_perturbation) tuple.

    Returns
    -------
    mask or (mask, counts)
        Boolean mask, optionally with cell counts per perturbation label.

    Raises
    ------
    KeyError
        If ``perturbation_column`` is not a column of ``obs``.
    """
    backed = read_backed(resolve_data_path(data))
    try:
        if perturbation_column not in backed.obs.columns:
            raise KeyError(
                f"Perturbation column '{perturbation_column}' was not found in adata.obs. Available columns: {list(backed.obs.columns)}"
            )
        labels = backed.obs[perturbation_column].astype(str).to_numpy()
        control_label = resolve_control_label(labels, control_label)
    finally:
        backed.file.close()

    considered = np.ones_like(labels, dtype=bool) if base_mask is None else base_mask
    per_label = pd.Series(labels)
    # Cells per label, counted only over the considered cells.
    counts = per_label[considered].value_counts()
    # Broadcast each label's count back onto its cells; unseen labels get 0.
    support = per_label.map(counts).fillna(0).to_numpy()
    keep = ((labels == control_label) | (support >= min_cells)) & considered
    return (keep, counts) if return_counts else keep
[docs]
def filter_genes_by_cell_count(
    data: str | Path | AnnData | ad.AnnData,
    *,
    min_cells: int = 100,
    cell_mask: np.ndarray | None = None,
    gene_name_column: str | None = None,
    chunk_size: int = 2048,
    return_counts: bool = False,
) -> np.ndarray | Tuple[np.ndarray, np.ndarray]:
    """Return a boolean mask selecting genes expressed in at least ``min_cells`` cells.

    Parameters
    ----------
    data
        Path to h5ad file, or a crispyx/anndata AnnData object.
    min_cells
        Minimum number of cells expressing each gene.
    cell_mask
        Optional boolean mask, one entry per cell, restricting which cells are
        counted. Array-likes are converted to a boolean numpy array.
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of cells to process per chunk.
    return_counts
        If True, return (mask, counts) tuple instead of just mask.

    Returns
    -------
    mask or (mask, counts)
        Boolean mask, optionally with the raw cell counts per gene.

    Raises
    ------
    ValueError
        If ``cell_mask`` does not have exactly one entry per cell.
    """
    path = resolve_data_path(data)
    backed = read_backed(path)
    try:
        ensure_gene_symbol_column(backed, gene_name_column)
        n_obs = backed.n_obs
        counts = np.zeros(backed.n_vars, dtype=np.int64)
        if cell_mask is None:
            cell_mask = np.ones(n_obs, dtype=bool)
        else:
            cell_mask = np.asarray(cell_mask, dtype=bool)
            # A longer mask would have its tail silently ignored, and a
            # shorter one would fail inside the chunk loop.
            if cell_mask.shape != (n_obs,):
                raise ValueError(
                    f"cell_mask must have one entry per cell ({n_obs}); got shape {cell_mask.shape}."
                )
        for slc, block in iter_matrix_chunks(backed, axis=0, chunk_size=chunk_size, convert_to_dense=False):
            local_mask = cell_mask[slc]
            if not np.any(local_mask):
                continue
            selected = block[local_mask]
            if sp.issparse(selected):
                counts += np.asarray(selected.getnnz(axis=0)).ravel()
            else:
                counts += np.count_nonzero(selected, axis=0)
    finally:
        backed.file.close()
    mask = counts >= min_cells
    if return_counts:
        return mask, counts
    return mask
# Chunk cache handed to the write phase: the in-memory implementation, the
# memmap-backed implementation, or None when caching is disabled.
_ChunkCacheType = Union[_ChunkCache, _MemmapChunkCache, type(None)]
@dataclass
class _GeneFilterResult:
"""Result of fused gene filtering and nnz counting."""
gene_mask: np.ndarray
gene_cell_counts: np.ndarray
row_nnz: np.ndarray
total_nnz: int
data_dtype: np.dtype
chunk_cache: _ChunkCacheType = None # Optional cache for write phase (memory or memmap)
def _compute_gene_count_delta(
    path: str | Path,
    *,
    removed_cell_mask: np.ndarray,
    gene_name_column: str | None = None,
    chunk_size: int = 2048,
) -> np.ndarray:
    """Count, for each gene, how many of the flagged (removed) cells express it.

    Perturbation filtering can drop cells that had already passed the
    gene-count filter. Subtracting this per-gene tally from the all-cell
    counts yields counts for the cells that remain, without recounting them.

    Parameters
    ----------
    path
        Path to h5ad file.
    removed_cell_mask
        Boolean mask where True marks the cells to count (cells that passed
        the gene filter but were removed by the perturbation filter).
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of cells to process per chunk.

    Returns
    -------
    np.ndarray
        int64 per-gene counts over the flagged cells only.
    """
    backed = read_backed(path)
    try:
        ensure_gene_symbol_column(backed, gene_name_column)
        delta = np.zeros(backed.n_vars, dtype=np.int64)
        chunks = iter_matrix_chunks(backed, axis=0, chunk_size=chunk_size, convert_to_dense=False)
        for rows, block in chunks:
            picked = removed_cell_mask[rows]
            if not np.any(picked):
                continue
            subset = block[picked]
            delta += (
                np.asarray(subset.getnnz(axis=0)).ravel()
                if sp.issparse(subset)
                else np.count_nonzero(subset, axis=0)
            )
    finally:
        backed.file.close()
    return delta
def _estimate_cache_nnz(backed, n_filtered_cells: int) -> int:
    """Estimate the non-zeros of the (not gene-filtered) rows of the kept cells.

    Only used to pre-size the memmap cache, which grows on demand. The
    estimate is padded by 20%. When the stored nnz cannot be read (dense
    storage, or any metadata access failure), ~10% density is assumed.
    """
    n_obs = backed.n_obs
    density_guess = int(n_filtered_cells * backed.n_vars * 0.1 * 1.2)
    if n_obs == 0:
        return density_guess
    kept_fraction = n_filtered_cells / n_obs
    try:
        X = backed.X
        if hasattr(X, 'group') and 'data' in X.group:
            # Backed sparse: the HDF5 "data" array length is the stored nnz.
            return int(len(X.group['data']) * kept_fraction * 1.2)
        if sp.issparse(X):
            # In-memory sparse (unexpected for large files).
            return int(X.nnz * kept_fraction * 1.2)
    except Exception:
        # Deliberate best effort: this is only a sizing hint.
        pass
    return density_guess


def _filter_genes_with_cache(
    path: str | Path,
    *,
    min_cells: int = 100,
    cell_mask: np.ndarray,
    gene_cell_counts: np.ndarray,
    gene_name_column: str | None = None,
    chunk_size: int = 2048,
    output_path: Path | None = None,
    cache_mode: Literal['memory', 'memmap', 'none'] = 'memmap',
) -> _GeneFilterResult:
    """Compute gene mask and cache CSR data in a single matrix pass.

    The matrix pass:
    1. Caches CSR chunk data (data, indices, indptr_diff) to memory or disk.
    2. Uses pre-computed gene_cell_counts to determine gene_mask.
    3. Computes row_nnz and total_nnz for the gene-filtered rows.

    With caching enabled, the write phase can read from the cache instead of
    re-reading the original matrix.

    Parameters
    ----------
    path
        Path to h5ad file.
    min_cells
        Minimum number of cells expressing each gene.
    cell_mask
        Boolean mask for cells to include (from prior cell filtering).
    gene_cell_counts
        Pre-computed cells per gene for filtered cells.
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of cells to process per chunk.
    output_path
        Path for output file. Required if cache_mode is not 'none'; without
        it, no cache is created.
    cache_mode
        Cache strategy: 'memory' (fast, high RAM), 'memmap' (low RAM, disk-based),
        or 'none' (no caching, requires re-reading source during write).
        Default is 'memmap' for better memory efficiency.

    Returns
    -------
    _GeneFilterResult
        Dataclass containing gene_mask, gene_cell_counts, row_nnz, total_nnz,
        data_dtype, and optionally chunk_cache for the write phase.

    Raises
    ------
    ValueError
        If ``cache_mode`` is not one of 'memory', 'memmap' or 'none'.
    """
    if cache_mode not in ('memory', 'memmap', 'none'):
        raise ValueError(
            f"cache_mode must be 'memory', 'memmap' or 'none'; got {cache_mode!r}."
        )
    gene_mask = gene_cell_counts >= min_cells
    gene_indices = np.flatnonzero(gene_mask)
    n_filtered_cells = int(cell_mask.sum())
    use_memmap = output_path is not None and cache_mode == 'memmap'

    # Metadata pass: validate gene names and, for the memmap cache, gather its
    # size estimate and value dtype while the file is already open.
    backed = read_backed(path)
    try:
        ensure_gene_symbol_column(backed, gene_name_column)
        n_vars = backed.n_vars
        if use_memmap:
            estimated_nnz = _estimate_cache_nnz(backed, n_filtered_cells)
            data_dtype_hint = backed.X.dtype
    finally:
        backed.file.close()

    chunk_cache: _ChunkCacheType = None
    if output_path is not None and cache_mode == 'memory':
        chunk_cache = _ChunkCache(output_path)
    elif use_memmap:
        chunk_cache = _MemmapChunkCache(output_path, estimated_nnz, data_dtype_hint)

    row_nnz = np.zeros(n_filtered_cells, dtype=np.int64)
    total_nnz = 0
    data_dtype: np.dtype | None = None
    row_offset = 0
    chunk_idx = 0
    try:
        backed = read_backed(path)
        try:
            for slc, block in iter_matrix_chunks(backed, axis=0, chunk_size=chunk_size, convert_to_dense=False):
                local_cell_mask = cell_mask[slc]
                if not np.any(local_cell_mask):
                    continue
                csr = _ensure_csr(block[local_cell_mask])
                # Cache the full CSR rows (before gene filtering) for the write phase.
                if chunk_cache is not None:
                    chunk_cache.write_chunk(
                        chunk_idx,
                        data=csr.data.copy(),
                        indices=csr.indices.copy(),
                        indptr_diff=np.diff(csr.indptr),
                        n_cols=n_vars,
                    )
                filtered = csr[:, gene_indices] if gene_indices.size else csr[:, []]
                filtered_csr = _ensure_csr(filtered)
                counts = np.diff(filtered_csr.indptr)
                size = counts.size
                row_nnz[row_offset : row_offset + size] = counts
                total_nnz += int(filtered_csr.nnz)
                if data_dtype is None and csr.nnz:
                    data_dtype = csr.data.dtype
                row_offset += size
                chunk_idx += 1
        finally:
            backed.file.close()
    except BaseException:
        # Do not leak the cache (and, for memmap, its temp files) on failure.
        if chunk_cache is not None:
            chunk_cache.cleanup()
        raise

    if data_dtype is None:
        data_dtype = np.float32
    return _GeneFilterResult(
        gene_mask=gene_mask,
        gene_cell_counts=gene_cell_counts,
        row_nnz=row_nnz,
        total_nnz=total_nnz,
        data_dtype=data_dtype,
        chunk_cache=chunk_cache,
    )
def _filter_genes_dense_optimized(
    path: str | Path,
    *,
    min_cells: int = 100,
    cell_mask: np.ndarray,
    gene_cell_counts: np.ndarray,
    gene_mask: np.ndarray,
    gene_name_column: str | None = None,
    chunk_size: int = 2048,
) -> _GeneFilterResult:
    """Gene filtering for dense-stored datasets without any CSR conversion.

    Row non-zeros are counted directly on each (cell-, then gene-)filtered
    block with vectorized numpy. Sparse blocks, if encountered, fall back to
    ``getnnz``. No chunk cache is built: avoiding dense-to-CSR conversion is
    the point of this path, and dense data is re-read in the write phase.

    Parameters
    ----------
    path
        Path to h5ad file.
    min_cells
        Minimum number of cells expressing each gene (not used here;
        ``gene_mask`` is already computed).
    cell_mask
        Boolean mask for cells to include (from prior cell filtering).
    gene_cell_counts
        Pre-computed cells per gene for filtered cells (passed through).
    gene_mask
        Pre-computed gene mask (genes with >= min_cells).
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of cells to process per chunk.

    Returns
    -------
    _GeneFilterResult
        gene_mask, gene_cell_counts, row_nnz, total_nnz and data_dtype;
        chunk_cache is always None for this path.
    """
    keep_cols = np.flatnonzero(gene_mask)
    backed = read_backed(path)
    try:
        ensure_gene_symbol_column(backed, gene_name_column)
        row_nnz = np.zeros(int(cell_mask.sum()), dtype=np.int64)
        total_nnz = 0
        dtype_seen: np.dtype | None = None
        cursor = 0
        chunks = iter_matrix_chunks(backed, axis=0, chunk_size=chunk_size, convert_to_dense=False)
        for slc, block in chunks:
            rows_here = cell_mask[slc]
            if not np.any(rows_here):
                continue
            sub = block[rows_here]
            sub = sub[:, keep_cols] if keep_cols.size else sub[:, []]
            if sp.issparse(sub):
                per_row = np.asarray(sub.getnnz(axis=1)).ravel()
                block_nnz = int(sub.nnz)
                if dtype_seen is None and sub.nnz:
                    dtype_seen = sub.data.dtype
            else:
                per_row = np.count_nonzero(sub, axis=1)
                block_nnz = int(per_row.sum())
                if dtype_seen is None:
                    dtype_seen = sub.dtype
            width = per_row.size
            row_nnz[cursor : cursor + width] = per_row
            total_nnz += block_nnz
            cursor += width
    finally:
        backed.file.close()

    return _GeneFilterResult(
        gene_mask=gene_mask,
        gene_cell_counts=gene_cell_counts,
        row_nnz=row_nnz,
        total_nnz=total_nnz,
        data_dtype=np.float32 if dtype_seen is None else dtype_seen,
        chunk_cache=None,  # never cached on the dense path
    )
def _qc_in_memory(
    path: str | Path,
    *,
    min_genes: int,
    min_cells_per_perturbation: int,
    min_cells_per_gene: int,
    perturbation_column: str,
    control_label: str,
    gene_name_column: str | None,
    output_path: Path,
) -> QualityControlResult:
    """Fast in-memory QC for small datasets (Option A).

    Loads the entire dataset into memory, applies the cell, perturbation and
    gene filters, writes the filtered subset to ``output_path`` and returns a
    backed view of the written file. All masks are computed first, so the
    data matrix is subset and copied only once at the end.

    Parameters
    ----------
    path
        Path to h5ad file.
    min_genes
        Minimum genes per cell.
    min_cells_per_perturbation
        Minimum cells per perturbation (counted among cells that pass the
        gene-count filter); control cells are always kept.
    min_cells_per_gene
        Minimum cells expressing each gene.
    perturbation_column
        Column in obs containing perturbation labels.
    control_label
        Control label (already resolved).
    gene_name_column
        Column in var containing gene names.
    output_path
        Path for output h5ad file.

    Returns
    -------
    QualityControlResult
        QC result with masks and filtered AnnData. ``cell_gene_counts`` has one
        entry per input cell. ``gene_cell_counts`` has one entry per input gene,
        counted over the retained cells (including genes that were dropped).
        Both are int64.
    """
    logger.debug("Using in-memory QC path (small dataset)")
    adata = ad.read_h5ad(path)

    # Row subsetting and per-row nnz are cheapest on CSR, so convert other
    # sparse formats (e.g. CSC). Checking ``.format`` also covers scipy
    # sparse arrays, which ``isspmatrix_csr`` does not recognise.
    if sp.issparse(adata.X) and adata.X.format != "csr":
        adata.X = adata.X.tocsr()

    # Get gene names before any filtering.
    gene_names = ensure_gene_symbol_column(adata, gene_name_column)

    # ===== Cell masks (no data copies yet) =====
    if sp.issparse(adata.X):
        gene_counts_per_cell = np.asarray(adata.X.getnnz(axis=1)).ravel()
    else:
        gene_counts_per_cell = np.count_nonzero(adata.X, axis=1)
    gene_counts_per_cell = gene_counts_per_cell.astype(np.int64, copy=False)
    cell_mask_genes = gene_counts_per_cell >= min_genes

    # Perturbation filter (vectorized). A label's size is counted only over
    # cells that passed the gene-count filter.
    labels = pd.Series(adata.obs[perturbation_column].astype(str).to_numpy())
    pert_counts = labels[cell_mask_genes].value_counts()
    cells_per_label = labels.map(pert_counts).fillna(0).to_numpy()
    is_control = labels.to_numpy() == control_label
    combined_cell_mask = cell_mask_genes & (
        is_control | (cells_per_label >= min_cells_per_perturbation)
    )

    # ===== Gene stats over the retained cells =====
    # Boolean-mask indexing returns a copy of the selected rows, freed right after.
    X_kept = adata.X[combined_cell_mask]
    if sp.issparse(X_kept):
        gene_cell_counts = np.asarray(X_kept.getnnz(axis=0)).ravel()
    else:
        gene_cell_counts = np.count_nonzero(X_kept, axis=0)
    del X_kept
    gene_cell_counts = gene_cell_counts.astype(np.int64, copy=False)
    gene_mask = gene_cell_counts >= min_cells_per_gene

    # ===== Single copy at the end =====
    adata_filtered = adata[combined_cell_mask, :][:, gene_mask].copy()
    # Free the original data before writing.
    del adata
    gc.collect()

    adata_filtered.var["gene_symbols"] = gene_names[gene_mask].to_numpy()
    # Drop stale categories so downstream tools (e.g. scanpy) only see
    # groups that have at least one cell in the filtered subset.
    for _col in adata_filtered.obs.columns:
        if isinstance(adata_filtered.obs[_col].dtype, pd.CategoricalDtype):
            adata_filtered.obs[_col] = adata_filtered.obs[_col].cat.remove_unused_categories()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    adata_filtered.write(output_path)

    # Build perturbation_keep from the filtered data.
    filtered_labels = adata_filtered.obs[perturbation_column].astype(str)
    pert_counts_final = filtered_labels.value_counts()
    perturbation_keep = {
        label: (label == control_label) or (pert_counts_final.get(label, 0) >= min_cells_per_perturbation)
        for label in filtered_labels.unique()
    }

    # Return a backed view of the written file for consistency with the
    # streaming paths.
    del adata_filtered
    gc.collect()
    filtered_adata_view = AnnData(output_path)
    return QualityControlResult(
        cell_mask=combined_cell_mask,
        gene_mask=gene_mask,
        perturbation_keep=perturbation_keep,
        filtered=filtered_adata_view,
        cell_gene_counts=gene_counts_per_cell,
        gene_cell_counts=gene_cell_counts,
    )
def _qc_column_oriented(
    path: str | Path,
    *,
    min_genes: int,
    min_cells_per_perturbation: int,
    min_cells_per_gene: int,
    perturbation_column: str,
    control_label: str,
    gene_name_column: str | None,
    chunk_size: int,
    output_path: Path,
) -> QualityControlResult:
    """Column-oriented QC for large CSC datasets (Option B).

    Iterates over column chunks (contiguous reads for CSC storage) and
    accumulates per-cell nnz across chunks. Peak memory is bounded by one
    column chunk (all rows for ``chunk_size`` columns) plus the per-cell and
    per-gene count vectors, rather than by the full matrix.

    Parameters
    ----------
    path
        Path to h5ad file.
    min_genes
        Minimum genes per cell.
    min_cells_per_perturbation
        Minimum cells per perturbation.
    min_cells_per_gene
        Minimum cells expressing each gene.
    perturbation_column
        Column in obs containing perturbation labels.
    control_label
        Control label (already resolved).
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of columns to process per chunk.
    output_path
        Path for output h5ad file.

    Returns
    -------
    QualityControlResult
        QC result with masks and filtered AnnData.
    """
    logger.debug("Using column-oriented QC path (large CSC dataset)")
    # Read metadata.
    backed = read_backed(path)
    try:
        n_obs, n_vars = backed.n_obs, backed.n_vars
        gene_names = ensure_gene_symbol_column(backed, gene_name_column)
        labels = backed.obs[perturbation_column].astype(str).to_numpy()
    finally:
        backed.file.close()

    # Pass 1: column chunks give both metrics at once.
    # - genes_per_cell: row nnz, accumulated across column chunks
    # - cells_per_gene_all: column nnz, complete within each chunk
    genes_per_cell = np.zeros(n_obs, dtype=np.int64)
    cells_per_gene_all = np.zeros(n_vars, dtype=np.int64)
    backed = read_backed(path)
    try:
        for col_start in range(0, n_vars, chunk_size):
            col_end = min(col_start + chunk_size, n_vars)
            block = backed.X[:, col_start:col_end]
            if sp.issparse(block):
                genes_per_cell += np.asarray(block.getnnz(axis=1)).ravel()
                cells_per_gene_all[col_start:col_end] = np.asarray(block.getnnz(axis=0)).ravel()
            else:
                genes_per_cell += np.count_nonzero(block, axis=1)
                cells_per_gene_all[col_start:col_end] = np.count_nonzero(block, axis=0)
    finally:
        backed.file.close()
    gene_counts_per_cell = genes_per_cell

    # Cell filtering.
    cell_mask = genes_per_cell >= min_genes

    # Perturbation filtering (metadata only): label sizes are counted over
    # cells that passed the gene-count filter; control cells are always kept.
    label_series = pd.Series(labels)
    counts = label_series[cell_mask].value_counts()
    count_per_cell = label_series.map(counts).fillna(0).to_numpy()
    is_control = labels == control_label
    has_enough = count_per_cell >= min_cells_per_perturbation
    combined_cell_mask = (is_control | has_enough) & cell_mask

    # If perturbation filtering dropped cells, subtract their contribution
    # from the per-gene counts with one more column-oriented pass.
    removed_cells = cell_mask & ~combined_cell_mask
    if removed_cells.any():
        backed = read_backed(path)
        try:
            for col_start in range(0, n_vars, chunk_size):
                col_end = min(col_start + chunk_size, n_vars)
                block = backed.X[:, col_start:col_end]
                selected = block[removed_cells]
                if sp.issparse(selected):
                    cells_per_gene_all[col_start:col_end] -= np.asarray(selected.getnnz(axis=0)).ravel()
                else:
                    cells_per_gene_all[col_start:col_end] -= np.count_nonzero(selected, axis=0)
        finally:
            backed.file.close()
    gene_cell_counts = cells_per_gene_all
    gene_mask = gene_cell_counts >= min_cells_per_gene

    write_filtered_subset(
        path,
        cell_mask=combined_cell_mask,
        gene_mask=gene_mask,
        output_path=output_path,
        chunk_size=chunk_size,
        var_assignments={"gene_symbols": gene_names[gene_mask]},
    )
    filtered = AnnData(output_path)
    filtered_labels = filtered.obs[perturbation_column].astype(str)
    # Count every label in one pass instead of re-scanning per label.
    filtered_label_counts = filtered_labels.value_counts()
    perturbation_keep = {
        label: (label == control_label) or (int(filtered_label_counts.get(label, 0)) >= min_cells_per_perturbation)
        for label in filtered_labels.unique()
    }
    return QualityControlResult(
        cell_mask=combined_cell_mask,
        gene_mask=gene_mask,
        perturbation_keep=perturbation_keep,
        filtered=filtered,
        cell_gene_counts=gene_counts_per_cell,
        gene_cell_counts=gene_cell_counts,
    )
def _qc_row_oriented(
    path: str | Path,
    *,
    min_genes: int,
    min_cells_per_perturbation: int,
    min_cells_per_gene: int,
    perturbation_column: str,
    control_label: str,
    gene_name_column: str | None,
    chunk_size: int,
    output_path: Path,
    cache_mode: Literal['memory', 'memmap', 'none'] = 'memmap',
    delta_threshold: float = 0.3,
) -> QualityControlResult:
    """Row-oriented streaming QC for large CSR/dense datasets.

    This is the original streaming implementation optimized for row-oriented
    access patterns (CSR format or dense arrays). Steps:

    1. ``filter_cells_by_gene_count`` computes genes-per-cell and
       cells-per-gene over *all* cells.
    2. Perturbation filtering (metadata only) drops cells of perturbations
       with too few surviving cells; cells-per-gene is then corrected for the
       removed cells, either by subtracting their contribution (delta
       adjustment) or by a full recount over the kept cells.
    3. Genes are filtered (dense-optimized or CSR-cache path) and the
       filtered subset is written with ``write_filtered_subset``.

    Parameters
    ----------
    path
        Path to h5ad file.
    min_genes
        Minimum genes per cell.
    min_cells_per_perturbation
        Minimum cells per perturbation.
    min_cells_per_gene
        Minimum cells expressing each gene.
    perturbation_column
        Column in obs containing perturbation labels.
    control_label
        Control label (already resolved). Control cells are never removed by
        the perturbation filter.
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of cells to process per chunk.
    output_path
        Path for output h5ad file.
    cache_mode
        Cache strategy: 'memory' (fast, high RAM), 'memmap' (low RAM, disk-based),
        or 'none' (no caching, requires re-reading source during write).
        Default is 'memmap' for better memory efficiency. Only used when the
        source matrix is stored sparse.
    delta_threshold
        When ``removed / kept`` cells (from the perturbation filter) is below
        this ratio, cells-per-gene is corrected by subtracting the removed
        cells' counts; otherwise it is recomputed over the kept cells.

    Returns
    -------
    QualityControlResult
        QC result with masks and filtered AnnData.
    """
    logger.debug("Using row-oriented streaming QC path (large CSR/dense dataset)")

    # Only the gene symbols are needed from metadata here; perturbation
    # labels are read by ``filter_perturbations_by_cell_count`` itself.
    backed = read_backed(path)
    try:
        gene_names = ensure_gene_symbol_column(backed, gene_name_column)
    finally:
        backed.file.close()

    # Pass 1: filter cells by gene count AND compute cells-per-gene for all cells.
    cell_filter_result = filter_cells_by_gene_count(
        path,
        min_genes=min_genes,
        gene_name_column=gene_name_column,
        chunk_size=chunk_size,
        return_full_result=True,
    )
    cell_mask = cell_filter_result.cell_mask
    gene_counts_per_cell = cell_filter_result.gene_counts_per_cell
    gene_cell_counts_all = cell_filter_result.gene_cell_counts_all

    # No matrix pass: filter perturbations by cell count (metadata only).
    perturbation_mask = filter_perturbations_by_cell_count(
        path,
        perturbation_column=perturbation_column,
        control_label=control_label,
        min_cells=min_cells_per_perturbation,
        base_mask=cell_mask,
    )
    combined_cell_mask = cell_mask & perturbation_mask

    # Cells-per-gene must reflect only the kept cells. Cells that passed the
    # gene-count filter but were dropped by the perturbation filter still
    # contribute to ``gene_cell_counts_all`` and must be accounted for.
    removed_cell_mask = cell_mask & ~perturbation_mask
    n_removed = int(removed_cell_mask.sum())
    n_filtered = int(combined_cell_mask.sum())
    if n_removed == 0:
        gene_cell_counts = gene_cell_counts_all
        logger.debug("No cells removed by perturbation filter, using all-cell gene counts")
    elif n_filtered > 0 and n_removed / n_filtered < delta_threshold:
        # Few removed cells: reading just those rows is cheaper than a recount.
        logger.debug(
            "Using delta adjustment: %d removed cells (%.1f%% of %d filtered)",
            n_removed, 100 * n_removed / n_filtered, n_filtered
        )
        delta_counts = _compute_gene_count_delta(
            path,
            removed_cell_mask=removed_cell_mask,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
        )
        gene_cell_counts = gene_cell_counts_all - delta_counts
    else:
        logger.debug(
            "Using full recompute: %d removed cells (%.1f%% of %d filtered)",
            n_removed, 100 * n_removed / n_filtered if n_filtered > 0 else 0, n_filtered
        )
        _, gene_cell_counts = filter_genes_by_cell_count(
            path,
            min_cells=0,
            cell_mask=combined_cell_mask,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
            return_counts=True,
        )

    # Pass 2: gene filtering with nnz counting (sizes the output for the write).
    if is_dense_storage(path):
        logger.debug("Using dense-optimized path (source stored as dense array)")
        gene_filter_result = _filter_genes_dense_optimized(
            path,
            min_cells=min_cells_per_gene,
            cell_mask=combined_cell_mask,
            gene_cell_counts=gene_cell_counts,
            gene_mask=gene_cell_counts >= min_cells_per_gene,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
        )
        chunk_cache = None
    else:
        logger.debug("Using CSR cache path (source stored as sparse, cache_mode=%s)", cache_mode)
        gene_filter_result = _filter_genes_with_cache(
            path,
            min_cells=min_cells_per_gene,
            cell_mask=combined_cell_mask,
            gene_cell_counts=gene_cell_counts,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
            output_path=output_path,
            cache_mode=cache_mode,
        )
        chunk_cache = gene_filter_result.chunk_cache
    gene_mask = gene_filter_result.gene_mask

    try:
        write_filtered_subset(
            path,
            cell_mask=combined_cell_mask,
            gene_mask=gene_mask,
            output_path=output_path,
            chunk_size=chunk_size,
            var_assignments={"gene_symbols": gene_names[gene_mask]},
            row_nnz=gene_filter_result.row_nnz,
            total_nnz=gene_filter_result.total_nnz,
            data_dtype=gene_filter_result.data_dtype,
            chunk_cache=chunk_cache,
        )
    finally:
        # Release cached chunks even if the write fails.
        if chunk_cache is not None:
            chunk_cache.cleanup()

    filtered = AnnData(output_path)
    filtered_labels = filtered.obs[perturbation_column].astype(str)
    # Count every label in one hashing pass instead of re-scanning the column
    # once per unique label; iteration order still follows ``unique()``.
    label_counts = filtered_labels.value_counts()
    perturbation_keep = {
        label: (label == control_label)
        or (int(label_counts.get(label, 0)) >= min_cells_per_perturbation)
        for label in filtered_labels.unique()
    }
    return QualityControlResult(
        cell_mask=combined_cell_mask,
        gene_mask=gene_mask,
        perturbation_keep=perturbation_keep,
        filtered=filtered,
        cell_gene_counts=gene_counts_per_cell,
        gene_cell_counts=gene_cell_counts,
    )
def _qc_masks_only(
    path: str | Path,
    *,
    min_genes: int,
    min_cells_per_perturbation: int,
    min_cells_per_gene: int,
    perturbation_column: str,
    control_label: str,
    gene_name_column: str | None,
    chunk_size: int,
    delta_threshold: float = 0.3,
) -> QualityControlResult:
    """Compute QC masks without writing output file.

    This is a lightweight QC path that returns only the masks and statistics
    without writing a filtered h5ad file. Useful for:

    - Memory-constrained environments where users only need the masks
    - Workflows that apply masks downstream in a custom manner
    - Quick QC statistics without the cost of writing an output file

    Parameters
    ----------
    path
        Path to h5ad file.
    min_genes
        Minimum genes per cell.
    min_cells_per_perturbation
        Minimum cells per perturbation.
    min_cells_per_gene
        Minimum cells expressing each gene.
    perturbation_column
        Column in obs containing perturbation labels.
    control_label
        Control label (already resolved). Always marked as kept.
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of cells to process per chunk.
    delta_threshold
        When ``removed / kept`` cells (from the perturbation filter) is below
        this ratio, cells-per-gene is corrected by subtracting the removed
        cells' counts; otherwise it is recomputed over the kept cells.

    Returns
    -------
    QualityControlResult
        QC result with masks but filtered=None (no output file).
        ``perturbation_keep`` maps each label present among kept cells
        (built-in ``str``, sorted) to a built-in ``bool``.
    """
    logger.debug("Computing QC masks only (no output file)")

    # Read perturbation labels (needed for the perturbation_keep summary).
    backed = read_backed(path)
    try:
        labels = backed.obs[perturbation_column].astype(str).to_numpy()
    finally:
        backed.file.close()

    # Pass 1: filter cells by gene count AND compute cells-per-gene for all cells.
    cell_filter_result = filter_cells_by_gene_count(
        path,
        min_genes=min_genes,
        gene_name_column=gene_name_column,
        chunk_size=chunk_size,
        return_full_result=True,
    )
    cell_mask = cell_filter_result.cell_mask
    gene_counts_per_cell = cell_filter_result.gene_counts_per_cell
    gene_cell_counts_all = cell_filter_result.gene_cell_counts_all

    # Filter perturbations by cell count (metadata only).
    perturbation_mask = filter_perturbations_by_cell_count(
        path,
        perturbation_column=perturbation_column,
        control_label=control_label,
        min_cells=min_cells_per_perturbation,
        base_mask=cell_mask,
    )
    combined_cell_mask = cell_mask & perturbation_mask

    # Correct cells-per-gene for cells dropped by the perturbation filter.
    removed_cell_mask = cell_mask & ~perturbation_mask
    n_removed = int(removed_cell_mask.sum())
    n_filtered = int(combined_cell_mask.sum())
    if n_removed == 0:
        gene_cell_counts = gene_cell_counts_all
        logger.debug("No cells removed by perturbation filter, using all-cell gene counts")
    elif n_filtered > 0 and n_removed / n_filtered < delta_threshold:
        logger.debug(
            "Using delta adjustment: %d removed cells (%.1f%% of %d filtered)",
            n_removed, 100 * n_removed / n_filtered, n_filtered
        )
        delta_counts = _compute_gene_count_delta(
            path,
            removed_cell_mask=removed_cell_mask,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
        )
        gene_cell_counts = gene_cell_counts_all - delta_counts
    else:
        logger.debug(
            "Using full recompute: %d removed cells (%.1f%% of %d filtered)",
            n_removed, 100 * n_removed / n_filtered if n_filtered > 0 else 0, n_filtered
        )
        _, gene_cell_counts = filter_genes_by_cell_count(
            path,
            min_cells=0,
            cell_mask=combined_cell_mask,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
            return_counts=True,
        )

    gene_mask = gene_cell_counts >= min_cells_per_gene

    # One hashing pass over the kept labels; sorting only the distinct labels
    # reproduces the previous ``np.unique`` order. Keys/values are converted
    # to built-in str/bool so this path matches the file-writing QC paths
    # (numpy scalars would otherwise leak out, e.g. ``keep[x] is True`` fails).
    label_counts = pd.Series(labels[combined_cell_mask]).value_counts()
    perturbation_keep = {
        str(label): bool(
            label == control_label
            or label_counts[label] >= min_cells_per_perturbation
        )
        for label in sorted(label_counts.index)
    }
    return QualityControlResult(
        cell_mask=combined_cell_mask,
        gene_mask=gene_mask,
        perturbation_keep=perturbation_keep,
        filtered=None,  # No output file
        cell_gene_counts=gene_counts_per_cell,
        gene_cell_counts=gene_cell_counts,
    )
[docs]
def quality_control_summary(
    data: str | Path | AnnData | ad.AnnData,
    *,
    min_genes: int = 100,
    min_cells_per_perturbation: int = 50,
    min_cells_per_gene: int = 100,
    perturbation_column: str,
    control_label: str | None = None,
    gene_name_column: str | None = None,
    chunk_size: int | None = None,
    memory_limit_gb: float | None = None,
    output_dir: str | Path | None = None,
    data_name: str | None = None,
    cache_mode: Literal['memory', 'memmap', 'none'] = 'memmap',
    delta_threshold: float = 0.3,
    force_streaming: bool = False,
) -> QualityControlResult:
    """Run QC with automatic strategy selection for optimal performance.

    This function automatically selects the best QC strategy based on:

    1. Small data (any format): In-memory processing (fastest)
    2. Large CSC data: Column-oriented streaming (memory efficient for CSC)
    3. Large CSR/dense data: Row-oriented streaming (current behavior)

    If ``output_dir`` is None, a masks-only path is used regardless of size
    and no filtered file is written.

    Parameters
    ----------
    data
        Path to h5ad file, or a crispyx/anndata AnnData object.
    min_genes
        Minimum number of expressed genes per cell.
    min_cells_per_perturbation
        Minimum number of cells required per perturbation.
    min_cells_per_gene
        Minimum number of cells expressing each gene.
    perturbation_column
        Column in obs containing perturbation labels.
    control_label
        Label identifying control cells. If None, auto-detected.
    gene_name_column
        Column in var containing gene names.
    chunk_size
        Number of cells to process per chunk. If None, automatically
        calculated based on available memory. Must be positive when given.
    memory_limit_gb
        Optional memory limit in GB for strategy selection and chunk size.
        If None, 50% of currently available memory (via ``psutil``) is used,
        falling back to 8 GB when ``psutil`` is not installed.
    output_dir
        Directory for output files. If None, returns QC masks without writing
        a filtered h5ad file (QualityControlResult.filtered will be None).
    data_name
        Base name for output files.
    cache_mode
        Cache strategy for row-oriented streaming: 'memory' (fast, high RAM),
        'memmap' (low RAM, disk-based), or 'none' (no caching, requires
        re-reading source during write). Default is 'memmap' for better
        memory efficiency.
    delta_threshold
        Threshold for delta adjustment in row-oriented streaming.
        Default 0.3 (30%).
    force_streaming
        If True, always use streaming path regardless of data size.
        Useful for testing or memory-constrained environments.

    Returns
    -------
    QualityControlResult
        Dataclass containing masks, filtered AnnData (or None if output_dir
        is None), and QC statistics.

    Raises
    ------
    ValueError
        If ``chunk_size`` is given and is not a positive integer.
    KeyError
        If ``perturbation_column`` is not a column of ``obs``.
    """
    # Validate before touching the file. The streaming passes step through
    # the matrix with ``range(0, n, chunk_size)`` (e.g. the column-oriented
    # pass): a zero step raises deep inside them, and a negative step yields
    # no chunks at all, leaving every count at zero.
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    path = resolve_data_path(data)

    # Read metadata and resolve control label
    backed = read_backed(path)
    try:
        n_obs, n_vars = backed.n_obs, backed.n_vars
        if perturbation_column not in backed.obs.columns:
            raise KeyError(
                f"Perturbation column '{perturbation_column}' was not found in adata.obs. "
                f"Available columns: {list(backed.obs.columns)}"
            )
        labels = backed.obs[perturbation_column].astype(str).to_numpy()
        control_label_resolved = resolve_control_label(labels, control_label)
    finally:
        backed.file.close()

    # Detect storage format and file size
    storage_format = get_matrix_storage_format(path)
    file_size_gb = path.stat().st_size / 1e9

    # Determine available memory
    if memory_limit_gb is None:
        try:
            import psutil
            memory_limit_gb = psutil.virtual_memory().available / 1e9 * 0.5  # Use 50% of available
        except ImportError:
            memory_limit_gb = 8.0

    # Estimate memory needed for in-memory processing.
    if storage_format == 'dense':
        # For dense arrays, the compressed file size can be much smaller than
        # the actual uncompressed footprint, so use n_obs * n_vars * itemsize.
        import h5py as _h5py
        with _h5py.File(path, 'r') as _f:
            _dtype = _f['X'].dtype if isinstance(_f.get('X'), _h5py.Dataset) else None
        _itemsize = _dtype.itemsize if _dtype is not None else 4  # default float32
        # 2x: one copy to load + one working copy
        estimated_memory_gb = n_obs * n_vars * _itemsize * 2 / 1e9
    else:
        # For sparse HDF5 formats, the compressed file size significantly
        # underestimates the in-memory footprint. HDF5 gzip compression
        # achieves 3-5x for sparse scRNA-seq data, so a 27 GB file expands to
        # 80+ GB when loaded as a scipy CSR matrix. QC operations then create
        # additional working copies (boolean indexing, tocsr(), etc.), raising
        # peak usage to ~4-6x the compressed file size.
        # Use 4x as a conservative estimate to avoid OOM on large datasets.
        estimated_memory_gb = file_size_gb * 4

    # Determine chunk size for streaming paths
    if chunk_size is None:
        chunk_size = calculate_optimal_chunk_size(
            n_obs, n_vars, available_memory_gb=memory_limit_gb
        )

    # Handle output_dir=None: return masks only without writing output
    if output_dir is None:
        logger.info("output_dir is None, returning QC masks without writing filtered h5ad")
        return _qc_masks_only(
            path,
            min_genes=min_genes,
            min_cells_per_perturbation=min_cells_per_perturbation,
            min_cells_per_gene=min_cells_per_gene,
            perturbation_column=perturbation_column,
            control_label=control_label_resolved,
            gene_name_column=gene_name_column,
            chunk_size=chunk_size,
            delta_threshold=delta_threshold,
        )

    filtered_path = resolve_output_path(
        path, suffix="filtered", output_dir=output_dir, data_name=data_name
    )

    # Common kwargs for all file-writing strategies
    common_kwargs = {
        "min_genes": min_genes,
        "min_cells_per_perturbation": min_cells_per_perturbation,
        "min_cells_per_gene": min_cells_per_gene,
        "perturbation_column": perturbation_column,
        "control_label": control_label_resolved,
        "gene_name_column": gene_name_column,
        "output_path": filtered_path,
    }

    # Cap the in-memory threshold at 50 GB to avoid OOM on nodes with very
    # high memory_limit_gb (e.g., 500 GB auto-detected on 1 TB HPC nodes).
    # On a 128 GB node this allows in-memory QC for files up to ~12.5 GB
    # (estimated_memory = file_size x 4 < 50 GB -> file < 12.5 GB).
    _in_memory_threshold_gb = min(memory_limit_gb * 0.6, 50.0)
    if not force_streaming and estimated_memory_gb < _in_memory_threshold_gb:
        # Option A: In-memory for small datasets
        logger.info(
            "Using in-memory QC (estimated: %.1fGB, threshold: %.1fGB)",
            estimated_memory_gb, _in_memory_threshold_gb,
        )
        return _qc_in_memory(path, **common_kwargs)
    if storage_format == 'csc':
        # Option B: Column-oriented for large CSC
        logger.info(
            "Using column-oriented streaming QC (CSC format, %.2fGB)", file_size_gb
        )
        return _qc_column_oriented(path, chunk_size=chunk_size, **common_kwargs)
    # Option C: Row-oriented streaming for large CSR/dense
    logger.info(
        "Using row-oriented streaming QC (%s format, %.2fGB)", storage_format, file_size_gb
    )
    return _qc_row_oriented(
        path,
        chunk_size=chunk_size,
        cache_mode=cache_mode,
        delta_threshold=delta_threshold,
        **common_kwargs,
    )