Source code for ewoksxrpd.tasks.diagnostics

import os
import warnings
from contextlib import contextmanager
from numbers import Number

import matplotlib
import numpy
from numpy.typing import ArrayLike
from packaging.version import Version

try:
    import matplotlib.pyplot as plt
    from matplotlib.colors import LogNorm
    from matplotlib.colors import SymLogNorm
except ModuleNotFoundError:
    plt = SymLogNorm = LogNorm = None

import pyFAI

from .. import pyfai_api
from ..pyfai_api import AzimuthalIntegrator
from .calibrate import calculate_geometry
from .data_access import TaskWithDataAccess
from .utils import data_utils
from .utils import pyfai_utils
from .utils import xrpd_utils

__all__ = [
    "DiagnoseCalibrateSingleResults",
    "DiagnoseCalibrateMultiResults",
    "DiagnoseIntegrate1D",
]


class _DiagnoseTask(
    TaskWithDataAccess,
    optional_input_names=["show", "pause", "filename", "figsize", "dpi", "fontsize"],
    output_names=["saved"],
    register=False,
):
    def prepare(self):
        matplotlib.rc("font", size=self.get_input_value("fontsize", 22))

    def show(self):
        if self.inputs.show:
            if self.inputs.pause and numpy.isfinite(self.inputs.pause):
                plt.pause(self.inputs.pause)
            else:
                plt.show()
        if self.inputs.filename:
            path = os.path.dirname(self.inputs.filename)
            if path:
                os.makedirs(path, exist_ok=True)
            plt.gcf().savefig(
                self.inputs.filename, dpi=self.get_input_value("dpi", 150)
            )
            self.outputs.saved = True
        else:
            self.outputs.saved = False
        plt.close()


class _DiagnoseCalibrateResults(
    _DiagnoseTask,
    input_names=["detector", "calibrant"],
    optional_input_names=["detector_config", "scaling"],
    register=False,
):
    def plot_calibration(
        self,
        ax1,
        ax2,
        image: ArrayLike,
        control_pts: dict,
        geometry: dict,
        energy: Number,
        title=None,
    ):
        scaling = self.get_input_value("scaling", "symlog").lower()
        if scaling == "symlog":
            colornorm = SymLogNorm(
                1,
                base=10,
                vmin=numpy.nanmin(image),
                vmax=numpy.nanmax(image),
            )
        elif scaling == "log":
            colornorm = LogNorm(
                vmin=numpy.nanmin(image),
                vmax=numpy.nanmax(image),
            )

        if title:
            title = f"{title}: "
        else:
            title = ""

        self.prepare()
        ax1.imshow(image, origin="lower", cmap="inferno", norm=colornorm)
        ax1.set_title(f"{title}Rings")
        if control_pts:
            ax2.imshow(image, origin="lower", cmap="inferno", norm=colornorm)
            ax2.set_title(f"{title}Control points")

        # Calibrant rings on 2D pattern
        detector = data_utils.data_from_storage(self.inputs.detector)
        detector_config = data_utils.data_from_storage(
            self.get_input_value("detector_config", None)
        )
        detector_object = pyFAI.detectors.detector_factory(
            detector, config=detector_config
        )
        calibrant = data_utils.data_from_storage(self.inputs.calibrant)
        calibrant_object = pyFAI.calibrant.get_calibrant(calibrant)
        wavelength = xrpd_utils.energy_wavelength(energy)
        ai = AzimuthalIntegrator(
            detector=detector_object, **geometry, wavelength=wavelength
        )
        if pyfai_api.USE_PROPERTY_API:
            calibrant_object.wavelength = wavelength
        else:
            calibrant_object.set_wavelength(wavelength)
        tth = calibrant_object.get_2th()
        if pyfai_api.PYFAI_VERSION >= Version("2025.12.0"):
            ttha = ai.center_array(unit="2th_rad")
        else:
            ttha = ai.twoThetaArray()
        ax1.contour(
            ttha, levels=tth, cmap="autumn", linewidths=1
        )  # linestyles="dashed"

        # Detected points on 2D pattern
        for label, points in control_pts.items():
            ax2.scatter(points["p1"], points["p0"], label=label, marker=".")

    def show(self):
        # Diagnose
        # plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.tight_layout()
        super().show()


