"""Datashader backend for fast C2 heatmap visualization (CPU-only in v2.3.0+).
This module provides high-performance CPU-optimized heatmap rendering using Datashader,
offering 5-10x speedup over matplotlib for C2 correlation data visualization.
Key features:
- CPU-optimized rendering for HPC environments
- Fast PNG generation for large datasets
- Backward compatible with matplotlib output format
- Parallel processing support for multi-angle plots
"""
from pathlib import Path
import datashader as ds
import datashader.transfer_functions as tf
import numpy as np
import xarray as xr
from PIL import Image
from homodyne.utils.logging import get_logger
# Module-level logger from the project's logging helper; used below for
# colormap-fallback warnings and debug messages about saved plot paths.
logger = get_logger(__name__)
[docs]
class DatashaderRenderer:
"""Fast heatmap rendering using Datashader (CPU-only in v2.3.0+).
This renderer uses Datashader's optimized CPU rasterization pipeline to convert
2D gridded data (C2 correlation matrices) into RGB images much faster than
matplotlib's imshow() + savefig() workflow.
Performance:
- Typical speedup: 5-10x over matplotlib
- Recommended for grids >100x100 or multiple plots
- CPU-optimized for HPC environments
Examples
--------
>>> renderer = DatashaderRenderer(width=800, height=800)
>>> img = renderer.rasterize_heatmap(c2_data, t2_coords, t1_coords)
>>> img.save('output.png') # Direct PIL save, very fast
"""
[docs]
def __init__(
self,
width: int = 800,
height: int = 800,
):
"""Initialize Datashader renderer (CPU-only in v2.3.0+).
Parameters
----------
width : int, default=800
Output image width in pixels
height : int, default=800
Output image height in pixels
"""
self.width = width
self.height = height
[docs]
def rasterize_heatmap(
self,
data: np.ndarray,
x_coords: np.ndarray,
y_coords: np.ndarray,
cmap: str = "jet",
vmin: float | None = None,
vmax: float | None = None,
) -> Image.Image:
"""Rasterize 2D gridded data to PIL Image using Datashader.
This is 5-10x faster than matplotlib.imshow() + savefig() for typical
C2 correlation data (50x50 to 200x200 grids).
The workflow:
1. Convert numpy array to xarray DataArray (Datashader's native format)
2. Create canvas at target resolution (e.g., 800x800 pixels)
3. Rasterize data to canvas (regrid/resample)
4. Apply colormap and convert to RGB PIL Image
Parameters
----------
data : np.ndarray
2D array to rasterize, shape (n_y, n_x)
For C2 data: pass c2.T to swap axes for correct display
x_coords : np.ndarray
X-axis (horizontal) coordinates, shape (n_x,)
For C2 data: pass t1 time array
y_coords : np.ndarray
Y-axis (vertical) coordinates, shape (n_y,)
For C2 data: pass t2 time array
cmap : str, default='jet'
Colormap name. Supported:
- 'jet' (default for all plots including residuals)
- 'viridis', 'plasma', 'inferno', 'magma' (matplotlib perceptually uniform)
- 'coolwarm', 'RdBu_r' (diverging, alternative for residuals)
- Any matplotlib or colorcet colormap name
vmin, vmax : float, optional
Color scale limits. If None, auto-computed from data min/max.
Returns
-------
Image
PIL Image object in RGB format, ready for saving or display
Raises
------
ValueError
If data dimensions don't match coordinate arrays
Examples
--------
>>> renderer = DatashaderRenderer(width=800, height=800)
>>> c2_data = np.random.rand(50, 50)
>>> t1 = np.linspace(0, 1, 50)
>>> t2 = np.linspace(0, 1, 50)
>>> img = renderer.rasterize_heatmap(c2_data, t2, t1, cmap='jet')
>>> img.save('c2_heatmap.png')
"""
# Validate input dimensions
if data.shape[0] != len(y_coords):
raise ValueError(
f"Data y-dimension ({data.shape[0]}) doesn't match "
f"y_coords length ({len(y_coords)})"
)
if data.shape[1] != len(x_coords):
raise ValueError(
f"Data x-dimension ({data.shape[1]}) doesn't match "
f"x_coords length ({len(x_coords)})"
)
# Convert to xarray (Datashader's native format)
xr_data = xr.DataArray(
data,
coords={"y": y_coords, "x": x_coords},
dims=["y", "x"],
name="intensity",
)
# Filter NaN values from coordinates for Canvas range computation
x_finite = x_coords[np.isfinite(x_coords)]
y_finite = y_coords[np.isfinite(y_coords)]
if x_finite.size == 0 or y_finite.size == 0:
raise ValueError("Cannot rasterize: all coordinate values are NaN")
# Create canvas at target resolution
canvas = ds.Canvas(
plot_width=self.width,
plot_height=self.height,
x_range=(float(x_finite.min()), float(x_finite.max())),
y_range=(float(y_finite.min()), float(y_finite.max())),
)
# Rasterize (CPU-optimized for HPC)
# For gridded data, canvas.raster() resamples to canvas resolution
agg = canvas.raster(xr_data)
# Get colormap
cmap_obj = self._get_colormap(cmap)
# Compute span for color normalization
if vmin is None or vmax is None:
span = (float(data.min()), float(data.max()))
else:
span = (float(vmin), float(vmax))
# Apply colormap and shade (fast!)
# Returns xarray Image with RGB channels
img = tf.shade(agg, cmap=cmap_obj, how="linear", span=span)
# Convert to PIL Image for easy saving/display
pil_img = img.to_pil()
# Convert RGBA to RGB (reduce file size by ~25%)
if pil_img.mode == "RGBA":
# Create white background
rgb_img = Image.new("RGB", pil_img.size, (255, 255, 255))
rgb_img.paste(pil_img, mask=pil_img.split()[3]) # Use alpha as mask
return rgb_img
return pil_img
def _get_colormap(self, cmap: str):
"""Get Datashader-compatible colormap.
Datashader accepts:
1. Lists of hex colors
2. Colorcet colormap objects
3. Matplotlib colormaps (via conversion)
We convert matplotlib colormap names to color lists for compatibility.
"""
import matplotlib
import matplotlib.colors as mcolors
# Get matplotlib colormap
try:
mpl_cmap = matplotlib.colormaps.get_cmap(cmap)
except ValueError:
# Fallback to jet if colormap not found
logger.warning(f"Colormap '{cmap}' not found, using 'jet'")
mpl_cmap = matplotlib.colormaps.get_cmap("jet")
# Convert to list of hex colors (Datashader format)
# Sample 256 colors from the colormap
colors = [mpl_cmap(i) for i in np.linspace(0, 1, 256)]
hex_colors = [mcolors.rgb2hex(c[:3]) for c in colors]
return hex_colors
[docs]
def plot_c2_heatmap_fast(
    c2_data: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    output_path: Path,
    title: str = "",
    phi_angle: float | None = None,
    cmap: str = "jet",
    width: int = 800,
    height: int = 800,
    *,
    vmin: float | None = None,
    vmax: float | None = None,
    adaptive: bool = False,
    percentile_min: float = 1.0,
    percentile_max: float = 99.0,
) -> None:
    """Plot a C2 heatmap using Datashader for fast CPU rendering.

    Datashader does the rasterization. Matplotlib only draws the annotations
    (colorbar, title, labels).

    Workflow:

    1. Rasterize C2 data to an RGB image with Datashader.
    2. Display the RGB image in a matplotlib figure.
    3. Add a colorbar built from the numeric color scale (not from RGB).
    4. Add title and labels, then save the PNG.

    Parameters
    ----------
    c2_data : np.ndarray
        C2 correlation data, shape (n_t1, n_t2).
    t1, t2 : np.ndarray
        Time arrays, shapes (n_t1,) and (n_t2,).
    output_path : Path
        Output PNG file path.
    title : str, default=""
        Plot title (phi_angle is appended if provided).
    phi_angle : float, optional
        Scattering angle in degrees (added to the title).
    cmap : str, default='jet'
        Matplotlib colormap name. Unknown names fall back to 'jet' with a
        logged warning. The same colormap is used for image and colorbar.
    width, height : int, default=800
        Rasterization resolution in pixels.
    vmin, vmax : float, optional
        Explicit color scale limits.
    adaptive : bool, default=False
        If True, any limit not given explicitly is taken from the
        ``percentile_min``/``percentile_max`` percentiles of the finite data.
    percentile_min, percentile_max : float, default=1.0, 99.0
        Percentiles used when ``adaptive`` is True.

    Notes
    -----
    Limits that remain unset (adaptive off, or no finite data) default to
    1.0 and 1.5.

    Examples
    --------
    >>> c2_data = np.random.rand(50, 50)
    >>> t1 = np.linspace(0, 1, 50)
    >>> t2 = np.linspace(0, 1, 50)
    >>> plot_c2_heatmap_fast(
    ...     c2_data, t1, t2,
    ...     Path('c2_phi0.png'),
    ...     phi_angle=0.0
    ... )
    """
    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize

    c2_data = np.asarray(c2_data)
    t1 = np.asarray(t1)
    t2 = np.asarray(t2)

    # Resolve the colormap once, so the rasterized image and the colorbar are
    # guaranteed to use the same one. Previously an unknown name rendered with
    # the renderer's 'jet' fallback but then crashed when building the colorbar.
    try:
        mpl_cmap = matplotlib.colormaps.get_cmap(cmap)
    except ValueError:
        logger.warning(f"Colormap '{cmap}' not found, using 'jet'")
        cmap = "jet"
        mpl_cmap = matplotlib.colormaps.get_cmap(cmap)

    renderer = DatashaderRenderer(width=width, height=height)

    auto_vmin, auto_vmax = vmin, vmax
    # Skip adaptive scaling when there is no finite data: nanpercentile would
    # return NaN and poison both the image span and the colorbar.
    if adaptive and c2_data.size > 0 and np.isfinite(c2_data).any():
        if vmin is None:
            auto_vmin = float(np.nanpercentile(c2_data, percentile_min))
        if vmax is None:
            auto_vmax = float(np.nanpercentile(c2_data, percentile_max))
    vmin_use = 1.0 if auto_vmin is None else auto_vmin
    vmax_use = 1.5 if auto_vmax is None else auto_vmax

    # c2[t1_idx, t2_idx] -> c2.T so rows follow t2 (vertical axis) and
    # columns follow t1 (horizontal axis).
    img_pil = renderer.rasterize_heatmap(
        c2_data.T,
        t1,
        t2,
        cmap=cmap,
        vmin=vmin_use,
        vmax=vmax_use,
    )
    # The PIL image puts the highest y in the top row (image convention).
    # Flip it so imshow(origin="lower") puts y=t2[0] at the bottom.
    img_array = np.flipud(np.array(img_pil))

    if phi_angle is not None:
        title = f"{title} at φ={phi_angle:.1f}°" if title else f"φ={phi_angle:.1f}°"

    fig, ax = plt.subplots(figsize=(8, 7), dpi=100)
    try:
        extent = [t1[0], t1[-1], t2[0], t2[-1]]
        ax.imshow(img_array, extent=extent, origin="lower", aspect="equal")
        ax.set_xlabel("t₁ (s)", fontsize=11)
        ax.set_ylabel("t₂ (s)", fontsize=11)
        ax.set_title(title, fontsize=13, fontweight="bold")

        # The colorbar uses the same colormap and limits as the rasterized image.
        sm = ScalarMappable(cmap=mpl_cmap, norm=Normalize(vmin=vmin_use, vmax=vmax_use))
        sm.set_array([])  # Required for colorbar on older matplotlib
        cbar = fig.colorbar(sm, ax=ax, label="g₂(t₁,t₂)", shrink=0.9)
        cbar.ax.tick_params(labelsize=9)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure so batch plotting doesn't leak memory.
        plt.close(fig)
    logger.debug(f"Saved Datashader plot: {output_path}")
