Source code for homodyne.optimization.batch_statistics
"""Batch-level statistics tracking for streaming optimization.
This module provides a circular buffer for tracking batch-level optimization
statistics, success rates, and error distributions during streaming optimization.
"""
from collections import defaultdict, deque
from typing import Any
# Type alias for batch record
BatchRecord = dict[str, Any]
[docs]
class BatchStatistics:
"""Circular buffer for tracking batch-level statistics.
Maintains statistics for the most recent N batches (default 100) to
provide running averages and trends without unbounded memory growth.
Attributes
----------
buffer : deque
Circular buffer storing batch records (max_size most recent)
total_batches : int
Total number of batches processed (all time)
total_successes : int
Total number of successful batches (all time)
total_failures : int
Total number of failed batches (all time)
error_counts : dict
Count of each error type encountered (all time)
Examples
--------
>>> stats = BatchStatistics(max_size=100)
>>> stats.record_batch(
... batch_idx=0,
... success=True,
... loss=0.123,
... iterations=50,
... recovery_actions=[]
... )
>>> stats.get_success_rate()
1.0
"""
[docs]
def __init__(self, max_size: int = 100):
"""Initialize batch statistics tracker.
Parameters
----------
max_size : int, optional
Maximum number of batches to keep in circular buffer, by default 100
"""
self.buffer: deque[BatchRecord] = deque(maxlen=max_size)
self.total_batches = 0
self.total_successes = 0
self.total_failures = 0
self.error_counts: defaultdict[str, int] = defaultdict(int)
[docs]
def record_batch(
self,
batch_idx: int,
success: bool,
loss: float,
iterations: int,
recovery_actions: list[str],
error_type: str | None = None,
) -> None:
"""Record statistics for a single batch.
Parameters
----------
batch_idx : int
Batch index (0-indexed)
success : bool
Whether batch optimization succeeded
loss : float
Final loss value for this batch
iterations : int
Number of iterations performed
recovery_actions : list of str
List of recovery actions applied (if any)
error_type : str, optional
Type of error encountered (if failed), by default None
"""
batch_record = {
"batch_idx": batch_idx,
"success": success,
"loss": loss,
"iterations": iterations,
"recovery_actions": recovery_actions,
"error_type": error_type,
}
self.buffer.append(batch_record)
self.total_batches += 1
if success:
self.total_successes += 1
else:
self.total_failures += 1
if error_type:
self.error_counts[error_type] += 1
[docs]
def get_success_rate(self) -> float:
"""Calculate success rate from recent batches in buffer.
Returns
-------
float
Success rate (0.0 to 1.0) from recent batches. Returns 1.0 when
no batches have been recorded yet (optimistic prior) so that quality
gates do not falsely reject the first batch. Callers that need to
distinguish "no data yet" should check BatchStatistics.total_batches.
"""
if not self.buffer:
return 1.0
successes = sum(1 for batch in self.buffer if batch["success"])
return successes / len(self.buffer)
[docs]
def get_average_loss(self) -> float:
"""Calculate average loss from recent successful batches.
Returns
-------
float
Average loss from successful batches in buffer
"""
successful_batches = [b for b in self.buffer if b["success"]]
if not successful_batches:
return float("inf")
total_loss: float = sum(float(b["loss"]) for b in successful_batches)
return total_loss / len(successful_batches)
[docs]
def get_average_iterations(self) -> float:
"""Calculate average iterations from recent batches.
Returns
-------
float
Average number of iterations per batch
"""
if not self.buffer:
return 0.0
total_iterations: int = sum(int(b["iterations"]) for b in self.buffer)
return float(total_iterations) / len(self.buffer)
[docs]
def get_statistics(self) -> dict[str, Any]:
"""Return comprehensive statistics dictionary.
Returns
-------
dict
Dictionary containing:
- total_batches: Total batches processed (all time)
- total_successes: Total successful batches (all time)
- total_failures: Total failed batches (all time)
- success_rate: Success rate from recent batches
- average_loss: Average loss from recent successful batches
- average_iterations: Average iterations per batch
- error_distribution: Dictionary of error type counts
- recent_batches: List of recent batch records
"""
return {
"total_batches": self.total_batches,
"total_successes": self.total_successes,
"total_failures": self.total_failures,
"success_rate": self.get_success_rate(),
"average_loss": self.get_average_loss(),
"average_iterations": self.get_average_iterations(),
"error_distribution": dict(self.error_counts),
"recent_batches": list(self.buffer),
}
[docs]
def __repr__(self) -> str:
"""Return string representation of statistics."""
return (
f"BatchStatistics(total={self.total_batches}, "
f"successes={self.total_successes}, "
f"failures={self.total_failures}, "
f"success_rate={self.get_success_rate():.2%})"
)