[docs] class DiagnoseCalibrateSingleResults( _DiagnoseCalibrateResults, input_names=["image", "geometry", "energy"], optional_input_names=["rings"], ): """Quality of a single-distance pyFAI calibration"""
[docs] def run(self): if self.inputs.filename: if os.path.exists(self.inputs.filename): self.outputs.saved = True return if plt is None: raise RuntimeError("'matplotlib' is not installed") self.prepare() rings = self.inputs.rings figsize = self.get_input_value("figsize", (16, 8)) if rings: with _ignore_mpl_thread_warning(): _, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=figsize) else: rings = dict() with _ignore_mpl_thread_warning(): _, ax1 = plt.subplots(nrows=1, ncols=1, figsize=figsize) ax2 = None self.plot_calibration( ax1, ax2, self.get_image(self.inputs.image), data_utils.data_from_storage(rings), data_utils.data_from_storage(self.inputs.geometry), self.inputs.energy, ) self.show()
[docs] class DiagnoseCalibrateMultiResults( _DiagnoseCalibrateResults, input_names=[ "images", "positions", "parametrization", "parameters", ], optional_input_names=["show", "pause", "rings"], ): """Quality of a multi-distance pyFAI calibration"""
[docs] def run(self): if self.inputs.filename: if os.path.exists(self.inputs.filename): self.outputs.saved = True return if plt is None: raise RuntimeError("'matplotlib' is not installed") nimages = len(self.inputs.images) rings = self.inputs.rings figsize = self.get_input_value("figsize", (16, 8)) figsize = (figsize[0], figsize[1] * nimages) if rings: rings = {int(k): v for k, v in rings.items()} with _ignore_mpl_thread_warning(): _, axes = plt.subplots(nrows=nimages, ncols=2, figsize=figsize) else: rings = {i: dict() for i in range(nimages)} with _ignore_mpl_thread_warning(): _, axes = plt.subplots( nrows=nimages, ncols=1, figsize=(10, 10 * nimages) ) axes = [(ax1, None) for ax1 in axes] for image, position, ringsi, (ax1, ax2) in zip( self.inputs.images, self.inputs.positions, sorted(rings), axes ): image = self.get_image(image) position = self.get_data(position) title = f"position={position}" parametrization = data_utils.data_from_storage(self.inputs.parametrization) geometry, energy = calculate_geometry( parametrization, self.inputs.parameters, position ) control_pts = data_utils.data_from_storage(rings[ringsi]) self.plot_calibration( ax1, ax2, image, control_pts, geometry, energy, title=title ) self.show()
[docs] class DiagnoseIntegrate1D( _DiagnoseTask, input_names=[ "x", "y", "xunits", ], optional_input_names=["calibrant", "energy", "yerror", "scaling"], ): """Quality of a pyFAI integration"""
[docs] def run(self): if self.inputs.filename: if os.path.exists(self.inputs.filename): self.outputs.saved = True return if plt is None: raise RuntimeError("'matplotlib' is not installed") with _ignore_mpl_thread_warning(): plt.figure(figsize=(16, 8)) plt.title("Diffractogram") scaling = self.get_input_value("scaling", "squareroot").lower() if self.missing_inputs.yerror: plt.plot(self.inputs.x, self.inputs.y) else: yerr = numpy.minimum(3 * self.inputs.y, self.inputs.yerror) plt.errorbar(self.inputs.x, self.inputs.y, yerr=yerr) plt.yscale(scaling) xaxis = pyfai_utils.parse_string_units(self.inputs.xunits) plt.xlabel(f"{xaxis.name} ({xaxis.units})") plt.ylabel("Intensity") if self.inputs.calibrant: assert self.inputs.energy, "'energy' task parameter is missing" self.plot_calibrant_lines() self.show()
[docs] def plot_calibrant_lines(self): calibrant = data_utils.data_from_storage(self.inputs.calibrant) calibrant = pyFAI.calibrant.get_calibrant(calibrant) wavelength = xrpd_utils.energy_wavelength(self.inputs.energy) if pyfai_api.USE_PROPERTY_API: calibrant.wavelength = wavelength else: calibrant.set_wavelength(wavelength) xvalues = calibrant.get_peaks(unit=self.inputs.xunits) mask = (xvalues >= min(self.inputs.x)) & (xvalues <= max(self.inputs.x)) xvalues = xvalues[mask] if xvalues.size: yvalues = numpy.interp(xvalues, self.inputs.x, self.inputs.y) labels = xrpd_utils.calibrant_ring_labels(calibrant) labels = numpy.array(labels[: mask.size])[mask] ymin = min(self.inputs.y) for label, x, ymax in zip(labels, xvalues, yvalues): plt.plot([x, x], [ymin, ymax])
# plt.text(x, ymax, label) @contextmanager def _ignore_mpl_thread_warning(): with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message="Starting a Matplotlib GUI outside of the main thread will likely fail.", category=UserWarning, ) yield