Source code for pyPLUTO.toolfuncs.set_units

"""Attach and detach astropy units on loaded variables."""

from __future__ import annotations

from collections.abc import Iterable

from pyPLUTO.baseloadstate import BaseLoadState
from pyPLUTO.loadmixin import BaseLoadMixin

# Names exported by ``from ... import *``; only the manager class is public.
__all__ = ["SetUnitsManager"]


# State attribute names that are considered in addition to ``state.d_vars``
# when no explicit variable selection is passed to the unit operations.
_TIME_VARS: tuple[str, ...] = ("ntime", "timelist")


class SetUnitsManager(BaseLoadMixin):
    """Attach and detach astropy units on variables stored in state.

    The manager reads and writes the following attributes of ``state``:

    - ``state.units``: container indexed by variable name that yields the
      factor a variable is multiplied by when units are attached.
    - ``state.d_vars``: iterable of variable names used (together with
      ``_TIME_VARS``) as the default selection.
    - ``state.unit_attached``: set-like collection of the variable names
      that currently carry units.
    """

    def __init__(self, state: BaseLoadState) -> None:
        """Initialize the unit-attachment manager with the given load state."""
        self.state = state

    def _resolve_unit_vars(
        self,
        var: str | Iterable[str] | bool | None = None,
        skip_units: str | Iterable[str] | None = None,
    ) -> tuple[list[str], bool]:
        """Resolve variable selection for unit attach/detach operations.

        Parameters
        ----------
        - var: str | Iterable[str] | bool | None, default None
            The variable name(s) to select. ``None`` or ``True`` selects
            every name in ``state.d_vars`` and ``_TIME_VARS`` that has an
            entry in ``state.units`` and exists as an attribute of state.
        - skip_units: str | Iterable[str] | None, default None
            The variable name(s) to exclude from the selection.

        Returns
        -------
        - tuple[list[str], bool]
            A tuple containing the selected variable names (in selection
            order, without the excluded ones) and a boolean indicating if
            the selection was explicit.

        Raises
        ------
        - TypeError
            If ``var`` or ``skip_units`` is neither a string nor an
            iterable (and not one of the accepted sentinel values).
        """

        def _names(value: object, label: str, sentinels: str) -> list[str]:
            """Coerce a name selector to a flat list of name strings."""
            # A string is itself iterable, so it must be tested first to
            # avoid splitting a single name into characters.
            if isinstance(value, str):
                return [value]
            if isinstance(value, Iterable):
                return [str(item) for item in value]
            raise TypeError(
                f"{label} must be {sentinels}, a string, "
                "or an iterable of strings.",
            )

        explicit = var is not None and var is not True
        if explicit:
            selected = _names(var, "var", "None, True")
        else:
            selected = [
                name
                for name in (*self.state.d_vars, *_TIME_VARS)
                if name in self.state.units and hasattr(self.state, name)
            ]

        excluded: set[str] = (
            set()
            if skip_units is None
            else set(_names(skip_units, "skip_units", "None"))
        )

        ordered = [name for name in selected if name not in excluded]
        return ordered, explicit

    def to_astropy_units(
        self,
        var: str | Iterable[str] | bool | None = None,
        skip_units: str | Iterable[str] | None = None,
    ) -> None:
        """Attach astropy units to selected variables in place.

        Variables already listed in ``state.unit_attached`` are left
        untouched, so calling this method repeatedly does not multiply a
        variable by its unit more than once.

        Parameters
        ----------
        - var: str | Iterable[str] | bool | None, default None
            The variable name(s) to select. ``None`` or ``True`` selects
            all variables with a known unit.
        - skip_units: str | Iterable[str] | None, default None
            The variable name(s) to exclude from the operation.

        Returns
        -------
        - None

        Raises
        ------
        - KeyError
            If an explicitly requested variable is not an attribute of the
            state or has no entry in ``state.units``.
        """
        selected, explicit = self._resolve_unit_vars(
            var=var,
            skip_units=skip_units,
        )
        state = self.state

        for name in selected:
            if not hasattr(state, name):
                if explicit:
                    # The variable itself is missing; report that rather
                    # than a missing unit.
                    raise KeyError(
                        f"Variable '{name}' is not available in the "
                        "loaded state",
                    )
                continue
            if name not in state.units:
                if explicit:
                    raise KeyError(f"No known unit for variable '{name}'")
                continue
            if name in state.unit_attached:
                continue

            setattr(state, name, getattr(state, name) * state.units[name])
            state.unit_attached.add(name)

    def to_code_units(
        self,
        var: str | Iterable[str] | bool | None = None,
        skip_units: str | Iterable[str] | None = None,
    ) -> None:
        """Convert selected astropy Quantity variables back to code units.

        With the default selection only variables listed in
        ``state.unit_attached`` are processed. Explicitly requested
        variables that are not attributes of the state are skipped
        silently. Values that do not expose both ``unit`` and ``value``
        are left unchanged, but their name is still removed from
        ``state.unit_attached``.

        Parameters
        ----------
        - var: str | Iterable[str] | bool | None, default None
            The variable name(s) to select. ``None`` or ``True`` selects
            all variables that currently carry units.
        - skip_units: str | Iterable[str] | None, default None
            The variable name(s) to exclude from the operation.

        Returns
        -------
        - None

        Raises
        ------
        - KeyError
            If an explicitly requested variable has no entry in
            ``state.units``.
        """
        selected, explicit = self._resolve_unit_vars(
            var=var,
            skip_units=skip_units,
        )
        state = self.state

        if not explicit:
            selected = [
                name for name in selected if name in state.unit_attached
            ]

        for name in selected:
            if name not in state.units:
                if explicit:
                    raise KeyError(f"No known unit for variable '{name}'")
                continue
            if not hasattr(state, name):
                continue

            arr = getattr(state, name)
            if hasattr(arr, "unit") and hasattr(arr, "value"):
                # Dividing by the unit factor and decomposing leaves the
                # plain code-unit number, whatever unit the quantity is
                # currently expressed in.
                arr_code = (arr / state.units[name]).decompose().value
                setattr(state, name, arr_code)
            state.unit_attached.discard(name)