# Source code for pyPLUTO.imagefuncs.create_axes

"""Module to create axes in the image class."""

import copy
import warnings
from itertools import islice
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

from pyPLUTO.imagemixin import ImageMixin
from pyPLUTO.imagestate import ImageState
from pyPLUTO.utils.inspector import track_kwargs

# Default layout values used when axes are placed by hand. Passing any of
# these keys to ``create_axes`` switches it to manual positioning, and the
# keys missing from the call are taken from here.
defaults: dict[str, Any] = dict(
    # Figure margins (figure units)
    left=0.125,
    right=0.9,
    top=0.9,
    bottom=0.1,
    # Gaps between rows / columns (empty: filled in later)
    hspace=[],
    wspace=[],
    # Relative heights of the rows / widths of the columns
    hratio=[1.0],
    wratio=[1.0],
)


class CreateAxesManager(ImageMixin):
    """Class to manage the creation of axes in the image.

    This class provides methods to create axes in the image class. It allows
    for customization of the axes' position, spacing, projection, and other
    properties.
    """

    exposed_methods = ("create_axes",)

    def __init__(self, state: ImageState) -> None:
        """Initialize the CreateAxesManager class."""
        self.state: ImageState = state

    @track_kwargs(extra_keys=set(defaults.keys()))
    def create_axes(
        self, ncol: int = 1, nrow: int = 1, check: bool = True, **kwargs: Any
    ) -> Axes | list[Axes]:
        """Creation of a set of axes using add_subplot from matplotlib.

        If additional layout parameters (like the figure limits or the
        spacing) are given, the plots are located using set_position. The
        spacing and the ratio between the plots can be given by hand. In case
        only a few custom options are given, the code computes the rest from
        the defaults (a warning is issued only when too many ratios or spaces
        are given); in case no custom option is given, the axes are located
        through the standard methods of matplotlib. If more axes are present
        in the figure, the list of all axes is returned, otherwise the single
        axis is returned.

        Returns
        -------
        - The list of axes (if more axes are in the figure) or the axis (if
          only one axis is present)

        Parameters
        ----------
        - bottom: float, default 0.1
            The space from the bottom border to the last row of plots.
        - check: bool, default True
            Not used in the body of this method; kept for backward
            compatibility.
        - figsize: [float, float], default [6*sqrt(ncol),5*sqrt(nrow)]
            Sets the figure size. The default value is computed from the
            number of rows and columns.
        - fontsize: float, default 17.0
            Sets the fontsize for all the axes.
        - hratio: [float], default [1.0]
            Ratio between the rows of the plot. The default is that every
            plot row has the same height.
        - hspace: [float], default []
            The space between plot rows (in figure units). If not enough or
            too many spaces are considered, the program will remove the
            excess and fill the lacks with [0.1].
        - left: float, default 0.125
            The space from the left border to the leftmost column of plots.
        - ncol: int, default 1
            The number of columns of subplots.
        - nrow: int, default 1
            The number of rows of subplots.
        - proj: str, default None
            Custom projection for the plot (e.g. 3D). Recommended only if
            needed. WARNING: pyPLUTO does not support 3D plotting for now,
            only 3D axes. The 3D plot feature will be available in future
            releases.
        - right: float, default 0.9
            The space from the right border to the rightmost column of plots.
        - sharex: bool | str | int | Matplotlib axis, default False
            Enables/disables the sharing of the x-axis between the subplots.
            True shares with the first axis of the image, an int selects the
            axis to share with by index, False/None disables sharing.
        - sharey: bool | str | int | Matplotlib axis, default False
            Enables/disables the sharing of the y-axis between the subplots.
            Same conventions as sharex.
        - suptitle: str, default None
            Creates a figure title over all the subplots.
        - tight: bool, default True
            Enables/disables tight layout options for the figure. In case of
            a highly customized plot (i.e. any of left, right, top, bottom,
            hspace, wspace, hratio, wratio is given) the option is forced to
            False, since the tight layout would override the custom
            positions.
        - top: float, default 0.9
            The space from the top border to the first row of plots.
        - wratio: [float], default [1.0]
            Ratio between the columns of the plot. The default is that every
            plot column has the same width.
        - wspace: [float], default []
            The space between plot columns (in figure units). If not enough
            or too many spaces are considered, the program will remove the
            excess and fill the lacks with [0.1].

        ----

        Examples
        --------
        - Example #1: create a simple grid of 2 columns and 2 rows on a new
          figure

            >>> import pyPLUTO as pp
            >>> I = pp.Image()
            >>> ax = I.create_axes(ncol=2, nrow=2)

        - Example #2: create a grid of 2 columns with the first one having
          half the width of the second one

            >>> import pyPLUTO as pp
            >>> I = pp.Image()
            >>> ax = I.create_axes(ncol=2, wratio=[0.5, 1])

        - Example #3: create a grid of 2 rows with a lot of blank space
          between them

            >>> import pyPLUTO as pp
            >>> I = pp.Image()
            >>> ax = I.create_axes(nrow=2, hspace=[0.5])

        - Example #4: create a 2x2 grid with a fifth image on the right side

            >>> import pyPLUTO as pp
            >>> I = pp.Image()
            >>> ax = I.create_axes(ncol=2, nrow=2, right=0.7)
            >>> ax = I.create_axes(left=0.75)

        """
        # NOTE(review): ``check`` is a named parameter, so it can never be a
        # key of ``kwargs`` (the former ``kwargs.pop("check", check)`` was a
        # no-op). It stays in the signature for backward compatibility.

        # Fail early, before any global (rcParams) or figure state is touched
        if self.fig is None:
            raise ValueError(
                "You need to create a figure before creating axes."
            )

        # Change fontsize if requested (global matplotlib setting)
        if "fontsize" in kwargs:
            plt.rcParams.update({"font.size": kwargs["fontsize"]})

        # Any layout key switches to manual positioning; the tight layout
        # engine would move the axes again, so it is disabled.
        custom_plot = bool(defaults.keys() & kwargs.keys())
        wplot: list[list[float]] | None = None
        hplot: list[list[float]] | None = None
        if custom_plot:
            kwargs["tight"] = False
            filtered_kwargs = {
                key: kwargs.get(key, value) for key, value in defaults.items()
            }
            wplot, hplot = self._set_custom_axes(filtered_kwargs, nrow, ncol)

        # Set figure size
        if figsize := kwargs.get("figsize"):
            self.fig.set_size_inches(*figsize)
            self.figsize = figsize
        elif not (custom_plot or self.set_size):
            self.fig.set_size_inches(6 * np.sqrt(ncol), 5 * np.sqrt(nrow))

        # Projection and axis-sharing options
        proj = kwargs.get("proj")
        sharex: bool | str | Axes | int | None = kwargs.get("sharex")
        sharey: bool | str | Axes | int | None = kwargs.get("sharey")

        for i in range(ncol * nrow):
            sharex_ref = self._check_shareaxis(i, sharex)
            sharey_ref = self._check_shareaxis(i, sharey)

            # The subplot grid accounts for the rows/columns created by
            # previous calls (accumulated in nrow0 / ncol0).
            axis = self.fig.add_subplot(
                nrow + self.nrow0,
                ncol + self.ncol0,
                i + 1,
                projection=proj,
                sharex=sharex_ref,
                sharey=sharey_ref,
            )
            self.add_ax(axis, len(self.ax))
            self.ax.append(axis)

            # Override the grid position when a custom layout was requested
            if wplot is not None and hplot is not None:
                row, col = divmod(i, ncol)
                axis.set_position(
                    pos=(
                        wplot[col][0],
                        hplot[row][0],
                        wplot[col][1],
                        hplot[row][1],
                    )
                )

        # Update the total number of rows and columns
        self.nrow0 = self.nrow0 + nrow
        self.ncol0 = self.ncol0 + ncol

        # Set figure title if requested
        if "suptitle" in kwargs:
            self.fig.suptitle(kwargs["suptitle"])

        # Tight layout (depending on the subplot creation)
        self.tight = kwargs.get("tight", self.tight)
        self.fig.set_layout_engine("tight" if self.tight else None)

        # A single axis is returned directly, otherwise the list of all axes
        return self.ax[0] if len(self.ax) == 1 else self.ax

    def _set_custom_axes(
        self, custom: dict[str, Any], nrow: int, ncol: int
    ) -> tuple[list[list[float]], list[list[float]]]:
        """Set the axes position and spacing according to the parameters.

        Returns
        -------
        - wplot: list[list[float]]
            For every column (left to right), the pair [left, width] in
            figure units.
        - hplot: list[list[float]]
            For every row (top to bottom), the pair [bottom, height] in
            figure units.

        Parameters
        ----------
        - custom: dict[str, Any]
            Dictionary with the custom parameters for the axes (the keys of
            the module-level ``defaults``).
        - nrow: int
            Number of rows in the axes.
        - ncol: int
            Number of columns in the axes.

        """
        hspace, hratio = self._check_rowcol(
            custom["hratio"], custom["hspace"], nrow, "rows"
        )
        wspace, wratio = self._check_rowcol(
            custom["wratio"], custom["wspace"], ncol, "cols"
        )

        # Space left for the plots once the gaps are removed
        hsize = custom["top"] - custom["bottom"] - sum(hspace)
        wsize = custom["right"] - custom["left"] - sum(wspace)
        htot, wtot = sum(hratio), sum(wratio)

        # Columns, left to right (no gap after the last column)
        wplot: list[list[float]] = []
        left = custom["left"]
        for ratio, gap in zip(wratio, [*wspace, 0.0]):
            width = wsize * ratio / wtot
            wplot.append([left, width])
            left += width + gap

        # Rows, top to bottom (no gap after the last row)
        hplot: list[list[float]] = []
        top = custom["top"]
        for ratio, gap in zip(hratio, [*hspace, 0.0]):
            bottom = top - hsize * ratio / htot
            hplot.append([bottom, top - bottom])
            top = bottom - gap

        return wplot, hplot

    def _check_rowcol(
        self,
        ratio: list[float],
        space: float | list[float | int],
        length: int,
        func: str,
    ) -> tuple[list[float | int], list[float | int]]:
        """Check the width and spacing of the plots on a single row or column.

        Missing ratios are filled with 1.0 and missing spaces with 0.1; a
        scalar space is used for every gap. If too many values are given, a
        warning is issued and the excess is dropped.

        Returns
        -------
        - space: list[float]
            the space between the rows or columns (length - 1 values)
        - ratio: list[float]
            the ratio of the rows or columns (length values)

        Parameters
        ----------
        - ratio: list[float]
            the ratio of the rows or columns
        - space: float | list[float]
            the space between the rows or columns
        - length: int
            the number of rows or columns
        - func: str
            the direction to check ("rows" or "cols"), used for the warnings

        ----

        Examples
        --------
        - Example #1: ratio and space are given correctly (rows)

            >>> _check_rowcol([1, 2, 3], [0.1, 0.2], 3, "rows")

        - Example #2: ratio missing and scalar space (filled with defaults)

            >>> _check_rowcol([], 0.1, 3, "rows")

        - Example #3: ratio and space are given correctly (cols)

            >>> _check_rowcol([1, 2, 3], [0.1, 0.2], 3, "cols")

        """
        rat = {"rows": "hratio", "cols": "wratio"}
        spc = {"rows": "hspace", "cols": "wspace"}

        # A scalar space is used for every gap between consecutive plots
        if isinstance(space, (list, tuple)):
            space = list(space)
        else:
            space = [space] * (length - 1)

        # Fill the lists with the default values
        ratio = list(ratio) + [1.0] * (length - len(ratio))
        space = space + [0.1] * (length - 1 - len(space))

        # After filling, a mismatch can only mean too many values were given
        if len(ratio) != length:
            warn = f"WARNING! {rat[func]} has wrong length!"
            warnings.warn(warn, UserWarning, stacklevel=2)
        if len(space) + 1 != length:
            warn = f"WARNING! {spc[func]} has wrong length!"
            warnings.warn(warn, UserWarning, stacklevel=2)

        # Drop any excess values
        return space[: length - 1], ratio[:length]

    def add_ax(self, ax: Axes, i: int) -> None:
        """Add the default per-axis properties of a new axis to the class.

        A default value for each per-axis property list (legpos, legpar,
        nline, ntext, setax, setay, shade, tickspar, xscale, yscale, vlims)
        is appended, and the index of the axis is written at the centre of
        the axis. The axis itself is not appended to the list of axes here.

        Returns
        -------
        - None

        Parameters
        ----------
        - ax (not optional): ax
            The axis to be added.
        - i (not optional): int
            The index of the axis in the list.

        ----

        Examples
        --------
        - Example #1: Add the axis to the class info variables

            >>> add_ax(ax, i)

        """
        ax_pars = {
            "legpos": None,
            "legpar": [self.fontsize, 1, 2, 0.8, 0.8],
            "nline": 0,
            "ntext": None,
            "setax": 0,
            "setay": 0,
            "shade": "auto",
            "tickspar": 0,
            "xscale": "linear",
            "yscale": "linear",
            "vlims": [],
        }

        # Copy the defaults so that mutable ones (lists) are not shared
        for attr, default in ax_pars.items():
            getattr(self, attr).append(copy.copy(default))

        # Position the axis index in the middle of the axis
        ax.annotate(str(i), (0.47, 0.47), xycoords="axes fraction")

    def _check_shareaxis(
        self, i: int, share: bool | str | Axes | int | None
    ) -> Axes | str | None:
        """Check the sharing of the x or y axis.

        - True: share with the first axis of the image, except for the
          first axis created in the current call (i == 0).
        - False / None: no sharing.
        - int: share with the axis with that index.
        - anything else (e.g. an Axes) is returned unchanged.

        Returns
        -------
        - share_ref: Axes | str | None
            The reference axis for sharing.

        Parameters
        ----------
        - i: int
            The index of the current axis.
        - share: bool | str | Axes | int | None
            The sharing option.

        ----

        Examples
        --------
        - Example #1: share is True

            >>> _check_shareaxis(0, True)

        - Example #2: share is False

            >>> _check_shareaxis(0, False)

        - Example #3: share is an axis

            >>> _check_shareaxis(0, ax)

        """
        # bool is a subclass of int: handle booleans before indices,
        # otherwise False would be interpreted as the index 0.
        if share is True:
            return self.ax[0] if i > 0 else None
        if share is False or share is None:
            return None
        if isinstance(share, (int, np.integer)):
            return self.ax[int(share)]
        return share