"""Factories for explicit and implicit algorithm step implementations."""
from typing import Any, Mapping, Optional, Tuple, Type
from .base_algorithm_step import BaseAlgorithmStep, BaseStepConfig, ButcherTableau
from .ode_explicitstep import ExplicitStepConfig, ODEExplicitStep
from .ode_implicitstep import ImplicitStepConfig, ODEImplicitStep
from .backwards_euler import BackwardsEulerStep
from .backwards_euler_predict_correct import BackwardsEulerPCStep
from .crank_nicolson import CrankNicolsonStep
from .explicit_euler import ExplicitEulerStep
from .generic_dirk import (
DIRKStep,
)
from .generic_dirk_tableaus import DIRK_TABLEAU_REGISTRY, DIRKTableau
from .generic_firk import (
FIRKStep,
)
from .generic_firk_tableaus import FIRK_TABLEAU_REGISTRY, FIRKTableau
from .generic_erk import (
ERKStep,
ERKTableau,
)
from .generic_erk_tableaus import ERK_TABLEAU_REGISTRY
from .generic_rosenbrock_w import (
GenericRosenbrockWStep,
)
from .generic_rosenbrockw_tableaus import ROSENBROCK_TABLEAUS, RosenbrockTableau
__all__ = [
"get_algorithm_step",
"ExplicitStepConfig",
"ImplicitStepConfig",
"ExplicitEulerStep",
"BackwardsEulerStep",
"BackwardsEulerPCStep",
"CrankNicolsonStep",
"DIRKStep",
"FIRKStep",
"ERKStep",
"GenericRosenbrockWStep",
"_ALGORITHM_REGISTRY",
"DIRKTableau",
"DIRK_TABLEAU_REGISTRY",
"FIRKTableau",
"FIRK_TABLEAU_REGISTRY",
"ERKTableau",
"ERK_TABLEAU_REGISTRY",
"RosenbrockTableau",
"ROSENBROCK_TABLEAUS",
]
_ALGORITHM_REGISTRY = {
"euler": ExplicitEulerStep,
"backwards_euler": BackwardsEulerStep,
"backwards_euler_pc": BackwardsEulerPCStep,
"crank_nicolson": CrankNicolsonStep,
"dirk": DIRKStep,
"firk": FIRKStep,
"erk": ERKStep,
"rosenbrock": GenericRosenbrockWStep,
}
_TABLEAU_REGISTRY_BY_ALGORITHM = {
key: (constructor, None)
for key, constructor in _ALGORITHM_REGISTRY.items()
}
for alias, tableau in ERK_TABLEAU_REGISTRY.items():
_TABLEAU_REGISTRY_BY_ALGORITHM[alias] = (ERKStep, tableau)
for alias, tableau in DIRK_TABLEAU_REGISTRY.items():
_TABLEAU_REGISTRY_BY_ALGORITHM[alias] = (DIRKStep, tableau)
for alias, tableau in FIRK_TABLEAU_REGISTRY.items():
_TABLEAU_REGISTRY_BY_ALGORITHM[alias] = (FIRKStep, tableau)
for alias, tableau in ROSENBROCK_TABLEAUS.items():
_TABLEAU_REGISTRY_BY_ALGORITHM[alias] = (
GenericRosenbrockWStep,
tableau,
)
def resolve_alias(alias: str) -> Tuple[Type[BaseAlgorithmStep], Optional[ButcherTableau]]:
"""Return the step constructor and tableau associated with ``alias``."""
key = alias.lower()
if key not in _TABLEAU_REGISTRY_BY_ALGORITHM:
raise KeyError(alias)
return _TABLEAU_REGISTRY_BY_ALGORITHM[key]
def resolve_supplied_tableau(
tableau: ButcherTableau,
) -> Tuple[Type[BaseAlgorithmStep], ButcherTableau]:
"""Return the step constructor matching ``tableau``."""
if isinstance(tableau, ERKTableau):
return ERKStep, tableau
if isinstance(tableau, DIRKTableau):
return DIRKStep, tableau
if isinstance(tableau, FIRKTableau):
return FIRKStep, tableau
if isinstance(tableau, RosenbrockTableau):
return GenericRosenbrockWStep, tableau
raise TypeError(
"Received tableau of type "
f"{type(tableau).__name__} which does not match known algorithms."
)
[docs]
def get_algorithm_step(
precision: type,
settings: Optional[Mapping[str, Any]] = None,
warn_on_unused: bool = False,
**kwargs: Any,
) -> BaseAlgorithmStep:
"""Thin factory which filters arguments and instantiates an algorithm.
Parameters
----------
precision
Floating-point dtype used when compiling the step implementation.
settings
Mapping of settings applied to the algorithm. Must include
``"algorithm"`` and can contain any keywords from
``ALL_ALGORITHM_STEP_PARAMETERS``.
warn_on_unused
If ``True``, issue a warning for settings that the selected algorithm
does not accept.
**kwargs
Additional keywords from ``ALL_ALGORITHM_STEP_PARAMETERS``. These
override entries provided in ``settings``.
Returns
-------
BaseAlgorithmStep
The requested step instance.
Raises
------
ValueError
Raised when settings['algorithm'] does not match a known algorithm
type or when required configuration keys are missing.
"""
algorithm_settings = {}
if settings is not None:
algorithm_settings.update(settings)
algorithm_settings.update(kwargs)
algorithm_value = algorithm_settings.pop("algorithm", None)
if algorithm_value is None:
raise ValueError("Algorithm settings must include 'algorithm'.")
if isinstance(algorithm_value, str):
try:
algorithm_type, resolved_tableau = resolve_alias(algorithm_value)
except KeyError as exc:
raise ValueError(f"Unknown algorithm '{algorithm_value}'.") from exc
elif isinstance(algorithm_value, ButcherTableau):
algorithm_type, resolved_tableau = resolve_supplied_tableau(
algorithm_value
)
else:
raise TypeError(
"Expected algorithm name or ButcherTableau instance, "
f"received {type(algorithm_value).__name__}."
)
algorithm_settings["precision"] = precision
if resolved_tableau is not None:
algorithm_settings["tableau"] = resolved_tableau
# Pass all settings to algorithm __init__ which uses build_config internally
# build_config filters to valid config fields and handles defaults
return algorithm_type(**algorithm_settings)