"""Execution engine for financial calculations with DAG resolution."""
from __future__ import annotations
import logging
from decimal import Decimal
from typing import Any
from .exceptions import CalculationError, CircularDependencyError, MissingInputError
from .policy import DEFAULT_POLICY, Policy
from .policy_context import get_policy, use_policy
from .registry import deps, get, is_registered
from .utils import SupportsDecimal, to_decimal
from .value import FinancialValue
# Module-level logger; Engine uses it to report calculations that fail when
# allow_partial=True instead of raising.
logger = logging.getLogger(__name__)
class Engine:
    """
    Execution engine for financial calculations.

    Builds dependency graphs, detects circular dependencies, caches results,
    and executes calculations in the correct order.
    """

    def __init__(self, default_policy: Policy | None = None):
        """
        Initialize the engine with an optional default policy.

        Args:
            default_policy: Default policy for calculations. Uses DEFAULT_POLICY if None.

        Raises:
            Exception: Whatever loading the calculations raised, unless Python
                runs with ``-O`` (then only a warning is emitted).
        """
        self.default_policy: Policy = default_policy or DEFAULT_POLICY
        self.metric_policy: dict[str, Policy] = {}  # optional per-metric override
        # Ensure calculations are registered on engine creation
        try:
            from .calculations import load_all

            load_all()
        except Exception as e:
            # Don't silently ignore exceptions during development
            import warnings

            warnings.warn(f"Failed to load calculations: {e}", stacklevel=2)
            # __debug__ is True unless the interpreter runs with -O, so by
            # default the failure is re-raised; optimized runs only warn.
            if __debug__:
                raise

    def _active_policy(self) -> Policy:
        """Return the ambient policy, else the engine default, else DEFAULT_POLICY."""
        return get_policy() or self.default_policy or DEFAULT_POLICY

    def _choose_policy(self, name: str, override: Policy | None) -> Policy:
        """
        Choose a non-None policy with a single rule:
        explicit override > metric override > ambient > engine default > DEFAULT_POLICY
        """
        ambient = get_policy()
        return (
            override
            or self.metric_policy.get(name)
            or ambient
            or self.default_policy
            or DEFAULT_POLICY
        )

    def calculate(
        self,
        name: str,
        ctx: dict[str, SupportsDecimal] | None = None,
        *,
        policy: Policy | None = None,
        allow_partial: bool = False,
        **kwargs: SupportsDecimal,
    ) -> FinancialValue:
        """
        Calculate a target metric given a context of input values.

        The engine follows a "let calculations validate" philosophy:
        - None values propagate naturally through calculations
        - No need for defensive checks before calling calculate
        - Each calculation determines what inputs are valid
        - FinancialValue results can be passed directly to other calculations

        Args:
            name: Name of the calculation to compute
            ctx: Dictionary of input values (optional if using kwargs)
            policy: Optional policy to override default
            allow_partial: If True, return a FinancialValue wrapping None on
                failure instead of raising
            **kwargs: Input values as keyword arguments (can include None);
                they take precedence over same-named ``ctx`` entries

        Returns:
            FinancialValue containing the result (may wrap None)

        Raises:
            MissingInputError: If required non-None inputs are missing
            CircularDependencyError: If circular dependencies are detected
            CalculationError: If calculation fails. Library errors are re-raised
                unchanged; any other exception is wrapped.
        """
        # Merge ctx and kwargs, with kwargs taking precedence
        ctx = {} if ctx is None else dict(ctx)
        ctx.update(kwargs)
        effective_policy = self._choose_policy(name, policy)
        try:
            with use_policy(effective_policy):
                result = self._run_calc(name, ctx, allow_partial=allow_partial)
                if not isinstance(result, FinancialValue):
                    result = FinancialValue(
                        None if result is None else to_decimal(result),
                        effective_policy,
                    )
                return result
        except (MissingInputError, CircularDependencyError, CalculationError) as exc:
            if allow_partial:
                logger.error("Calculation '%s' failed: %s", name, exc)
                return FinancialValue(None, effective_policy)
            # Already a library error: re-raise as-is rather than nesting a
            # second "Error in calculation ..." prefix around its message.
            raise
        except Exception as exc:
            if allow_partial:
                logger.error("Calculation '%s' failed: %s", name, exc)
                return FinancialValue(None, effective_policy)
            raise CalculationError(f"Error in calculation '{name}': {exc}") from exc

    def _run_calc(self, name: str, ctx: dict, *, allow_partial: bool = False):
        """
        Internal method to run a single calculation.

        This method handles the actual calculation execution and can be overridden
        by subclasses to customize calculation behavior.

        Returns:
            The value produced for ``name`` (normally a FinancialValue), or None
            when it could not be computed under ``allow_partial``.
        """
        # Delegate to calculate_many for consistency
        results = self.calculate_many({name}, ctx, allow_partial=allow_partial)
        result = results.get(name)
        # calculate_many() already attaches calculation provenance to registered
        # metrics; only targets served straight from ctx still need it here.
        if isinstance(result, FinancialValue) and not is_registered(name):
            result = self._add_calculation_provenance(name, result, ctx)
        return result

    @staticmethod
    def _provenance_parents(
        calc_name: str, result: FinancialValue, ctx: dict, log_error: Any
    ) -> tuple[list[FinancialValue], dict[Any, str]]:
        """Wrap each ctx entry as a provenance parent and map parent ids to input names.

        Non-FinancialValue entries (including None) are wrapped in a temporary
        FinancialValue using ``result.policy``. Entries that fail are logged via
        ``log_error`` and skipped; the remaining inputs are still processed.
        """
        parents: list[FinancialValue] = []
        input_names: dict[Any, str] = {}
        for key, value in ctx.items():
            if isinstance(value, FinancialValue):
                parent = value
                where = "_add_calculation_provenance_input"
            else:
                where = "_add_calculation_provenance_temp_fv"
                try:
                    parent = FinancialValue(value, policy=result.policy)
                except Exception as error:
                    log_error(error, where, calculation=calc_name, input_key=key)
                    continue  # Continue without this input
            try:
                parents.append(parent)
                parent_prov = getattr(parent, "_prov", None)
                if parent_prov:
                    input_names[parent_prov.id] = str(key)
            except Exception as error:
                log_error(error, where, calculation=calc_name, input_key=key)
        return parents, input_names

    def _add_calculation_provenance(
        self, calc_name: str, result: FinancialValue, ctx: dict
    ) -> FinancialValue:
        """Add calculation-specific provenance to a result.

        Provenance is best-effort: when the provenance modules cannot be
        imported, or tracking is disabled, ``result`` is returned unchanged.
        A failure in any stage is logged once. Then the original ``result`` is
        returned, unless the provenance configuration asks to fail on error, in
        which case the failing exception propagates.

        Args:
            calc_name: Name of the calculation
            result: The calculated FinancialValue result
            ctx: Context dictionary with input names and values

        Returns:
            FinancialValue with calculation provenance (or ``result`` on degradation)
        """
        try:
            from .provenance import Provenance, hash_node
            from .provenance_config import (
                log_provenance_error,
                should_fail_on_error,
                should_track_calculations,
            )
        except ImportError:
            # Provenance module not available - graceful degradation
            return result

        def report(error: Exception, where: str) -> None:
            # Log a stage failure exactly once; propagate it if configured to fail.
            log_provenance_error(error, where, calculation=calc_name)
            if should_fail_on_error():
                raise error

        try:
            if not should_track_calculations():
                return result
        except Exception as exc:
            report(exc, "_add_calculation_provenance")
            return result

        parents, input_names = self._provenance_parents(
            calc_name, result, ctx, log_provenance_error
        )
        meta = {"calculation": str(calc_name), "input_names": input_names}
        op = f"calc:{calc_name}"

        # Generate provenance ID for this calculation
        try:
            prov_id = hash_node(op, tuple(parents), result.policy, meta)
        except Exception as hash_error:
            report(hash_error, "_add_calculation_provenance_hash")
            return result  # Graceful degradation

        # Create new provenance record and a FinancialValue carrying it
        try:
            parent_ids = []
            for parent in parents:
                try:
                    parent_prov = getattr(parent, "_prov", None)
                    if parent_prov:
                        parent_ids.append(parent_prov.id)
                except Exception as parent_error:
                    log_provenance_error(
                        parent_error,
                        "_add_calculation_provenance_parent_id",
                        calculation=calc_name,
                    )
            prov = Provenance(id=prov_id, op=op, inputs=tuple(parent_ids), meta=meta)
            return FinancialValue(
                result._value,
                policy=result.policy,
                unit=result.unit,
                _is_percentage=result._is_percentage,
                _prov=prov,
            )
        except Exception as prov_error:
            report(prov_error, "_add_calculation_provenance_create")
            return result  # Graceful degradation

    def constant(self, value: int | float | Decimal | None) -> FinancialValue:
        """
        Create a constant FinancialValue.

        Args:
            value: The constant value to wrap; None yields ``self.none()``.

        The active (ambient) policy is used so constants respect ``use_policy``.
        """
        if value is None:
            return self.none()
        return FinancialValue(to_decimal(value), self._active_policy())

    def zero(self) -> FinancialValue:
        """
        Create a zero FinancialValue under the active policy.
        """
        return FinancialValue(to_decimal(0), self._active_policy())

    def none(self) -> FinancialValue:
        """
        Create a None FinancialValue under the active policy.
        """
        return FinancialValue(None, self._active_policy())

    @staticmethod
    def _missing_input_error(
        targets: set[str],
        failed_targets: set[str],
        ctx: dict[str, Any],
        cache: dict[str, Any],
        invalid_inputs: set[str],
    ) -> MissingInputError:
        """Build the MissingInputError explaining why ``failed_targets`` failed."""

        def find_missing(name: str, visited: set[str]) -> set[str]:
            # Resolved, invalid (not missing), or already inspected: nothing to add.
            if name in visited or name in cache or name in invalid_inputs:
                return set()
            visited.add(name)
            # Provided but unresolved, or neither provided nor registered.
            if name in ctx or not is_registered(name):
                return {name}
            missing: set[str] = set()
            for dep in deps(name):
                missing |= find_missing(dep, visited)
            return missing

        all_missing: set[str] = set()
        for target in failed_targets:
            all_missing |= find_missing(target, set())

        details = []
        if all_missing:
            details.append("missing: " + ", ".join(sorted(all_missing)))
        if invalid_inputs:
            details.append("invalid: " + ", ".join(sorted(invalid_inputs)))
        summary = "; ".join(details) or "unspecified failure"

        if len(targets) == 1:
            only = next(iter(targets))
            return MissingInputError(
                f"Cannot compute '{only}' due to " + summary,
                sorted(all_missing),
            )
        # For multiple targets, include the failed target names
        failed_target_names = ", ".join(sorted(failed_targets))
        return MissingInputError(
            f"Cannot compute targets [{failed_target_names}]. Details → " + summary,
            sorted(all_missing),
        )

    def calculate_many(
        self,
        targets: set[str],
        ctx: dict[str, Any] | None = None,
        *,
        policy: Policy | None = None,
        allow_partial: bool = False,
        **kwargs: Any,
    ) -> dict[str, FinancialValue]:
        """
        Resolve all targets in one pass with shared dependency resolution.

        Parameters
        ----------
        targets : Set of metric names you want
        ctx : Inputs you already have (optional if using kwargs)
        policy : Optional Policy override
        allow_partial : If True, return what can be computed and
            leave missing, invalid or failing ones out instead of raising.
        **kwargs : Input values as keyword arguments

        Returns
        -------
        Dictionary mapping metric name to FinancialValue

        Raises
        ------
        MissingInputError: If any targets cannot be computed (unless allow_partial=True)
        CircularDependencyError: If circular dependencies are detected
        CalculationError: If an input is invalid or any calculation fails
            (unless allow_partial=True)

        Examples:
            # Using context dict
            >>> results = engine.calculate_many(
            ...     {"gross_profit", "gross_margin_percentage"},
            ...     {"sales": 1000, "cost": 650}
            ... )
            # Using keyword arguments
            >>> results = engine.calculate_many(
            ...     {"gross_profit", "gross_margin_percentage"},
            ...     sales=1000, cost=650
            ... )
        """
        # Merge ctx and kwargs
        ctx = {} if ctx is None else dict(ctx)
        ctx.update(kwargs)
        # IMPORTANT: respect whichever policy is already active in context
        # (e.g., set by calculate()) when no explicit policy is given.
        batch_policy = policy or self._active_policy()
        cache: dict[str, Any] = {}  # Can hold FinancialValues or sequences
        invalid_inputs: set[str] = set()  # provided inputs that failed conversion
        # Names known not to resolve; memoized so failing calculations are not
        # re-run (and re-logged) for every dependent that needs them.
        unresolvable: set[str] = set()

        def load_input(name: str) -> bool:
            """Cache the ctx value for ``name``; False if invalid under allow_partial."""
            value = ctx[name]
            # Pass through sequences and FinancialValues (calc decides what to do)
            if isinstance(value, (list, tuple, FinancialValue)):
                cache[name] = value
                return True
            try:
                cache[name] = FinancialValue(
                    None if value is None else to_decimal(value), batch_policy
                )
            except (ValueError, TypeError, ArithmeticError, CalculationError) as exc:
                # ArithmeticError covers decimal.InvalidOperation from conversion.
                if allow_partial:
                    logger.warning("Invalid input '%s': %s", name, exc)
                    invalid_inputs.add(name)
                    unresolvable.add(name)
                    return False
                raise CalculationError(
                    f"Invalid input type for '{name}': {exc}"
                ) from exc
            return True

        def run(name: str, calculation_deps: Any) -> bool:
            """Execute calculation ``name`` on its cached dependencies."""
            try:
                calc_func = get(name)
                dep_values = {d: cache[d] for d in calculation_deps}
                # choose policy per metric
                pol_for_this = self._choose_policy(name, override=policy)
                with use_policy(pol_for_this):
                    result = calc_func(**dep_values)
                    # keep the result's own policy if it returns FV, else wrap
                    if isinstance(result, FinancialValue):
                        cache[name] = result
                    else:
                        cache[name] = FinancialValue(
                            None if result is None else to_decimal(result),
                            pol_for_this,
                        )
                return True
            except Exception as exc:
                if allow_partial:
                    logger.warning("Calculation '%s' failed: %s", name, exc)
                    unresolvable.add(name)
                    return False
                raise CalculationError(f"Error in calculation '{name}': {exc}") from exc

        def resolve(name: str, stack: tuple[str, ...] = ()) -> bool:
            """
            Recursively resolve a calculation and its dependencies.

            Args:
                name: Name of calculation to resolve
                stack: Current resolution stack for cycle detection

            Returns:
                True if successfully resolved (``name`` is in the cache), False otherwise
            """
            if name in cache:
                return True
            if name in unresolvable:
                return False
            if name in stack:
                raise CircularDependencyError(stack + (name,))
            # Base case: value provided in context
            if name in ctx:
                return load_input(name)
            if not is_registered(name):
                unresolvable.add(name)
                return False
            calculation_deps = deps(name)
            # Visit every dependency (no short-circuit) so cycles anywhere below
            # this node are detected, as the original eager loop did.
            outcomes = [resolve(dep, stack + (name,)) for dep in calculation_deps]
            if not all(outcomes):
                unresolvable.add(name)
                return False
            return run(name, calculation_deps)

        # Kick off resolution for each requested target and collect failed ones
        failed_targets: set[str] = set()
        with use_policy(batch_policy):
            for target in targets:
                if not resolve(target):
                    failed_targets.add(target)

        if failed_targets and not allow_partial:
            raise self._missing_input_error(
                targets, failed_targets, ctx, cache, invalid_inputs
            )

        # result collection
        result: dict[str, Any] = {}
        for key in targets:
            if key not in cache or key in failed_targets:
                continue
            cached_result = cache[key]
            # Add calculation provenance if this is a registered calculation
            if is_registered(key) and isinstance(cached_result, FinancialValue):
                cached_result = self._add_calculation_provenance(key, cached_result, ctx)
            result[key] = cached_result
        return result

    def set_metric_policy(self, name: str, policy: Policy) -> None:
        """Use ``policy`` for metric ``name`` unless a call passes an explicit policy."""
        self.metric_policy[name] = policy

    def clear_metric_policy(self, name: str) -> None:
        """Remove the per-metric policy override for ``name``; no-op if absent."""
        self.metric_policy.pop(name, None)

    def get_dependencies(self, target: str) -> set[str]:
        """
        Get all dependencies (direct and transitive) for a calculation.

        Args:
            target: Name of the calculation

        Returns:
            Set of all dependency names (registered or not)

        Raises:
            CalculationError: If ``target`` is not registered
            CircularDependencyError: If circular dependencies detected
        """
        if not is_registered(target):
            raise CalculationError(f"Calculation '{target}' is not registered")
        all_deps: set[str] = set()
        visited: set[str] = set()

        def collect_deps(name: str, stack: tuple[str, ...] = ()) -> None:
            # The stack check precedes the visited check so a node reached again
            # on its own path is reported as a cycle, not skipped.
            if name in stack:
                raise CircularDependencyError(stack + (name,))
            if name in visited:
                return
            visited.add(name)
            if is_registered(name):
                for dep in deps(name):
                    all_deps.add(dep)
                    collect_deps(dep, stack + (name,))

        collect_deps(target)
        return all_deps

    def validate_dependencies(self, target: str) -> tuple[set[str], set[str]]:
        """
        Validate dependencies for a calculation.

        Args:
            target: Name of the calculation to validate

        Returns:
            Tuple of (registered_deps, unregistered_deps)

        Raises:
            CalculationError: If ``target`` is not registered
            CircularDependencyError: If circular dependencies detected
        """
        all_deps = self.get_dependencies(target)
        registered = {dep for dep in all_deps if is_registered(dep)}
        unregistered = all_deps - registered
        return registered, unregistered

    def get_all_calculations(self) -> dict[str, dict[str, Any]]:
        """
        Get information about all registered calculations.

        Returns:
            Dict mapping calculation names to their metadata including:
            - function: The calculation function
            - depends_on: A copy of its registered dependencies
            - docstring: The function's docstring ("" if none)
        """
        # Reads the registry's private tables directly; assumes every entry in
        # _registry has a matching _dependencies entry (KeyError otherwise).
        from .registry import _dependencies, _registry

        result = {}
        for name, calc_func in _registry.items():
            result[name] = {
                "function": calc_func,
                "depends_on": _dependencies[name].copy(),
                "docstring": calc_func.__doc__ or "",
            }
        return result