[docs]
def plot_c2_comparison_fast(
    c2_exp: np.ndarray,
    c2_fit: np.ndarray,
    residuals: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    output_path: Path,
    phi_angle: float,
    width: int = 800,
    height: int = 800,
    *,
    vmin: float | None = None,
    vmax: float | None = None,
    adaptive: bool = True,
    percentile_min: float = 1.0,
    percentile_max: float = 99.0,
) -> None:
    """Generate a 3-panel comparison plot using Datashader CPU rendering.

    Replaces generate_nlsq_plots() with faster Datashader rendering.
    Creates a side-by-side comparison: Experimental | Fitted | Residuals.
    The experimental and fit panels share one color scale. The residual
    panel uses the finite min/max of the residuals.

    Parameters
    ----------
    c2_exp : np.ndarray
        Experimental correlation data, shape (n_t1, n_t2).
    c2_fit : np.ndarray
        Fitted theoretical data, shape (n_t1, n_t2).
    residuals : np.ndarray
        Residuals = exp - fit, shape (n_t1, n_t2).
    t1, t2 : np.ndarray
        Time arrays, shapes (n_t1,) and (n_t2,).
    output_path : Path
        Output PNG file path.
    phi_angle : float
        Scattering angle in degrees.
    width, height : int, default=800
        Individual panel rasterization size in pixels.
    vmin, vmax : float, optional
        Explicit shared color scale limits for the experimental/fit panels.
    adaptive : bool, default=True
        Fill any unset shared limit from percentiles of the combined
        experimental AND fit data. This prevents block artifacts when the fit
        has a narrower range than the experiment. Arrays with no finite
        values are ignored.
    percentile_min, percentile_max : float, default=1.0, 99.0
        Percentiles for adaptive color scale computation.

    Notes
    -----
    Shared limits that remain unset (adaptive off, or no usable data)
    default to 1.0 and 1.5.

    Examples
    --------
    >>> plot_c2_comparison_fast(
    ...     c2_exp, c2_fit, residuals,
    ...     t1, t2,
    ...     Path('comparison_phi_45.png'),
    ...     phi_angle=45.0
    ... )
    """
    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize

    c2_exp = np.asarray(c2_exp)
    c2_fit = np.asarray(c2_fit)
    residuals = np.asarray(residuals)
    t1 = np.asarray(t1)
    t2 = np.asarray(t2)

    cmap_name = "jet"
    renderer = DatashaderRenderer(width=width, height=height)

    # Shared scale for experimental + fit panels.
    vmin_shared = vmin
    vmax_shared = vmax
    if adaptive and c2_exp.size > 0 and c2_fit.size > 0:
        # Only arrays with finite values contribute. An all-NaN array would
        # make nanpercentile return NaN, and Python's min()/max() with NaN
        # gives an order-dependent result.
        sources = [arr for arr in (c2_exp, c2_fit) if np.isfinite(arr).any()]
        if sources:
            if vmin_shared is None:
                vmin_shared = min(
                    float(np.nanpercentile(arr, percentile_min)) for arr in sources
                )
            if vmax_shared is None:
                vmax_shared = max(
                    float(np.nanpercentile(arr, percentile_max)) for arr in sources
                )
    # Fixed defaults apply when adaptive scaling is off or had no usable data.
    vmin_shared = 1.0 if vmin_shared is None else vmin_shared
    vmax_shared = 1.5 if vmax_shared is None else vmax_shared

    # Residual scale from the finite values only. This excludes inf as well
    # as NaN.
    finite_res = residuals[np.isfinite(residuals)]
    if finite_res.size > 0:
        res_min, res_max = float(finite_res.min()), float(finite_res.max())
    else:
        res_min, res_max = 0.0, 1.0
    if res_min == res_max:
        # Relative widening so it is still representable for large magnitudes.
        res_max = res_min + max(abs(res_min), 1.0) * 1e-10

    # (data, vmin, vmax, title, colorbar label) for each panel.
    panels = (
        (c2_exp, vmin_shared, vmax_shared,
         f"Experimental C₂ (φ={phi_angle:.1f}°)", "C₂(t₁,t₂)"),
        (c2_fit, vmin_shared, vmax_shared,
         f"Classical Fit (φ={phi_angle:.1f}°)", "C₂(t₁,t₂)"),
        (residuals, res_min, res_max,
         f"Residuals (φ={phi_angle:.1f}°)", "ΔC₂"),
    )

    # Rasterize every panel before opening a figure. c2[t1_idx, t2_idx] is
    # transposed so x=t1 and y=t2. Each image is flipped vertically because
    # the PIL image has its highest y in the top row, while
    # imshow(origin="lower") expects the lowest y first.
    images = [
        np.flipud(
            np.array(
                renderer.rasterize_heatmap(
                    arr.T, t1, t2, cmap=cmap_name, vmin=lo, vmax=hi
                )
            )
        )
        for arr, lo, hi, _, _ in panels
    ]

    mpl_cmap = matplotlib.colormaps.get_cmap(cmap_name)
    extent = [t1[0], t1[-1], t2[0], t2[-1]]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, img, (_, lo, hi, panel_title, cbar_label) in zip(axes, images, panels):
            ax.imshow(img, extent=extent, origin="lower", aspect="equal")
            ax.set_title(panel_title, fontsize=12)
            ax.set_xlabel("t₁ (s)", fontsize=10)
            ax.set_ylabel("t₂ (s)", fontsize=10)

            # Colorbar built from the same numeric scale used to rasterize.
            sm = ScalarMappable(cmap=mpl_cmap, norm=Normalize(vmin=lo, vmax=hi))
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=ax, label=cbar_label)
            cbar.ax.tick_params(labelsize=8)

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure so batch plotting doesn't leak memory.
        plt.close(fig)
    logger.debug(f"Saved Datashader 3-panel plot: {output_path}")