"""Profiling utilities for timing and memory measurement.
This module provides a unified `Profiler` class for measuring execution time
and memory usage of code sections. It supports:
- Timing measurement with start/stop labels
- Memory tracking via tracemalloc (Python objects) or RSS (process-level)
- Continuous background memory sampling
- Visualization utilities for timing and memory plots
Example usage::
from crispyx.profiling import Profiler
# Basic timing
profiler = Profiler(timing=True)
profiler.start("data_loading")
# ... load data ...
profiler.stop("data_loading")
print(profiler.get_report())
# Combined timing and memory with context manager
with Profiler(timing=True, memory=True) as p:
p.start("processing")
# ... process data ...
p.snapshot("after_processing")
p.stop("processing")
# Access results programmatically
stats = p.get_stats()
print(f"Total time: {stats['timing']['total_seconds']}s")
print(f"Peak memory: {stats['memory']['peak_mb']}MB")
"""
from __future__ import annotations
import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Literal
logger = logging.getLogger(__name__)
__all__ = ["Profiler", "TimingProfiler", "MemoryProfiler"]
[docs]
class Profiler:
"""Unified profiler for timing and memory measurement.
This class combines timing profiling, memory snapshots, and continuous
memory sampling into a single interface. It can be used as a context
manager or with explicit start/stop calls.
Parameters
----------
timing : bool, default=False
Enable timing measurement of code sections.
memory : bool, default=False
Enable memory tracking (snapshots at labeled points).
memory_method : {"tracemalloc", "rss"}, default="tracemalloc"
Method for memory measurement:
- "tracemalloc": Python object memory via tracemalloc (more detailed)
- "rss": Process resident set size via psutil (total process memory)
sampling : bool, default=False
Enable continuous background memory sampling. Runs a daemon thread
that records memory usage at `sample_interval` intervals.
sample_interval : float, default=0.1
Seconds between samples when `sampling=True`.
top_n : int, default=10
Number of top allocations to report (only for tracemalloc).
Examples
--------
>>> with Profiler(timing=True, memory=True) as p:
... p.start("section1")
... # ... code ...
... p.stop("section1")
>>> print(p.get_report())
"""
def __init__(
self,
timing: bool = False,
memory: bool = False,
memory_method: Literal["tracemalloc", "rss"] = "tracemalloc",
sampling: bool = False,
sample_interval: float = 0.1,
top_n: int = 10,
):
self.timing_enabled = timing
self.memory_enabled = memory
self.memory_method = memory_method
self.sampling_enabled = sampling
self.sample_interval = sample_interval
self.top_n = top_n
# Timing state
self._timings: dict[str, float] = {}
self._start_times: dict[str, float] = {}
self._total_start: float | None = None
# Memory state
self._snapshots: dict[str, dict] = {}
self._peak_memory_mb: float = 0.0
self._tracemalloc_start_time: float | None = None
# Sampling state
self._samples: list[tuple[float, float]] = [] # (timestamp, memory_mb)
self._sampling_thread: threading.Thread | None = None
self._stop_sampling_event: threading.Event | None = None
def __enter__(self):
"""Start profiling context."""
if self.memory_enabled and self.memory_method == "tracemalloc":
import tracemalloc
tracemalloc.start()
self._tracemalloc_start_time = time.perf_counter()
if self.sampling_enabled:
self.start_sampling()
if self.timing_enabled:
self._total_start = time.perf_counter()
self.start("total")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop profiling context."""
if self.timing_enabled and "total" in self._start_times:
self.stop("total")
if self.sampling_enabled:
self.stop_sampling()
if self.memory_enabled:
self.snapshot("end")
if self.memory_method == "tracemalloc":
import tracemalloc
tracemalloc.stop()
return False
# =========================================================================
# Timing methods
# =========================================================================
[docs]
def start(self, label: str) -> None:
"""Start timing a labeled section."""
if not self.timing_enabled:
return
if self._total_start is None:
self._total_start = time.perf_counter()
self._start_times[label] = time.perf_counter()
[docs]
def stop(self, label: str) -> float:
"""Stop timing a labeled section and return elapsed time."""
if not self.timing_enabled:
return 0.0
if label not in self._start_times:
logger.warning(f"Profiler.stop() called for unstarted label: {label}")
return 0.0
elapsed = time.perf_counter() - self._start_times[label]
if label in self._timings:
self._timings[label] += elapsed # Accumulate if called multiple times
else:
self._timings[label] = elapsed
del self._start_times[label]
return elapsed
[docs]
def get_total_time(self) -> float:
"""Get total elapsed time since first start() call."""
if not self.timing_enabled or self._total_start is None:
return 0.0
return time.perf_counter() - self._total_start
# =========================================================================
# Memory snapshot methods
# =========================================================================
[docs]
def snapshot(self, label: str) -> None:
"""Take a memory snapshot at the current point."""
if not self.memory_enabled:
return
timestamp = time.perf_counter() - (self._tracemalloc_start_time or self._total_start or time.perf_counter())
if self.memory_method == "tracemalloc":
import tracemalloc
snap = tracemalloc.take_snapshot()
current, peak = tracemalloc.get_traced_memory()
current_mb = current / 1024 / 1024
peak_mb = peak / 1024 / 1024
self._snapshots[label] = {
"snapshot": snap,
"timestamp_s": timestamp,
"current_mb": current_mb,
"peak_mb": peak_mb,
}
else: # rss
current_mb = self._get_rss_mb()
self._snapshots[label] = {
"snapshot": None,
"timestamp_s": timestamp,
"current_mb": current_mb,
"peak_mb": current_mb, # RSS doesn't track peak
}
peak_mb = current_mb
self._peak_memory_mb = max(self._peak_memory_mb, peak_mb)
logger.debug(
f"Memory snapshot '{label}': current={current_mb:.1f}MB, peak={peak_mb:.1f}MB"
)
def _get_rss_mb(self) -> float:
"""Get current process RSS in MB."""
try:
import psutil
return psutil.Process().memory_info().rss / 1024 / 1024
except ImportError:
# Fallback for systems without psutil
try:
import resource
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 # KB to MB on Linux
except ImportError:
logger.warning("Neither psutil nor resource available for RSS measurement")
return 0.0
[docs]
def reset_peak(self) -> None:
"""Reset peak memory tracking to current memory level.
This is useful when profiling a specific operation after loading data,
so that peak memory measures only the operation's memory usage,
not the memory used by previously loaded data.
"""
if not self.memory_enabled and not self.sampling_enabled:
return
if self.memory_method == "tracemalloc":
import tracemalloc
tracemalloc.reset_peak()
_, current_peak = tracemalloc.get_traced_memory()
self._peak_memory_mb = current_peak / 1024 / 1024
else: # rss
# For RSS, we can only reset our tracked peak (can't reset OS stats)
current_mb = self._get_rss_mb()
self._peak_memory_mb = current_mb
# Also clear previous samples if sampling is active
if self.sampling_enabled:
self._samples = []
# =========================================================================
# Continuous sampling methods
# =========================================================================
[docs]
def start_sampling(self) -> None:
"""Start background memory sampling thread."""
if not self.sampling_enabled:
return
if self._sampling_thread is not None:
return # Already running
self._stop_sampling_event = threading.Event()
self._samples = []
start_time = time.perf_counter()
def _sample_loop():
while not self._stop_sampling_event.is_set():
timestamp = time.perf_counter() - start_time
memory_mb = self._get_rss_mb()
self._samples.append((timestamp, memory_mb))
self._peak_memory_mb = max(self._peak_memory_mb, memory_mb)
self._stop_sampling_event.wait(self.sample_interval)
self._sampling_thread = threading.Thread(target=_sample_loop, daemon=True)
self._sampling_thread.start()
[docs]
def stop_sampling(self) -> None:
"""Stop background memory sampling thread."""
if self._stop_sampling_event is not None:
self._stop_sampling_event.set()
if self._sampling_thread is not None:
self._sampling_thread.join(timeout=1.0)
self._sampling_thread = None
# =========================================================================
# Results methods
# =========================================================================
[docs]
def get_stats(self) -> dict:
"""Get profiling statistics as a dict.
Returns
-------
dict
Dictionary with keys ``"timing"`` and ``"memory"``.
The ``"timing"`` entry contains ``"total_seconds"`` (float) and
``"sections"`` (a mapping of label to ``{"seconds": float,
"percent": float}``).
The ``"memory"`` entry contains ``"peak_mb"`` (float),
``"snapshots"`` (a mapping of label to snapshot data),
``"samples"`` (list of timestamp/memory pairs if sampling is
enabled), and ``"top_allocations"`` (if tracemalloc is used).
"""
stats = {}
# Timing stats
if self.timing_enabled:
total = self._timings.get("total", self.get_total_time())
timing_stats = {
"total_seconds": round(total, 3),
"sections": {},
}
for label, elapsed in self._timings.items():
pct = (elapsed / total * 100) if total > 0 else 0
timing_stats["sections"][label] = {
"seconds": round(elapsed, 3),
"percent": round(pct, 1),
}
stats["timing"] = timing_stats
# Memory stats
if self.memory_enabled or self.sampling_enabled:
memory_stats = {
"peak_mb": round(self._peak_memory_mb, 2),
"snapshots": {},
}
for label, snap_data in self._snapshots.items():
memory_stats["snapshots"][label] = {
"timestamp_s": round(snap_data["timestamp_s"], 3),
"current_mb": round(snap_data["current_mb"], 2),
}
if self.sampling_enabled and self._samples:
memory_stats["samples"] = [
(round(t, 3), round(m, 2)) for t, m in self._samples
]
# Top allocations from tracemalloc
if self.memory_method == "tracemalloc" and "end" in self._snapshots:
snap = self._snapshots["end"].get("snapshot")
if snap is not None:
top_stats = snap.statistics("lineno")[:self.top_n]
memory_stats["top_allocations"] = [
{
"file": str(stat.traceback),
"size_mb": round(stat.size / 1024 / 1024, 2),
"count": stat.count,
}
for stat in top_stats
]
stats["memory"] = memory_stats
return stats
[docs]
def get_report(self) -> str:
"""Generate a human-readable profiling report."""
if not self.timing_enabled and not self.memory_enabled:
return "Profiling was not enabled."
lines = ["=" * 60, "Profiling Report", "=" * 60]
# Timing section
if self.timing_enabled and self._timings:
total = self._timings.get("total", self.get_total_time())
lines.append(f"\nTotal time: {total:.2f}s\n")
lines.append("Section Breakdown:")
lines.append("-" * 50)
sorted_timings = sorted(
[(k, v) for k, v in self._timings.items() if k != "total"],
key=lambda x: -x[1]
)
for label, elapsed in sorted_timings:
pct = (elapsed / total * 100) if total > 0 else 0
bar_len = int(pct / 2)
bar = "█" * bar_len + "░" * (25 - bar_len)
lines.append(f" {label:30s} {elapsed:7.2f}s ({pct:5.1f}%) {bar}")
# Memory section
if self.memory_enabled and self._snapshots:
lines.append("\nMemory Snapshots:")
lines.append("-" * 50)
for label, snap_data in self._snapshots.items():
lines.append(
f" {label:25s} t={snap_data['timestamp_s']:7.2f}s "
f"current={snap_data['current_mb']:8.1f}MB"
)
lines.append(f"\nPeak memory: {self._peak_memory_mb:.1f}MB")
# Sampling summary
if self.sampling_enabled and self._samples:
lines.append(f"\nMemory sampling: {len(self._samples)} samples collected")
if self._samples:
min_mem = min(m for _, m in self._samples)
max_mem = max(m for _, m in self._samples)
lines.append(f" Range: {min_mem:.1f}MB - {max_mem:.1f}MB")
lines.append("=" * 60)
return "\n".join(lines)
# =========================================================================
# Visualization methods
# =========================================================================
[docs]
def plot_timeline(self, ax=None):
"""Plot timing breakdown as horizontal bar chart.
Parameters
----------
ax : matplotlib.axes.Axes, optional
Axes to plot on. If None, creates new figure.
Returns
-------
matplotlib.axes.Axes
The axes with the plot.
"""
try:
import matplotlib.pyplot as plt
except ImportError:
logger.warning("matplotlib not installed; cannot create plot")
return None
if not self.timing_enabled or not self._timings:
logger.warning("No timing data to plot")
return None
if ax is None:
fig, ax = plt.subplots(figsize=(10, 6))
# Sort by elapsed time
sorted_timings = sorted(
[(k, v) for k, v in self._timings.items() if k != "total"],
key=lambda x: x[1]
)
labels = [k for k, v in sorted_timings]
times = [v for k, v in sorted_timings]
colors = plt.cm.viridis([i / len(labels) for i in range(len(labels))])
ax.barh(labels, times, color=colors)
ax.set_xlabel("Time (seconds)")
ax.set_title("Timing Breakdown by Section")
# Add time labels on bars
for i, (label, t) in enumerate(sorted_timings):
ax.text(t + 0.1, i, f"{t:.2f}s", va="center", fontsize=9)
plt.tight_layout()
return ax
[docs]
def plot_memory(self, ax=None):
"""Plot memory usage over time (requires sampling mode).
Parameters
----------
ax : matplotlib.axes.Axes, optional
Axes to plot on. If None, creates new figure.
Returns
-------
matplotlib.axes.Axes
The axes with the plot.
"""
try:
import matplotlib.pyplot as plt
except ImportError:
logger.warning("matplotlib not installed; cannot create plot")
return None
if not self._samples:
logger.warning("No memory samples to plot. Enable sampling=True.")
return None
if ax is None:
fig, ax = plt.subplots(figsize=(10, 6))
timestamps = [t for t, m in self._samples]
memory = [m for t, m in self._samples]
ax.plot(timestamps, memory, "b-", linewidth=1.5)
ax.fill_between(timestamps, memory, alpha=0.3)
ax.set_xlabel("Time (seconds)")
ax.set_ylabel("Memory (MB)")
ax.set_title("Memory Usage Over Time")
# Mark peak
peak_idx = memory.index(max(memory))
ax.axhline(y=max(memory), color="r", linestyle="--", alpha=0.5)
ax.annotate(
f"Peak: {max(memory):.1f}MB",
xy=(timestamps[peak_idx], max(memory)),
xytext=(10, 10),
textcoords="offset points",
fontsize=9,
)
plt.tight_layout()
return ax
# =============================================================================
# Specialized profiler subclasses
# =============================================================================
[docs]
class TimingProfiler(Profiler):
"""Timing-only profiler.
Thin wrapper around ``Profiler`` with ``timing=True``.
"""
def __init__(self, enabled: bool = False):
super().__init__(timing=enabled)
self.enabled = enabled
self.timings = self._timings
[docs]
class MemoryProfiler(Profiler):
"""Memory-only profiler.
Thin wrapper around ``Profiler`` with ``memory=True``.
"""
def __init__(self, enabled: bool = False, top_n: int = 10):
super().__init__(memory=enabled, top_n=top_n)
self.enabled = enabled
self.snapshots = self._snapshots
# =============================================================================
# Standalone visualization utilities
# =============================================================================
def plot_benchmark_comparison(
profiler_results: list[dict],
labels: list[str],
metric: Literal["timing", "memory"] = "timing",
ax=None,
):
"""Compare profiling results from multiple runs side-by-side.
Useful for comparing before/after optimization, different parameters,
or different methods (crispyx vs PyDESeq2).
Parameters
----------
profiler_results : list[dict]
List of profiler stats dicts from `Profiler.get_stats()`.
labels : list[str]
Labels for each run (e.g., ["before", "after"]).
metric : {"timing", "memory"}
Which metric to compare.
ax : matplotlib.axes.Axes, optional
Axes to plot on. If None, creates new figure.
Returns
-------
matplotlib.axes.Axes
The axes with the plot.
Examples
--------
>>> stats_before = profiler_before.get_stats()
>>> stats_after = profiler_after.get_stats()
>>> plot_benchmark_comparison([stats_before, stats_after], ["Before", "After"])
"""
try:
import matplotlib.pyplot as plt
import numpy as np
except ImportError:
logger.warning("matplotlib not installed; cannot create plot")
return None
if not profiler_results:
logger.warning("No profiler results to compare")
return None
if ax is None:
fig, ax = plt.subplots(figsize=(12, 6))
if metric == "timing":
# Collect all section names
all_sections = set()
for result in profiler_results:
if "timing" in result:
all_sections.update(result["timing"].get("sections", {}).keys())
all_sections.discard("total")
sections = sorted(all_sections)
if not sections:
logger.warning("No timing sections to compare")
return ax
# Build data matrix
n_runs = len(profiler_results)
n_sections = len(sections)
x = np.arange(n_sections)
width = 0.8 / n_runs
for i, (result, label) in enumerate(zip(profiler_results, labels)):
timing = result.get("timing", {}).get("sections", {})
values = [timing.get(s, {}).get("seconds", 0) for s in sections]
offset = (i - n_runs / 2 + 0.5) * width
ax.bar(x + offset, values, width, label=label)
ax.set_ylabel("Time (seconds)")
ax.set_title("Timing Comparison")
ax.set_xticks(x)
ax.set_xticklabels(sections, rotation=45, ha="right")
ax.legend()
else: # memory
# Compare peak memory
peaks = []
for result in profiler_results:
peak = result.get("memory", {}).get("peak_mb", 0)
peaks.append(peak)
x = np.arange(len(labels))
ax.bar(x, peaks, color=plt.cm.viridis([0.3, 0.7][:len(labels)]))
ax.set_ylabel("Peak Memory (MB)")
ax.set_title("Memory Comparison")
ax.set_xticks(x)
ax.set_xticklabels(labels)
# Add value labels
for i, peak in enumerate(peaks):
ax.text(i, peak + 5, f"{peak:.1f}MB", ha="center", fontsize=10)
plt.tight_layout()
return ax