diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..31040a30 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,23 @@ +#FROM python:3.12-slim-bullseye +FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime + +RUN apt-get update && \ + apt-get install -y libqt5gui5 && \ + rm -rf /var/lib/apt/lists/* +#ENV QT_DEBUG_PLUGINS=1 + +# Upgrade pip +RUN python3 -m pip install --upgrade pip + +# Set the working directory in the container +WORKDIR /app + +# Copy the wheel file into the container at /app +COPY dist/ptychodus-*.whl dist/ptychi-*.whl . + +# Install the wheel +RUN python3 -m pip install --no-cache-dir --find-links=. ptychodus[globus,gui] ptychi && \ + rm ptychodus-*.whl ptychi-*.whl + +# Run ptychodus when the container launches +CMD ["python3", "-m", "ptychodus"] diff --git a/apptainer/ptychodus.def b/apptainer/ptychodus.def deleted file mode 100644 index 6f4bb435..00000000 --- a/apptainer/ptychodus.def +++ /dev/null @@ -1,26 +0,0 @@ -Bootstrap: docker -From: registry.fedoraproject.org/fedora-minimal:40-{{ target_arch }} - -%arguments -target_arch=x86_64 -cuda_version=12.0 -pkg_version=master - -%post -curl -L -o conda-installer.sh https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-{{ target_arch }}.sh -bash conda-installer.sh -b -p "/opt/miniconda" -rm conda-installer.sh -/opt/miniconda/bin/conda install unzip --yes -curl -L -o source.zip https://github.com/AdvancedPhotonSource/ptychodus/archive/{{ pkg_version }}.zip -/opt/miniconda/bin/unzip source.zip -rm source.zip -cd ptychodus* -CONDA_OVERRIDE_CUDA={{ cuda_version }} /opt/miniconda/bin/conda install cuda-version={{ cuda_version }} --file requirements.txt -c conda-forge --yes -/opt/miniconda/bin/pip install . --no-deps --no-build-isolation -/opt/miniconda/bin/pip check -cd .. -rm ptychodus* -rf -/opt/miniconda/bin/conda clean --all --yes - -%runscript -/opt/miniconda/bin/python -m ptychodus "$@" diff --git a/doc/dist.rst b/doc/dist.rst new file mode 100644 index 00000000..f4a5d017 --- /dev/null +++ b/doc/dist.rst @@ -0,0 +1,47 @@ +Distribution Instructions +========================= + +Python Package Index (PyPI) +--------------------------- + +From the ptychodus directory, create wheel in ./dist/ + +.. code-block:: shell + + $ python -m build . + +Upload to PyPI + +.. code-block:: shell + + $ twine upload dist/* + +Docker +------ + +Build Docker image + +.. code-block:: shell + + $ time docker build -t python-ptychodus . + +Run container + +.. code-block:: shell + + $ xhost +local:docker + $ docker run -it --rm -e "DISPLAY=$DISPLAY" -v "$HOME/.Xauthority:/root/.Xauthority:ro" --network host \ + --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 python-ptychodus + $ xhost -local:docker + +Check container status + +.. code-block:: shell + + $ docker ps -a + +Clean up images + +.. code-block:: shell + + $ sudo docker system prune -a diff --git a/polaris.yaml b/polaris.yaml new file mode 100644 index 00000000..ba18891f --- /dev/null +++ b/polaris.yaml @@ -0,0 +1,40 @@ +engine: + type: HighThroughputEngine + max_workers_per_node: 1 + + # Un-comment to give each worker exclusive access to a single GPU + # available_accelerators: 4 + + strategy: + type: SimpleStrategy + max_idletime: 3600 + + address: + type: address_by_interface + ifname: bond0 + + provider: + type: PBSProProvider + + launcher: + type: MpiExecLauncher + # Ensures 1 manger per node, work on all 64 cores + bind_cmd: --cpu-bind + overrides: --depth=64 --ppn 1 + + account: APSDataAnalysis + queue: preemptable + cpus_per_node: 32 + select_options: ngpus=4 + + # e.g., "#PBS -l filesystems=home:grand:eagle\n#PBS -k doe" + scheduler_options: "#PBS -l filesystems=home:grand:eagle" + + # Node setup: activate necessary conda environment and such + worker_init: "source ~/miniconda3/etc/profile.d/conda.sh; conda activate ptychodus", + + walltime: 01:00:00 + nodes_per_block: 1 + init_blocks: 0 + min_blocks: 0 + max_blocks: 2 diff --git a/pyproject.toml b/pyproject.toml index aae521b1..016f28f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=64", "setuptools_scm>=8"] +requires = ["setuptools>=64", "setuptools_scm[toml]>=8"] build-backend = "setuptools.build_meta" [project] @@ -7,14 +7,16 @@ name = "ptychodus" description = "Ptychodus is a ptychography data analysis application." readme = "README.rst" requires-python = ">=3.10" -license = {file = "LICENSE.txt"} +license = {file = "LICENSE"} dependencies = [ - "h5py", + "h5py>=3", + "hdf5plugin", "matplotlib", "numpy", "psutil", "scikit-image", "scipy", + "tables", "tifffile", "watchdog", ] @@ -24,11 +26,17 @@ dynamic = ["version"] ptychodus = "ptychodus.__main__:main" [project.optional-dependencies] -globus = ["gladier", "gladier-tools"] +globus = ["gladier", "gladier-tools>=0.5.4"] gui = ["PyQt5"] ptychonn = ["ptychonn==0.3.*,>=0.3.7"] tike = ["tike==0.25.*,>=0.25.3"] +[tool.setuptools.package-data] +"ptychodus" = ["py.typed"] + +[tool.setuptools.packages.find] +where = ["src"] + [tool.setuptools_scm] [tool.mypy] @@ -44,6 +52,7 @@ module = [ "hdf5plugin", "lightning.*", "parsl.*", + "ptychi.*", "ptychonn.*", "pvaccess", "pvapy.*", @@ -60,10 +69,3 @@ target-version = "py310" [tool.ruff.format] quote-style = "single" - -[tool.setuptools.package-data] -"ptychodus" = ["py.typed"] - -[tool.setuptools.packages.find] -where = ["src"] - diff --git a/requirements-dev.txt b/requirements-dev.txt index c42305e2..3113269b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,5 @@ -h5py +build +h5py>=3 hdf5plugin matplotlib mypy @@ -6,6 +7,7 @@ numpy psutil pyqt pyqt-stubs +pytables pytest python>=3.10 ruff @@ -16,4 +18,3 @@ setuptools_scm>=8 tifffile toml watchdog -wheel diff --git a/src/ptychodus/__main__.py b/src/ptychodus/__main__.py index 5fb3c484..53f6d78c 100644 --- a/src/ptychodus/__main__.py +++ b/src/ptychodus/__main__.py @@ -2,11 +2,14 @@ from pathlib import Path import argparse +import logging import sys from ptychodus.model import ModelCore import ptychodus +logger = logging.getLogger(__name__) + def versionString() -> str: return f'{ptychodus.__name__.title()} ({ptychodus.__version__})' @@ -120,6 +123,7 @@ def main() -> int: try: from PyQt5.QtWidgets import QApplication except ModuleNotFoundError: + logger.warning('PyQt5 not found.') return 0 # QApplication expects the first argument to be the program name diff --git a/src/ptychodus/api/constants.py b/src/ptychodus/api/constants.py deleted file mode 100644 index a377bc38..00000000 --- a/src/ptychodus/api/constants.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Final - -# Source: https://physics.nist.gov/cuu/Constants/index.html -ELECTRON_VOLT_J: Final[float] = 1.602176634e-19 -LIGHT_SPEED_M_PER_S: Final[float] = 299792458 -PLANCK_CONSTANT_J_PER_HZ: Final[float] = 6.62607015e-34 diff --git a/src/ptychodus/api/geometry.py b/src/ptychodus/api/geometry.py index 324c6aab..5ce22364 100644 --- a/src/ptychodus/api/geometry.py +++ b/src/ptychodus/api/geometry.py @@ -11,6 +11,12 @@ class PixelGeometry: widthInMeters: float heightInMeters: float + def copy(self) -> PixelGeometry: + return PixelGeometry( + widthInMeters=float(self.widthInMeters), + heightInMeters=float(self.heightInMeters), + ) + def __repr__(self) -> str: return f'{type(self).__name__}({self.widthInMeters}, {self.heightInMeters})' diff --git a/src/ptychodus/api/object.py b/src/ptychodus/api/object.py index 60c2cfc0..9edf4102 100644 --- a/src/ptychodus/api/object.py +++ b/src/ptychodus/api/object.py @@ -14,6 +14,21 @@ ObjectArrayType: TypeAlias = numpy.typing.NDArray[numpy.complexfloating[Any, Any]] +@dataclass(frozen=True) +class ObjectCenter: + positionXInMeters: float + positionYInMeters: float + + def copy(self) -> ObjectCenter: + return ObjectCenter( + positionXInMeters=float(self.positionXInMeters), + positionYInMeters=float(self.positionYInMeters), + ) + + def __repr__(self) -> str: + return f'{type(self).__name__}({self.positionXInMeters}, {self.positionYInMeters})' + + @dataclass(frozen=True) class ObjectPoint: index: int @@ -52,6 +67,12 @@ def getPixelGeometry(self) -> PixelGeometry: heightInMeters=self.pixelHeightInMeters, ) + def getCenter(self) -> ObjectCenter: + return ObjectCenter( + positionXInMeters=self.centerXInMeters, + positionYInMeters=self.centerYInMeters, + ) + def mapObjectPointToScanPoint(self, point: ObjectPoint) -> ScanPoint: rx_px = self.widthInPixels / 2 ry_px = self.heightInPixels / 2 @@ -91,65 +112,49 @@ def getObjectGeometry(self) -> ObjectGeometry: class Object: def __init__( self, - array: ObjectArrayType | None = None, - layerDistanceInMeters: Sequence[float] | None = None, - *, - pixelWidthInMeters: float = 0.0, - pixelHeightInMeters: float = 0.0, - centerXInMeters: float = 0.0, - centerYInMeters: float = 0.0, + array: ObjectArrayType | None, + pixelGeometry: PixelGeometry | None, + center: ObjectCenter | None, + layerDistanceInMeters: Sequence[float] = [], ) -> None: if array is None: - self._array = numpy.zeros((1, 0, 0), dtype=complex) - else: - if numpy.iscomplexobj(array): - if array.ndim == 2: - self._array = array[numpy.newaxis, :, :] - elif array.ndim == 3: + self._array: ObjectArrayType = numpy.zeros((1, 0, 0), dtype=complex) + elif numpy.iscomplexobj(array): + match array.ndim: + case 2: + self._array = array[numpy.newaxis, ...] + case 3: self._array = array - else: + case _: raise ValueError('Object must be 2- or 3-dimensional ndarray.') - else: - raise TypeError('Object must be a complex-valued ndarray') - - if layerDistanceInMeters is None: - self._layerDistanceInMeters: Sequence[float] = [numpy.inf] else: - self._layerDistanceInMeters = layerDistanceInMeters + raise TypeError('Object must be a complex-valued ndarray') - expectedLayers = self.numberOfLayers - actualLayers = len(self._layerDistanceInMeters) + self._pixelGeometry = pixelGeometry + self._center = center + self._layerDistanceInMeters = layerDistanceInMeters - if actualLayers < expectedLayers: - raise ValueError(f'Expected {expectedLayers} layer distances; got {actualLayers}!') + expectedSpaces = self._array.shape[-3] - 1 + actualSpaces = len(layerDistanceInMeters) - self._pixelWidthInMeters = pixelWidthInMeters - self._pixelHeightInMeters = pixelHeightInMeters - self._centerXInMeters = centerXInMeters - self._centerYInMeters = centerYInMeters + if actualSpaces != expectedSpaces: + raise ValueError(f'Expected {expectedSpaces} layer distances; got {actualSpaces}!') def copy(self) -> Object: return Object( - array=numpy.array(self._array), + array=self._array.copy(), + pixelGeometry=None if self._pixelGeometry is None else self._pixelGeometry.copy(), + center=None if self._center is None else self._center.copy(), layerDistanceInMeters=list(self._layerDistanceInMeters), - pixelWidthInMeters=float(self._pixelWidthInMeters), - pixelHeightInMeters=float(self._pixelHeightInMeters), - centerXInMeters=float(self._centerXInMeters), - centerYInMeters=float(self._centerYInMeters), ) - @property - def array(self) -> ObjectArrayType: + def getArray(self) -> ObjectArrayType: return self._array @property def dataType(self) -> numpy.dtype: return self._array.dtype - @property - def numberOfLayers(self) -> int: - return self._array.shape[-3] - @property def sizeInBytes(self) -> int: return self._array.nbytes @@ -163,35 +168,37 @@ def heightInPixels(self) -> int: return self._array.shape[-2] @property - def pixelWidthInMeters(self) -> float: - return self._pixelWidthInMeters - - @property - def pixelHeightInMeters(self) -> float: - return self._pixelHeightInMeters + def numberOfLayers(self) -> int: + return self._array.shape[-3] - @property - def centerXInMeters(self) -> float: - return self._centerXInMeters + def getPixelGeometry(self) -> PixelGeometry | None: + return self._pixelGeometry - @property - def centerYInMeters(self) -> float: - return self._centerYInMeters + def getCenter(self) -> ObjectCenter | None: + return self._center def getGeometry(self) -> ObjectGeometry: + pixelWidthInMeters = 0.0 + pixelHeightInMeters = 0.0 + + if self._pixelGeometry is not None: + pixelWidthInMeters = self._pixelGeometry.widthInMeters + pixelHeightInMeters = self._pixelGeometry.heightInMeters + + centerXInMeters = 0.0 + centerYInMeters = 0.0 + + if self._center is not None: + centerXInMeters = self._center.positionXInMeters + centerYInMeters = self._center.positionYInMeters + return ObjectGeometry( widthInPixels=self.widthInPixels, heightInPixels=self.heightInPixels, - pixelWidthInMeters=self._pixelWidthInMeters, - pixelHeightInMeters=self._pixelHeightInMeters, - centerXInMeters=self._centerXInMeters, - centerYInMeters=self._centerYInMeters, - ) - - def getPixelGeometry(self) -> PixelGeometry: - return PixelGeometry( - widthInMeters=self._pixelWidthInMeters, - heightInMeters=self._pixelHeightInMeters, + pixelWidthInMeters=pixelWidthInMeters, + pixelHeightInMeters=pixelHeightInMeters, + centerXInMeters=centerXInMeters, + centerYInMeters=centerYInMeters, ) def getLayer(self, number: int) -> ObjectArrayType: @@ -204,16 +211,13 @@ def getLayersFlattened(self) -> ObjectArrayType: def layerDistanceInMeters(self) -> Sequence[float]: return self._layerDistanceInMeters - def getLayerDistanceInMeters(self, number: int) -> float: - return self._layerDistanceInMeters[number] - def getTotalLayerDistanceInMeters(self) -> float: - return sum(self._layerDistanceInMeters[:-1]) + return sum(self._layerDistanceInMeters) class ObjectInterpolator(ABC): @abstractmethod - def getPatch(self, patchCenter: ScanPoint, patchExtent: ImageExtent) -> Object: + def get_patch(self, patch_center: ScanPoint, patch_extent: ImageExtent) -> Object: """returns an interpolated patch from the object array""" pass diff --git a/src/ptychodus/api/parametric.py b/src/ptychodus/api/parametric.py index 6a00d4f2..26215df9 100644 --- a/src/ptychodus/api/parametric.py +++ b/src/ptychodus/api/parametric.py @@ -220,8 +220,57 @@ def copy(self) -> RealParameter: ) -class RealArrayParameter(ParameterBase[MutableSequence[float]]): - def __init__(self, value: Sequence[float], parent: RealArrayParameter | None) -> None: +class IntegerSequenceParameter(ParameterBase[MutableSequence[int]]): + def __init__(self, value: Sequence[int], parent: IntegerSequenceParameter | None) -> None: + super().__init__(list(value), parent) + + def __iter__(self) -> Iterator[int]: + return iter(self._value) + + def __getitem__(self, index: int) -> int: + return self._value[index] + + def __setitem__(self, index: int, value: int) -> None: + if self._value[index] != value: + self._value[index] = value + self.notifyObservers() + + def __delitem__(self, index: int) -> None: + del self._value[index] + self.notifyObservers() + + def insert(self, index: int, value: int) -> None: + self._value.insert(index, value) + self.notifyObservers() + + def __len__(self) -> int: + return len(self._value) + + def setValue(self, value: Sequence[int], *, notify: bool = True) -> None: + if self._value != value: + self._value = list(value) + + if notify: + self.notifyObservers() + + def getValueAsString(self) -> str: + return ','.join(repr(value) for value in self) + + def setValueFromString(self, value: str) -> None: + newValue: list[int] = list() + + for xstr in value.split(','): + if xstr: + newValue.append(int(xstr)) + + self.setValue(newValue) + + def copy(self) -> IntegerSequenceParameter: + return IntegerSequenceParameter(self.getValue(), self) + + +class RealSequenceParameter(ParameterBase[MutableSequence[float]]): + def __init__(self, value: Sequence[float], parent: RealSequenceParameter | None) -> None: super().__init__(list(value), parent) def __iter__(self) -> Iterator[float]: @@ -260,21 +309,22 @@ def setValueFromString(self, value: str) -> None: tmp: list[float] = list() for xstr in value.split(','): - try: - x = float(xstr) - except ValueError: - x = float('nan') + if xstr: + try: + x = float(xstr) + except ValueError: + x = float('nan') - tmp.append(x) + tmp.append(x) self.setValue(tmp) - def copy(self) -> RealArrayParameter: - return RealArrayParameter(self.getValue(), self) + def copy(self) -> RealSequenceParameter: + return RealSequenceParameter(self.getValue(), self) -class ComplexArrayParameter(ParameterBase[MutableSequence[complex]]): - def __init__(self, value: Sequence[complex], parent: ComplexArrayParameter | None) -> None: +class ComplexSequenceParameter(ParameterBase[MutableSequence[complex]]): + def __init__(self, value: Sequence[complex], parent: ComplexSequenceParameter | None) -> None: super().__init__(list(value), parent) def __iter__(self) -> Iterator[complex]: @@ -313,17 +363,18 @@ def setValueFromString(self, value: str) -> None: tmp: list[complex] = list() for xstr in value.split(','): - try: - x = complex(xstr) - except ValueError: - x = float('nan') * 1j + if xstr: + try: + x = complex(xstr) + except ValueError: + x = float('nan') * 1j - tmp.append(x) + tmp.append(x) self.setValue(tmp) - def copy(self) -> ComplexArrayParameter: - return ComplexArrayParameter(self.getValue(), self) + def copy(self) -> ComplexSequenceParameter: + return ComplexSequenceParameter(self.getValue(), self) class ParameterGroup(Observable, Observer): @@ -368,6 +419,13 @@ def createIntegerParameter( self._addParameter(name, parameter) return parameter + def createIntegerSequenceParameter( + self, name: str, value: Sequence[int] + ) -> IntegerSequenceParameter: + parameter = IntegerSequenceParameter(value, parent=None) + self._addParameter(name, parameter) + return parameter + def createRealParameter( self, name: str, value: float, *, minimum: float | None = None, maximum: float | None = None ) -> RealParameter: @@ -375,15 +433,17 @@ def createRealParameter( self._addParameter(name, parameter) return parameter - def createRealArrayParameter(self, name: str, value: Sequence[float]) -> RealArrayParameter: - parameter = RealArrayParameter(value, parent=None) + def createRealSequenceParameter( + self, name: str, value: Sequence[float] + ) -> RealSequenceParameter: + parameter = RealSequenceParameter(value, parent=None) self._addParameter(name, parameter) return parameter - def createComplexArrayParameter( + def createComplexSequenceParameter( self, name: str, value: Sequence[complex] - ) -> ComplexArrayParameter: - parameter = ComplexArrayParameter(value, parent=None) + ) -> ComplexSequenceParameter: + parameter = ComplexSequenceParameter(value, parent=None) self._addParameter(name, parameter) return parameter diff --git a/src/ptychodus/api/probe.py b/src/ptychodus/api/probe.py index 6e46cdc3..3726596f 100644 --- a/src/ptychodus/api/probe.py +++ b/src/ptychodus/api/probe.py @@ -1,12 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path import numpy -from .geometry import ImageExtent, PixelGeometry +from .geometry import PixelGeometry from .propagator import WavefieldArrayType, intensity from .typing import RealArrayType @@ -40,20 +39,12 @@ def widthInMeters(self) -> float: def heightInMeters(self) -> float: return self.heightInPixels * self.pixelHeightInMeters - def _asTuple(self) -> tuple[int, int, float, float]: - return ( - self.widthInPixels, - self.heightInPixels, - self.pixelWidthInMeters, - self.pixelHeightInMeters, + def getPixelGeometry(self) -> PixelGeometry: + return PixelGeometry( + widthInMeters=self.pixelWidthInMeters, + heightInMeters=self.pixelHeightInMeters, ) - def __eq__(self, other: object) -> bool: - if isinstance(other, ProbeGeometry): - return self._asTuple() == other._asTuple() - - return False - class ProbeGeometryProvider(ABC): @property @@ -61,6 +52,11 @@ class ProbeGeometryProvider(ABC): def detectorDistanceInMeters(self) -> float: pass + @property + @abstractmethod + def probePhotonCount(self) -> float: + pass + @property @abstractmethod def probeWavelengthInMeters(self) -> float: @@ -77,59 +73,49 @@ def getProbeGeometry(self) -> ProbeGeometry: class Probe: - @staticmethod - def _calculateModeRelativePower(array: WavefieldArrayType) -> Sequence[float]: - power = numpy.sum(intensity(array), axis=(-2, -1)) - powersum = numpy.sum(power) - - if powersum > 0.0: - power /= powersum - - return power.tolist() - def __init__( self, - array: WavefieldArrayType | None = None, - *, - pixelWidthInMeters: float = 0.0, - pixelHeightInMeters: float = 0.0, + array: WavefieldArrayType | None, + pixelGeometry: PixelGeometry | None, ) -> None: if array is None: - self._array = numpy.zeros((1, 0, 0), dtype=complex) - else: - if numpy.iscomplexobj(array): - if array.ndim == 2: - self._array = array[numpy.newaxis, :, :] - elif array.ndim == 3: + self._array: WavefieldArrayType = numpy.zeros((1, 1, 0, 0), dtype=complex) + elif numpy.iscomplexobj(array): + match array.ndim: + case 2: + self._array = array[numpy.newaxis, numpy.newaxis, ...] + case 3: + self._array = array[numpy.newaxis, ...] + case 4: self._array = array - else: - raise ValueError('Probe must be 2- or 3-dimensional ndarray.') - else: - raise TypeError('Probe must be a complex-valued ndarray') + case _: + raise ValueError('Probe must be 2-, 3-, or 4-dimensional ndarray.') + else: + raise TypeError('Probe must be a complex-valued ndarray') - self._modeRelativePower = Probe._calculateModeRelativePower(self._array) - self._pixelWidthInMeters = pixelWidthInMeters - self._pixelHeightInMeters = pixelHeightInMeters + self._pixelGeometry = pixelGeometry + + power = numpy.sum(intensity(self._array[0]), axis=(-2, -1)) + powersum = numpy.sum(power) + + if powersum > 0.0: + power /= powersum + + self._modeRelativePower = power.tolist() def copy(self) -> Probe: return Probe( - array=numpy.array(self._array), - pixelWidthInMeters=float(self._pixelWidthInMeters), - pixelHeightInMeters=float(self._pixelHeightInMeters), + array=self._array.copy(), + pixelGeometry=None if self._pixelGeometry is None else self._pixelGeometry.copy(), ) - @property - def array(self) -> WavefieldArrayType: + def getArray(self) -> WavefieldArrayType: return self._array @property def dataType(self) -> numpy.dtype: return self._array.dtype - @property - def numberOfModes(self) -> int: - return self._array.shape[-3] - @property def sizeInBytes(self) -> int: return self._array.nbytes @@ -143,50 +129,50 @@ def heightInPixels(self) -> int: return self._array.shape[-2] @property - def pixelWidthInMeters(self) -> float: - return self._pixelWidthInMeters + def numberOfIncoherentModes(self) -> int: + return self._array.shape[-3] @property - def pixelHeightInMeters(self) -> float: - return self._pixelHeightInMeters + def numberOfCoherentModes(self) -> int: + return self._array.shape[-4] + + def getPixelGeometry(self) -> PixelGeometry | None: + return self._pixelGeometry def getGeometry(self) -> ProbeGeometry: - return ProbeGeometry( - widthInPixels=self.widthInPixels, - heightInPixels=self.heightInPixels, - pixelWidthInMeters=self._pixelWidthInMeters, - pixelHeightInMeters=self._pixelHeightInMeters, - ) + pixelWidthInMeters = 0.0 + pixelHeightInMeters = 0.0 - def getPixelGeometry(self) -> PixelGeometry: - return PixelGeometry( - widthInMeters=self._pixelWidthInMeters, - heightInMeters=self._pixelHeightInMeters, - ) + if self._pixelGeometry is not None: + pixelWidthInMeters = self._pixelGeometry.widthInMeters + pixelHeightInMeters = self._pixelGeometry.heightInMeters - def getExtent(self) -> ImageExtent: - return ImageExtent( + return ProbeGeometry( widthInPixels=self.widthInPixels, heightInPixels=self.heightInPixels, + pixelWidthInMeters=pixelWidthInMeters, + pixelHeightInMeters=pixelHeightInMeters, ) - def getMode(self, number: int) -> WavefieldArrayType: - return self._array[number, :, :] + def getIncoherentMode(self, number: int) -> WavefieldArrayType: + return self._array[0, number, :, :] - def getModesFlattened(self) -> WavefieldArrayType: - if self._array.size > 0: - return self._array.transpose((1, 0, 2)).reshape(self._array.shape[-2], -1) - else: - return self._array + def getIncoherentModesFlattened(self) -> WavefieldArrayType: + modes = self._array[0] + return modes.transpose((1, 0, 2)).reshape(modes.shape[-2], -1) - def getModeRelativePower(self, number: int) -> float: + def getIncoherentModeRelativePower(self, number: int) -> float: return self._modeRelativePower[number] def getCoherence(self) -> float: return numpy.sqrt(numpy.sum(numpy.square(self._modeRelativePower))) + def getCoherentMode(self, number: int) -> WavefieldArrayType: + return self._array[number, 0, :, :] + def getIntensity(self) -> RealArrayType: - return numpy.sum(intensity(self._array), axis=-3) + array_no_opr = self._array[0] # TODO OPR + return numpy.sum(intensity(array_no_opr), axis=-3) class ProbeFileReader(ABC): diff --git a/src/ptychodus/api/product.py b/src/ptychodus/api/product.py index 73b660d4..85da5d4a 100644 --- a/src/ptychodus/api/product.py +++ b/src/ptychodus/api/product.py @@ -1,14 +1,19 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from typing import Final from dataclasses import dataclass from pathlib import Path from sys import getsizeof -from .constants import ELECTRON_VOLT_J, LIGHT_SPEED_M_PER_S, PLANCK_CONSTANT_J_PER_HZ from .object import Object from .probe import Probe from .scan import Scan +# Source: https://physics.nist.gov/cuu/Constants/index.html +ELECTRON_VOLT_J: Final[float] = 1.602176634e-19 +LIGHT_SPEED_M_PER_S: Final[float] = 299792458 +PLANCK_CONSTANT_J_PER_HZ: Final[float] = 6.62607015e-34 + @dataclass(frozen=True) class ProductMetadata: @@ -16,7 +21,7 @@ class ProductMetadata: comments: str detectorDistanceInMeters: float probeEnergyInElectronVolts: float - probePhotonsPerSecond: float + probePhotonCount: float exposureTimeInSeconds: float @property @@ -38,7 +43,7 @@ def sizeInBytes(self) -> int: sz += getsizeof(self.comments) sz += getsizeof(self.detectorDistanceInMeters) sz += getsizeof(self.probeEnergyInElectronVolts) - sz += getsizeof(self.probePhotonsPerSecond) + sz += getsizeof(self.probePhotonCount) sz += getsizeof(self.exposureTimeInSeconds) return sz diff --git a/src/ptychodus/api/reconstructor.py b/src/ptychodus/api/reconstructor.py index ad5eb67e..c62b475a 100644 --- a/src/ptychodus/api/reconstructor.py +++ b/src/ptychodus/api/reconstructor.py @@ -41,63 +41,27 @@ class TrainOutput: class TrainableReconstructor(Reconstructor): @abstractmethod - def ingestTrainingData(self, parameters: ReconstructInput) -> None: + def getModelFileFilter(self) -> str: pass @abstractmethod - def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: - pass - - @abstractmethod - def getOpenTrainingDataFileFilter(self) -> str: - pass - - @abstractmethod - def openTrainingData(self, filePath: Path) -> None: - pass - - @abstractmethod - def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: - pass - - @abstractmethod - def getSaveTrainingDataFileFilter(self) -> str: - pass - - @abstractmethod - def saveTrainingData(self, filePath: Path) -> None: - pass - - @abstractmethod - def train(self) -> TrainOutput: - pass - - @abstractmethod - def clearTrainingData(self) -> None: - pass - - @abstractmethod - def getOpenModelFileFilterList(self) -> Sequence[str]: - pass - - @abstractmethod - def getOpenModelFileFilter(self) -> str: + def openModel(self, filePath: Path) -> None: pass @abstractmethod - def openModel(self, filePath: Path) -> None: + def saveModel(self, filePath: Path) -> None: pass @abstractmethod - def getSaveModelFileFilterList(self) -> Sequence[str]: + def getTrainingDataFileFilter(self) -> str: pass @abstractmethod - def getSaveModelFileFilter(self) -> str: + def exportTrainingData(self, filePath: Path, parameters: ReconstructInput) -> None: pass @abstractmethod - def saveModel(self, filePath: Path) -> None: + def train(self, dataPath: Path) -> TrainOutput: pass @@ -112,50 +76,23 @@ def name(self) -> str: def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: return ReconstructOutput(parameters.product, 0) - def ingestTrainingData(self, parameters: ReconstructInput) -> None: - pass - - def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: - return list() - - def getOpenTrainingDataFileFilter(self) -> str: - return str() - - def openTrainingData(self, filePath: Path) -> None: - pass - - def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: - return list() - - def getSaveTrainingDataFileFilter(self) -> str: + def getModelFileFilter(self) -> str: return str() - def saveTrainingData(self, filePath: Path) -> None: + def openModel(self, filePath: Path) -> None: pass - def train(self) -> TrainOutput: - return TrainOutput([], [], 0) - - def clearTrainingData(self) -> None: + def saveModel(self, filePath: Path) -> None: pass - def getOpenModelFileFilterList(self) -> Sequence[str]: - return list() - - def getOpenModelFileFilter(self) -> str: + def getTrainingDataFileFilter(self) -> str: return str() - def openModel(self, filePath: Path) -> None: + def exportTrainingData(self, filePath: Path, parameters: ReconstructInput) -> None: pass - def getSaveModelFileFilterList(self) -> Sequence[str]: - return list() - - def getSaveModelFileFilter(self) -> str: - return str() - - def saveModel(self, filePath: Path) -> None: - pass + def train(self, dataPath: Path) -> TrainOutput: + return TrainOutput([], [], 0) class ReconstructorLibrary(Iterable[Reconstructor]): @@ -163,3 +100,8 @@ class ReconstructorLibrary(Iterable[Reconstructor]): @abstractmethod def name(self) -> str: pass + + @property + @abstractmethod + def logger_name(self) -> str: + pass diff --git a/src/ptychodus/api/workflow.py b/src/ptychodus/api/workflow.py index 54e331b8..92d3ba68 100644 --- a/src/ptychodus/api/workflow.py +++ b/src/ptychodus/api/workflow.py @@ -41,7 +41,7 @@ def buildObject( pass @abstractmethod - def reconstructLocal(self, outputProductName: str) -> WorkflowProductAPI: + def reconstructLocal(self) -> WorkflowProductAPI: pass @abstractmethod @@ -89,7 +89,7 @@ def createProduct( comments: str = '', detectorDistanceInMeters: float | None = None, probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, + probePhotonCount: float | None = None, exposureTimeInSeconds: float | None = None, ) -> WorkflowProductAPI: """creates a new product""" diff --git a/src/ptychodus/controller/core.py b/src/ptychodus/controller/core.py index e709cbd8..e695110b 100644 --- a/src/ptychodus/controller/core.py +++ b/src/ptychodus/controller/core.py @@ -13,6 +13,7 @@ from .patterns import PatternsController from .probe import ProbeController from .product import ProductController +from .ptychi import PtyChiViewControllerFactory from .ptychonn import PtychoNNViewControllerFactory from .reconstructor import ReconstructorController from .scan import ScanController @@ -27,6 +28,9 @@ def __init__(self, model: ModelCore, view: ViewCore) -> None: self._memoryController = MemoryController(model.memoryPresenter, view.memoryProgressBar) self._fileDialogFactory = FileDialogFactory() + self._ptyChiViewControllerFactory = PtyChiViewControllerFactory( + model.ptyChiReconstructorLibrary + ) self._ptychonnViewControllerFactory = PtychoNNViewControllerFactory( model.ptychonnReconstructorLibrary, self._fileDialogFactory ) @@ -43,7 +47,7 @@ def __init__(self, model: ModelCore, view: ViewCore) -> None: view.statusBar(), self._fileDialogFactory, ) - self._patternsController = PatternsController.createInstance( + self._patternsController = PatternsController( model.detector, model.diffractionDatasetInputOutputPresenter, model.diffractionMetadataPresenter, @@ -103,19 +107,20 @@ def __init__(self, model: ModelCore, view: ViewCore) -> None: view.objectView, self._fileDialogFactory, ) - self._reconstructorParametersController = ReconstructorController.createInstance( + self._reconstructorController = ReconstructorController( model.reconstructorPresenter, model.productRepository, - view.reconstructorParametersView, + view.reconstructorView, view.reconstructorPlotView, - self._fileDialogFactory, self._productController.tableModel, + self._fileDialogFactory, [ + self._ptyChiViewControllerFactory, self._ptychonnViewControllerFactory, self._tikeViewControllerFactory, ], ) - self._workflowController = WorkflowController.createInstance( + self._workflowController = WorkflowController( model.workflowParametersPresenter, model.workflowAuthorizationPresenter, model.workflowStatusPresenter, diff --git a/src/ptychodus/controller/object/core.py b/src/ptychodus/controller/object/core.py index 268a3da0..0b0d5b48 100644 --- a/src/ptychodus/controller/object/core.py +++ b/src/ptychodus/controller/object/core.py @@ -240,7 +240,12 @@ def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: if current.parent().isValid() else object_.getLayersFlattened() ) - self._imageController.setArray(array, object_.getPixelGeometry()) + pixelGeometry = object_.getPixelGeometry() + + if pixelGeometry is None: + logger.warning('Missing object pixel geometry!') + else: + self._imageController.setArray(array, pixelGeometry) def handleItemInserted(self, index: int, item: ObjectRepositoryItem) -> None: self._treeModel.insertItem(index, item) diff --git a/src/ptychodus/controller/object/treeModel.py b/src/ptychodus/controller/object/treeModel.py index 24cc29c7..b1c1b0c9 100644 --- a/src/ptychodus/controller/object/treeModel.py +++ b/src/ptychodus/controller/object/treeModel.py @@ -150,7 +150,7 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A try: return item.layerDistanceInMeters[index.row()] except IndexError: - return float('NaN') + return float('inf') else: item = self._repository[index.row()] object_ = item.getObject() diff --git a/src/ptychodus/controller/object/xmcd.py b/src/ptychodus/controller/object/xmcd.py index d8317972..2a26a36a 100644 --- a/src/ptychodus/controller/object/xmcd.py +++ b/src/ptychodus/controller/object/xmcd.py @@ -66,16 +66,21 @@ def _analyze(self) -> None: return self._result = result - self._differenceVisualizationWidgetController.setArray( - result.polar_difference[0, :, :], result.pixel_geometry - ) - self._sumVisualizationWidgetController.setArray( - result.polar_sum[0, :, :], result.pixel_geometry - ) - # TODO support multi-layer objects - self._ratioVisualizationWidgetController.setArray( - result.polar_ratio[0, :, :], result.pixel_geometry - ) + pixel_geometry = result.pixel_geometry + + if pixel_geometry is None: + logger.warning('Missing XMCD pixel geometry!') + else: + self._differenceVisualizationWidgetController.setArray( + result.polar_difference[0, :, :], pixel_geometry + ) + self._sumVisualizationWidgetController.setArray( + result.polar_sum[0, :, :], pixel_geometry + ) + # TODO support multi-layer objects + self._ratioVisualizationWidgetController.setArray( + result.polar_ratio[0, :, :], pixel_geometry + ) def analyze(self, lcircItemIndex: int, rcircItemIndex: int) -> None: self._dialog.parametersView.lcircComboBox.setCurrentIndex(lcircItemIndex) diff --git a/src/ptychodus/controller/parametric.py b/src/ptychodus/controller/parametric.py index 9cdd9686..6df4edfb 100644 --- a/src/ptychodus/controller/parametric.py +++ b/src/ptychodus/controller/parametric.py @@ -1,6 +1,8 @@ +from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from decimal import Decimal +from pathlib import Path from typing import Final import logging @@ -14,7 +16,9 @@ QDialogButtonBox, QFormLayout, QGroupBox, + QHBoxLayout, QLineEdit, + QPushButton, QSpinBox, QVBoxLayout, QWidget, @@ -25,18 +29,16 @@ from ptychodus.api.parametric import ( BooleanParameter, IntegerParameter, + PathParameter, RealParameter, StringParameter, ) from ..view.widgets import AngleWidget, DecimalLineEdit, DecimalSlider, LengthWidget +from .data import FileDialogFactory logger = logging.getLogger(__name__) -__all__ = [ - 'ParameterViewBuilder', -] - class ParameterViewController(ABC): @abstractmethod @@ -44,6 +46,31 @@ def getWidget(self) -> QWidget: pass +class CheckableGroupBoxParameterViewController(ParameterViewController, Observer): + def __init__(self, parameter: BooleanParameter, title: str, *, tool_tip: str = '') -> None: + super().__init__() + self._parameter = parameter + self._widget = QGroupBox(title) + self._widget.setCheckable(True) + + if tool_tip: + self._widget.setToolTip(tool_tip) + + self._syncModelToView() + self._widget.toggled.connect(parameter.setValue) + self._parameter.addObserver(self) + + def getWidget(self) -> QWidget: + return self._widget + + def _syncModelToView(self) -> None: + self._widget.setChecked(self._parameter.getValue()) + + def update(self, observable: Observable) -> None: + if observable is self._parameter: + self._syncModelToView() + + class CheckBoxParameterViewController(ParameterViewController, Observer): def __init__(self, parameter: BooleanParameter, text: str, *, tool_tip: str = '') -> None: super().__init__() @@ -91,7 +118,7 @@ def _syncModelToView(self) -> None: maximum = self._parameter.getMaximum() if minimum is None: - logger.error('Minimum not provided!') + raise ValueError('Minimum not provided!') else: self._widget.blockSignals(True) @@ -110,7 +137,7 @@ def update(self, observable: Observable) -> None: class ComboBoxParameterViewController(ParameterViewController, Observer): def __init__( - self, parameter: StringParameter, items: Sequence[str], *, tool_tip: str = '' + self, parameter: StringParameter, items: Iterable[str], *, tool_tip: str = '' ) -> None: super().__init__() self._parameter = parameter @@ -169,6 +196,157 @@ def update(self, observable: Observable) -> None: self._syncModelToView() +class PathParameterViewController(ParameterViewController, Observer): + def __init__( + self, + parameter: PathParameter, + fileDialogFactory: FileDialogFactory, + *, + caption: str, + nameFilters: Sequence[str] | None, + mimeTypeFilters: Sequence[str] | None, + selectedNameFilter: str | None, + tool_tip: str, + ) -> None: + super().__init__() + self._parameter = parameter + self._fileDialogFactory = fileDialogFactory + self._caption = caption + self._nameFilters = nameFilters + self._mimeTypeFilters = mimeTypeFilters + self._selectedNameFilter = selectedNameFilter + self._lineEdit = QLineEdit() + self._browseButton = QPushButton('Browse') + self._widget = QWidget() + + if tool_tip: + self._lineEdit.setToolTip(tool_tip) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._lineEdit) + layout.addWidget(self._browseButton) + self._widget.setLayout(layout) + + self._syncModelToView() + parameter.addObserver(self) + self._lineEdit.editingFinished.connect(self._syncPathToModel) + + @classmethod + def createFileOpener( + cls, + parameter: PathParameter, + fileDialogFactory: FileDialogFactory, + *, + caption: str = 'Open File', + nameFilters: Sequence[str] | None = None, + mimeTypeFilters: Sequence[str] | None = None, + selectedNameFilter: str | None = None, + tool_tip: str = '', + ) -> PathParameterViewController: + viewController = cls( + parameter, + fileDialogFactory, + caption=caption, + nameFilters=nameFilters, + mimeTypeFilters=mimeTypeFilters, + selectedNameFilter=selectedNameFilter, + tool_tip=tool_tip, + ) + viewController._browseButton.clicked.connect(viewController._chooseFileToOpen) + return viewController + + @classmethod + def createFileSaver( + cls, + parameter: PathParameter, + fileDialogFactory: FileDialogFactory, + *, + caption: str = 'Save File', + nameFilters: Sequence[str] | None = None, + mimeTypeFilters: Sequence[str] | None = None, + selectedNameFilter: str | None = None, + tool_tip: str = '', + ) -> PathParameterViewController: + viewController = cls( + parameter, + fileDialogFactory, + caption=caption, + nameFilters=nameFilters, + mimeTypeFilters=mimeTypeFilters, + selectedNameFilter=selectedNameFilter, + tool_tip=tool_tip, + ) + viewController._browseButton.clicked.connect(viewController._chooseFileToSave) + return viewController + + @classmethod + def createDirectoryChooser( + cls, + parameter: PathParameter, + fileDialogFactory: FileDialogFactory, + *, + caption: str = 'Choose Directory', + tool_tip: str = '', + ) -> PathParameterViewController: + viewController = cls( + parameter, + fileDialogFactory, + caption=caption, + nameFilters=None, + mimeTypeFilters=None, + selectedNameFilter=None, + tool_tip=tool_tip, + ) + viewController._browseButton.clicked.connect(viewController._chooseDirectory) + return viewController + + def getWidget(self) -> QWidget: + return self._widget + + def _syncPathToModel(self) -> None: + path = Path(self._lineEdit.text()) + self._parameter.setValue(path) + + def _chooseFileToOpen(self) -> None: + path, _ = self._fileDialogFactory.getOpenFilePath( + self._widget, + self._caption, + self._nameFilters, + self._mimeTypeFilters, + self._selectedNameFilter, + ) + + if path: + self._parameter.setValue(path) + + def _chooseFileToSave(self) -> None: + path, _ = self._fileDialogFactory.getSaveFilePath( + self._widget, + self._caption, + self._nameFilters, + self._mimeTypeFilters, + self._selectedNameFilter, + ) + + if path: + self._parameter.setValue(path) + + def _chooseDirectory(self) -> None: + path = self._fileDialogFactory.getExistingDirectoryPath(self._widget, self._caption) + + if path: + self._parameter.setValue(path) + + def _syncModelToView(self) -> None: + path = self._parameter.getValue() + self._lineEdit.setText(str(path)) + + def update(self, observable: Observable) -> None: + if observable is self._parameter: + self._syncModelToView() + + class IntegerLineEditParameterViewController(ParameterViewController, Observer): def __init__(self, parameter: IntegerParameter, *, tool_tip: str = '') -> None: super().__init__() @@ -268,7 +446,7 @@ def _syncModelToView(self) -> None: maximum = self._parameter.getMaximum() if minimum is None or maximum is None: - logger.error('Range not provided!') + raise ValueError('Range not provided!') else: value = Decimal(repr(self._parameter.getValue())) range_ = Interval[Decimal](Decimal(repr(minimum)), Decimal(repr(maximum))) @@ -280,11 +458,16 @@ def update(self, observable: Observable) -> None: class LengthWidgetParameterViewController(ParameterViewController, Observer): - def __init__(self, parameter: RealParameter, *, is_signed: bool = False) -> None: + def __init__( + self, parameter: RealParameter, *, is_signed: bool = False, tool_tip: str = '' + ) -> None: super().__init__() self._parameter = parameter self._widget = LengthWidget.createInstance(isSigned=is_signed) + if tool_tip: + self._widget.setToolTip(tool_tip) + self._syncModelToView() self._widget.lengthChanged.connect(self._syncViewToModel) parameter.addObserver(self) @@ -304,11 +487,14 @@ def update(self, observable: Observable) -> None: class AngleWidgetParameterViewController(ParameterViewController, Observer): - def __init__(self, parameter: RealParameter) -> None: + def __init__(self, parameter: RealParameter, tool_tip: str = '') -> None: super().__init__() self._parameter = parameter self._widget = AngleWidget.createInstance() + if tool_tip: + self._widget.setToolTip(tool_tip) + self._syncModelToView() self._widget.angleChanged.connect(self._syncViewToModel) parameter.addObserver(self) @@ -359,7 +545,8 @@ def _handleButtonBoxClicked(self, button: QAbstractButton) -> None: class ParameterViewBuilder: - def __init__(self) -> None: + def __init__(self, fileDialogFactory: FileDialogFactory | None = None) -> None: + self._fileDialogFactory = fileDialogFactory self._viewControllersTop: list[ParameterViewController] = list() self._viewControllers: dict[tuple[str, str], ParameterViewController] = dict() self._viewControllersBottom: list[ParameterViewController] = list() @@ -372,8 +559,85 @@ def addCheckBox( tool_tip: str = '', group: str = '', ) -> None: - viewController = CheckBoxParameterViewController(parameter, '') - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + viewController = CheckBoxParameterViewController(parameter, '', tool_tip=tool_tip) + self.addViewController(viewController, label, group=group) + + def addComboBox( + self, + parameter: StringParameter, + items: Iterable[str], + label: str, + *, + tool_tip: str = '', + group: str = '', + ) -> None: + viewController = ComboBoxParameterViewController(parameter, items, tool_tip=tool_tip) + self.addViewController(viewController, label, group=group) + + def addFileOpener( + self, + parameter: PathParameter, + label: str, + *, + caption: str = 'Open File', + nameFilters: Sequence[str] | None = None, + mimeTypeFilters: Sequence[str] | None = None, + selectedNameFilter: str | None = None, + tool_tip: str = '', + group: str = '', + ) -> None: + if self._fileDialogFactory is None: + raise ValueError('Cannot add file chooser without FileDialogFactory!') + else: + viewController = PathParameterViewController.createFileOpener( + parameter, + self._fileDialogFactory, + caption=caption, + nameFilters=nameFilters, + mimeTypeFilters=mimeTypeFilters, + selectedNameFilter=selectedNameFilter, + tool_tip=tool_tip, + ) + self.addViewController(viewController, label, group=group) + + def addFileSaver( + self, + parameter: PathParameter, + label: str, + *, + caption: str = 'Save File', + nameFilters: Sequence[str] | None = None, + mimeTypeFilters: Sequence[str] | None = None, + selectedNameFilter: str | None = None, + tool_tip: str = '', + group: str = '', + ) -> None: + if self._fileDialogFactory is None: + raise ValueError('Cannot add file chooser without FileDialogFactory!') + else: + viewController = PathParameterViewController.createFileSaver( + parameter, + self._fileDialogFactory, + caption=caption, + nameFilters=nameFilters, + mimeTypeFilters=mimeTypeFilters, + selectedNameFilter=selectedNameFilter, + tool_tip=tool_tip, + ) + self.addViewController(viewController, label, group=group) + + def addDirectoryChooser( + self, parameter: PathParameter, label: str, *, tool_tip: str = '', group: str = '' + ) -> None: + if self._fileDialogFactory is None: + raise ValueError('Cannot add directory chooser without FileDialogFactory!') + else: + viewController = PathParameterViewController.createDirectoryChooser( + parameter, + self._fileDialogFactory, + tool_tip=tool_tip, + ) + self.addViewController(viewController, label, group=group) def addSpinBox( self, @@ -383,8 +647,19 @@ def addSpinBox( tool_tip: str = '', group: str = '', ) -> None: - viewController = SpinBoxParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + viewController = SpinBoxParameterViewController(parameter, tool_tip=tool_tip) + self.addViewController(viewController, label, group=group) + + def addIntegerLineEdit( + self, + parameter: IntegerParameter, + label: str, + *, + tool_tip: str = '', + group: str = '', + ) -> None: + viewController = IntegerLineEditParameterViewController(parameter, tool_tip=tool_tip) + self.addViewController(viewController, label, group=group) def addDecimalLineEdit( self, @@ -394,8 +669,8 @@ def addDecimalLineEdit( tool_tip: str = '', group: str = '', ) -> None: - viewController = DecimalLineEditParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + viewController = DecimalLineEditParameterViewController(parameter, tool_tip=tool_tip) + self.addViewController(viewController, label, group=group) def addDecimalSlider( self, @@ -405,8 +680,8 @@ def addDecimalSlider( tool_tip: str = '', group: str = '', ) -> None: - viewController = DecimalSliderParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + viewController = DecimalSliderParameterViewController(parameter, tool_tip=tool_tip) + self.addViewController(viewController, label, group=group) def addLengthWidget( self, @@ -416,8 +691,8 @@ def addLengthWidget( tool_tip: str = '', group: str = '', ) -> None: - viewController = LengthWidgetParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + viewController = LengthWidgetParameterViewController(parameter, tool_tip=tool_tip) + self.addViewController(viewController, label, group=group) def addAngleWidget( self, @@ -427,8 +702,8 @@ def addAngleWidget( tool_tip: str = '', group: str = '', ) -> None: - viewController = AngleWidgetParameterViewController(parameter) - self.addViewController(viewController, label, tool_tip=tool_tip, group=group) + viewController = AngleWidgetParameterViewController(parameter, tool_tip=tool_tip) + self.addViewController(viewController, label, group=group) def addViewControllerToTop(self, viewController: ParameterViewController) -> None: self._viewControllersTop.append(viewController) @@ -438,7 +713,6 @@ def addViewController( viewController: ParameterViewController, label: str, *, - tool_tip: str = '', group: str = '', ) -> None: self._viewControllers[group, label] = viewController diff --git a/src/ptychodus/controller/patterns/core.py b/src/ptychodus/controller/patterns/core.py index e110112d..e5c50358 100644 --- a/src/ptychodus/controller/patterns/core.py +++ b/src/ptychodus/controller/patterns/core.py @@ -1,8 +1,7 @@ -from __future__ import annotations import logging from PyQt5.QtCore import QModelIndex -from PyQt5.QtWidgets import QAbstractItemView, QMessageBox +from PyQt5.QtWidgets import QAbstractItemView, QFormLayout, QMessageBox from ptychodus.api.observer import Observable, Observer @@ -13,18 +12,39 @@ DiffractionMetadataPresenter, DiffractionPatternPresenter, ) -from ...view.patterns import PatternsView +from ...view.patterns import DetectorView, PatternsView from ...view.widgets import ExceptionDialog from ..data import FileDialogFactory from ..image import ImageController -from .detector import DetectorController +from ..parametric import LengthWidgetParameterViewController, SpinBoxParameterViewController +from .dataset import DatasetTreeModel, DatasetTreeNode from .info import PatternsInfoViewController -from .treeModel import DatasetTreeModel, DatasetTreeNode from .wizard import OpenDatasetWizardController logger = logging.getLogger(__name__) +class DetectorController: + def __init__(self, detector: Detector, view: DetectorView) -> None: + self._widthInPixelsViewController = SpinBoxParameterViewController(detector.widthInPixels) + self._heightInPixelsViewController = SpinBoxParameterViewController(detector.heightInPixels) + self._pixelWidthViewController = LengthWidgetParameterViewController( + detector.pixelWidthInMeters + ) + self._pixelHeightViewController = LengthWidgetParameterViewController( + detector.pixelHeightInMeters + ) + self._bitDepthViewController = SpinBoxParameterViewController(detector.bitDepth) + + layout = QFormLayout() + layout.addRow('Detector Width [px]:', self._widthInPixelsViewController.getWidget()) + layout.addRow('Detector Height [px]:', self._heightInPixelsViewController.getWidget()) + layout.addRow('Pixel Width:', self._pixelWidthViewController.getWidget()) + layout.addRow('Pixel Height:', self._pixelHeightViewController.getWidget()) + layout.addRow('Bit Depth:', self._bitDepthViewController.getWidget()) + view.setLayout(layout) + + class PatternsController(Observer): def __init__( self, @@ -55,44 +75,19 @@ def __init__( ) self._treeModel = DatasetTreeModel() - @classmethod - def createInstance( - cls, - detector: Detector, - ioPresenter: DiffractionDatasetInputOutputPresenter, - metadataPresenter: DiffractionMetadataPresenter, - datasetPresenter: DiffractionDatasetPresenter, - patternPresenter: DiffractionPatternPresenter, - imageController: ImageController, - view: PatternsView, - fileDialogFactory: FileDialogFactory, - ) -> PatternsController: - controller = cls( - detector, - ioPresenter, - metadataPresenter, - datasetPresenter, - patternPresenter, - imageController, - view, - fileDialogFactory, - ) - - view.treeView.setModel(controller._treeModel) + view.treeView.setModel(self._treeModel) view.treeView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) - view.treeView.selectionModel().currentChanged.connect(controller._updateView) - controller._updateView(QModelIndex(), QModelIndex()) + view.treeView.selectionModel().currentChanged.connect(self._updateView) + self._updateView(QModelIndex(), QModelIndex()) - view.buttonBox.openButton.clicked.connect(controller._wizardController.openDataset) - view.buttonBox.saveButton.clicked.connect(controller._saveDataset) - view.buttonBox.infoButton.clicked.connect(controller._openPatternsInfo) - view.buttonBox.closeButton.clicked.connect(controller._closeDataset) + view.buttonBox.openButton.clicked.connect(self._wizardController.openDataset) + view.buttonBox.saveButton.clicked.connect(self._saveDataset) + view.buttonBox.infoButton.clicked.connect(self._openPatternsInfo) + view.buttonBox.closeButton.clicked.connect(self._closeDataset) view.buttonBox.closeButton.setEnabled(False) # TODO - datasetPresenter.addObserver(controller) - - controller._syncModelToView() + datasetPresenter.addObserver(self) - return controller + self._syncModelToView() def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: if current.isValid(): diff --git a/src/ptychodus/controller/patterns/treeModel.py b/src/ptychodus/controller/patterns/dataset.py similarity index 94% rename from src/ptychodus/controller/patterns/treeModel.py rename to src/ptychodus/controller/patterns/dataset.py index f4fe525a..0633c12a 100644 --- a/src/ptychodus/controller/patterns/treeModel.py +++ b/src/ptychodus/controller/patterns/dataset.py @@ -139,12 +139,13 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A node = index.internalPointer() if role == Qt.ItemDataRole.DisplayRole: - if index.column() == 0: - return node.label - elif index.column() == 1: - return node.numberOfFrames - elif index.column() == 2: - return f'{node.sizeInBytes / (1024 * 1024):.2f}' + match index.column(): + case 0: + return node.label + case 1: + return node.numberOfFrames + case 2: + return f'{node.sizeInBytes / (1024 * 1024):.2f}' elif role == Qt.ItemDataRole.FontRole: font = QFont() font.setItalic(node.state == DiffractionPatternState.FOUND) diff --git a/src/ptychodus/controller/patterns/detector.py b/src/ptychodus/controller/patterns/detector.py deleted file mode 100644 index d63b1971..00000000 --- a/src/ptychodus/controller/patterns/detector.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -from PyQt5.QtWidgets import QFormLayout - - -from ...model.patterns import Detector -from ...view.patterns import DetectorView -from ..parametric import LengthWidgetParameterViewController, SpinBoxParameterViewController - - -class DetectorController: - def __init__(self, detector: Detector, view: DetectorView) -> None: - self._widthInPixelsViewController = SpinBoxParameterViewController(detector.widthInPixels) - self._heightInPixelsViewController = SpinBoxParameterViewController(detector.heightInPixels) - self._pixelWidthViewController = LengthWidgetParameterViewController( - detector.pixelWidthInMeters - ) - self._pixelHeightViewController = LengthWidgetParameterViewController( - detector.pixelHeightInMeters - ) - self._bitDepthViewController = SpinBoxParameterViewController(detector.bitDepth) - - layout = QFormLayout() - layout.addRow('Detector Width [px]:', self._widthInPixelsViewController.getWidget()) - layout.addRow('Detector Height [px]:', self._heightInPixelsViewController.getWidget()) - layout.addRow('Pixel Width:', self._pixelWidthViewController.getWidget()) - layout.addRow('Pixel Height:', self._pixelHeightViewController.getWidget()) - layout.addRow('Bit Depth:', self._bitDepthViewController.getWidget()) - view.setLayout(layout) diff --git a/src/ptychodus/controller/patterns/info.py b/src/ptychodus/controller/patterns/info.py index 77aa5157..eb92e9d9 100644 --- a/src/ptychodus/controller/patterns/info.py +++ b/src/ptychodus/controller/patterns/info.py @@ -1,10 +1,95 @@ +from typing import Any, overload + from PyQt5.QtWidgets import QWidget +from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject from ptychodus.api.observer import Observable, Observer +from ptychodus.api.tree import SimpleTreeNode from ...model.patterns import DiffractionDatasetPresenter from ...view.patterns import PatternsInfoDialog -from .tree import SimpleTreeModel + + +class SimpleTreeModel(QAbstractItemModel): + def __init__(self, rootNode: SimpleTreeNode, parent: QObject | None = None) -> None: + super().__init__(parent) + self._rootNode = rootNode + + def setRootNode(self, rootNode: SimpleTreeNode) -> None: + self.beginResetModel() + self._rootNode = rootNode + self.endResetModel() + + @overload + def parent(self, child: QModelIndex) -> QModelIndex: ... + + @overload + def parent(self) -> QObject: ... + + def parent(self, child: QModelIndex | None = None) -> QModelIndex | QObject: + if child is None: + return super().parent() + else: + value = QModelIndex() + + if child.isValid(): + childItem = child.internalPointer() + parentItem = childItem.parentItem + + if parentItem is self._rootNode: + value = QModelIndex() + else: + value = self.createIndex(parentItem.row(), 0, parentItem) + + return value + + def headerData( + self, + section: int, + orientation: Qt.Orientation, + role: int = Qt.ItemDataRole.DisplayRole, + ) -> Any: + if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: + return self._rootNode.data(section) + + def flags(self, index: QModelIndex) -> Qt.ItemFlags: + return super().flags(index) + + def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: + value = QModelIndex() + + if self.hasIndex(row, column, parent): + parentItem = parent.internalPointer() if parent.isValid() else self._rootNode + childItem = parentItem.childItems[row] + + if childItem: + value = self.createIndex(row, column, childItem) + + return value + + def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: + if index.isValid() and role == Qt.ItemDataRole.DisplayRole: + node = index.internalPointer() + return node.data(index.column()) + + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + if parent.column() > 0: + return 0 + + node = self._rootNode + + if parent.isValid(): + node = parent.internalPointer() + + return len(node.childItems) + + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: + node = self._rootNode + + if parent.isValid(): + node = parent.internalPointer() + + return len(node.itemData) class PatternsInfoViewController(Observer): diff --git a/src/ptychodus/controller/patterns/tree.py b/src/ptychodus/controller/patterns/tree.py deleted file mode 100644 index b5e70838..00000000 --- a/src/ptychodus/controller/patterns/tree.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Any, overload - -from PyQt5.QtCore import Qt, QAbstractItemModel, QModelIndex, QObject - -from ptychodus.api.tree import SimpleTreeNode - - -class SimpleTreeModel(QAbstractItemModel): - def __init__(self, rootNode: SimpleTreeNode, parent: QObject | None = None) -> None: - super().__init__(parent) - self._rootNode = rootNode - - def setRootNode(self, rootNode: SimpleTreeNode) -> None: - self.beginResetModel() - self._rootNode = rootNode - self.endResetModel() - - @overload - def parent(self, child: QModelIndex) -> QModelIndex: ... - - @overload - def parent(self) -> QObject: ... - - def parent(self, child: QModelIndex | None = None) -> QModelIndex | QObject: - if child is None: - return super().parent() - else: - value = QModelIndex() - - if child.isValid(): - childItem = child.internalPointer() - parentItem = childItem.parentItem - - if parentItem is self._rootNode: - value = QModelIndex() - else: - value = self.createIndex(parentItem.row(), 0, parentItem) - - return value - - def headerData( - self, - section: int, - orientation: Qt.Orientation, - role: int = Qt.ItemDataRole.DisplayRole, - ) -> Any: - if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: - return self._rootNode.data(section) - - def flags(self, index: QModelIndex) -> Qt.ItemFlags: - return super().flags(index) - - def index(self, row: int, column: int, parent: QModelIndex = QModelIndex()) -> QModelIndex: - value = QModelIndex() - - if self.hasIndex(row, column, parent): - parentItem = parent.internalPointer() if parent.isValid() else self._rootNode - childItem = parentItem.childItems[row] - - if childItem: - value = self.createIndex(row, column, childItem) - - return value - - def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: - if index.isValid() and role == Qt.ItemDataRole.DisplayRole: - node = index.internalPointer() - return node.data(index.column()) - - def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: - if parent.column() > 0: - return 0 - - node = self._rootNode - - if parent.isValid(): - node = parent.internalPointer() - - return len(node.childItems) - - def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: - node = self._rootNode - - if parent.isValid(): - node = parent.internalPointer() - - return len(node.itemData) diff --git a/src/ptychodus/controller/probe/core.py b/src/ptychodus/controller/probe/core.py index 509b7ef9..f23fe6bd 100644 --- a/src/ptychodus/controller/probe/core.py +++ b/src/ptychodus/controller/probe/core.py @@ -294,11 +294,16 @@ def _updateView(self, current: QModelIndex, previous: QModelIndex) -> None: else: probe = item.getProbe() array = ( - probe.getMode(current.row()) + probe.getIncoherentMode(current.row()) if current.parent().isValid() - else probe.getModesFlattened() + else probe.getIncoherentModesFlattened() ) - self._imageController.setArray(array, probe.getPixelGeometry()) + pixelGeometry = probe.getPixelGeometry() + + if pixelGeometry is None: + logger.warning('Missing probe pixel geometry!') + else: + self._imageController.setArray(array, pixelGeometry) def handleItemInserted(self, index: int, item: ProbeRepositoryItem) -> None: self._treeModel.insertItem(index, item) diff --git a/src/ptychodus/controller/probe/editorFactory.py b/src/ptychodus/controller/probe/editorFactory.py index 453e3256..db06f60a 100644 --- a/src/ptychodus/controller/probe/editorFactory.py +++ b/src/ptychodus/controller/probe/editorFactory.py @@ -173,24 +173,24 @@ def _appendAdditionalModes( dialogBuilder: ParameterViewBuilder, modesBuilder: MultimodalProbeBuilder, ) -> None: - additionalModesGroup = 'Additional Modes' + additionalModesGroup = 'Additional Modes' # FIXME OPR dialogBuilder.addSpinBox( - modesBuilder.numberOfModes, + modesBuilder.numberOfIncoherentModes, 'Number of Modes:', group=additionalModesGroup, ) dialogBuilder.addCheckBox( - modesBuilder.isOrthogonalizeModesEnabled, + modesBuilder.orthogonalizeIncoherentModes, 'Orthogonalize Modes:', group=additionalModesGroup, ) dialogBuilder.addViewController( - DecayTypeParameterViewController(modesBuilder.modeDecayType), + DecayTypeParameterViewController(modesBuilder.incoherentModeDecayType), 'Decay Type:', group=additionalModesGroup, ) dialogBuilder.addDecimalSlider( - modesBuilder.modeDecayRatio, + modesBuilder.incoherentModeDecayRatio, 'Decay Ratio:', group=additionalModesGroup, ) diff --git a/src/ptychodus/controller/probe/propagator.py b/src/ptychodus/controller/probe/propagator.py index 2fb95cec..3c04f7b8 100644 --- a/src/ptychodus/controller/probe/propagator.py +++ b/src/ptychodus/controller/probe/propagator.py @@ -72,9 +72,12 @@ def _updateCurrentCoordinate(self, step: int) -> None: logger.exception(err) ExceptionDialog.showException('Update Current Coordinate', err) else: - self._xyVisualizationWidgetController.setArray( - xyProjection, self._propagator.getPixelGeometry() - ) + pixelGeometry = self._propagator.getPixelGeometry() + + if pixelGeometry is None: + logger.warning('Missing propagator pixel geometry!') + else: + self._xyVisualizationWidgetController.setArray(xyProjection, pixelGeometry) # TODO auto-units lerpValue *= 1e6 @@ -142,13 +145,18 @@ def _syncModelToView(self) -> None: self._dialog.coordinateSlider.setValue(0) self._updateCurrentCoordinate(self._dialog.coordinateSlider.value()) + pixelGeometry = self._propagator.getPixelGeometry() + + if pixelGeometry is None: + logger.warning('Missing propagator pixel geometry!') + return try: self._zxVisualizationWidgetController.setArray( - self._propagator.getZXProjection(), self._propagator.getPixelGeometry() + self._propagator.getZXProjection(), pixelGeometry ) self._zyVisualizationWidgetController.setArray( - self._propagator.getZYProjection(), self._propagator.getPixelGeometry() + self._propagator.getZYProjection(), pixelGeometry ) except ValueError: self._zxVisualizationWidgetController.clearArray() diff --git a/src/ptychodus/controller/probe/treeModel.py b/src/ptychodus/controller/probe/treeModel.py index 3c2f38ec..cede4ec2 100644 --- a/src/ptychodus/controller/probe/treeModel.py +++ b/src/ptychodus/controller/probe/treeModel.py @@ -49,9 +49,9 @@ def __init__( @staticmethod def _appendModes(node: ProbeTreeNode, item: ProbeRepositoryItem) -> None: - object_ = item.getProbe() + probe = item.getProbe() - for layer in range(object_.numberOfModes): + for layer in range(probe.numberOfIncoherentModes): node.insertNode() def insertItem(self, index: int, item: ProbeRepositoryItem) -> None: @@ -66,7 +66,7 @@ def updateItem(self, index: int, item: ProbeRepositoryItem) -> None: node = self._treeRoot.children[index] numModesOld = len(node.children) - numModesNew = item.getProbe().numberOfModes + numModesNew = item.getProbe().numberOfIncoherentModes if numModesOld < numModesNew: self.beginInsertRows(topLeft, numModesOld, numModesNew) @@ -148,7 +148,7 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A probe = item.getProbe() try: - relativePower = probe.getModeRelativePower(index.row()) + relativePower = probe.getIncoherentModeRelativePower(index.row()) except IndexError: return -1 diff --git a/src/ptychodus/controller/product/core.py b/src/ptychodus/controller/product/core.py index fae48798..f85ba658 100644 --- a/src/ptychodus/controller/product/core.py +++ b/src/ptychodus/controller/product/core.py @@ -38,7 +38,7 @@ def __init__(self, repository: ProductRepository, parent: QObject | None = None) 'Name', 'Detector-Object\nDistance [m]', 'Probe Energy\n[keV]', - 'Probe Photon\nFlux [ph/s]', + 'Probe Photon\nCount', 'Exposure\nTime [s]', 'Pixel Width\n[nm]', 'Pixel Height\n[nm]', @@ -81,7 +81,7 @@ def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> A elif index.column() == 2: return f'{metadata.probeEnergyInElectronVolts.getValue() / 1e3:.4g}' elif index.column() == 3: - return f'{metadata.probePhotonsPerSecond.getValue():.4g}' + return f'{metadata.probePhotonCount.getValue():.4g}' elif index.column() == 4: return f'{metadata.exposureTimeInSeconds.getValue():.4g}' elif index.column() == 5: @@ -123,11 +123,11 @@ def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.Ed return True elif index.column() == 3: try: - photonsPerSecond = float(value) + photonCount = float(value) except ValueError: return False - metadata.probePhotonsPerSecond.setValue(photonsPerSecond) + metadata.probePhotonCount.setValue(photonCount) return True elif index.column() == 4: try: diff --git a/src/ptychodus/controller/product/editor.py b/src/ptychodus/controller/product/editor.py index 2a8acb07..21e73064 100644 --- a/src/ptychodus/controller/product/editor.py +++ b/src/ptychodus/controller/product/editor.py @@ -22,6 +22,9 @@ def __init__(self, product: ProductRepositoryItem, parent: QObject | None = None self._header = ['Property', 'Value'] self._properties = [ 'Probe Wavelength [nm]', + 'Probe Wavenumber [1/nm]', + 'Probe Angular Wavenumber [rad/nm]', + 'Probe Photon Flux [ph/s]', 'Probe Power [W]', 'Object Plane Pixel Width [nm]', 'Object Plane Pixel Height [nm]', @@ -39,21 +42,32 @@ def headerData( def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any: if index.isValid() and role == Qt.ItemDataRole.DisplayRole: - if index.column() == 0: - return self._properties[index.row()] - elif index.column() == 1: - geometry = self._product.getGeometry() - - if index.row() == 0: - return f'{geometry.probeWavelengthInMeters * 1e9:.4g}' - elif index.row() == 1: - return f'{geometry.probePowerInWatts:.4g}' - elif index.row() == 2: - return f'{geometry.objectPlanePixelWidthInMeters * 1e9:.4g}' - elif index.row() == 3: - return f'{geometry.objectPlanePixelHeightInMeters * 1e9:.4g}' - elif index.row() == 4: - return f'{geometry.fresnelNumber:.4g}' + match index.column(): + case 0: + return self._properties[index.row()] + case 1: + geometry = self._product.getGeometry() + + match index.row(): + case 0: + return f'{geometry.probeWavelengthInMeters * 1e9:.4g}' + case 1: + return f'{geometry.probeWavelengthsPerMeter * 1e-9:.4g}' + case 2: + return f'{geometry.probeRadiansPerMeter * 1e-9:.4g}' + case 3: + return f'{geometry.probePhotonsPerSecond:.4g}' + case 4: + return f'{geometry.probePowerInWatts:.4g}' + case 5: + return f'{geometry.objectPlanePixelWidthInMeters * 1e9:.4g}' + case 6: + return f'{geometry.objectPlanePixelHeightInMeters * 1e9:.4g}' + case 7: + try: + return f'{geometry.fresnelNumber:.4g}' + except ZeroDivisionError: + return 'inf' def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: return len(self._properties) diff --git a/src/ptychodus/controller/ptychi/__init__.py b/src/ptychodus/controller/ptychi/__init__.py new file mode 100644 index 00000000..f85a0981 --- /dev/null +++ b/src/ptychodus/controller/ptychi/__init__.py @@ -0,0 +1,5 @@ +from .core import PtyChiViewControllerFactory + +__all__ = [ + 'PtyChiViewControllerFactory', +] diff --git a/src/ptychodus/controller/ptychi/core.py b/src/ptychodus/controller/ptychi/core.py new file mode 100644 index 00000000..41fe1f0a --- /dev/null +++ b/src/ptychodus/controller/ptychi/core.py @@ -0,0 +1,95 @@ +from PyQt5.QtWidgets import QVBoxLayout, QWidget + +from ...model.ptychi import ( + PtyChiAutodiffSettings, + PtyChiDMSettings, + PtyChiLSQMLSettings, + PtyChiPIESettings, + PtyChiReconstructorLibrary, +) + +from ..reconstructor import ReconstructorViewControllerFactory +from .object import PtyChiObjectViewController +from .opr import PtyChiOPRViewController +from .positions import PtyChiProbePositionsViewController +from .probe import PtyChiProbeViewController +from .reconstructor import PtyChiReconstructorViewController + +__all__ = ['PtyChiViewControllerFactory'] + + +class PtyChiViewController(QWidget): + def __init__( + self, + model: PtyChiReconstructorLibrary, + reconstructorName: str, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + autodiffSettings: PtyChiAutodiffSettings | None = None + dmSettings: PtyChiDMSettings | None = None + lsqmlSettings: PtyChiLSQMLSettings | None = None + pieSettings: PtyChiPIESettings | None = None + + match reconstructorName: + case 'Autodiff': + autodiffSettings = model.autodiffSettings + case 'DM': + dmSettings = model.dmSettings + case 'LSQML': + lsqmlSettings = model.lsqmlSettings + case 'PIE' | 'ePIE' | 'rPIE': + pieSettings = model.pieSettings + + # FIXME verify tooltips + self._reconstructorViewController = PtyChiReconstructorViewController( + model.reconstructorSettings, + autodiffSettings, + dmSettings, + lsqmlSettings, + model.enumerators, + model.deviceRepository, + ) + self._objectViewController = PtyChiObjectViewController( + model.objectSettings, + dmSettings, + lsqmlSettings, + pieSettings, + model.reconstructorSettings.numEpochs, + model.enumerators, + ) + self._probeViewController = PtyChiProbeViewController( + model.probeSettings, + lsqmlSettings, + pieSettings, + model.reconstructorSettings.numEpochs, + model.enumerators, + ) + self._probePositionsViewController = PtyChiProbePositionsViewController( + model.probePositionSettings, model.reconstructorSettings.numEpochs, model.enumerators + ) + self._oprViewController = PtyChiOPRViewController( + model.oprSettings, model.reconstructorSettings.numEpochs, model.enumerators + ) + + layout = QVBoxLayout() + layout.addWidget(self._reconstructorViewController.getWidget()) + layout.addWidget(self._objectViewController.getWidget()) + layout.addWidget(self._probeViewController.getWidget()) + layout.addWidget(self._probePositionsViewController.getWidget()) + layout.addWidget(self._oprViewController.getWidget()) + layout.addStretch() + self.setLayout(layout) + + +class PtyChiViewControllerFactory(ReconstructorViewControllerFactory): + def __init__(self, model: PtyChiReconstructorLibrary) -> None: + super().__init__() + self._model = model + + @property + def backendName(self) -> str: + return 'pty-chi' + + def createViewController(self, reconstructorName: str) -> QWidget: + return PtyChiViewController(self._model, reconstructorName) diff --git a/src/ptychodus/controller/ptychi/object.py b/src/ptychodus/controller/ptychi/object.py new file mode 100644 index 00000000..e0873d96 --- /dev/null +++ b/src/ptychodus/controller/ptychi/object.py @@ -0,0 +1,332 @@ +from PyQt5.QtWidgets import QFormLayout + +from ptychodus.api.parametric import ( + BooleanParameter, + IntegerParameter, + RealParameter, + StringParameter, +) + +from ...model.ptychi import ( + PtyChiDMSettings, + PtyChiEnumerators, + PtyChiLSQMLSettings, + PtyChiObjectSettings, + PtyChiPIESettings, +) +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + LengthWidgetParameterViewController, + SpinBoxParameterViewController, +) +from .optimizer import PtyChiOptimizationPlanViewController, PtyChiOptimizerParameterViewController + +__all__ = ['PtyChiObjectViewController'] + + +class PtyChiConstrainL1NormViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrainL1Norm: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + weight: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrainL1Norm, + 'Constrain L\u2081 Norm', + tool_tip='Whether to constrain the L\u2081 norm.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._weightViewController = DecimalLineEditParameterViewController( + weight, + tool_tip='Weight of the L\u2081 norm constraint.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Weight:', self._weightViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiConstrainSmoothnessViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrainSmoothness: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + alpha: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrainSmoothness, + 'Constrain Smoothness', + tool_tip='Whether to constrain smoothness in the magnitude (but not phase) of the object', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._alphaViewController = DecimalSliderParameterViewController( + alpha, tool_tip='Relaxation smoothing constant.' + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Alpha:', self._alphaViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiConstrainTotalVariationViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrainTotalVariation: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + weight: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrainTotalVariation, + 'Constrain Total Variation', + tool_tip='Whether to constrain the total variation.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._weightViewController = DecimalLineEditParameterViewController( + weight, + tool_tip='Weight of the total variation constraint.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Weight:', self._weightViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiRemoveGridArtifactsViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + removeGridArtifacts: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + periodXInMeters: RealParameter, + periodYInMeters: RealParameter, + windowSizeInPixels: IntegerParameter, + direction: StringParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + removeGridArtifacts, + 'Remove Grid Artifacts', + tool_tip="Whether to remove grid artifacts in the object's phase at the end of an epoch.", + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._periodXViewController = LengthWidgetParameterViewController( + periodXInMeters, tool_tip='Horizontal period of grid artifacts in meters.' + ) + self._periodYViewController = LengthWidgetParameterViewController( + periodYInMeters, tool_tip='Vertical period of grid artifacts in meters.' + ) + self._windowSizeViewController = SpinBoxParameterViewController( + windowSizeInPixels, tool_tip='Window size for grid artifact removal in pixels.' + ) + self._directionViewController = ComboBoxParameterViewController( + direction, enumerators.directions(), tool_tip='Direction of grid artifact removal.' + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Period X:', self._periodXViewController.getWidget()) + layout.addRow('Period Y:', self._periodYViewController.getWidget()) + layout.addRow('Window Size [px]:', self._windowSizeViewController.getWidget()) + layout.addRow('Direction:', self._directionViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiRegularizeMultisliceViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + regularizeMultislice: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + weight: RealParameter, + unwrapPhase: BooleanParameter, + gradientMethod: StringParameter, + integrationMethod: StringParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + regularizeMultislice, + 'Regularize Multislice', + tool_tip='Whether to regularize multislice objects using cross-slice smoothing.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._weightViewController = DecimalLineEditParameterViewController( + weight, + tool_tip='Weight for multislice regularization.', + ) + self._unwrapPhaseViewController = CheckBoxParameterViewController( + unwrapPhase, + 'Unwrap Phase', + tool_tip='Whether to unwrap the phase of the object during multislice regularization.', + ) + self._gradientMethodViewController = ComboBoxParameterViewController( + gradientMethod, + enumerators.imageGradientMethods(), + tool_tip='Method for calculating the phase gradient during phase unwrapping.', + ) + self._integrationMethodViewController = ComboBoxParameterViewController( + integrationMethod, + enumerators.imageIntegrationMethods(), + tool_tip='Method for integrating the phase gradient during phase unwrapping.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Weight:', self._weightViewController.getWidget()) + layout.addRow(self._unwrapPhaseViewController.getWidget()) + layout.addRow('Gradient Method:', self._gradientMethodViewController.getWidget()) + layout.addRow('Integration Method:', self._integrationMethodViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiObjectViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PtyChiObjectSettings, + dmSettings: PtyChiDMSettings | None, + lsqmlSettings: PtyChiLSQMLSettings | None, + pieSettings: PtyChiPIESettings | None, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + settings.isOptimizable, 'Optimize Object', tool_tip='Whether the object is optimizable.' + ) + self._optimizationPlanViewController = PtyChiOptimizationPlanViewController( + settings.optimizationPlanStart, + settings.optimizationPlanStop, + settings.optimizationPlanStride, + num_epochs, + ) + self._optimizerViewController = PtyChiOptimizerParameterViewController( + settings.optimizer, enumerators + ) + self._stepSizeViewController = DecimalLineEditParameterViewController( + settings.stepSize, tool_tip='Optimizer step size' + ) + self._patchInterpolatorViewController = ComboBoxParameterViewController( + settings.patchInterpolator, + enumerators.patchInterpolationMethods(), + tool_tip='Interpolation method used for extracting and updating patches of the object.', + ) + self._constrainL1NormViewController = PtyChiConstrainL1NormViewController( + settings.constrainL1Norm, + settings.constrainL1NormStart, + settings.constrainL1NormStop, + settings.constrainL1NormStride, + settings.constrainL1NormWeight, + num_epochs, + ) + self._constrainSmoothnessViewController = PtyChiConstrainSmoothnessViewController( + settings.constrainSmoothness, + settings.constrainSmoothnessStart, + settings.constrainSmoothnessStop, + settings.constrainSmoothnessStride, + settings.constrainSmoothnessAlpha, + num_epochs, + ) + self._constrainTotalVariationViewController = PtyChiConstrainTotalVariationViewController( + settings.constrainTotalVariation, + settings.constrainTotalVariationStart, + settings.constrainTotalVariationStop, + settings.constrainTotalVariationStride, + settings.constrainTotalVariationWeight, + num_epochs, + ) + self._removeGridArtifactsViewController = PtyChiRemoveGridArtifactsViewController( + settings.removeGridArtifacts, + settings.removeGridArtifactsStart, + settings.removeGridArtifactsStop, + settings.removeGridArtifactsStride, + settings.removeGridArtifactsPeriodXInMeters, + settings.removeGridArtifactsPeriodYInMeters, + settings.removeGridArtifactsWindowSizeInPixels, + settings.removeGridArtifactsDirection, + num_epochs, + enumerators, + ) + self._regularizeMultisliceViewController = PtyChiRegularizeMultisliceViewController( + settings.regularizeMultislice, + settings.regularizeMultisliceStart, + settings.regularizeMultisliceStop, + settings.regularizeMultisliceStride, + settings.regularizeMultisliceWeight, + settings.regularizeMultisliceUnwrapPhase, + settings.regularizeMultisliceUnwrapPhaseImageGradientMethod, + settings.regularizeMultisliceUnwrapPhaseImageIntegrationMethod, + num_epochs, + enumerators, + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._optimizationPlanViewController.getWidget()) + layout.addRow('Optimizer:', self._optimizerViewController.getWidget()) + layout.addRow('Step Size:', self._stepSizeViewController.getWidget()) + layout.addRow('Patch Interpolator:', self._patchInterpolatorViewController.getWidget()) + layout.addRow(self._constrainL1NormViewController.getWidget()) + layout.addRow(self._constrainSmoothnessViewController.getWidget()) + layout.addRow(self._constrainTotalVariationViewController.getWidget()) + layout.addRow(self._removeGridArtifactsViewController.getWidget()) + layout.addRow(self._regularizeMultisliceViewController.getWidget()) + + if dmSettings is not None: + self._amplitudeClampLimitViewController = DecimalLineEditParameterViewController( + dmSettings.objectAmplitudeClampLimit, + tool_tip='Maximum amplitude value for the object.', + ) + layout.addRow( + 'Amplitude Clamp Limit:', self._amplitudeClampLimitViewController.getWidget() + ) + + if lsqmlSettings is not None: + self._objectOptimalStepSizeScalerViewController = ( + DecimalLineEditParameterViewController(lsqmlSettings.objectOptimalStepSizeScaler) + ) + layout.addRow( + 'Optimal Step Size Scaler:', + self._objectOptimalStepSizeScalerViewController.getWidget(), + ) + + self._objectMultimodalUpdateViewController = CheckBoxParameterViewController( + lsqmlSettings.objectMultimodalUpdate, + 'Multimodal Update', + ) + layout.addRow(self._objectMultimodalUpdateViewController.getWidget()) + + if pieSettings is not None: + self._alphaViewController = DecimalSliderParameterViewController( + pieSettings.objectAlpha + ) + layout.addRow('Alpha:', self._alphaViewController.getWidget()) + + self.getWidget().setLayout(layout) diff --git a/src/ptychodus/controller/ptychi/opr.py b/src/ptychodus/controller/ptychi/opr.py new file mode 100644 index 00000000..4d53cee3 --- /dev/null +++ b/src/ptychodus/controller/ptychi/opr.py @@ -0,0 +1,112 @@ +from PyQt5.QtWidgets import QFormLayout + +from ptychodus.api.parametric import BooleanParameter, IntegerParameter, StringParameter + +from ...model.ptychi import PtyChiEnumerators, PtyChiOPRSettings +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + SpinBoxParameterViewController, +) +from .optimizer import PtyChiOptimizationPlanViewController, PtyChiOptimizerParameterViewController + + +class PtyChiSmoothOPRModeWeightsViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + smoothModeWeights: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + smoothingMethod: StringParameter, + polynomialSmoothingDegree: IntegerParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + smoothModeWeights, + 'Smooth OPR Mode Weights', + tool_tip='Smooth the OPR mode weights.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._smoothingMethodViewController = ComboBoxParameterViewController( + smoothingMethod, + enumerators.oprWeightSmoothingMethods(), + tool_tip='The method for smoothing OPR mode weights.', + ) + self._polynomialSmoothingDegreeViewController = SpinBoxParameterViewController( + polynomialSmoothingDegree, + tool_tip='The degree of the polynomial used for smoothing OPR mode weights.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Smoothing Method:', self._smoothingMethodViewController.getWidget()) + layout.addRow( + 'Polynomial Degree:', self._polynomialSmoothingDegreeViewController.getWidget() + ) + self.getWidget().setLayout(layout) + + +class PtyChiOPRViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PtyChiOPRSettings, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + settings.isOptimizable, + 'Orthogonal Probe Relaxation', + tool_tip='Whether OPR modes are optimizable.', + ) + self._optimizationPlanViewController = PtyChiOptimizationPlanViewController( + settings.optimizationPlanStart, + settings.optimizationPlanStop, + settings.optimizationPlanStride, + num_epochs, + ) + self._optimizerViewController = PtyChiOptimizerParameterViewController( + settings.optimizer, enumerators + ) + self._stepSizeViewController = DecimalLineEditParameterViewController( + settings.stepSize, tool_tip='Optimizer step size' + ) + self._optimizeIntensitiesViewController = CheckBoxParameterViewController( + settings.optimizeIntensities, + 'Optimize Intensities', + tool_tip='Whether to optimize intensity variation (i.e., the weight of the first OPR mode).', + ) + self._optimizeEigenmodeWeightsViewController = CheckBoxParameterViewController( + settings.optimizeEigenmodeWeights, + 'Optimize Eigenmode Weights', + tool_tip='Whether to optimize eigenmode weights (i.e., the weights of the second and following OPR modes).', + ) + self._smoothModeWeightsViewController = PtyChiSmoothOPRModeWeightsViewController( + settings.smoothModeWeights, + settings.smoothModeWeightsStart, + settings.smoothModeWeightsStop, + settings.smoothModeWeightsStride, + settings.smoothingMethod, + settings.polynomialSmoothingDegree, + num_epochs, + enumerators, + ) + self._relaxUpdateViewController = DecimalSliderParameterViewController( + settings.relaxUpdate, + tool_tip='Whether to relax the update of the OPR mode weights.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._optimizationPlanViewController.getWidget()) + layout.addRow('Optimizer:', self._optimizerViewController.getWidget()) + layout.addRow('Step Size:', self._stepSizeViewController.getWidget()) + layout.addRow(self._optimizeIntensitiesViewController.getWidget()) + layout.addRow(self._optimizeEigenmodeWeightsViewController.getWidget()) + layout.addRow(self._smoothModeWeightsViewController.getWidget()) + self.getWidget().setLayout(layout) diff --git a/src/ptychodus/controller/ptychi/optimizer.py b/src/ptychodus/controller/ptychi/optimizer.py new file mode 100644 index 00000000..86365e81 --- /dev/null +++ b/src/ptychodus/controller/ptychi/optimizer.py @@ -0,0 +1,90 @@ +from PyQt5.QtWidgets import QHBoxLayout, QSpinBox, QWidget + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import IntegerParameter, StringParameter + +from ...model.ptychi import PtyChiEnumerators +from ..parametric import ( + ComboBoxParameterViewController, + SpinBoxParameterViewController, + ParameterViewController, +) + +__all__ = [ + 'PtyChiOptimizationPlanViewController', + 'PtyChiOptimizerParameterViewController', +] + + +class PtyChiStopSpinBoxParameterViewController(ParameterViewController, Observer): + def __init__( + self, stop: IntegerParameter, num_epochs: IntegerParameter, *, tool_tip: str = '' + ) -> None: + super().__init__() + self._stop = stop + self._num_epochs = num_epochs + self._widget = QSpinBox() + + if tool_tip: + self._widget.setToolTip(tool_tip) + + self._syncModelToView() + self._widget.valueChanged.connect(self._syncViewToModel) + stop.addObserver(self) + num_epochs.addObserver(self) + + def getWidget(self) -> QWidget: + return self._widget + + def _syncViewToModel(self, value: int) -> None: + num_epochs = self._num_epochs.getValue() + self._stop.setValue(value if value < num_epochs else -1) + + def _syncModelToView(self) -> None: + num_epochs = self._num_epochs.getValue() + stop = self._stop.getValue() + + self._widget.blockSignals(True) + self._widget.setRange(0, num_epochs) + self._widget.setValue(num_epochs if stop < 0 else stop) + self._widget.blockSignals(False) + + def update(self, observable: Observable) -> None: + if observable in (self._stop, self._num_epochs): + self._syncModelToView() + + +class PtyChiOptimizationPlanViewController(ParameterViewController): + def __init__( + self, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__() + self._startViewController = SpinBoxParameterViewController( + start, tool_tip='Iteration to start optimizing' + ) + self._stopViewController = PtyChiStopSpinBoxParameterViewController( + stop, num_epochs, tool_tip='Iteration to stop optimizing' + ) + self._strideViewController = SpinBoxParameterViewController( + stride, tool_tip='Number of iterations between updates' + ) + self._widget = QWidget() + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._startViewController.getWidget()) + layout.addWidget(self._stopViewController.getWidget()) + layout.addWidget(self._strideViewController.getWidget()) + self._widget.setLayout(layout) + + def getWidget(self) -> QWidget: + return self._widget + + +class PtyChiOptimizerParameterViewController(ComboBoxParameterViewController): + def __init__(self, parameter: StringParameter, enumerators: PtyChiEnumerators) -> None: + super().__init__(parameter, enumerators.optimizers(), tool_tip='Name of the optimizer.') diff --git a/src/ptychodus/controller/ptychi/positions.py b/src/ptychodus/controller/ptychi/positions.py new file mode 100644 index 00000000..ba7bec7d --- /dev/null +++ b/src/ptychodus/controller/ptychi/positions.py @@ -0,0 +1,154 @@ +from PyQt5.QtWidgets import QFormLayout, QFrame, QWidget + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import ( + BooleanParameter, + IntegerParameter, + RealParameter, + StringParameter, +) + +from ...model.ptychi import PtyChiEnumerators, PtyChiProbePositionSettings +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + ParameterViewController, + SpinBoxParameterViewController, +) +from .optimizer import PtyChiOptimizationPlanViewController, PtyChiOptimizerParameterViewController + +__all__ = ['PtyChiProbePositionsViewController'] + + +class PtyChiCrossCorrelationViewController(ParameterViewController, Observer): + def __init__( + self, + algorithm: StringParameter, + scale: IntegerParameter, + realSpaceWidth: RealParameter, + probeThreshold: RealParameter, + ) -> None: + super().__init__() + self._algorithm = algorithm + self._scaleViewController = SpinBoxParameterViewController( + scale, tool_tip='Upsampling factor of the cross-correlation in real space.' + ) + self._realSpaceWidthViewController = DecimalLineEditParameterViewController( + realSpaceWidth, tool_tip='Width of the cross-correlation in real-space' + ) + self._probeThresholdViewController = DecimalSliderParameterViewController( + probeThreshold, tool_tip='Probe intensity threshold used to calculate the probe mask.' + ) + self._widget = QFrame() + self._widget.setFrameShape(QFrame.StyledPanel) + + layout = QFormLayout() + layout.addRow('Scale:', self._scaleViewController.getWidget()) + layout.addRow('Real Space Width:', self._realSpaceWidthViewController.getWidget()) + layout.addRow('Probe Threshold:', self._probeThresholdViewController.getWidget()) + self._widget.setLayout(layout) + + algorithm.addObserver(self) + self._syncModelToView() + + def getWidget(self) -> QWidget: + return self._widget + + def _syncModelToView(self) -> None: + self._widget.setVisible(self._algorithm.getValue().upper() == 'CROSS_CORRELATION') + + def update(self, observable: Observable) -> None: + if observable is self._algorithm: + self._syncModelToView() + + +class PtyChiUpdateMagnitudeLimitViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + limitMagnitudeUpdate: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + magnitudeUpdateLimit: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + limitMagnitudeUpdate, + 'Limit Update Magnitude', + tool_tip='Limit the magnitude of the probe update.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._viewController = DecimalLineEditParameterViewController( + magnitudeUpdateLimit, + tool_tip='Magnitude limit of the probe update.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Limit:', self._viewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiProbePositionsViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PtyChiProbePositionSettings, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + settings.isOptimizable, + 'Optimize Probe Positions', + tool_tip='Whether the probe positions are optimizable.', + ) + self._optimizationPlanViewController = PtyChiOptimizationPlanViewController( + settings.optimizationPlanStart, + settings.optimizationPlanStop, + settings.optimizationPlanStride, + num_epochs, + ) + self._optimizerViewController = PtyChiOptimizerParameterViewController( + settings.optimizer, enumerators + ) + self._stepSizeViewController = DecimalLineEditParameterViewController( + settings.stepSize, tool_tip='Optimizer step size' + ) + self._algorithmViewController = ComboBoxParameterViewController( + settings.positionCorrectionType, + enumerators.positionCorrectionTypes(), + tool_tip='Algorithm used to calculate the position correction update.', + ) + self._crossCorrelationViewController = PtyChiCrossCorrelationViewController( + settings.positionCorrectionType, + settings.crossCorrelationScale, + settings.crossCorrelationRealSpaceWidth, + settings.crossCorrelationProbeThreshold, + ) + self._magnitudeUpdateLimitViewController = PtyChiUpdateMagnitudeLimitViewController( + settings.limitMagnitudeUpdate, + settings.limitMagnitudeUpdateStart, + settings.limitMagnitudeUpdateStop, + settings.limitMagnitudeUpdateStride, + settings.magnitudeUpdateLimit, + num_epochs, + ) + self._constrainCentroidViewController = CheckBoxParameterViewController( + settings.constrainCentroid, + 'Constrain Centroid', + tool_tip='Whether to subtract the mean from positions after updating positions.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._optimizationPlanViewController.getWidget()) + layout.addRow('Optimizer:', self._optimizerViewController.getWidget()) + layout.addRow('Step Size:', self._stepSizeViewController.getWidget()) + layout.addRow('Algorithm:', self._algorithmViewController.getWidget()) + layout.addRow(self._crossCorrelationViewController.getWidget()) + layout.addRow(self._magnitudeUpdateLimitViewController.getWidget()) + layout.addRow(self._constrainCentroidViewController.getWidget()) + self.getWidget().setLayout(layout) diff --git a/src/ptychodus/controller/ptychi/probe.py b/src/ptychodus/controller/ptychi/probe.py new file mode 100644 index 00000000..e9383ad5 --- /dev/null +++ b/src/ptychodus/controller/ptychi/probe.py @@ -0,0 +1,248 @@ +from PyQt5.QtWidgets import QFormLayout + +from ptychodus.api.parametric import ( + BooleanParameter, + IntegerParameter, + RealParameter, + StringParameter, +) + +from ...model.ptychi import ( + PtyChiEnumerators, + PtyChiLSQMLSettings, + PtyChiPIESettings, + PtyChiProbeSettings, +) +from ..parametric import ( + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, +) +from .optimizer import PtyChiOptimizationPlanViewController, PtyChiOptimizerParameterViewController + +__all__ = ['PtyChiProbeViewController'] + + +class PtyChiConstrainProbePowerViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrainPower: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrainPower, 'Constrain Power', tool_tip='Whether to constrain probe power.' + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiOrthogonalizeIncoherentModesViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + orthogonalizeModes: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + method: StringParameter, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + orthogonalizeModes, + 'Orthogonalize Incoherent Modes', + tool_tip='Whether to orthogonalize incoherent probe modes.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._methodViewController = ComboBoxParameterViewController( + method, + enumerators.orthogonalizationMethods(), + tool_tip='Method to use for incoherent mode orthogonalization.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Method:', self._methodViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiOrthogonalizeOPRModesViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + orthogonalizeModes: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + orthogonalizeModes, + 'Orthogonalize OPR Modes', + tool_tip='Whether to orthogonalize OPR modes.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiConstrainSupportViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrainSupport: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + threshold: RealParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrainSupport, + 'Constrain Support', + tool_tip='When enabled, the probe will be shrinkwrapped so that small values are set to zero.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + self._thresholdViewController = DecimalLineEditParameterViewController( + threshold, tool_tip='Threshold for the probe support constraint.' + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + layout.addRow('Threshold:', self._thresholdViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiConstrainCenterViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + constrainCenter: BooleanParameter, + start: IntegerParameter, + stop: IntegerParameter, + stride: IntegerParameter, + num_epochs: IntegerParameter, + ) -> None: + super().__init__( + constrainCenter, + 'Constrain Center', + tool_tip='When enabled, the probe center of mass will be constrained to the center of the probe array.', + ) + self._planViewController = PtyChiOptimizationPlanViewController( + start, stop, stride, num_epochs + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._planViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiProbeViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + settings: PtyChiProbeSettings, + lsqmlSettings: PtyChiLSQMLSettings | None, + pieSettings: PtyChiPIESettings | None, + num_epochs: IntegerParameter, + enumerators: PtyChiEnumerators, + ) -> None: + super().__init__( + settings.isOptimizable, 'Optimize Probe', tool_tip='Whether the probe is optimizable.' + ) + self._optimizationPlanViewController = PtyChiOptimizationPlanViewController( + settings.optimizationPlanStart, + settings.optimizationPlanStop, + settings.optimizationPlanStride, + num_epochs, + ) + self._optimizerViewController = PtyChiOptimizerParameterViewController( + settings.optimizer, enumerators + ) + self._stepSizeViewController = DecimalLineEditParameterViewController( + settings.stepSize, tool_tip='Optimizer step size' + ) + self._constrainPowerViewController = PtyChiConstrainProbePowerViewController( + settings.constrainProbePower, + settings.constrainProbePowerStart, + settings.constrainProbePowerStop, + settings.constrainProbePowerStride, + num_epochs, + ) + self._orthogonalizeIncoherentModesViewController = ( + PtyChiOrthogonalizeIncoherentModesViewController( + settings.orthogonalizeIncoherentModes, + settings.orthogonalizeIncoherentModesStart, + settings.orthogonalizeIncoherentModesStop, + settings.orthogonalizeIncoherentModesStride, + settings.orthogonalizeIncoherentModesMethod, + num_epochs, + enumerators, + ) + ) + self._orthogonalizeOPRModesViewController = PtyChiOrthogonalizeOPRModesViewController( + settings.orthogonalizeOPRModes, + settings.orthogonalizeOPRModesStart, + settings.orthogonalizeOPRModesStop, + settings.orthogonalizeOPRModesStride, + num_epochs, + ) + self._constrainSupportViewController = PtyChiConstrainSupportViewController( + settings.constrainSupport, + settings.constrainSupportStart, + settings.constrainSupportStop, + settings.constrainSupportStride, + settings.constrainSupportThreshold, + num_epochs, + ) + self._constrainCenterViewController = PtyChiConstrainCenterViewController( + settings.constrainCenter, + settings.constrainCenterStart, + settings.constrainCenterStop, + settings.constrainCenterStride, + num_epochs, + ) + self._relaxEigenmodeUpdateViewController = DecimalSliderParameterViewController( + settings.relaxEigenmodeUpdate, + tool_tip='Relaxation factor for the eigenmode update.', + ) + + layout = QFormLayout() + layout.addRow('Plan:', self._optimizationPlanViewController.getWidget()) + layout.addRow('Optimizer:', self._optimizerViewController.getWidget()) + layout.addRow('Step Size:', self._stepSizeViewController.getWidget()) + layout.addRow(self._constrainPowerViewController.getWidget()) + layout.addRow(self._orthogonalizeIncoherentModesViewController.getWidget()) + layout.addRow(self._orthogonalizeOPRModesViewController.getWidget()) + layout.addRow(self._constrainSupportViewController.getWidget()) + layout.addRow(self._constrainCenterViewController.getWidget()) + layout.addRow( + 'Relax Eigenmode Update:', self._relaxEigenmodeUpdateViewController.getWidget() + ) + + if lsqmlSettings is not None: + self._probeOptimalStepSizeScalerViewController = DecimalLineEditParameterViewController( + lsqmlSettings.probeOptimalStepSizeScaler + ) + layout.addRow( + 'Optimal Step Size Scaler:', + self._probeOptimalStepSizeScalerViewController.getWidget(), + ) + + if pieSettings is not None: + self._probeAlpha = DecimalSliderParameterViewController(pieSettings.probeAlpha) + layout.addRow('Alpha:', self._probeAlpha.getWidget()) + + self.getWidget().setLayout(layout) diff --git a/src/ptychodus/controller/ptychi/reconstructor.py b/src/ptychodus/controller/ptychi/reconstructor.py new file mode 100644 index 00000000..4edaea96 --- /dev/null +++ b/src/ptychodus/controller/ptychi/reconstructor.py @@ -0,0 +1,249 @@ +from PyQt5.QtWidgets import ( + QButtonGroup, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QRadioButton, + QVBoxLayout, + QWidget, +) + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.parametric import BooleanParameter, RealParameter + +from ...model.ptychi import ( + PtyChiAutodiffSettings, + PtyChiDMSettings, + PtyChiDeviceRepository, + PtyChiEnumerators, + PtyChiLSQMLSettings, + PtyChiReconstructorSettings, +) +from ..parametric import ( + CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, + ComboBoxParameterViewController, + DecimalLineEditParameterViewController, + DecimalSliderParameterViewController, + ParameterViewController, + SpinBoxParameterViewController, +) + +__all__ = ['PtyChiReconstructorViewController'] + + +class PtyChiDeviceViewController(CheckableGroupBoxParameterViewController): + def __init__( + self, + useDevices: BooleanParameter, + repository: PtyChiDeviceRepository, + *, + tool_tip: str = '', + ) -> None: + super().__init__(useDevices, 'Use Devices', tool_tip=tool_tip) + layout = QVBoxLayout() + + for device in repository: + deviceLabel = QLabel(device) + layout.addWidget(deviceLabel) + + self.getWidget().setLayout(layout) + + +class PtyChiPrecisionParameterViewController(ParameterViewController, Observer): + def __init__(self, useDoublePrecision: BooleanParameter, *, tool_tip: str = '') -> None: + super().__init__() + self._useDoublePrecision = useDoublePrecision + self._singlePrecisionButton = QRadioButton('Single') + self._doublePrecisionButton = QRadioButton('Double') + self._buttonGroup = QButtonGroup() + self._widget = QWidget() + + self._singlePrecisionButton.setToolTip('Compute using single precision.') + self._doublePrecisionButton.setToolTip('Compute using double precision.') + + if tool_tip: + self._widget.setToolTip(tool_tip) + + self._buttonGroup.addButton(self._singlePrecisionButton, 1) + self._buttonGroup.addButton(self._doublePrecisionButton, 2) + self._buttonGroup.setExclusive(True) + self._buttonGroup.idToggled.connect(self._syncViewToModel) + + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._singlePrecisionButton) + layout.addWidget(self._doublePrecisionButton) + layout.addStretch() + self._widget.setLayout(layout) + + self._syncModelToView() + useDoublePrecision.addObserver(self) + + def getWidget(self) -> QWidget: + return self._widget + + def _syncViewToModel(self, toolId: int, checked: bool) -> None: + if toolId == 2: + self._useDoublePrecision.setValue(checked) + + def _syncModelToView(self) -> None: + button = self._buttonGroup.button(2 if self._useDoublePrecision.getValue() else 1) + button.setChecked(True) + + def update(self, observable: Observable) -> None: + if observable is self._useDoublePrecision: + self._syncModelToView() + + +class PtyChiMomentumAccelerationGradientMixingFactorViewController( + CheckableGroupBoxParameterViewController +): + def __init__( + self, + useGradientMixingFactor: BooleanParameter, + gradientMixingFactor: RealParameter, + ) -> None: + super().__init__( + useGradientMixingFactor, + 'Use Gradient Mixing Factor', + tool_tip='Controls how the current gradient is mixed with the accumulated velocity in LSQML momentum acceleration.', + ) + self._gradientMixingFactorViewController = DecimalLineEditParameterViewController( + gradientMixingFactor + ) + + layout = QVBoxLayout() + layout.addWidget(self._gradientMixingFactorViewController.getWidget()) + self.getWidget().setLayout(layout) + + +class PtyChiReconstructorViewController(ParameterViewController): + def __init__( + self, + settings: PtyChiReconstructorSettings, + autodiffSettings: PtyChiAutodiffSettings | None, + dmSettings: PtyChiDMSettings | None, + lsqmlSettings: PtyChiLSQMLSettings | None, + enumerators: PtyChiEnumerators, + repository: PtyChiDeviceRepository, + ) -> None: + super().__init__() + self._numEpochsViewController = SpinBoxParameterViewController( + settings.numEpochs, tool_tip='Number of epochs to run.' + ) + self._batchSizeViewController = SpinBoxParameterViewController( + settings.batchSize, tool_tip='Number of data to process in each minibatch.' + ) + self._batchingModeViewController = ComboBoxParameterViewController( + settings.batchingMode, enumerators.batchingModes(), tool_tip='Batching mode to use.' + ) + self._batchStride = SpinBoxParameterViewController( + settings.batchStride, tool_tip='Number of epochs between updating clusters.' + ) + self._precisionViewController = PtyChiPrecisionParameterViewController( + settings.useDoublePrecision, + tool_tip='Floating point precision to use for computation.', + ) + self._deviceViewController = PtyChiDeviceViewController( + settings.useDevices, repository, tool_tip='Default device to use for computation.' + ) + self._useLowMemoryViewController = CheckBoxParameterViewController( + settings.useLowMemoryForwardModel, + 'Use Low Memory Forward Model', + tool_tip='When checked, forward propagation of ptychography will be done using less vectorized code. This reduces the speed, but also lowers memory usage.', + ) + self._saveDataOnDeviceViewController = CheckBoxParameterViewController( + settings.saveDataOnDevice, + 'Save Data on Device', + tool_tip='When checked, diffraction data will be saved on the device.', + ) + self._widget = QGroupBox('Reconstructor') + + layout = QFormLayout() + layout.addRow('Number of Epochs:', self._numEpochsViewController.getWidget()) + layout.addRow('Batch Size:', self._batchSizeViewController.getWidget()) + layout.addRow('Batch Mode:', self._batchingModeViewController.getWidget()) + layout.addRow('Batch Stride:', self._batchStride.getWidget()) + + if repository: + layout.addRow(self._deviceViewController.getWidget()) + + layout.addRow('Precision:', self._precisionViewController.getWidget()) + layout.addRow(self._useLowMemoryViewController.getWidget()) + + if autodiffSettings is not None: + self._lossFunctionViewController = ComboBoxParameterViewController( + autodiffSettings.lossFunction, enumerators.lossFunctions() + ) + layout.addRow('Loss Function:', self._lossFunctionViewController.getWidget()) + + self._forwardModelClassViewController = ComboBoxParameterViewController( + autodiffSettings.forwardModelClass, enumerators.forwardModels() + ) + layout.addRow('Forward Model:', self._forwardModelClassViewController.getWidget()) + + if dmSettings is not None: + self._exitWaveUpdateRelaxationViewController = DecimalSliderParameterViewController( + dmSettings.exitWaveUpdateRelaxation + ) + layout.addRow( + 'Exit Wave Update Relaxation:', + self._exitWaveUpdateRelaxationViewController.getWidget(), + ) + + self._chunkLengthViewController = SpinBoxParameterViewController(dmSettings.chunkLength) + layout.addRow('Chunk Length:', self._chunkLengthViewController.getWidget()) + + if lsqmlSettings is not None: + self._noiseModelViewController = ComboBoxParameterViewController( + lsqmlSettings.noiseModel, enumerators.noiseModels() + ) + layout.addRow('Noise Model:', self._noiseModelViewController.getWidget()) + + self._gaussianNoiseDeviationViewController = DecimalLineEditParameterViewController( + lsqmlSettings.gaussianNoiseDeviation + ) + layout.addRow( + 'Gaussian Noise Deviation:', self._gaussianNoiseDeviationViewController.getWidget() + ) + + self._solveObjectProbeStepSizeJointlyForFirstSliceInMultisliceViewController = ( + CheckBoxParameterViewController( + lsqmlSettings.solveObjectProbeStepSizeJointlyForFirstSliceInMultislice, + 'SolveObjectProbeStepSizeJointlyForFirstSliceInMultislice', + ) + ) + layout.addRow( + self._solveObjectProbeStepSizeJointlyForFirstSliceInMultisliceViewController.getWidget() + ) + + self._solveStepSizesOnlyUsingFirstProbeModeViewController = ( + CheckBoxParameterViewController( + lsqmlSettings.solveStepSizesOnlyUsingFirstProbeMode, + 'SolveStepSizesOnlyUsingFirstProbeMode', + ) + ) + layout.addRow(self._solveStepSizesOnlyUsingFirstProbeModeViewController.getWidget()) + + self._momentumAccelerationGainViewController = DecimalLineEditParameterViewController( + lsqmlSettings.momentumAccelerationGain + ) + layout.addRow( + 'Momentum Acceleration Gain:', + self._momentumAccelerationGainViewController.getWidget(), + ) + + self._momentumAccelerationGradientMixingFactorViewController = ( + PtyChiMomentumAccelerationGradientMixingFactorViewController( + lsqmlSettings.useMomentumAccelerationGradientMixingFactor, + lsqmlSettings.momentumAccelerationGradientMixingFactor, + ) + ) + layout.addRow(self._momentumAccelerationGradientMixingFactorViewController.getWidget()) + + self._widget.setLayout(layout) + + def getWidget(self) -> QWidget: + return self._widget diff --git a/src/ptychodus/controller/reconstructor.py b/src/ptychodus/controller/reconstructor.py index a9ca3d7c..859ad199 100644 --- a/src/ptychodus/controller/reconstructor.py +++ b/src/ptychodus/controller/reconstructor.py @@ -3,8 +3,8 @@ from collections.abc import Iterable, Sequence import logging -from PyQt5.QtCore import Qt, QAbstractItemModel -from PyQt5.QtWidgets import QLabel, QWidget +from PyQt5.QtCore import Qt, QAbstractItemModel, QTimer +from PyQt5.QtWidgets import QActionGroup, QLabel, QWidget from ptychodus.api.observer import Observable, Observer @@ -18,7 +18,7 @@ from ..model.product.probe import ProbeRepositoryItem from ..model.product.scan import ScanRepositoryItem from ..model.reconstructor import ReconstructorPresenter -from ..view.reconstructor import ReconstructorParametersView, ReconstructorPlotView +from ..view.reconstructor import ReconstructorView, ReconstructorPlotView from ..view.widgets import ExceptionDialog from .data import FileDialogFactory @@ -41,8 +41,9 @@ def __init__( self, presenter: ReconstructorPresenter, productRepository: ProductRepository, - view: ReconstructorParametersView, + view: ReconstructorView, plotView: ReconstructorPlotView, + productTableModel: QAbstractItemModel, fileDialogFactory: FileDialogFactory, viewControllerFactoryList: Iterable[ReconstructorViewControllerFactory], ) -> None: @@ -56,79 +57,70 @@ def __init__( vcf.backendName: vcf for vcf in viewControllerFactoryList } - @classmethod - def createInstance( - cls, - presenter: ReconstructorPresenter, - productRepository: ProductRepository, - view: ReconstructorParametersView, - plotView: ReconstructorPlotView, - fileDialogFactory: FileDialogFactory, - productTableModel: QAbstractItemModel, - viewControllerFactoryList: list[ReconstructorViewControllerFactory], - ) -> ReconstructorController: - controller = cls( - presenter, - productRepository, - view, - plotView, - fileDialogFactory, - viewControllerFactoryList, - ) - presenter.addObserver(controller) - productRepository.addObserver(controller) - for name in presenter.getReconstructorList(): - controller._addReconstructor(name) + self._addReconstructor(name) - view.reconstructorView.algorithmComboBox.textActivated.connect(presenter.setReconstructor) - view.reconstructorView.algorithmComboBox.currentIndexChanged.connect( + view.parametersView.algorithmComboBox.textActivated.connect(presenter.setReconstructor) + view.parametersView.algorithmComboBox.currentIndexChanged.connect( view.stackedWidget.setCurrentIndex ) - view.reconstructorView.productComboBox.textActivated.connect(controller._redrawPlot) - view.reconstructorView.productComboBox.setModel(productTableModel) + view.parametersView.productComboBox.textActivated.connect(self._redrawPlot) + view.parametersView.productComboBox.setModel(productTableModel) - openModelAction = view.reconstructorView.modelMenu.addAction('Open...') - openModelAction.triggered.connect(controller._openModel) - saveModelAction = view.reconstructorView.modelMenu.addAction('Save...') - saveModelAction.triggered.connect(controller._saveModel) + self._progressTimer = QTimer() + self._progressTimer.timeout.connect(self._updateProgress) + self._progressTimer.start(5 * 1000) # TODO customize (in milliseconds) - openTrainingDataAction = view.reconstructorView.trainerMenu.addAction( - 'Open Training Data...' - ) - openTrainingDataAction.triggered.connect(controller._openTrainingData) - saveTrainingDataAction = view.reconstructorView.trainerMenu.addAction( - 'Save Training Data...' - ) - saveTrainingDataAction.triggered.connect(controller._saveTrainingData) - ingestTrainingDataAction = view.reconstructorView.trainerMenu.addAction( - 'Ingest Training Data' - ) - ingestTrainingDataAction.triggered.connect(controller._ingestTrainingData) - clearTrainingDataAction = view.reconstructorView.trainerMenu.addAction( - 'Clear Training Data' - ) - clearTrainingDataAction.triggered.connect(controller._clearTrainingData) - view.reconstructorView.trainerMenu.addSeparator() - trainAction = view.reconstructorView.trainerMenu.addAction('Train') - trainAction.triggered.connect(controller._train) + view.progressDialog.setModal(True) + view.progressDialog.setWindowModality(Qt.ApplicationModal) + view.progressDialog.setWindowFlags(Qt.Window | Qt.WindowTitleHint | Qt.CustomizeWindowHint) + view.progressDialog.textEdit.setReadOnly(True) + + openModelAction = view.parametersView.reconstructorMenu.addAction('Open Model...') + openModelAction.triggered.connect(self._openModel) + saveModelAction = view.parametersView.reconstructorMenu.addAction('Save Model...') + saveModelAction.triggered.connect(self._saveModel) - reconstructSplitAction = view.reconstructorView.reconstructorMenu.addAction( + self._modelActionGroup = QActionGroup(view.parametersView.reconstructorMenu) + self._modelActionGroup.setExclusive(False) + self._modelActionGroup.addAction(openModelAction) + self._modelActionGroup.addAction(saveModelAction) + self._modelActionGroup.addAction(view.parametersView.reconstructorMenu.addSeparator()) + + reconstructSplitAction = view.parametersView.reconstructorMenu.addAction( 'Reconstruct Odd/Even Split' ) - reconstructSplitAction.triggered.connect(controller._reconstructSplit) - reconstructAction = view.reconstructorView.reconstructorMenu.addAction('Reconstruct') - reconstructAction.triggered.connect(controller._reconstruct) + reconstructSplitAction.triggered.connect(self._reconstructSplit) + reconstructAction = view.parametersView.reconstructorMenu.addAction('Reconstruct') + reconstructAction.triggered.connect(self._reconstruct) + + exportTrainingDataAction = view.parametersView.trainerMenu.addAction( + 'Export Training Data...' + ) + exportTrainingDataAction.triggered.connect(self._exportTrainingData) + trainAction = view.parametersView.trainerMenu.addAction('Train') + trainAction.triggered.connect(self._train) + + presenter.addObserver(self) + productRepository.addObserver(self) + self._syncModelToView() - controller._syncAlgorithmToView() + def _updateProgress(self) -> None: + isReconstructing = self._presenter.isReconstructing - return controller + for button in self._view.progressDialog.buttonBox.buttons(): + button.setEnabled(not isReconstructing) + + for text in self._presenter.flushLog(): + self._view.progressDialog.textEdit.appendPlainText(text) + + self._presenter.processResults(block=False) def _addReconstructor(self, name: str) -> None: backendName, reconstructorName = name.split('/') # TODO REDO - self._view.reconstructorView.algorithmComboBox.addItem( - name, self._view.reconstructorView.algorithmComboBox.count() + self._view.parametersView.algorithmComboBox.addItem( + name, self._view.parametersView.algorithmComboBox.count() ) if backendName in self._viewControllerFactoryDict: @@ -141,37 +133,37 @@ def _addReconstructor(self, name: str) -> None: self._view.stackedWidget.addWidget(widget) def _reconstruct(self) -> None: - outputProductName = self._presenter.getReconstructor() - inputProductIndex = self._view.reconstructorView.productComboBox.currentIndex() + inputProductIndex = self._view.parametersView.productComboBox.currentIndex() if inputProductIndex < 0: return try: - self._presenter.reconstruct(inputProductIndex, outputProductName) + self._presenter.reconstruct(inputProductIndex) except Exception as err: logger.exception(err) ExceptionDialog.showException('Reconstructor', err) + self._view.progressDialog.show() + def _reconstructSplit(self) -> None: - outputProductName = self._presenter.getReconstructor() - inputProductIndex = self._view.reconstructorView.productComboBox.currentIndex() + inputProductIndex = self._view.parametersView.productComboBox.currentIndex() if inputProductIndex < 0: return try: - self._presenter.reconstructSplit(inputProductIndex, outputProductName) + self._presenter.reconstructSplit(inputProductIndex) except Exception as err: logger.exception(err) ExceptionDialog.showException('Split Reconstructor', err) + self._view.progressDialog.show() + def _openModel(self) -> None: + nameFilter = self._presenter.getModelFileFilter() filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( - self._view, - 'Open Model', - nameFilters=self._presenter.getOpenModelFileFilterList(), - selectedNameFilter=self._presenter.getOpenModelFileFilter(), + self._view, 'Open Model', nameFilters=[nameFilter], selectedNameFilter=nameFilter ) if filePath: @@ -182,11 +174,9 @@ def _openModel(self) -> None: ExceptionDialog.showException('Model Reader', err) def _saveModel(self) -> None: + nameFilter = self._presenter.getModelFileFilter() filePath, _ = self._fileDialogFactory.getSaveFilePath( - self._view, - 'Save Model', - nameFilters=self._presenter.getSaveModelFileFilterList(), - selectedNameFilter=self._presenter.getSaveModelFileFilter(), + self._view, 'Save Model', nameFilters=[nameFilter], selectedNameFilter=nameFilter ) if filePath: @@ -196,64 +186,41 @@ def _saveModel(self) -> None: logger.exception(err) ExceptionDialog.showException('Model Writer', err) - def _openTrainingData(self) -> None: - filePath, nameFilter = self._fileDialogFactory.getOpenFilePath( - self._view, - 'Open Training Data', - nameFilters=self._presenter.getOpenTrainingDataFileFilterList(), - selectedNameFilter=self._presenter.getOpenTrainingDataFileFilter(), - ) + def _exportTrainingData(self) -> None: + inputProductIndex = self._view.parametersView.productComboBox.currentIndex() - if filePath: - try: - self._presenter.openTrainingData(filePath) - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('Training Data Reader', err) + if inputProductIndex < 0: + return - def _saveTrainingData(self) -> None: + nameFilter = self._presenter.getTrainingDataFileFilter() filePath, _ = self._fileDialogFactory.getSaveFilePath( self._view, - 'Save Training Data', - nameFilters=self._presenter.getSaveTrainingDataFileFilterList(), - selectedNameFilter=self._presenter.getSaveTrainingDataFileFilter(), + 'Export Training Data', + nameFilters=[nameFilter], + selectedNameFilter=nameFilter, ) if filePath: try: - self._presenter.saveTrainingData(filePath) + self._presenter.exportTrainingData(filePath, inputProductIndex) except Exception as err: logger.exception(err) ExceptionDialog.showException('Training Data Writer', err) - def _ingestTrainingData(self) -> None: - inputProductIndex = self._view.reconstructorView.productComboBox.currentIndex() - - if inputProductIndex < 0: - return - - try: - self._presenter.ingestTrainingData(inputProductIndex) - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('Ingester', err) - - def _clearTrainingData(self) -> None: - try: - self._presenter.clearTrainingData() - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('Clear', err) - def _train(self) -> None: - try: - self._presenter.train() - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('Trainer', err) + dataPath = self._fileDialogFactory.getExistingDirectoryPath( + self._view, 'Choose Training Data Directory' + ) + + if dataPath: + try: + self._presenter.train(dataPath) + except Exception as err: + logger.exception(err) + ExceptionDialog.showException('Trainer', err) def _redrawPlot(self) -> None: - productIndex = self._view.reconstructorView.productComboBox.currentIndex() + productIndex = self._view.parametersView.productComboBox.currentIndex() if productIndex < 0: self._plotView.axes.clear() @@ -273,14 +240,14 @@ def _redrawPlot(self) -> None: ax.plot(item.getCosts(), '.-', label='Cost', linewidth=1.5) self._plotView.figureCanvas.draw() - def _syncAlgorithmToView(self) -> None: - self._view.reconstructorView.algorithmComboBox.setCurrentText( + def _syncModelToView(self) -> None: + self._view.parametersView.algorithmComboBox.setCurrentText( self._presenter.getReconstructor() ) isTrainable = self._presenter.isTrainable - self._view.reconstructorView.modelButton.setVisible(isTrainable) - self._view.reconstructorView.trainerButton.setVisible(isTrainable) + self._modelActionGroup.setVisible(isTrainable) + self._view.parametersView.trainerButton.setVisible(isTrainable) self._redrawPlot() @@ -300,7 +267,7 @@ def handleObjectChanged(self, index: int, item: ObjectRepositoryItem) -> None: pass def handleCostsChanged(self, index: int, costs: Sequence[float]) -> None: - currentIndex = self._view.reconstructorView.productComboBox.currentIndex() + currentIndex = self._view.parametersView.productComboBox.currentIndex() if index == currentIndex: self._redrawPlot() @@ -310,4 +277,4 @@ def handleItemRemoved(self, index: int, item: ProductRepositoryItem) -> None: def update(self, observable: Observable) -> None: if observable is self._presenter: - self._syncAlgorithmToView() + self._syncModelToView() diff --git a/src/ptychodus/controller/scan/editorFactory.py b/src/ptychodus/controller/scan/editorFactory.py index 39e3b6ef..9a24fbba 100644 --- a/src/ptychodus/controller/scan/editorFactory.py +++ b/src/ptychodus/controller/scan/editorFactory.py @@ -5,13 +5,11 @@ QDialog, QFormLayout, QGridLayout, - QGroupBox, QLabel, QMessageBox, QWidget, ) -from ptychodus.api.observer import Observable, Observer from ...model.product.scan import ( CartesianScanBuilder, @@ -24,6 +22,7 @@ SpiralScanBuilder, ) from ..parametric import ( + CheckableGroupBoxParameterViewController, DecimalLineEditParameterViewController, LengthWidgetParameterViewController, ParameterViewBuilder, @@ -102,13 +101,9 @@ def getWidget(self) -> QWidget: return self._widget -class ScanBoundingBoxViewController(ParameterViewController, Observer): +class ScanBoundingBoxViewController(CheckableGroupBoxParameterViewController): def __init__(self, item: ScanRepositoryItem) -> None: - super().__init__() - self._parameter = item.expandBoundingBox - self._widget = QGroupBox('Expand Bounding Box') - self._widget.setCheckable(True) - + super().__init__(item.expandBoundingBox, 'Expand Bounding Box') self._minimumXController = LengthWidgetParameterViewController( item.expandedBoundingBoxMinimumXInMeters, is_signed=True ) @@ -127,21 +122,7 @@ def __init__(self, item: ScanRepositoryItem) -> None: layout.addRow('Maximum X:', self._maximumXController.getWidget()) layout.addRow('Minimum Y:', self._minimumYController.getWidget()) layout.addRow('Maximum Y:', self._maximumYController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(self._parameter.setValue) - self._parameter.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._parameter.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._parameter: - self._syncModelToView() + self.getWidget().setLayout(layout) class ScanEditorViewControllerFactory: diff --git a/src/ptychodus/controller/tike/core.py b/src/ptychodus/controller/tike/core.py index fdaf9cb0..8bc115fc 100644 --- a/src/ptychodus/controller/tike/core.py +++ b/src/ptychodus/controller/tike/core.py @@ -1,9 +1,4 @@ -from __future__ import annotations - -from PyQt5.QtWidgets import ( - QVBoxLayout, - QWidget, -) +from PyQt5.QtWidgets import QVBoxLayout, QWidget from ...model.tike import TikeReconstructorLibrary from ..reconstructor import ReconstructorViewControllerFactory @@ -49,7 +44,6 @@ class TikeViewControllerFactory(ReconstructorViewControllerFactory): def __init__(self, model: TikeReconstructorLibrary) -> None: super().__init__() self._model = model - self._controllerList: list[TikeViewController] = list() @property def backendName(self) -> str: @@ -61,5 +55,4 @@ def createViewController(self, reconstructorName: str) -> QWidget: else: viewController = TikeViewController(self._model, showAlpha=False) - self._controllerList.append(viewController) return viewController diff --git a/src/ptychodus/controller/tike/viewControllers.py b/src/ptychodus/controller/tike/viewControllers.py index fc71fa0a..a3072618 100644 --- a/src/ptychodus/controller/tike/viewControllers.py +++ b/src/ptychodus/controller/tike/viewControllers.py @@ -14,6 +14,7 @@ ) from ..parametric import ( CheckBoxParameterViewController, + CheckableGroupBoxParameterViewController, ComboBoxParameterViewController, DecimalLineEditParameterViewController, DecimalSliderParameterViewController, @@ -100,41 +101,25 @@ def update(self, observable: Observable) -> None: self._syncModelToView() -class TikeMultigridViewController(ParameterViewController, Observer): +class TikeMultigridViewController(CheckableGroupBoxParameterViewController): def __init__(self, settings: TikeMultigridSettings) -> None: - super().__init__() + super().__init__(settings.useMultigrid, 'Multigrid') self._useMultigrid = settings.useMultigrid self._numLevelsController = SpinBoxParameterViewController( settings.numLevels, tool_tip='The number of times to reduce the problem by a factor of two.', ) - self._widget = QGroupBox('Multigrid') - self._widget.setCheckable(True) layout = QFormLayout() layout.addRow('Number of Levels:', self._numLevelsController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.useMultigrid.setValue) - self._useMultigrid.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useMultigrid.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._useMultigrid: - self._syncModelToView() + self.getWidget().setLayout(layout) -class TikeAdaptiveMomentViewController(ParameterViewController, Observer): +class TikeAdaptiveMomentViewController(CheckableGroupBoxParameterViewController): def __init__( self, useAdaptiveMoment: BooleanParameter, mdecay: RealParameter, vdecay: RealParameter ) -> None: - super().__init__() + super().__init__(useAdaptiveMoment, 'Adaptive Moment') self._useAdaptiveMoment = useAdaptiveMoment self._mdecayViewController = DecimalSliderParameterViewController( mdecay, tool_tip='The proportion of the first moment that is previous first moments.' @@ -142,33 +127,16 @@ def __init__( self._vdecayViewController = DecimalSliderParameterViewController( vdecay, tool_tip='The proportion of the second moment that is previous second moments.' ) - self._widget = QGroupBox('Adaptive Moment') - self._widget.setCheckable(True) layout = QFormLayout() layout.addRow('M Decay:', self._mdecayViewController.getWidget()) layout.addRow('V Decay:', self._vdecayViewController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(useAdaptiveMoment.setValue) - self._useAdaptiveMoment.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useAdaptiveMoment.getValue()) + self.getWidget().setLayout(layout) - def update(self, observable: Observable) -> None: - if observable is self._useAdaptiveMoment: - self._syncModelToView() - -class TikeObjectCorrectionViewController(ParameterViewController, Observer): +class TikeObjectCorrectionViewController(CheckableGroupBoxParameterViewController): def __init__(self, settings: TikeObjectCorrectionSettings) -> None: - super().__init__() - self._useObjectCorrection = settings.useObjectCorrection + super().__init__(settings.useObjectCorrection, 'Object Correction') self._positivityConstraintViewController = DecimalSliderParameterViewController( settings.positivityConstraint ) @@ -184,9 +152,6 @@ def __init__(self, settings: TikeObjectCorrectionSettings) -> None: tool_tip='Forces the object magnitude to be <= 1.', ) - self._widget = QGroupBox('Object Correction') - self._widget.setCheckable(True) - layout = QFormLayout() layout.addRow( 'Positivity Constraint:', self._positivityConstraintViewController.getWidget() @@ -196,27 +161,12 @@ def __init__(self, settings: TikeObjectCorrectionSettings) -> None: ) layout.addRow(self._adaptiveMomentViewController.getWidget()) layout.addRow(self._useMagnitudeClippingViewController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.useObjectCorrection.setValue) - self._useObjectCorrection.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useObjectCorrection.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._useObjectCorrection: - self._syncModelToView() + self.getWidget().setLayout(layout) -class TikeProbeSupportViewController(ParameterViewController, Observer): +class TikeProbeSupportViewController(CheckableGroupBoxParameterViewController): def __init__(self, settings: TikeProbeCorrectionSettings) -> None: - super().__init__() - self._useFiniteProbeSupport = settings.useFiniteProbeSupport + super().__init__(settings.useFiniteProbeSupport, 'Finite Probe Support') self._weightViewController = DecimalLineEditParameterViewController( settings.probeSupportWeight, tool_tip='Weight of the finite probe constraint.' ) @@ -228,34 +178,17 @@ def __init__(self, settings: TikeProbeCorrectionSettings) -> None: settings.probeSupportDegree, tool_tip='Degree of the supergaussian defining the probe support.', ) - self._widget = QGroupBox('Finite Probe Support') - self._widget.setCheckable(True) layout = QFormLayout() layout.addRow('Weight:', self._weightViewController.getWidget()) layout.addRow('Radius:', self._radiusViewController.getWidget()) layout.addRow('Degree:', self._degreeViewController.getWidget()) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.useFiniteProbeSupport.setValue) - self._useFiniteProbeSupport.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useFiniteProbeSupport.getValue()) + self.getWidget().setLayout(layout) - def update(self, observable: Observable) -> None: - if observable is self._useFiniteProbeSupport: - self._syncModelToView() - -class TikeProbeCorrectionViewController(ParameterViewController, Observer): +class TikeProbeCorrectionViewController(CheckableGroupBoxParameterViewController): def __init__(self, settings: TikeProbeCorrectionSettings) -> None: - super().__init__() - self._useProbeCorrection = settings.useProbeCorrection + super().__init__(settings.useProbeCorrection, 'Probe Correction') self._forceSparsityViewController = DecimalSliderParameterViewController( settings.forceSparsity, tool_tip='Forces this proportion of zero elements.' ) @@ -277,8 +210,6 @@ def __init__(self, settings: TikeProbeCorrectionSettings) -> None: settings.additionalProbePenalty, tool_tip='Penalty applied to the last probe for existing.', ) - self._widget = QGroupBox('Probe Correction') - self._widget.setCheckable(True) layout = QFormLayout() layout.addRow('Force Sparsity:', self._forceSparsityViewController.getWidget()) @@ -289,26 +220,12 @@ def __init__(self, settings: TikeProbeCorrectionSettings) -> None: layout.addRow( 'Additional Probe Penalty:', self._additionalProbePenaltyViewController.getWidget() ) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.useProbeCorrection.setValue) - self._useProbeCorrection.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget + self.getWidget().setLayout(layout) - def _syncModelToView(self) -> None: - self._widget.setChecked(self._useProbeCorrection.getValue()) - def update(self, observable: Observable) -> None: - if observable is self._useProbeCorrection: - self._syncModelToView() - - -class TikePositionCorrectionViewController(ParameterViewController, Observer): +class TikePositionCorrectionViewController(CheckableGroupBoxParameterViewController): def __init__(self, settings: TikePositionCorrectionSettings) -> None: - self._usePositionCorrection = settings.usePositionCorrection + super().__init__(settings.usePositionCorrection, 'Position Correction') self._usePositionRegularizationViewController = CheckBoxParameterViewController( settings.usePositionRegularization, 'Use Regularization', @@ -321,8 +238,6 @@ def __init__(self, settings: TikePositionCorrectionSettings) -> None: settings.updateMagnitudeLimit, tool_tip='When set to a positive number, x and y update magnitudes are clipped (limited) to this value.', ) - self._widget = QGroupBox('Position Correction') - self._widget.setCheckable(True) layout = QFormLayout() layout.addRow(self._usePositionRegularizationViewController.getWidget()) @@ -330,18 +245,4 @@ def __init__(self, settings: TikePositionCorrectionSettings) -> None: layout.addRow( 'Update Magnitude Limit:', self._updateMagnitudeLimitViewController.getWidget() ) - self._widget.setLayout(layout) - - self._syncModelToView() - self._widget.toggled.connect(settings.usePositionCorrection.setValue) - self._usePositionCorrection.addObserver(self) - - def getWidget(self) -> QWidget: - return self._widget - - def _syncModelToView(self) -> None: - self._widget.setChecked(self._usePositionCorrection.getValue()) - - def update(self, observable: Observable) -> None: - if observable is self._usePositionCorrection: - self._syncModelToView() + self.getWidget().setLayout(layout) diff --git a/src/ptychodus/controller/visualization/controller.py b/src/ptychodus/controller/visualization/controller.py index dbc7be15..a5b6ff37 100644 --- a/src/ptychodus/controller/visualization/controller.py +++ b/src/ptychodus/controller/visualization/controller.py @@ -77,15 +77,19 @@ def setArray( *, autoscaleColorAxis: bool = False, ) -> None: - try: - product = self._engine.render( - array, pixelGeometry, autoscaleColorAxis=autoscaleColorAxis - ) - except ValueError as err: - logger.exception(err) - ExceptionDialog.showException('Renderer', err) + if numpy.all(numpy.isfinite(array)): + try: + product = self._engine.render( + array, pixelGeometry, autoscaleColorAxis=autoscaleColorAxis + ) + except ValueError as err: + logger.exception(err) + ExceptionDialog.showException('Renderer', err) + else: + self._item.setProduct(product) else: - self._item.setProduct(product) + logger.warning('Array contains infinite or NaN values!') + self._item.clearProduct() def clearArray(self) -> None: self._item.clearProduct() diff --git a/src/ptychodus/controller/workflow/controller.py b/src/ptychodus/controller/workflow/controller.py index 6f757127..bf7313ae 100644 --- a/src/ptychodus/controller/workflow/controller.py +++ b/src/ptychodus/controller/workflow/controller.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from PyQt5.QtCore import QAbstractItemModel, QTimer from PyQt5.QtWidgets import QTableView @@ -33,7 +31,7 @@ def __init__( self._authorizationController = WorkflowAuthorizationController.createInstance( authorizationPresenter, parametersView.authorizationDialog ) - self._statusController = WorkflowStatusController.createInstance( + self._statusController = WorkflowStatusController( statusPresenter, parametersView.statusView, tableView ) self._executionController = WorkflowExecutionController.createInstance( @@ -43,32 +41,8 @@ def __init__( productItemModel, ) self._timer = QTimer() - - @classmethod - def createInstance( - cls, - parametersPresenter: WorkflowParametersPresenter, - authorizationPresenter: WorkflowAuthorizationPresenter, - statusPresenter: WorkflowStatusPresenter, - executionPresenter: WorkflowExecutionPresenter, - parametersView: WorkflowParametersView, - tableView: QTableView, - productItemModel: QAbstractItemModel, - ) -> WorkflowController: - controller = cls( - parametersPresenter, - authorizationPresenter, - statusPresenter, - executionPresenter, - parametersView, - tableView, - productItemModel, - ) - - controller._timer.timeout.connect(controller._processEvents) - controller._timer.start(1000) # TODO customize - - return controller + self._timer.timeout.connect(self._processEvents) + self._timer.start(5 * 1000) # TODO customize def _processEvents(self) -> None: self._authorizationController.startAuthorizationIfNeeded() diff --git a/src/ptychodus/controller/workflow/status.py b/src/ptychodus/controller/workflow/status.py index eb095415..602aa625 100644 --- a/src/ptychodus/controller/workflow/status.py +++ b/src/ptychodus/controller/workflow/status.py @@ -1,10 +1,11 @@ -from __future__ import annotations import logging from PyQt5.QtCore import Qt, QModelIndex, QSortFilterProxyModel, QTimer from PyQt5.QtWidgets import QAbstractItemView, QTableView from PyQt5.QtGui import QDesktopServices +from ptychodus.api.observer import Observable, Observer + from ...model.workflow import WorkflowStatusPresenter from ...view.workflow import WorkflowStatusView from .tableModel import WorkflowTableModel @@ -12,48 +13,39 @@ logger = logging.getLogger(__name__) -class WorkflowStatusController: +class WorkflowStatusController(Observer): def __init__( self, presenter: WorkflowStatusPresenter, view: WorkflowStatusView, tableView: QTableView, ) -> None: + super().__init__() self._presenter = presenter self._view = view self._tableView = tableView self._tableModel = WorkflowTableModel(presenter) self._proxyModel = QSortFilterProxyModel() + self._proxyModel.setSourceModel(self._tableModel) self._timer = QTimer() + self._timer.timeout.connect(presenter.refreshStatus) - @classmethod - def createInstance( - cls, - presenter: WorkflowStatusPresenter, - view: WorkflowStatusView, - tableView: QTableView, - ) -> WorkflowStatusController: - controller = cls(presenter, view, tableView) - - controller._proxyModel.setSourceModel(controller._tableModel) - tableView.setModel(controller._proxyModel) + tableView.setModel(self._proxyModel) tableView.setSortingEnabled(True) tableView.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) - tableView.clicked.connect(controller._handleTableViewClick) + tableView.clicked.connect(self._handleTableViewClick) - controller._timer.timeout.connect(presenter.refreshStatus) - view.autoRefreshCheckBox.toggled.connect(controller._autoRefreshStatus) + view.autoRefreshCheckBox.toggled.connect(self._autoRefreshStatus) view.autoRefreshSpinBox.valueChanged.connect(presenter.setRefreshIntervalInSeconds) view.refreshButton.clicked.connect(presenter.refreshStatus) - controller._syncModelToView() - - return controller + self._syncModelToView() + presenter.addObserver(self) def _handleTableViewClick(self, index: QModelIndex) -> None: if index.column() == 5: url = index.data(Qt.ItemDataRole.UserRole) - logger.debug(f'Opening URL: "{url.toString()}"') + logger.info(f'Opening URL: "{url.toString()}"') QDesktopServices.openUrl(url) def _autoRefreshStatus(self) -> None: @@ -80,3 +72,7 @@ def _syncModelToView(self) -> None: ) self._view.autoRefreshSpinBox.setValue(self._presenter.getRefreshIntervalInSeconds()) self._view.autoRefreshSpinBox.blockSignals(False) + + def update(self, observable: Observable) -> None: + if observable is self._presenter: + self._syncModelToView() diff --git a/src/ptychodus/model/analysis/exposure.py b/src/ptychodus/model/analysis/exposure.py index c884c403..baae2708 100644 --- a/src/ptychodus/model/analysis/exposure.py +++ b/src/ptychodus/model/analysis/exposure.py @@ -2,11 +2,13 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path +from typing import Any import logging import numpy from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import ObjectCenter from ptychodus.api.visualization import RealArrayType from ..product import ProductRepository @@ -16,16 +18,10 @@ @dataclass(frozen=True) class ExposureMap: - pixel_width_m: float - pixel_height_m: float - center_x_m: float - center_y_m: float + pixel_geometry: PixelGeometry | None + center: ObjectCenter | None counts: RealArrayType - @property - def pixel_geometry(self) -> PixelGeometry: - return PixelGeometry(self.pixel_width_m, self.pixel_height_m) - class ExposureAnalyzer: def __init__(self, repository: ProductRepository) -> None: @@ -36,13 +32,11 @@ def analyze(self, itemIndex: int) -> ExposureMap: objectItem = item.getObject() object_ = objectItem.getObject() - counts = numpy.zeros_like(object_.array, dtype=float) # FIXME + counts = numpy.zeros_like(object_.getArray(), dtype=float) # FIXME return ExposureMap( - pixel_width_m=object_.pixelWidthInMeters, - pixel_height_m=object_.pixelHeightInMeters, - center_x_m=object_.centerXInMeters, - center_y_m=object_.centerYInMeters, + pixel_geometry=object_.getPixelGeometry(), + center=object_.getCenter(), counts=counts, ) @@ -52,17 +46,21 @@ def getSaveFileFilterList(self) -> Sequence[str]: def getSaveFileFilter(self) -> str: return 'NumPy Zipped Archive (*.npz)' - def saveResult(self, filePath: Path, result: ExposureMap) -> None: - numpy.savez( - filePath, - 'pixel_height_m', - result.pixel_height_m, - 'pixel_width_m', - result.pixel_width_m, - 'center_x_m', - result.center_x_m, - 'center_y_m', - result.center_y_m, - 'counts', - result.counts, - ) + def saveResult(self, file_path: Path, result: ExposureMap) -> None: + contents: dict[str, Any] = { + 'counts': result.counts, + } + + pixel_geometry = result.pixel_geometry + + if pixel_geometry is not None: + contents['pixel_height_m'] = pixel_geometry.heightInMeters + contents['pixel_width_m'] = pixel_geometry.widthInMeters + + center = result.center + + if center is not None: + contents['center_x_m'] = center.positionXInMeters + contents['center_y_m'] = center.positionYInMeters + + numpy.savez(file_path, **contents) diff --git a/src/ptychodus/model/analysis/frc.py b/src/ptychodus/model/analysis/frc.py index a9735ab2..2a3e855f 100644 --- a/src/ptychodus/model/analysis/frc.py +++ b/src/ptychodus/model/analysis/frc.py @@ -76,6 +76,9 @@ def correlate(self, itemIndex1: int, itemIndex2: int) -> FourierRingCorrelation: # TODO verify compatible pixel geometry pixelGeometry = object2.getPixelGeometry() + if pixelGeometry is None: + raise ValueError('No pixel geometry!') + # TODO subpixel image registration: skimage.registration.phase_cross_correlation # TODO remove phase offset and ramp # TODO apply soft-edged mask diff --git a/src/ptychodus/model/analysis/objectInterpolator.py b/src/ptychodus/model/analysis/objectInterpolator.py index 544f2734..14be8575 100644 --- a/src/ptychodus/model/analysis/objectInterpolator.py +++ b/src/ptychodus/model/analysis/objectInterpolator.py @@ -7,12 +7,12 @@ class ObjectLinearInterpolator(ObjectInterpolator): def __init__(self, object_: Object) -> None: self._object = object_ - def getPatch(self, patchCenter: ScanPoint, patchExtent: ImageExtent) -> Object: + def get_patch(self, patch_center: ScanPoint, patch_extent: ImageExtent) -> Object: geometry = self._object.getGeometry() - patchWidth = patchExtent.widthInPixels + patchWidth = patch_extent.widthInPixels patchRadiusXInMeters = geometry.pixelWidthInMeters * patchWidth / 2 - patchMinimumXInMeters = patchCenter.positionXInMeters - patchRadiusXInMeters + patchMinimumXInMeters = patch_center.positionXInMeters - patchRadiusXInMeters ixBeginF, xi = divmod( patchMinimumXInMeters - geometry.minimumXInMeters, geometry.pixelWidthInMeters, @@ -22,9 +22,9 @@ def getPatch(self, patchCenter: ScanPoint, patchExtent: ImageExtent) -> Object: ixSlice0 = slice(ixBegin, ixEnd) ixSlice1 = slice(ixBegin + 1, ixEnd + 1) - patchHeight = patchExtent.heightInPixels + patchHeight = patch_extent.heightInPixels patchRadiusYInMeters = geometry.pixelHeightInMeters * patchHeight / 2 - patchMinimumYInMeters = patchCenter.positionYInMeters - patchRadiusYInMeters + patchMinimumYInMeters = patch_center.positionYInMeters - patchRadiusYInMeters iyBeginF, eta = divmod( patchMinimumYInMeters - geometry.minimumYInMeters, geometry.pixelHeightInMeters, @@ -42,16 +42,15 @@ def getPatch(self, patchCenter: ScanPoint, patchExtent: ImageExtent) -> Object: w10 = xiC * eta w11 = xi * eta - patch = w00 * self._object.array[:, iySlice0, ixSlice0] - patch += w01 * self._object.array[:, iySlice0, ixSlice1] - patch += w10 * self._object.array[:, iySlice1, ixSlice0] - patch += w11 * self._object.array[:, iySlice1, ixSlice1] + objectArray = self._object.getArray() + patch = w00 * objectArray[:, iySlice0, ixSlice0] + patch += w01 * objectArray[:, iySlice0, ixSlice1] + patch += w10 * objectArray[:, iySlice1, ixSlice0] + patch += w11 * objectArray[:, iySlice1, ixSlice1] - return Object( + return Object( # FIXME multilayer objects array=patch, layerDistanceInMeters=self._object.layerDistanceInMeters, - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, - centerXInMeters=geometry.centerXInMeters, - centerYInMeters=geometry.centerYInMeters, + pixelGeometry=geometry.getPixelGeometry(), + center=geometry.getCenter(), ) diff --git a/src/ptychodus/model/analysis/objectStitcher.py b/src/ptychodus/model/analysis/objectStitcher.py index 61c1bb12..abd44315 100644 --- a/src/ptychodus/model/analysis/objectStitcher.py +++ b/src/ptychodus/model/analysis/objectStitcher.py @@ -52,11 +52,9 @@ def addPatch(self, patchCenter: ScanPoint, patchArray: ObjectArrayType) -> None: self._addPatchPart(ixSlice0, iySlice1, xiC * eta, patchArray) self._addPatchPart(ixSlice1, iySlice1, xi * eta, patchArray) - def build(self) -> Object: + def build(self) -> Object: # FIXME multilayer objects? return Object( array=self._array, - pixelWidthInMeters=self._geometry.pixelWidthInMeters, - pixelHeightInMeters=self._geometry.pixelHeightInMeters, - centerXInMeters=self._geometry.centerXInMeters, - centerYInMeters=self._geometry.centerYInMeters, + pixelGeometry=self._geometry.getPixelGeometry(), + center=self._geometry.getCenter(), ) diff --git a/src/ptychodus/model/analysis/phaseUnwrapper.py b/src/ptychodus/model/analysis/phaseUnwrapper.py new file mode 100644 index 00000000..1502f2bf --- /dev/null +++ b/src/ptychodus/model/analysis/phaseUnwrapper.py @@ -0,0 +1,459 @@ +import numpy +from numpy.typing import NDArray +from scipy import signal, ndimage +from typing import Literal, Optional, Tuple + + +class PhaseUnwrapper: + def __init__( + self, + fourier_shift_step: float = 0.5, + image_grad_method: Literal[ + 'fourier_shift', 'fourier_differentiation', 'nearest' + ] = 'fourier_differentiation', + image_integration_method: Literal['fourier', 'discrete', 'deconvolution'] = 'fourier', + weight_map: Optional[NDArray] = None, + eps: float = 1e-9, + ) -> None: + """Get the unwrapped phase of a complex 2D image. + + Parameters + ---------- + fourier_shift_step : float + The finite-difference step size used to calculate the gradient, + if the Fourier shift method is used. + image_grad_method : str + The method used to calculate the phase gradient. + - "fourier_shift": Use Fourier shift to perform shift. + - "nearest": Use nearest neighbor to perform shift. + - "fourier_differentiation": Use Fourier differentiation. + image_integration_method : str + The method used to integrate the image back from gradients. + - "fourier": Use Fourier integration as implemented in PtychoShelves. + - "deconvolution": Deconvolve ramp filter. + - "discrete": Use cumulative sum. + weight_map : Optional[NDArray] + A weight map multiplied to the input image. + eps : float + A small number to avoid division by zero. + """ + self.fourier_shift_step = fourier_shift_step + self.image_grad_method = image_grad_method + self.image_integration_method = image_integration_method + self.weight_map = weight_map + self.eps = eps + + def unwrap(self, img: NDArray) -> NDArray: + """Run unwrapping. + + Parameters + ---------- + img : NDArray + A 2D complex array giving the image to be unwrapped. + + Returns + ------- + NDArray + A 2D real array giving the unwrapped phase of the input image. + """ + if not numpy.iscomplexobj(img): + raise ValueError('Input array must be complex.') + + if self.weight_map is not None: + weight_map = float(numpy.clip(self.weight_map, 0.0, 1.0)) + else: + weight_map = 1.0 + + img = weight_map * img / (numpy.abs(img) + self.eps) + bc_center = numpy.angle(img[img.shape[0] // 2, img.shape[1] // 2]) + + # Pad image to avoid FFT boundary artifacts. + padding = [64, 64] + if any(numpy.array(padding) > 0): + img = numpy.pad( + img, ((padding[0], padding[0]), (padding[1], padding[1])), mode='reflect' + ) + img = vignett(img, margin=10, sigma=2.5) + + gy, gx = get_phase_gradient( + img, + fourier_shift_step=self.fourier_shift_step, + image_grad_method=self.image_grad_method, + ) + + if self.image_integration_method == 'discrete' and any(numpy.array(padding) > 0): + gy = gy[padding[0] : -padding[0], padding[1] : -padding[1]] + gx = gx[padding[0] : -padding[0], padding[1] : -padding[1]] + if self.image_integration_method == 'discrete': + phase = numpy.real(integrate_image_2d(gy, gx, bc_center=bc_center)) + elif self.image_integration_method == 'fourier': + phase = numpy.real(integrate_image_2d_fourier(gy, gx)) + elif self.image_integration_method == 'deconvolution': + phase = numpy.real(integrate_image_2d_deconvolution(gy, gx, bc_center=bc_center)) + else: + raise ValueError(f'Unknown integration method: {self.image_integration_method}') + + if self.image_integration_method != 'discrete' and any(numpy.array(padding) > 0): + gy = gy[padding[0] : -padding[0], padding[1] : -padding[1]] + gx = gx[padding[0] : -padding[0], padding[1] : -padding[1]] + phase = phase[padding[0] : -padding[0], padding[1] : -padding[1]] + + return phase + + +def vignett(img: NDArray, margin: int = 20, sigma: float = 1.0) -> NDArray: + """Vignett an image so that it gradually decays near the boundary. + For each dimension of the image, a mask with a width of `2 * margin` + and with half of it filled with 0s and half with 1s is + generated and convolved with a Gaussian kernel of size + `margin` and standard deviation `sigma`. The blurred mask is cropped and + multiplied to the near-edge regions of the image. + + Parameters + ---------- + img : Tensor + The input image. + margin : int + The margin of image where the decay takes place. + sigma : float + The standard deviation of the Gaussian kernel. + """ + img = img.copy() + for i_dim in range(img.ndim): + if img.shape[i_dim] <= 2 * margin: + continue + + mask_shape = ( + [img.shape[i] for i in range(i_dim)] + + [2 * margin] + + [img.shape[i] for i in range(i_dim + 1, img.ndim)] + ) + mask = numpy.zeros(mask_shape) + mask_slicer = [slice(None)] * i_dim + [slice(margin, None)] + mask[tuple(mask_slicer)] = 1.0 + + gauss_win = signal.windows.gaussian(margin // 2, std=sigma) + gauss_win = gauss_win / numpy.sum(gauss_win) + mask = ndimage.convolve1d(mask, gauss_win, axis=i_dim, mode='constant') + mask_final_slicer = [slice(None)] * i_dim + [slice(len(gauss_win), len(gauss_win) + margin)] + mask = mask[*mask_final_slicer] + mask = numpy.where(mask < 1e-3, 0, mask) + + slicer = tuple([slice(None)] * i_dim + [slice(0, margin)]) + img[slicer] = img[slicer] * mask + + slicer = tuple([slice(None)] * i_dim + [slice(-margin, None)]) + img[slicer] = img[slicer] * numpy.flip(mask, axis=i_dim) + return img + + +def nearest_neighbor_gradient( + image: NDArray, direction: Literal['forward', 'backward'], dim: Tuple[int, ...] = (0, 1) +) -> Tuple[NDArray, NDArray]: + """ + Calculate the nearest neighbor gradient of a 2D image. + + Parameters + ---------- + image : NDArray + a (... H, W) tensor of images. + direction : str + 'forward' or 'backward'. + dim : tuple of int, optional + Dimensions to calculate gradient. Default is (0, 1). + + Returns + ------- + tuple of NDArray + a tuple of 2 images with the gradient in y and x directions. + """ + if not hasattr(dim, '__len__'): + dim = (dim,) + grad_x = None + grad_y = None + if direction == 'forward': + if 1 in dim: + grad_x = numpy.concatenate([image[:, 1:], image[:, -1:]], axis=1) - image + if 0 in dim: + grad_y = numpy.concatenate([image[1:, :], image[-1:, :]], axis=0) - image + elif direction == 'backward': + if 1 in dim: + grad_x = image - numpy.concatenate([image[:, :1], image[:, :-1]], axis=1) + if 0 in dim: + grad_y = image - numpy.concatenate([image[:1, :], image[:-1, :]], axis=0) + else: + raise ValueError("direction must be 'forward' or 'backward'") + return grad_y, grad_x + + +def gaussian_gradient(image: NDArray, sigma: float = 1.0, kernel_size=5) -> Tuple[NDArray, NDArray]: + """ + Calculate the gradient of a 2D image with a Gaussian-derivative kernel. + + Parameters + ---------- + image : NDArray + A (... H, W) tensor of images. + sigma : float + Sigma of the Gaussian. + + Returns + ------- + tuple of NDArray + A tuple of 2 images with the gradient in y and x directions. + """ + r = numpy.arange(kernel_size) - (kernel_size - 1) / 2.0 + kernel = -r / (numpy.sqrt(2 * numpy.pi) * sigma**3) * numpy.exp(-(r**2) / (2 * sigma**2)) + grad_y = ndimage.convolve(image, kernel.reshape(-1, 1), mode='nearest') + grad_x = ndimage.convolve(image, kernel.reshape(1, -1), mode='nearest') + + # Gate the gradients + grads = [grad_y, grad_x] + for i, g in enumerate(grads): + m = numpy.logical_and(numpy.abs(grad_y) < 1e-6, numpy.abs(grad_y) != 0) + if numpy.count_nonzero(m) > 0: + print('Gradient magnitudes between 0 and 1e-6 are set to 0.') + g = g * numpy.logical_not(m) + grads[i] = g + grad_y, grad_x = grads + return grad_y, grad_x + + +def fourier_gradient(image: NDArray) -> Tuple[NDArray, NDArray]: + """Calculate gradient using NumPy FFT operations""" + u = numpy.fft.fftfreq(image.shape[0]) + v = numpy.fft.fftfreq(image.shape[1]) + u, v = numpy.meshgrid(u, v, indexing='ij') + + grad_y = numpy.fft.ifft(numpy.fft.fft(image, axis=-2) * (2j * numpy.pi * u), axis=-2) + grad_x = numpy.fft.ifft(numpy.fft.fft(image, axis=-1) * (2j * numpy.pi * v), axis=-1) + + return grad_y, grad_x + + +def get_phase_gradient( + img: NDArray, + fourier_shift_step: float = 0, + image_grad_method: Literal[ + 'fourier_shift', 'fourier_differentiation', 'nearest' + ] = 'fourier_shift', + eps: float = 1e-6, +) -> Tuple[NDArray, NDArray]: + """ + Get the gradient of the phase of a complex 2D image by first calculating + the spatial gradient of the complex image, then taking the phase of the + complex gradient -- i.e., it takes the phase of the gradient rather than + the gradient of the phase. This avoids the sharp gradients due to phase + wrapping when directly taking the gradient of the phase. + + Parameters + ---------- + img : NDArray + A [N, H, W] or [H, W] tensor giving a batch of images or a single image. + step : float + The finite-difference step size used to calculate the gradient, if + the Fourier shift method is used. + finite_diff_method : enums.ImageGradientMethods + The method used to calculate the phase gradient. + - "fourier_shift": Use Fourier shift to perform shift. + - "nearest": Use nearest neighbor to perform shift. + - "fourier_differentiation": Use Fourier differentiation. + eps : float + A stablizing constant. + + Returns + ------- + Tuple[NDArray, NDArray] + A tuple of 2 images with the gradient in y and x directions. + """ + if fourier_shift_step <= 0 and image_grad_method == 'fourier_shift': + raise ValueError('Step must be positive.') + + if image_grad_method == 'fourier_differentiation': + gy, gx = fourier_gradient(img) + gy = numpy.imag(numpy.conj(img) * gy) + gx = numpy.imag(numpy.conj(img) * gx) + else: + # Use finite difference. + if img.ndim == 2: + img = img[None, ...] + pad = int(numpy.ceil(fourier_shift_step)) + 1 + img = numpy.pad(img, ((0, 0), (pad, pad), (pad, pad)), mode='reflect') + + sy1 = numpy.array([[-fourier_shift_step, 0]]).repeat(img.shape[0], axis=0) + sy2 = numpy.array([[fourier_shift_step, 0]]).repeat(img.shape[0], axis=0) + if image_grad_method == 'fourier_shift': + # If the image contains zero-valued pixels, Fourier shift can result in small + # non-zero values that dangles around 0. This can cause the phase + # of the shifted image to dangle between pi and -pi. In that case, use + # `finite_diff_method="nearest" instead`, or use `step=1`. + complex_prod = fourier_shift(img, sy1) * fourier_shift(img, sy2).conj() + elif image_grad_method == 'nearest': + complex_prod = img * numpy.concatenate([img[:, :1, :], img[:, :-1, :]], axis=1).conj() + else: + raise ValueError(f'Unknown finite-difference method: {image_grad_method}') + complex_prod = numpy.where( + numpy.abs(complex_prod) < numpy.abs(complex_prod).max() * 1e-6, 0, complex_prod + ) + gy = numpy.angle(complex_prod) / (2 * fourier_shift_step) + gy = gy[0, pad:-pad, pad:-pad] + + sx1 = numpy.array([[0, -fourier_shift_step]]).repeat(img.shape[0], axis=0) + sx2 = numpy.array([[0, fourier_shift_step]]).repeat(img.shape[0], axis=0) + if image_grad_method == 'fourier_shift': + complex_prod = fourier_shift(img, sx1) * fourier_shift(img, sx2).conj() + elif image_grad_method == 'nearest': + complex_prod = img * numpy.concatenate([img[:, :, :1], img[:, :, :-1]], axis=2).conj() + complex_prod = numpy.where( + numpy.abs(complex_prod) < numpy.abs(complex_prod).max() * 1e-6, 0, complex_prod + ) + gx = numpy.angle(complex_prod) / (2 * fourier_shift_step) + gx = gx[0, pad:-pad, pad:-pad] + return gy, gx + + +def integrate_image_2d_fourier(grad_y: NDArray, grad_x: NDArray) -> NDArray: + """ + Integrate an image with the gradient in y and x directions using Fourier + differentiation. + + Parameters + ---------- + grad_y, grad_x: NDArray + A (H, W) tensor of gradients in y or x directions. + + Returns + ------- + NDArray + The integrated image. + """ + shape = grad_y.shape + f = numpy.fft.fft2(grad_x + 1j * grad_y) + y, x = numpy.fft.fftfreq(shape[0]), numpy.fft.fftfreq(shape[1]) + + r = numpy.exp(2j * numpy.pi * (x + y[:, None])) + r = r / (2j * numpy.pi * (x + 1j * y[:, None])) + r[0, 0] = 0 + integrated_image = f * r + integrated_image = numpy.fft.ifft2(integrated_image) + if not numpy.iscomplexobj(grad_x): + integrated_image = integrated_image.real + return integrated_image + + +def integrate_image_2d_deconvolution( + grad_y: NDArray, + grad_x: NDArray, + tf_y: Optional[NDArray] = None, + tf_x: Optional[NDArray] = None, + bc_center: float = 0, +) -> NDArray: + """ + Integrate an image with the gradient in y and x directions by deconvolving + the differentiation kernel, whose transfer function is assumed to be a + ramp function. + + Adapted from Tripathi, A., McNulty, I., Munson, T., & Wild, S. M. (2016). + Single-view phase retrieval of an extended sample by exploiting edge detection + and sparsity. Optics Express, 24(21), 24719–24738. doi:10.1364/OE.24.024719 + + Parameters + ---------- + grad_y, grad_x: NDArray + A (H, W) tensor of gradients in y or x directions. + tf_y, tf_x: NDArray + A (H, W) tensor of transfer functions in y or x directions. If not + provided, they are assumed to be 2i * pi * u (or v), which are the + effective transfer functions in Fourier differentiation. + bc_center: float + The value of the boundary condition at the center of the image. + + Returns + ------- + NDArray + The integrated image. + """ + u, v = numpy.fft.fftfreq(grad_x.shape[0]), numpy.fft.fftfreq(grad_x.shape[1]) + u, v = numpy.meshgrid(u, v, indexing='ij') + if tf_y is None or tf_x is None: + tf_y = 2j * numpy.pi * u + tf_x = 2j * numpy.pi * v + f_grad_y = numpy.fft.fft2(grad_y) + f_grad_x = numpy.fft.fft2(grad_x) + img = (f_grad_y * tf_y + f_grad_x * tf_x) / (numpy.abs(tf_y) ** 2 + numpy.abs(tf_x) ** 2 + 1e-5) + img = -numpy.fft.ifft2(img) + img = img + bc_center - img[img.shape[0] // 2, img.shape[1] // 2] + return img + + +def integrate_image_2d(grad_y: NDArray, grad_x: NDArray, bc_center: float = 0) -> NDArray: + """ + Integrate an image with the gradient in y and x directions. + + Parameters + ---------- + grad_y : NDArray + The gradient in y direction. + grad_x : NDArray + The gradient in x direction. + bc_center : float + The boundary condition at the center of the image, by default 0 + + Returns + ------- + NDArray + The integrated image. + """ + left_boundary = numpy.cumsum(grad_y[:, 0], axis=0) + int_img = numpy.cumsum(grad_x, axis=1) + left_boundary[:, None] + int_img = int_img + bc_center - int_img[int_img.shape[0] // 2, int_img.shape[1] // 2] + return int_img + + +def fourier_shift( + images: NDArray, shifts: NDArray, strictly_preserve_zeros: bool = False +) -> NDArray: + """ + Apply Fourier shift to a batch of images. + + Parameters + ---------- + images : NDArray + A [N, H, W] array of images. + shifts : NDArray + A [N, 2] array of shifts in pixels. + strictly_preserve_zeros : bool + If True, mask of strictly zero pixels will be generated and shifted + by the same amount. Pixels that have a non-zero value in the shifted + mask will be set to zero in the shifted image. This preserves the zero + pixels in the original image, preventing FFT from introducing small + non-zero values due to machine precision. + + Returns + ------- + NDArray + Shifted images. + """ + if strictly_preserve_zeros: + zero_mask = images == 0 + zero_mask = zero_mask.float() + zero_mask_shifted = fourier_shift(zero_mask, shifts, strictly_preserve_zeros=False) + ft_images = numpy.fft.fft2(images) + freq_y, freq_x = numpy.meshgrid( + numpy.fft.fftfreq(images.shape[-2]), numpy.fft.fftfreq(images.shape[-1]), indexing='ij' + ) + freq_x = freq_x.repeat(images.shape[0], axis=0) + freq_y = freq_y.repeat(images.shape[0], axis=0) + mult = numpy.exp( + 1j + * -2 + * numpy.pi + * (freq_x * shifts[:, 1].reshape([-1, 1, 1]) + freq_y * shifts[:, 0].reshape([-1, 1, 1])) + ) + ft_images = ft_images * mult + shifted_images = numpy.fft.ifft2(ft_images) + if not numpy.iscomplexobj(images): + shifted_images = shifted_images.real + if strictly_preserve_zeros: + shifted_images[zero_mask_shifted > 0] = 0 + return shifted_images diff --git a/src/ptychodus/model/analysis/propagator.py b/src/ptychodus/model/analysis/propagator.py index a7d7d2de..aa5ca6af 100644 --- a/src/ptychodus/model/analysis/propagator.py +++ b/src/ptychodus/model/analysis/propagator.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Sequence from pathlib import Path +from typing import Any import logging import numpy @@ -49,33 +50,38 @@ def propagate( beginCoordinateInMeters: float, endCoordinateInMeters: float, numberOfSteps: int, - ) -> None: + ) -> None: # FIXME OPR item = self._repository[self._productIndex] probe = item.getProbe().getProbe() wavelengthInMeters = item.getGeometry().probeWavelengthInMeters propagatedWavefield = numpy.zeros( - (numberOfSteps, *probe.array.shape), - dtype=probe.array.dtype, + (numberOfSteps, probe.heightInPixels, probe.widthInPixels), + dtype=probe.dataType, + ) + propagatedIntensity = numpy.zeros( + (numberOfSteps, probe.heightInPixels, probe.widthInPixels) ) - propagatedIntensity = numpy.zeros((numberOfSteps, *probe.array.shape[-2:])) distanceInMeters = numpy.linspace( beginCoordinateInMeters, endCoordinateInMeters, numberOfSteps ) pixelGeometry = probe.getPixelGeometry() + if pixelGeometry is None: + raise ValueError('No pixel geometry!') + for idx, zInMeters in enumerate(distanceInMeters): propagatorParameters = PropagatorParameters( wavelength_m=wavelengthInMeters, - width_px=probe.array.shape[-1], - height_px=probe.array.shape[-2], + width_px=probe.widthInPixels, + height_px=probe.heightInPixels, pixel_width_m=pixelGeometry.widthInMeters, pixel_height_m=pixelGeometry.heightInMeters, - propagation_distance_m=zInMeters, + propagation_distance_m=float(zInMeters), ) propagator = AngularSpectrumPropagator(propagatorParameters) - for mode in range(probe.array.shape[-3]): - wf = propagator.propagate(probe.array[mode]) + for mode in range(probe.numberOfIncoherentModes): + wf = propagator.propagate(probe.getIncoherentMode(mode)) propagatedWavefield[idx, mode, :, :] = wf propagatedIntensity[idx, :, :] += intensity(wf) @@ -95,9 +101,13 @@ def _getProbe(self) -> Probe: item = self._repository[self._productIndex] return item.getProbe().getProbe() - def getPixelGeometry(self) -> PixelGeometry: - probe = self._getProbe() - return probe.getPixelGeometry() + def getPixelGeometry(self) -> PixelGeometry | None: + try: + probe = self._getProbe() + except IndexError: + return None + else: + return probe.getPixelGeometry() def getNumberOfSteps(self) -> int: if self._propagatedIntensity is None: @@ -139,19 +149,17 @@ def savePropagatedProbe(self, filePath: Path) -> None: if self._propagatedWavefield is None or self._propagatedIntensity is None: raise ValueError('No propagated wavefield!') - pixelGeometry = self.getPixelGeometry() - numpy.savez( - filePath, - 'begin_coordinate_m', - float(self.getBeginCoordinateInMeters()), - 'end_coordinate_m', - float(self.getEndCoordinateInMeters()), - 'pixel_height_m', - pixelGeometry.heightInMeters, - 'pixel_width_m', - pixelGeometry.widthInMeters, - 'wavefield', - self._propagatedWavefield, - 'intensity', - self._propagatedIntensity, - ) + contents: dict[str, Any] = { + 'begin_coordinate_m': self.getBeginCoordinateInMeters(), + 'end_coordinate_m': self.getEndCoordinateInMeters(), + 'wavefield': self._propagatedWavefield, + 'intensity': self._propagatedIntensity, + } + + pixel_geometry = self.getPixelGeometry() + + if pixel_geometry is not None: + contents['pixel_height_m'] = pixel_geometry.heightInMeters + contents['pixel_width_m'] = pixel_geometry.widthInMeters + + numpy.savez(filePath, **contents) diff --git a/src/ptychodus/model/analysis/xmcd.py b/src/ptychodus/model/analysis/xmcd.py index 38044e30..0623ed49 100644 --- a/src/ptychodus/model/analysis/xmcd.py +++ b/src/ptychodus/model/analysis/xmcd.py @@ -2,12 +2,13 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path +from typing import Any import logging import numpy from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.object import ObjectArrayType +from ptychodus.api.object import ObjectArrayType, ObjectCenter from ..product import ObjectRepository @@ -16,18 +17,12 @@ @dataclass(frozen=True) class XMCDResult: - pixel_width_m: float - pixel_height_m: float - center_x_m: float - center_y_m: float + pixel_geometry: PixelGeometry | None + center: ObjectCenter | None polar_difference: ObjectArrayType polar_sum: ObjectArrayType polar_ratio: ObjectArrayType - @property - def pixel_geometry(self) -> PixelGeometry: - return PixelGeometry(self.pixel_width_m, self.pixel_height_m) - class XMCDAnalyzer: # TODO feature request: want ability to align/add reconstructed slices @@ -40,22 +35,26 @@ def analyze(self, lcircItemIndex: int, rcircItemIndex: int) -> XMCDResult: lcircObject = self._repository[lcircItemIndex].getObject() rcircObject = self._repository[rcircItemIndex].getObject() - if lcircObject.widthInPixels != rcircObject.widthInPixels: + lcircObjectGeometry = lcircObject.getGeometry() + rcircObjectGeometry = rcircObject.getGeometry() + + if lcircObjectGeometry.widthInPixels != rcircObjectGeometry.widthInPixels: raise ValueError('Object width mismatch!') - if lcircObject.heightInPixels != rcircObject.heightInPixels: + if lcircObjectGeometry.heightInPixels != rcircObjectGeometry.heightInPixels: raise ValueError('Object height mismatch!') - if lcircObject.pixelWidthInMeters != rcircObject.pixelWidthInMeters: + if lcircObjectGeometry.pixelWidthInMeters != rcircObjectGeometry.pixelWidthInMeters: raise ValueError('Object pixel width mismatch!') - if lcircObject.pixelHeightInMeters != rcircObject.pixelHeightInMeters: + if lcircObjectGeometry.pixelHeightInMeters != rcircObjectGeometry.pixelHeightInMeters: raise ValueError('Object pixel height mismatch!') # TODO align lcircArray/rcircArray - lcircAmp = numpy.absolute(lcircObject.array) - rcircAmp = numpy.absolute(rcircObject.array) + # FIXME OPR + lcircAmp = numpy.absolute(lcircObject.getArray()) + rcircAmp = numpy.absolute(rcircObject.getArray()) ratio = numpy.divide(lcircAmp, rcircAmp) product = numpy.multiply(lcircAmp, rcircAmp) @@ -70,10 +69,8 @@ def analyze(self, lcircItemIndex: int, rcircItemIndex: int) -> XMCDResult: ) return XMCDResult( - pixel_width_m=rcircObject.pixelWidthInMeters, - pixel_height_m=rcircObject.pixelHeightInMeters, - center_x_m=rcircObject.centerXInMeters, - center_y_m=rcircObject.centerYInMeters, + pixel_geometry=rcircObject.getPixelGeometry(), + center=rcircObject.getCenter(), polar_difference=polar_difference, polar_sum=polar_sum, polar_ratio=polar_ratio, @@ -85,21 +82,23 @@ def getSaveFileFilterList(self) -> Sequence[str]: def getSaveFileFilter(self) -> str: return 'NumPy Zipped Archive (*.npz)' - def saveResult(self, filePath: Path, result: XMCDResult) -> None: - numpy.savez( - filePath, - 'pixel_height_m', - result.pixel_height_m, - 'pixel_width_m', - result.pixel_width_m, - 'center_x_m', - result.center_x_m, - 'center_y_m', - result.center_y_m, - 'polar_difference', - result.polar_difference, - 'polar_sum', - result.polar_sum, - 'polar_ratio', - result.polar_ratio, - ) + def saveResult(self, file_path: Path, result: XMCDResult) -> None: + contents: dict[str, Any] = { + 'polar_difference': result.polar_difference, + 'polar_sum': result.polar_sum, + 'polar_ratio': result.polar_ratio, + } + + pixel_geometry = result.pixel_geometry + + if pixel_geometry is not None: + contents['pixel_height_m'] = pixel_geometry.heightInMeters + contents['pixel_width_m'] = pixel_geometry.widthInMeters + + center = result.center + + if center is not None: + contents['center_x_m'] = center.positionXInMeters + contents['center_y_m'] = center.positionYInMeters + + numpy.savez(file_path, **contents) diff --git a/src/ptychodus/model/core.py b/src/ptychodus/model/core.py index e8502692..fc980b53 100644 --- a/src/ptychodus/model/core.py +++ b/src/ptychodus/model/core.py @@ -55,6 +55,7 @@ ScanAPI, ScanRepository, ) +from .ptychi import PtyChiReconstructorLibrary from .ptychonn import PtychoNNReconstructorLibrary from .reconstructor import ReconstructorCore, ReconstructorPresenter from .tike import TikeReconstructorLibrary @@ -78,7 +79,6 @@ def configureLogger(isDeveloperModeEnabled: bool) -> None: level=logging.DEBUG if isDeveloperModeEnabled else logging.INFO, ) logging.getLogger('matplotlib').setLevel(logging.WARNING) - logging.getLogger('tike').setLevel(logging.WARNING) logger.info(f'Ptychodus {version("ptychodus")}') logger.info(f'NumPy {version("numpy")}') @@ -127,6 +127,9 @@ def __init__( self.probeVisualizationEngine = VisualizationEngine(isComplex=True) self.objectVisualizationEngine = VisualizationEngine(isComplex=True) + self.ptyChiReconstructorLibrary = PtyChiReconstructorLibrary( + self.settingsRegistry, self.detector, isDeveloperModeEnabled + ) self.tikeReconstructorLibrary = TikeReconstructorLibrary.createInstance( self.settingsRegistry, isDeveloperModeEnabled ) @@ -138,6 +141,7 @@ def __init__( self._patternsCore.dataset, self._productCore.productRepository, [ + self.ptyChiReconstructorLibrary, self.tikeReconstructorLibrary, self.ptychonnReconstructorLibrary, ], @@ -176,6 +180,7 @@ def __init__( def __enter__(self) -> ModelCore: self._patternsCore.start() + self._reconstructorCore.start() self._workflowCore.start() self._automationCore.start() return self @@ -199,6 +204,7 @@ def __exit__( ) -> None: self._automationCore.stop() self._workflowCore.stop() + self._reconstructorCore.stop() self._patternsCore.stop() @property @@ -283,31 +289,37 @@ def refreshActiveDataset(self) -> None: def batchModeExecute( self, action: str, - inputFilePath: Path, - outputFilePath: Path, + inputPath: Path, + outputPath: Path, *, + productFileType: str = 'NPZ', fluorescenceInputFilePath: Path | None = None, fluorescenceOutputFilePath: Path | None = None, ) -> int: # TODO add enum for actions; implement using workflow API - inputProductIndex = self._productCore.productAPI.openProduct(inputFilePath, fileType='NPZ') + if action.lower() == 'train': + output = self._reconstructorCore.reconstructorAPI.train(inputPath) + self._reconstructorCore.reconstructorAPI.saveModel(outputPath) + return output.result + + inputProductIndex = self._productCore.productAPI.openProduct( + inputPath, fileType=productFileType + ) if inputProductIndex < 0: - logger.error(f'Failed to open product "{inputFilePath}"') + logger.error(f'Failed to open product "{inputPath}"!') return -1 if action.lower() == 'reconstruct': - outputProductName = self._productCore.productAPI.getItemName(inputProductIndex) + logger.info('Reconstructing...') outputProductIndex = self._reconstructorCore.reconstructorAPI.reconstruct( - inputProductIndex, outputProductName + inputProductIndex ) - - if outputProductIndex < 0: - logger.error(f'Failed to reconstruct product index="{inputProductIndex}"') - return -1 + self._reconstructorCore.reconstructorAPI.processResults(block=True) + logger.info('Reconstruction complete.') self._productCore.productAPI.saveProduct( - outputProductIndex, outputFilePath, fileType='NPZ' + outputProductIndex, outputPath, fileType=productFileType ) if fluorescenceInputFilePath is not None and fluorescenceOutputFilePath is not None: @@ -316,11 +328,10 @@ def batchModeExecute( fluorescenceInputFilePath, fluorescenceOutputFilePath, ) - - elif action.lower() == 'train': - self._reconstructorCore.reconstructorAPI.ingestTrainingData(inputProductIndex) - _ = self._reconstructorCore.reconstructorAPI.train() - self._reconstructorCore.reconstructorAPI.saveModel(outputFilePath) + elif action.lower() == 'prepare_training_data': + self._reconstructorCore.reconstructorAPI.exportTrainingData( + outputPath, inputProductIndex + ) else: logger.error(f'Unknown batch mode action "{action}"!') return -1 diff --git a/src/ptychodus/model/fluorescence/two_step.py b/src/ptychodus/model/fluorescence/two_step.py index 28c08e98..25e460a1 100644 --- a/src/ptychodus/model/fluorescence/two_step.py +++ b/src/ptychodus/model/fluorescence/two_step.py @@ -50,6 +50,7 @@ def __init__( reinitObservable.addObserver(self) def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: + # FIXME OPR upscaler = self._upscalingStrategyChooser.currentPlugin.strategy deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy element_maps: list[ElementMap] = list() diff --git a/src/ptychodus/model/fluorescence/vspi.py b/src/ptychodus/model/fluorescence/vspi.py index 15247c66..c1cba215 100644 --- a/src/ptychodus/model/fluorescence/vspi.py +++ b/src/ptychodus/model/fluorescence/vspi.py @@ -131,6 +131,7 @@ def __init__(self, settings: FluorescenceSettings) -> None: settings.vspiMaxIterations.addObserver(self) def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: + # FIXME OPR object_geometry = product.object_.getGeometry() e_cps_shape = object_geometry.heightInPixels, object_geometry.widthInPixels element_maps: list[ElementMap] = list() diff --git a/src/ptychodus/model/patterns/builder.py b/src/ptychodus/model/patterns/builder.py index f677dc89..ff2bfcd2 100644 --- a/src/ptychodus/model/patterns/builder.py +++ b/src/ptychodus/model/patterns/builder.py @@ -34,7 +34,7 @@ def __init__(self, settings: PatternSettings, dataset: ActiveDiffractionDataset) @property def isAssembling(self) -> bool: - return len(self._workers) > 0 + return self._arrayQueue.unfinished_tasks > 0 def _getArrayAndAssemble(self) -> None: while not self._stopWorkEvent.is_set(): diff --git a/src/ptychodus/model/patterns/core.py b/src/ptychodus/model/patterns/core.py index f0d22570..e9167d3c 100644 --- a/src/ptychodus/model/patterns/core.py +++ b/src/ptychodus/model/patterns/core.py @@ -149,10 +149,8 @@ def __init__( fileWriterChooser.setCurrentPluginByName(self.patternSettings.fileType.getValue()) # TODO ^^^^^^^^^^^^^^^^ - self.patternSizer = PatternSizer.createInstance(self.patternSettings, self.detector) - self.patternPresenter = DiffractionPatternPresenter.createInstance( - self.patternSettings, self.patternSizer - ) + self.patternSizer = PatternSizer(self.patternSettings, self.detector) + self.patternPresenter = DiffractionPatternPresenter(self.patternSettings, self.patternSizer) self.dataset = ActiveDiffractionDataset(self.patternSettings, self.patternSizer) self._builder = ActiveDiffractionDatasetBuilder(self.patternSettings, self.dataset) diff --git a/src/ptychodus/model/patterns/detector.py b/src/ptychodus/model/patterns/detector.py index 662afd45..2ccefbf8 100644 --- a/src/ptychodus/model/patterns/detector.py +++ b/src/ptychodus/model/patterns/detector.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from ptychodus.api.geometry import ImageExtent, PixelGeometry from ptychodus.api.observer import Observable, Observer from ptychodus.api.settings import SettingsRegistry @@ -12,25 +10,19 @@ def __init__(self, registry: SettingsRegistry) -> None: self._settingsGroup.addObserver(self) self.widthInPixels = self._settingsGroup.createIntegerParameter( - 'WidthInPixels', 1024, minimum=0 + 'WidthInPixels', 1024, minimum=1 ) self.pixelWidthInMeters = self._settingsGroup.createRealParameter( 'PixelWidthInMeters', 75e-6, minimum=0.0 ) self.heightInPixels = self._settingsGroup.createIntegerParameter( - 'HeightInPixels', 1024, minimum=0 + 'HeightInPixels', 1024, minimum=1 ) self.pixelHeightInMeters = self._settingsGroup.createRealParameter( 'PixelHeightInMeters', 75e-6, minimum=0.0 ) self.bitDepth = self._settingsGroup.createIntegerParameter('BitDepth', 8, minimum=1) - def getImageExtent(self) -> ImageExtent: - return ImageExtent( - widthInPixels=self.widthInPixels.getValue(), - heightInPixels=self.heightInPixels.getValue(), - ) - def setImageExtent(self, imageExtent: ImageExtent) -> None: self.widthInPixels.setValue(imageExtent.widthInPixels) self.heightInPixels.setValue(imageExtent.heightInPixels) diff --git a/src/ptychodus/model/patterns/patterns.py b/src/ptychodus/model/patterns/patterns.py index 181c7d86..26b36b1e 100644 --- a/src/ptychodus/model/patterns/patterns.py +++ b/src/ptychodus/model/patterns/patterns.py @@ -16,13 +16,7 @@ def __init__(self, settings: PatternSettings, sizer: PatternSizer) -> None: self._settings = settings self._sizer = sizer - @classmethod - def createInstance( - cls, settings: PatternSettings, sizer: PatternSizer - ) -> DiffractionPatternPresenter: - presenter = cls(settings, sizer) - sizer.addObserver(presenter) - return presenter + sizer.addObserver(self) def isCropEnabled(self) -> bool: return self._sizer.isCropEnabled() diff --git a/src/ptychodus/model/patterns/settings.py b/src/ptychodus/model/patterns/settings.py index 2a7ed846..60c1d250 100644 --- a/src/ptychodus/model/patterns/settings.py +++ b/src/ptychodus/model/patterns/settings.py @@ -10,7 +10,7 @@ def __init__(self, registry: SettingsRegistry) -> None: self._settingsGroup = registry.createGroup('Patterns') self._settingsGroup.addObserver(self) - self.fileType = self._settingsGroup.createStringParameter('FileType', 'HDF5') + self.fileType = self._settingsGroup.createStringParameter('FileType', 'NeXus') self.filePath = self._settingsGroup.createPathParameter( 'FilePath', Path('/path/to/data.h5') ) @@ -19,30 +19,37 @@ def __init__(self, registry: SettingsRegistry) -> None: 'ScratchDirectory', Path.home() / '.ptychodus' ) self.numberOfDataThreads = self._settingsGroup.createIntegerParameter( - 'NumberOfDataThreads', 8 + 'NumberOfDataThreads', 8, minimum=1 ) self.cropEnabled = self._settingsGroup.createBooleanParameter('CropEnabled', True) self.cropCenterXInPixels = self._settingsGroup.createIntegerParameter( - 'CropCenterXInPixels', 32 + 'CropCenterXInPixels', 32, minimum=0 ) self.cropCenterYInPixels = self._settingsGroup.createIntegerParameter( - 'CropCenterYInPixels', 32 + 'CropCenterYInPixels', 32, minimum=0 + ) + self.cropWidthInPixels = self._settingsGroup.createIntegerParameter( + 'CropWidthInPixels', 64, minimum=1 ) - self.cropWidthInPixels = self._settingsGroup.createIntegerParameter('CropWidthInPixels', 64) self.cropHeightInPixels = self._settingsGroup.createIntegerParameter( - 'CropHeightInPixels', 64 + 'CropHeightInPixels', 64, minimum=1 ) + # TODO ExtraPaddingXY self.flipXEnabled = self._settingsGroup.createBooleanParameter('FlipXEnabled', False) self.flipYEnabled = self._settingsGroup.createBooleanParameter('FlipYEnabled', False) self.valueLowerBoundEnabled = self._settingsGroup.createBooleanParameter( 'ValueLowerBoundEnabled', False ) - self.valueLowerBound = self._settingsGroup.createIntegerParameter('ValueLowerBound', 0) + self.valueLowerBound = self._settingsGroup.createIntegerParameter( + 'ValueLowerBound', 0, minimum=0 + ) self.valueUpperBoundEnabled = self._settingsGroup.createBooleanParameter( 'ValueUpperBoundEnabled', False ) - self.valueUpperBound = self._settingsGroup.createIntegerParameter('ValueUpperBound', 65535) + self.valueUpperBound = self._settingsGroup.createIntegerParameter( + 'ValueUpperBound', 65535, minimum=0 + ) def update(self, observable: Observable) -> None: if observable is self._settingsGroup: @@ -63,8 +70,8 @@ def __init__(self, registry: SettingsRegistry) -> None: self.probeEnergyInElectronVolts = self._settingsGroup.createRealParameter( 'ProbeEnergyInElectronVolts', 10000.0, minimum=0.0 ) - self.probePhotonsPerSecond = self._settingsGroup.createRealParameter( - 'ProbePhotonsPerSecond', 0.0, minimum=0.0 + self.probePhotonCount = self._settingsGroup.createRealParameter( + 'ProbePhotonCount', 0.0, minimum=0.0 ) self.exposureTimeInSeconds = self._settingsGroup.createRealParameter( 'ExposureTimeInSeconds', 0.0, minimum=0.0 diff --git a/src/ptychodus/model/patterns/sizer.py b/src/ptychodus/model/patterns/sizer.py index 11011c06..3839748a 100644 --- a/src/ptychodus/model/patterns/sizer.py +++ b/src/ptychodus/model/patterns/sizer.py @@ -16,13 +16,9 @@ def __init__(self, settings: PatternSettings, detector: Detector) -> None: self._sliceX = slice(0) self._sliceY = slice(0) - @classmethod - def createInstance(cls, settings: PatternSettings, detector: Detector) -> PatternSizer: - sizer = cls(settings, detector) - sizer._updateSlicesAndNotifyObservers() - settings.addObserver(sizer) - detector.addObserver(sizer) - return sizer + self._updateSlicesAndNotifyObservers() + settings.addObserver(self) + detector.addObserver(self) def isCropEnabled(self) -> bool: return self._settings.cropEnabled.getValue() @@ -31,18 +27,17 @@ def setCropEnabled(self, value: bool) -> None: self._settings.cropEnabled.setValue(value) def getWidthLimitsInPixels(self) -> Interval[int]: - return Interval[int](1, self._detector.getImageExtent().widthInPixels) + return Interval[int](1, self._detector.widthInPixels.getValue()) def getWidthInPixels(self) -> int: - limitsInPixels = self.getWidthLimitsInPixels() - return ( - limitsInPixels.clamp(self._settings.cropWidthInPixels.getValue()) - if self.isCropEnabled() - else limitsInPixels.upper - ) + if self.isCropEnabled(): + limits = self.getWidthLimitsInPixels() + return limits.clamp(self._settings.cropWidthInPixels.getValue()) + + return self._detector.widthInPixels.getValue() def getCenterXLimitsInPixels(self) -> Interval[int]: - return Interval[int](0, self._detector.getImageExtent().widthInPixels) + return Interval[int](0, self._detector.widthInPixels.getValue()) def getCenterXInPixels(self) -> int: limitsInPixels = self.getCenterXLimitsInPixels() @@ -54,7 +49,7 @@ def getCenterXInPixels(self) -> int: def _getSafeCenterXInPixels(self) -> int: lower = self.getWidthInPixels() // 2 - upper = self._detector.getImageExtent().widthInPixels - 1 - lower + upper = self._detector.widthInPixels.getValue() - 1 - lower limits = Interval[int](lower, upper) return limits.clamp(self.getCenterXInPixels()) @@ -65,18 +60,17 @@ def getWidthInMeters(self) -> float: return self.getWidthInPixels() * self.getPixelWidthInMeters() def getHeightLimitsInPixels(self) -> Interval[int]: - return Interval[int](1, self._detector.getImageExtent().heightInPixels) + return Interval[int](1, self._detector.heightInPixels.getValue()) def getHeightInPixels(self) -> int: - limitsInPixels = self.getHeightLimitsInPixels() - return ( - limitsInPixels.clamp(self._settings.cropHeightInPixels.getValue()) - if self.isCropEnabled() - else limitsInPixels.upper - ) + if self.isCropEnabled(): + limits = self.getHeightLimitsInPixels() + return limits.clamp(self._settings.cropHeightInPixels.getValue()) + + return self._detector.heightInPixels.getValue() def getCenterYLimitsInPixels(self) -> Interval[int]: - return Interval[int](0, self._detector.getImageExtent().heightInPixels) + return Interval[int](0, self._detector.heightInPixels.getValue()) def getCenterYInPixels(self) -> int: limitsInPixels = self.getCenterYLimitsInPixels() @@ -88,7 +82,7 @@ def getCenterYInPixels(self) -> int: def _getSafeCenterYInPixels(self) -> int: lower = self.getHeightInPixels() // 2 - upper = self._detector.getImageExtent().heightInPixels - 1 - lower + upper = self._detector.heightInPixels.getValue() - 1 - lower limits = Interval[int](lower, upper) return limits.clamp(self.getCenterYInPixels()) diff --git a/src/ptychodus/model/product/api.py b/src/ptychodus/model/product/api.py index ee7aefe5..532b9f69 100644 --- a/src/ptychodus/model/product/api.py +++ b/src/ptychodus/model/product/api.py @@ -55,8 +55,7 @@ def buildScan( parameter = builder.parameters()[parameterName] except KeyError: logger.warning( - f'Scan builder "{builder.getName()}" does not have' - f' parameter "{parameterName}"!' + f'Scan builder "{builder.getName()}" does not have parameter "{parameterName}"!' ) else: parameter.setValue(parameterValue) @@ -365,7 +364,7 @@ def insertNewProduct( comments: str = '', detectorDistanceInMeters: float | None = None, probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, + probePhotonCount: float | None = None, exposureTimeInSeconds: float | None = None, likeIndex: int = -1, ) -> int: @@ -374,7 +373,7 @@ def insertNewProduct( comments=comments, detectorDistanceInMeters=detectorDistanceInMeters, probeEnergyInElectronVolts=probeEnergyInElectronVolts, - probePhotonsPerSecond=probePhotonsPerSecond, + probePhotonCount=probePhotonCount, exposureTimeInSeconds=exposureTimeInSeconds, likeIndex=likeIndex, ) diff --git a/src/ptychodus/model/product/metadata.py b/src/ptychodus/model/product/metadata.py index 7cd517b3..73540e43 100644 --- a/src/ptychodus/model/product/metadata.py +++ b/src/ptychodus/model/product/metadata.py @@ -26,7 +26,7 @@ def __init__( comments: str = '', detectorDistanceInMeters: float | None = None, probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, + probePhotonCount: float | None = None, exposureTimeInSeconds: float | None = None, ) -> None: super().__init__() @@ -52,12 +52,12 @@ def __init__( self._addParameter('probe_energy_eV', self.probeEnergyInElectronVolts) - self.probePhotonsPerSecond = settings.probePhotonsPerSecond.copy() + self.probePhotonCount = settings.probePhotonCount.copy() - if probePhotonsPerSecond is not None: - self.probePhotonsPerSecond.setValue(probePhotonsPerSecond) + if probePhotonCount is not None: + self.probePhotonCount.setValue(probePhotonCount) - self._addParameter('probe_photons_per_second', self.probePhotonsPerSecond) + self._addParameter('probe_photon_count', self.probePhotonCount) self.exposureTimeInSeconds = settings.exposureTimeInSeconds.copy() @@ -73,7 +73,7 @@ def assignItem(self, item: MetadataRepositoryItem, *, notify: bool = True) -> No self.comments.setValue(item.comments.getValue()) self.detectorDistanceInMeters.setValue(item.detectorDistanceInMeters.getValue()) self.probeEnergyInElectronVolts.setValue(item.probeEnergyInElectronVolts.getValue()) - self.probePhotonsPerSecond.setValue(item.probePhotonsPerSecond.getValue()) + self.probePhotonCount.setValue(item.probePhotonCount.getValue()) self.exposureTimeInSeconds.setValue(item.exposureTimeInSeconds.getValue()) def assign(self, metadata: ProductMetadata) -> None: @@ -81,7 +81,7 @@ def assign(self, metadata: ProductMetadata) -> None: self.comments.setValue(metadata.comments) self.detectorDistanceInMeters.setValue(metadata.detectorDistanceInMeters) self.probeEnergyInElectronVolts.setValue(metadata.probeEnergyInElectronVolts) - self.probePhotonsPerSecond.setValue(metadata.probePhotonsPerSecond) + self.probePhotonCount.setValue(metadata.probePhotonCount) self.exposureTimeInSeconds.setValue(metadata.exposureTimeInSeconds) def syncToSettings(self) -> None: @@ -110,6 +110,6 @@ def getMetadata(self) -> ProductMetadata: comments=self.comments.getValue(), detectorDistanceInMeters=self.detectorDistanceInMeters.getValue(), probeEnergyInElectronVolts=self.probeEnergyInElectronVolts.getValue(), - probePhotonsPerSecond=self.probePhotonsPerSecond.getValue(), + probePhotonCount=self.probePhotonCount.getValue(), exposureTimeInSeconds=self.exposureTimeInSeconds.getValue(), ) diff --git a/src/ptychodus/model/product/metadataFactory.py b/src/ptychodus/model/product/metadataFactory.py index 45e86cce..a64ae278 100644 --- a/src/ptychodus/model/product/metadataFactory.py +++ b/src/ptychodus/model/product/metadataFactory.py @@ -28,7 +28,7 @@ def create(self, metadata: ProductMetadata) -> MetadataRepositoryItem: comments=metadata.comments, detectorDistanceInMeters=metadata.detectorDistanceInMeters, probeEnergyInElectronVolts=metadata.probeEnergyInElectronVolts, - probePhotonsPerSecond=metadata.probePhotonsPerSecond, + probePhotonCount=metadata.probePhotonCount, exposureTimeInSeconds=metadata.exposureTimeInSeconds, ) @@ -39,7 +39,7 @@ def createDefault( comments: str = '', detectorDistanceInMeters: float | None = None, probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, + probePhotonCount: float | None = None, exposureTimeInSeconds: float | None = None, ) -> MetadataRepositoryItem: return MetadataRepositoryItem( @@ -49,7 +49,7 @@ def createDefault( comments=comments, detectorDistanceInMeters=detectorDistanceInMeters, probeEnergyInElectronVolts=probeEnergyInElectronVolts, - probePhotonsPerSecond=probePhotonsPerSecond, + probePhotonCount=probePhotonCount, exposureTimeInSeconds=exposureTimeInSeconds, ) diff --git a/src/ptychodus/model/product/object/builder.py b/src/ptychodus/model/product/object/builder.py index 7a63de7b..71248344 100644 --- a/src/ptychodus/model/product/object/builder.py +++ b/src/ptychodus/model/product/object/builder.py @@ -85,8 +85,24 @@ def build( logger.debug(f'Reading "{filePath}" as "{fileType}"') try: - object_ = self._fileReader.read(filePath) + objectFromFile = self._fileReader.read(filePath) except Exception as exc: raise RuntimeError(f'Failed to read "{filePath}"') from exc - return object_ + objectGeometry = geometryProvider.getObjectGeometry() + pixelGeometry = objectFromFile.getPixelGeometry() + center = objectFromFile.getCenter() + + if pixelGeometry is None: + pixelGeometry = objectGeometry.getPixelGeometry() + + if center is None: + center = objectGeometry.getCenter() + + # TODO remap object from pixelGeometryFromFile to pixelGeometryFromProvider + return Object( + objectFromFile.getArray(), + pixelGeometry, + center, + objectFromFile.layerDistanceInMeters, + ) diff --git a/src/ptychodus/model/product/object/item.py b/src/ptychodus/model/product/object/item.py index d4e02bce..306cb873 100644 --- a/src/ptychodus/model/product/object/item.py +++ b/src/ptychodus/model/product/object/item.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import numpy from ptychodus.api.object import Object, ObjectGeometryProvider from ptychodus.api.observer import Observable @@ -24,17 +23,18 @@ def __init__( self._geometryProvider = geometryProvider self._settings = settings self._builder = builder - self._object = Object() + self._object = Object(array=None, pixelGeometry=None, center=None) - self._addGroup('builder', builder, observe=True) - # TODO sync layer distance to/from settings - self.layerDistanceInMeters = self.createRealArrayParameter('layer_distance_m', [numpy.inf]) + self.layerDistanceInMeters = settings.objectLayerDistanceInMeters.copy() + self._addParameter('layer_distance_m', self.layerDistanceInMeters) + self._addGroup('builder', builder, observe=True) self._rebuild() def assignItem(self, item: ObjectRepositoryItem) -> None: self.layerDistanceInMeters.setValue(item.layerDistanceInMeters.getValue(), notify=False) self.setBuilder(item.getBuilder().copy()) + self._rebuild() def assign(self, object_: Object) -> None: builder = FromMemoryObjectBuilder(self._settings, object_) @@ -47,23 +47,23 @@ def syncToSettings(self) -> None: self._builder.syncToSettings() def getNumberOfLayers(self) -> int: - return len(self.layerDistanceInMeters) + return len(self.layerDistanceInMeters) + 1 - def setNumberOfLayers(self, number: int) -> None: - numRequested = max(1, number) + def setNumberOfLayers(self, numberOfLayers: int) -> None: + numberOfSpaces = max(0, numberOfLayers - 1) distanceInMeters = list(self.layerDistanceInMeters.getValue()) - numExisting = len(distanceInMeters) - defaultDistanceInMeters = float(self._settings.objectLayerDistanceInMeters.getValue()) - - if numExisting < 2: - distanceInMeters = [defaultDistanceInMeters] * numRequested - elif numExisting < numRequested: - distanceInMeters[-1] = distanceInMeters[-2] # overwrite inf - distanceInMeters.extend(distanceInMeters[-1:] * (numRequested - numExisting)) - elif numExisting > numRequested: - distanceInMeters = distanceInMeters[:numRequested] - - distanceInMeters[-1] = numpy.inf + + try: + defaultDistanceInMeters = distanceInMeters[-1] + except IndexError: + defaultDistanceInMeters = 0.0 + + while len(distanceInMeters) < numberOfSpaces: + distanceInMeters.append(defaultDistanceInMeters) + + if len(distanceInMeters) > numberOfSpaces: + distanceInMeters = distanceInMeters[:numberOfSpaces] + self.layerDistanceInMeters.setValue(distanceInMeters) self._rebuild() @@ -82,13 +82,10 @@ def setBuilder(self, builder: ObjectBuilder) -> None: self._rebuild() def _rebuild(self) -> None: - layerDistanceInMeters = list(self.layerDistanceInMeters.getValue()) - - if len(layerDistanceInMeters) < 1: - layerDistanceInMeters.append(numpy.inf) - try: - object_ = self._builder.build(self._geometryProvider, layerDistanceInMeters) + object_ = self._builder.build( + self._geometryProvider, self.layerDistanceInMeters.getValue() + ) except Exception as exc: logger.error(''.join(exc.args)) return diff --git a/src/ptychodus/model/product/object/random.py b/src/ptychodus/model/product/object/random.py index 0a0b4982..ef6b717e 100644 --- a/src/ptychodus/model/product/object/random.py +++ b/src/ptychodus/model/product/object/random.py @@ -4,6 +4,7 @@ import numpy from ptychodus.api.object import Object, ObjectGeometryProvider +from ptychodus.model.analysis.phaseUnwrapper import PhaseUnwrapper from .builder import ObjectBuilder from .settings import ObjectSettings @@ -44,7 +45,7 @@ def build( geometry = geometryProvider.getObjectGeometry() heightInPixels = geometry.heightInPixels + 2 * self.extraPaddingY.getValue() widthInPixels = geometry.widthInPixels + 2 * self.extraPaddingX.getValue() - objectShape = (len(layerDistanceInMeters), heightInPixels, widthInPixels) + objectShape = (1 + len(layerDistanceInMeters), heightInPixels, widthInPixels) amplitude = self._rng.normal( self.amplitudeMean.getValue(), @@ -60,8 +61,55 @@ def build( return Object( array=numpy.clip(amplitude, 0.0, 1.0) * numpy.exp(1j * phase), layerDistanceInMeters=layerDistanceInMeters, - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, - centerXInMeters=geometry.centerXInMeters, - centerYInMeters=geometry.centerYInMeters, + pixelGeometry=geometry.getPixelGeometry(), + center=geometry.getCenter(), + ) + + +class UserObjectBuilder(ObjectBuilder): # FIXME use + def __init__(self, object_: Object, settings: ObjectSettings) -> None: + """Create an object from an existing object with a potentially + different number of slices. + + If the new object is supposed to be a multislice object with a + different number of slices than the existing object, the object is + created as + `abs(o) ** (1 / nSlices) * exp(i * unwrapPhase(o) / nSlices)`. + Otherwise, the object is copied as is. + """ + super().__init__(settings, 'user') + self._existingObject = object_ + self._settings = settings + + def copy(self) -> UserObjectBuilder: + builder = UserObjectBuilder(self._existingObject, self._settings) + + for key, value in self.parameters().items(): + builder.parameters()[key].setValue(value.getValue()) + + return builder + + def build( + self, + geometryProvider: ObjectGeometryProvider, + layerDistanceInMeters: Sequence[float], + ) -> Object: + geometry = self._existingObject.getGeometry() + exitingObjectArr = self._existingObject.getArray() + nSlices = len(layerDistanceInMeters) + 1 + + if nSlices > 1 and nSlices != exitingObjectArr.shape[0]: + amplitude = numpy.abs(exitingObjectArr[0:1]) ** (1.0 / nSlices) + amplitude = amplitude.repeat(nSlices, axis=0) + phase = PhaseUnwrapper().unwrap(exitingObjectArr[0])[None, ...] / nSlices + phase = phase.repeat(nSlices, axis=0) + data = numpy.clip(amplitude, 0.0, 1.0) * numpy.exp(1j * phase) + else: + data = exitingObjectArr + + return Object( + array=data, + layerDistanceInMeters=layerDistanceInMeters, + pixelGeometry=geometry.getPixelGeometry(), + center=geometry.getCenter(), ) diff --git a/src/ptychodus/model/product/object/settings.py b/src/ptychodus/model/product/object/settings.py index 15780279..82ab93aa 100644 --- a/src/ptychodus/model/product/object/settings.py +++ b/src/ptychodus/model/product/object/settings.py @@ -18,8 +18,8 @@ def __init__(self, registry: SettingsRegistry) -> None: ) self.fileType = self._settingsGroup.createStringParameter('FileType', 'NPY') - self.objectLayerDistanceInMeters = self._settingsGroup.createRealParameter( - 'ObjectLayerDistanceInMeters', 1e-7 + self.objectLayerDistanceInMeters = self._settingsGroup.createRealSequenceParameter( + 'ObjectLayerDistanceInMeters', [] ) self.extraPaddingX = self._settingsGroup.createIntegerParameter( diff --git a/src/ptychodus/model/product/probe/averagePattern.py b/src/ptychodus/model/product/probe/averagePattern.py index 30bc42da..ea0260c0 100644 --- a/src/ptychodus/model/product/probe/averagePattern.py +++ b/src/ptychodus/model/product/probe/averagePattern.py @@ -40,6 +40,5 @@ def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: return Probe( array=self.normalize(array), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, + pixelGeometry=geometry.getPixelGeometry(), ) diff --git a/src/ptychodus/model/product/probe/builder.py b/src/ptychodus/model/product/probe/builder.py index b8fec604..40110ce6 100644 --- a/src/ptychodus/model/product/probe/builder.py +++ b/src/ptychodus/model/product/probe/builder.py @@ -53,7 +53,7 @@ def getTransverseCoordinates(self, geometry: ProbeGeometry) -> ProbeTransverseCo ) def normalize(self, array: WavefieldArrayType) -> WavefieldArrayType: - return array / numpy.sqrt(numpy.sum(numpy.abs(array) ** 2)) + return array / numpy.sqrt(numpy.sum(numpy.square(numpy.abs(array)))) def getName(self) -> str: return self._name.getValue() @@ -109,8 +109,15 @@ def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: logger.debug(f'Reading "{filePath}" as "{fileType}"') try: - probe = self._fileReader.read(filePath) + probeFromFile = self._fileReader.read(filePath) except Exception as exc: raise RuntimeError(f'Failed to read "{filePath}"') from exc - return probe + pixelGeometryFromFile = probeFromFile.getPixelGeometry() + pixelGeometryFromProvider = geometryProvider.getProbeGeometry().getPixelGeometry() + + if pixelGeometryFromFile is None: + return Probe(probeFromFile.getArray(), pixelGeometryFromProvider) + + # TODO remap probe from pixelGeometryFromFile to pixelGeometryFromProvider + return probeFromFile diff --git a/src/ptychodus/model/product/probe/disk.py b/src/ptychodus/model/product/probe/disk.py index d4630b8f..4f44cf31 100644 --- a/src/ptychodus/model/product/probe/disk.py +++ b/src/ptychodus/model/product/probe/disk.py @@ -50,6 +50,5 @@ def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: return Probe( array=self.normalize(array), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, + pixelGeometry=geometry.getPixelGeometry(), ) diff --git a/src/ptychodus/model/product/probe/fzp.py b/src/ptychodus/model/product/probe/fzp.py index adaf919f..70d9b2f0 100644 --- a/src/ptychodus/model/product/probe/fzp.py +++ b/src/ptychodus/model/product/probe/fzp.py @@ -100,6 +100,5 @@ def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: return Probe( array=self.normalize(array), - pixelWidthInMeters=samplePlaneGeometry.pixelWidthInMeters, - pixelHeightInMeters=samplePlaneGeometry.pixelHeightInMeters, + pixelGeometry=samplePlaneGeometry.getPixelGeometry(), ) diff --git a/src/ptychodus/model/product/probe/item.py b/src/ptychodus/model/product/probe/item.py index b3f58615..10dd05b8 100644 --- a/src/ptychodus/model/product/probe/item.py +++ b/src/ptychodus/model/product/probe/item.py @@ -25,7 +25,7 @@ def __init__( self._settings = settings self._builder = builder self._additionalModesBuilder = additionalModesBuilder - self._probe = Probe() + self._probe = Probe(array=None, pixelGeometry=None) self._addGroup('builder', builder, observe=True) self._addGroup('additional_modes', additionalModesBuilder, observe=True) @@ -65,6 +65,7 @@ def setBuilder(self, builder: ProbeBuilder) -> None: self._builder = builder self._builder.addObserver(self) self._addGroup('builder', self._builder, observe=True) + self._rebuild() def _rebuild(self) -> None: try: @@ -73,7 +74,7 @@ def _rebuild(self) -> None: logger.error(''.join(exc.args)) return - self._probe = self._additionalModesBuilder.build(probe) + self._probe = self._additionalModesBuilder.build(probe, self._geometryProvider) self.notifyObservers() def getAdditionalModesBuilder(self) -> MultimodalProbeBuilder: diff --git a/src/ptychodus/model/product/probe/multimodal.py b/src/ptychodus/model/product/probe/multimodal.py index 1563ad6e..c91f25f0 100644 --- a/src/ptychodus/model/product/probe/multimodal.py +++ b/src/ptychodus/model/product/probe/multimodal.py @@ -6,10 +6,8 @@ import numpy import scipy.linalg -from ptychodus.api.parametric import ( - ParameterGroup, -) -from ptychodus.api.probe import Probe, WavefieldArrayType +from ptychodus.api.parametric import ParameterGroup +from ptychodus.api.probe import Probe, ProbeGeometryProvider, WavefieldArrayType from .settings import ProbeSettings @@ -27,17 +25,17 @@ def __init__(self, rng: numpy.random.Generator, settings: ProbeSettings) -> None self._rng = rng self._settings = settings - self.numberOfModes = settings.numberOfModes.copy() - self._addParameter('number_of_modes', self.numberOfModes) + self.numberOfIncoherentModes = settings.numberOfIncoherentModes.copy() + self._addParameter('number_of_incoherent_modes', self.numberOfIncoherentModes) - self.modeDecayType = settings.modeDecayType.copy() - self._addParameter('mode_decay_type', self.modeDecayType) + self.incoherentModeDecayType = settings.incoherentModeDecayType.copy() + self._addParameter('incoherent_mode_decay_type', self.incoherentModeDecayType) - self.modeDecayRatio = settings.modeDecayRatio.copy() - self._addParameter('mode_decay_ratio', self.modeDecayRatio) + self.incoherentModeDecayRatio = settings.incoherentModeDecayRatio.copy() + self._addParameter('incoherent_mode_decay_ratio', self.incoherentModeDecayRatio) - self.isOrthogonalizeModesEnabled = settings.isOrthogonalizeModesEnabled.copy() - self._addParameter('orthogonalize_modes', self.isOrthogonalizeModesEnabled) + self.orthogonalizeIncoherentModes = settings.orthogonalizeIncoherentModes.copy() + self._addParameter('orthogonalize_incoherent_modes', self.orthogonalizeIncoherentModes) def syncToSettings(self) -> None: for parameter in self.parameters().values(): @@ -51,84 +49,76 @@ def copy(self) -> MultimodalProbeBuilder: return builder - def _initializeModes(self, probe: WavefieldArrayType) -> WavefieldArrayType: - modeList: list[WavefieldArrayType] = list() + def _init_modes(self, probe: WavefieldArrayType) -> WavefieldArrayType: + # TODO OPR + assert probe.ndim == 4 + array = numpy.tile(probe[0, 0, :, :], (self.numberOfIncoherentModes.getValue(), 1, 1)) + it = iter(array) # iterate incoherent modes + next(it) # preserve the first incoherent mode - if probe.ndim == 2: - modeList.append(probe) - elif probe.ndim >= 3: - probe3D = probe - - while probe3D.ndim > 3: - probe3D = probe3D[0] - - for mode in probe3D: - modeList.append(mode) - else: - raise ValueError('Probe array must contain at least two dimensions.') - - for mode in range(self.numberOfModes.getValue() - 1): - # randomly shift the first mode + for imode in it: # phase shift the rest pw = probe.shape[-1] # TODO clean up variate1 = self._rng.uniform(size=(2, 1)) - 0.5 - variate2 = (numpy.arange(0, pw) + 0.5) / pw - 0.5 + variate2 = (numpy.arange(pw) + 0.5) / pw - 0.5 ps = numpy.exp(-2j * numpy.pi * variate1 * variate2) - phaseShift = ps[0][numpy.newaxis] * ps[1][:, numpy.newaxis] - mode = modeList[0] * phaseShift - modeList.append(mode) + imode *= ps[0][numpy.newaxis] * ps[1][:, numpy.newaxis] - return numpy.stack(modeList) + return array - def _orthogonalizeModes(self, probe: WavefieldArrayType) -> WavefieldArrayType: - probeModesAsRows = probe.reshape(probe.shape[-3], -1) - probeModesAsCols = probeModesAsRows.T - probeModesAsOrthoCols = scipy.linalg.orth(probeModesAsCols) - probeModesAsOrthoRows = probeModesAsOrthoCols.T - return probeModesAsOrthoRows.reshape(*probe.shape) + def _orthogonalizeIncoherentModes(self, probe: WavefieldArrayType) -> WavefieldArrayType: + # TODO OPR + imodes_as_rows = probe.reshape(probe.shape[-3], -1) + imodes_as_cols = imodes_as_rows.T + imodes_as_ortho_cols = scipy.linalg.orth(imodes_as_cols) + imodes_as_ortho_rows = imodes_as_ortho_cols.T + return imodes_as_ortho_rows.reshape(*probe.shape) - def _getModeWeights(self, totalNumberOfModes: int) -> Sequence[float]: - modeDecayTypeText = self.modeDecayType.getValue() - modeDecayRatio = self.modeDecayRatio.getValue() + def _get_imode_weights(self, num_imodes: int) -> Sequence[float]: + imode_decay_type_text = self.incoherentModeDecayType.getValue() + imode_decay_ratio = self.incoherentModeDecayRatio.getValue() - if modeDecayRatio > 0.0: + if imode_decay_ratio > 0.0: try: - modeDecayType = ProbeModeDecayType[modeDecayTypeText.upper()] + imode_decay_type = ProbeModeDecayType[imode_decay_type_text.upper()] except KeyError: - modeDecayType = ProbeModeDecayType.POLYNOMIAL + imode_decay_type = ProbeModeDecayType.POLYNOMIAL - if modeDecayType == ProbeModeDecayType.EXPONENTIAL: - b = 1.0 + (1.0 - modeDecayRatio) / modeDecayRatio - return [b**-n for n in range(totalNumberOfModes)] + if imode_decay_type == ProbeModeDecayType.EXPONENTIAL: + b = 1.0 + (1.0 - imode_decay_ratio) / imode_decay_ratio + return [b**-n for n in range(num_imodes)] else: - b = numpy.log(modeDecayRatio) / numpy.log(2.0) - return [(n + 1) ** b for n in range(totalNumberOfModes)] + b = numpy.log(imode_decay_ratio) / numpy.log(2.0) + return [(n + 1) ** b for n in range(num_imodes)] - return [1.0] + [0.0] * (totalNumberOfModes - 1) + return [1.0] + [0.0] * (num_imodes - 1) - def _adjustRelativePower(self, probe: WavefieldArrayType) -> WavefieldArrayType: - modeWeights = self._getModeWeights(probe.shape[-3]) - power0 = numpy.sum(numpy.square(numpy.abs(probe[0, ...]))) - adjustedProbe = probe.copy() + def _adjust_power(self, probe: WavefieldArrayType, power: float) -> WavefieldArrayType: + imode_weights = self._get_imode_weights(probe.shape[-3]) + array = probe.copy() + it = iter(array) # iterate incoherent modes - for modeIndex, weight in enumerate(modeWeights): - powerN = numpy.sum(numpy.square(numpy.abs(adjustedProbe[modeIndex, ...]))) - adjustedProbe[modeIndex, ...] *= numpy.sqrt(weight * power0 / powerN) + for weight in imode_weights: + imode = next(it) + ipower = numpy.sum(numpy.square(numpy.abs(imode))) + imode *= numpy.sqrt(weight * power / ipower) - return adjustedProbe + return array - def build(self, probe: Probe) -> Probe: - if self.numberOfModes.getValue() <= 1: + def build(self, probe: Probe, geometryProvider: ProbeGeometryProvider) -> Probe: + num_requested_modes = self.numberOfIncoherentModes.getValue() + num_existing_modes = probe.numberOfIncoherentModes + + if num_requested_modes <= num_existing_modes: return probe - array = self._initializeModes(probe.array) + array = self._init_modes(probe.getArray()) + + if self.orthogonalizeIncoherentModes.getValue(): + array = self._orthogonalizeIncoherentModes(array) - if self.isOrthogonalizeModesEnabled.getValue(): - array = self._orthogonalizeModes(array) + power = probe.getIntensity().sum() - array = self._adjustRelativePower(array) + if geometryProvider.probePhotonCount > 0.0: + power = geometryProvider.probePhotonCount - return Probe( - array, - pixelWidthInMeters=probe.pixelWidthInMeters, - pixelHeightInMeters=probe.pixelHeightInMeters, - ) + return Probe(self._adjust_power(array, power), probe.getPixelGeometry()) diff --git a/src/ptychodus/model/product/probe/rect.py b/src/ptychodus/model/product/probe/rect.py index 5715e834..57baaadb 100644 --- a/src/ptychodus/model/product/probe/rect.py +++ b/src/ptychodus/model/product/probe/rect.py @@ -57,6 +57,5 @@ def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: return Probe( array=self.normalize(array), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, + pixelGeometry=geometry.getPixelGeometry(), ) diff --git a/src/ptychodus/model/product/probe/settings.py b/src/ptychodus/model/product/probe/settings.py index fed8ee7a..e90f6ac4 100644 --- a/src/ptychodus/model/product/probe/settings.py +++ b/src/ptychodus/model/product/probe/settings.py @@ -16,17 +16,20 @@ def __init__(self, registry: SettingsRegistry) -> None: ) self.fileType = self._settingsGroup.createStringParameter('FileType', 'NPY') - self.numberOfModes = self._settingsGroup.createIntegerParameter( - 'NumberOfModes', 1, minimum=1 + self.numberOfCoherentModes = self._settingsGroup.createIntegerParameter( + 'NumberOfCoherentModes', 1, minimum=1 ) - self.isOrthogonalizeModesEnabled = self._settingsGroup.createBooleanParameter( - 'OrthogonalizeModesEnabled', True + self.numberOfIncoherentModes = self._settingsGroup.createIntegerParameter( + 'NumberOfIncoherentModes', 1, minimum=1 ) - self.modeDecayType = self._settingsGroup.createStringParameter( - 'ModeDecayType', 'Polynomial' + self.orthogonalizeIncoherentModes = self._settingsGroup.createBooleanParameter( + 'OrthogonalizeIncoherentModes', True ) - self.modeDecayRatio = self._settingsGroup.createRealParameter( - 'ModeDecayRatio', 1.0, minimum=0.0, maximum=1.0 + self.incoherentModeDecayType = self._settingsGroup.createStringParameter( + 'IncoherentModeDecayType', 'Polynomial' + ) + self.incoherentModeDecayRatio = self._settingsGroup.createRealParameter( + 'IncoherentModeDecayRatio', 1.0, minimum=0.0, maximum=1.0 ) self.diskDiameterInMeters = self._settingsGroup.createRealParameter( diff --git a/src/ptychodus/model/product/probe/superGaussian.py b/src/ptychodus/model/product/probe/superGaussian.py index 9df3b494..9bed2558 100644 --- a/src/ptychodus/model/product/probe/superGaussian.py +++ b/src/ptychodus/model/product/probe/superGaussian.py @@ -41,6 +41,5 @@ def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: return Probe( array=self.normalize(numpy.exp(-numpy.log(2) * ZP) + 0j), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, + pixelGeometry=geometry.getPixelGeometry(), ) diff --git a/src/ptychodus/model/product/probe/zernike.py b/src/ptychodus/model/product/probe/zernike.py index f2c03c4c..92fc1890 100644 --- a/src/ptychodus/model/product/probe/zernike.py +++ b/src/ptychodus/model/product/probe/zernike.py @@ -84,7 +84,7 @@ def __init__(self, settings: ProbeSettings) -> None: self._addParameter('diameter_m', self.diameterInMeters) # TODO init zernike coefficients from settings - self.coefficients = self.createComplexArrayParameter('coefficients', [1 + 0j]) + self.coefficients = self.createComplexSequenceParameter('coefficients', [1 + 0j]) self.setOrder(1) @@ -150,6 +150,5 @@ def build(self, geometryProvider: ProbeGeometryProvider) -> Probe: return Probe( array=self.normalize(array), - pixelWidthInMeters=geometry.pixelWidthInMeters, - pixelHeightInMeters=geometry.pixelHeightInMeters, + pixelGeometry=geometry.getPixelGeometry(), ) diff --git a/src/ptychodus/model/product/productGeometry.py b/src/ptychodus/model/product/productGeometry.py index 04a744cb..f569d04c 100644 --- a/src/ptychodus/model/product/productGeometry.py +++ b/src/ptychodus/model/product/productGeometry.py @@ -1,14 +1,14 @@ import numpy -from ptychodus.api.constants import ( - ELECTRON_VOLT_J, - LIGHT_SPEED_M_PER_S, - PLANCK_CONSTANT_J_PER_HZ, -) from ptychodus.api.geometry import PixelGeometry from ptychodus.api.object import ObjectGeometry, ObjectGeometryProvider from ptychodus.api.observer import Observable, Observer from ptychodus.api.probe import ProbeGeometry, ProbeGeometryProvider +from ptychodus.api.product import ( + ELECTRON_VOLT_J, + LIGHT_SPEED_M_PER_S, + PLANCK_CONSTANT_J_PER_HZ, +) from ..patterns import PatternSizer from .metadata import MetadataRepositoryItem @@ -31,6 +31,10 @@ def __init__( self._metadata.addObserver(self) self._scan.addObserver(self) + @property + def probePhotonCount(self) -> float: + return self._metadata.probePhotonCount.getValue() + @property def probeEnergyInJoules(self) -> float: return self._metadata.probeEnergyInElectronVolts.getValue() * ELECTRON_VOLT_J @@ -45,12 +49,29 @@ def probeWavelengthInMeters(self) -> float: return 0.0 @property - def detectorDistanceInMeters(self) -> float: - return self._metadata.detectorDistanceInMeters.getValue() + def probeWavelengthsPerMeter(self) -> float: + """wavenumber""" + return 1.0 / self.probeWavelengthInMeters + + @property + def probeRadiansPerMeter(self) -> float: + """angular wavenumber""" + return 2.0 * numpy.pi / self.probeWavelengthInMeters + + @property + def probePhotonsPerSecond(self) -> float: + try: + return self.probePhotonCount / self._metadata.exposureTimeInSeconds.getValue() + except ZeroDivisionError: + return 0.0 @property def probePowerInWatts(self) -> float: - return self.probeEnergyInJoules * self._metadata.probePhotonsPerSecond.getValue() + return self.probeEnergyInJoules * self.probePhotonsPerSecond + + @property + def detectorDistanceInMeters(self) -> float: + return self._metadata.detectorDistanceInMeters.getValue() @property def _lambdaZInSquareMeters(self) -> float: @@ -74,8 +95,8 @@ def getPixelGeometry(self) -> PixelGeometry: def fresnelNumber(self) -> float: widthInMeters = self._patternSizer.getWidthInMeters() heightInMeters = self._patternSizer.getHeightInMeters() - sizeInMeters = max(widthInMeters, heightInMeters) - return sizeInMeters**2 / self._lambdaZInSquareMeters + areaInSquareMeters = widthInMeters * heightInMeters + return areaInSquareMeters / self._lambdaZInSquareMeters def getProbeGeometry(self) -> ProbeGeometry: extent = self._patternSizer.getImageExtent() diff --git a/src/ptychodus/model/product/productRepository.py b/src/ptychodus/model/product/productRepository.py index 71e7751c..6514d473 100644 --- a/src/ptychodus/model/product/productRepository.py +++ b/src/ptychodus/model/product/productRepository.py @@ -75,7 +75,7 @@ def insertNewProduct( comments: str = '', detectorDistanceInMeters: float | None = None, probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, + probePhotonCount: float | None = None, exposureTimeInSeconds: float | None = None, likeIndex: int, ) -> int: @@ -84,7 +84,7 @@ def insertNewProduct( comments=comments, detectorDistanceInMeters=detectorDistanceInMeters, probeEnergyInElectronVolts=probeEnergyInElectronVolts, - probePhotonsPerSecond=probePhotonsPerSecond, + probePhotonCount=probePhotonCount, exposureTimeInSeconds=exposureTimeInSeconds, ) scanItem = self._scanRepositoryItemFactory.create() @@ -103,10 +103,12 @@ def insertNewProduct( costs=list(), ) + index = self._insertProduct(item) + if likeIndex >= 0: item.assignItem(self._itemList[likeIndex], notify=False) - return self._insertProduct(item) + return index def insertProductFromSettings(self) -> int: # TODO add mechanism to sync product state to settings diff --git a/src/ptychodus/model/ptychi/__init__.py b/src/ptychodus/model/ptychi/__init__.py new file mode 100644 index 00000000..12f08af3 --- /dev/null +++ b/src/ptychodus/model/ptychi/__init__.py @@ -0,0 +1,29 @@ +from .core import PtyChiReconstructorLibrary +from .device import PtyChiDeviceRepository +from .enums import PtyChiEnumerators +from .settings import ( + PtyChiAutodiffSettings, + PtyChiDMSettings, + PtyChiLSQMLSettings, + PtyChiOPRSettings, + PtyChiObjectSettings, + PtyChiPIESettings, + PtyChiProbePositionSettings, + PtyChiProbeSettings, + PtyChiReconstructorSettings, +) + +__all__ = [ + 'PtyChiAutodiffSettings', + 'PtyChiDMSettings', + 'PtyChiDeviceRepository', + 'PtyChiEnumerators', + 'PtyChiLSQMLSettings', + 'PtyChiOPRSettings', + 'PtyChiObjectSettings', + 'PtyChiPIESettings', + 'PtyChiProbePositionSettings', + 'PtyChiProbeSettings', + 'PtyChiReconstructorLibrary', + 'PtyChiReconstructorSettings', +] diff --git a/src/ptychodus/model/ptychi/autodiff.py b/src/ptychodus/model/ptychi/autodiff.py new file mode 100644 index 00000000..ac31f8e4 --- /dev/null +++ b/src/ptychodus/model/ptychi/autodiff.py @@ -0,0 +1,190 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + AutodiffPtychographyOPRModeWeightsOptions, + AutodiffPtychographyObjectOptions, + AutodiffPtychographyOptions, + AutodiffPtychographyProbeOptions, + AutodiffPtychographyProbePositionOptions, + AutodiffPtychographyReconstructorOptions, + ForwardModels, + LossFunctions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import Probe +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import Scan + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiAutodiffSettings + +logger = logging.getLogger(__name__) + + +class AutodiffReconstructor(Reconstructor): + def __init__( + self, options_helper: PtyChiOptionsHelper, settings: PtyChiAutodiffSettings + ) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'Autodiff' + + def _create_reconstructor_options(self) -> AutodiffPtychographyReconstructorOptions: + helper = self._options_helper.reconstructor_helper + + #### + + loss_function_str = self._settings.lossFunction.getValue() + + try: + loss_function = LossFunctions[loss_function_str.upper()] + except KeyError: + logger.warning('Failed to parse loss function "{loss_function_str}"!') + loss_function = LossFunctions.MSE_SQRT + + #### + + forward_model_class_str = self._settings.forwardModelClass.getValue() + + try: + forward_model_class = ForwardModels[forward_model_class_str.upper()] + except KeyError: + logger.warning('Failed to parse forward model class "{forward_model_class_str}"!') + forward_model_class = ForwardModels.PLANAR_PTYCHOGRAPHY + + #### + + return AutodiffPtychographyReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + use_low_memory_forward_model=helper.use_low_memory_forward_model, + loss_function=loss_function, + forward_model_class=forward_model_class, + forward_model_params=None, + ) + + def _create_object_options(self, object_: Object) -> AutodiffPtychographyObjectOptions: + helper = self._options_helper.object_helper + return AutodiffPtychographyObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + pixel_size_m=helper.get_pixel_size_m(object_), + l1_norm_constraint=helper.l1_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + ) + + def _create_probe_options( + self, probe: Probe, metadata: ProductMetadata + ) -> AutodiffPtychographyProbeOptions: + helper = self._options_helper.probe_helper + return AutodiffPtychographyProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probe), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + ) + + def _create_probe_position_options( + self, scan: Scan, object_geometry: ObjectGeometry + ) -> AutodiffPtychographyProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return AutodiffPtychographyProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + magnitude_limit=helper.magnitude_limit, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + ) + + def _create_opr_mode_weight_options(self) -> AutodiffPtychographyOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return AutodiffPtychographyOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> AutodiffPtychographyOptions: + product = parameters.product + return AutodiffPtychographyOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probe, product.metadata), + probe_position_options=self._create_probe_position_options( + product.scan, product.object_.getGeometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_mode_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/core.py b/src/ptychodus/model/ptychi/core.py new file mode 100644 index 00000000..9645b5b3 --- /dev/null +++ b/src/ptychodus/model/ptychi/core.py @@ -0,0 +1,95 @@ +from collections.abc import Iterator +from importlib.metadata import version +import logging + +from ptychodus.api.reconstructor import ( + NullReconstructor, + Reconstructor, + ReconstructorLibrary, +) +from ptychodus.api.settings import SettingsRegistry + +from ..patterns import Detector +from .device import PtyChiDeviceRepository +from .enums import PtyChiEnumerators +from .settings import ( + PtyChiAutodiffSettings, + PtyChiDMSettings, + PtyChiLSQMLSettings, + PtyChiOPRSettings, + PtyChiObjectSettings, + PtyChiPIESettings, + PtyChiProbePositionSettings, + PtyChiProbeSettings, + PtyChiReconstructorSettings, +) + +logger = logging.getLogger(__name__) + + +class PtyChiReconstructorLibrary(ReconstructorLibrary): + def __init__( + self, settingsRegistry: SettingsRegistry, detector: Detector, isDeveloperModeEnabled: bool + ) -> None: + super().__init__() + self.autodiffSettings = PtyChiAutodiffSettings(settingsRegistry) + self.dmSettings = PtyChiDMSettings(settingsRegistry) + self.lsqmlSettings = PtyChiLSQMLSettings(settingsRegistry) + self.objectSettings = PtyChiObjectSettings(settingsRegistry) + self.oprSettings = PtyChiOPRSettings(settingsRegistry) + self.pieSettings = PtyChiPIESettings(settingsRegistry) + self.probePositionSettings = PtyChiProbePositionSettings(settingsRegistry) + self.probeSettings = PtyChiProbeSettings(settingsRegistry) + self.reconstructorSettings = PtyChiReconstructorSettings(settingsRegistry) + + self.enumerators = PtyChiEnumerators() + self.deviceRepository = PtyChiDeviceRepository( + isDeveloperModeEnabled=isDeveloperModeEnabled + ) + self.reconstructor_list: list[Reconstructor] = list() + + try: + from .autodiff import AutodiffReconstructor + from .dm import DMReconstructor + from .epie import EPIEReconstructor + from .helper import PtyChiOptionsHelper + from .lsqml import LSQMLReconstructor + from .pie import PIEReconstructor + from .rpie import RPIEReconstructor + except ModuleNotFoundError: + logger.info('pty-chi not found.') + + if isDeveloperModeEnabled: + for reconstructor in ('DM', 'PIE', 'ePIE', 'rPIE', 'LSQML', 'Autodiff'): + self.reconstructor_list.append(NullReconstructor(reconstructor)) + else: + ptychiVersion = version('ptychi') + logger.info(f'Pty-Chi {ptychiVersion}') + + optionsHelper = PtyChiOptionsHelper( + self.reconstructorSettings, + self.objectSettings, + self.probeSettings, + self.probePositionSettings, + self.oprSettings, + detector, + ) + self.reconstructor_list.append(DMReconstructor(optionsHelper, self.dmSettings)) + self.reconstructor_list.append(PIEReconstructor(optionsHelper, self.pieSettings)) + self.reconstructor_list.append(EPIEReconstructor(optionsHelper, self.pieSettings)) + self.reconstructor_list.append(RPIEReconstructor(optionsHelper, self.pieSettings)) + self.reconstructor_list.append(LSQMLReconstructor(optionsHelper, self.lsqmlSettings)) + self.reconstructor_list.append( + AutodiffReconstructor(optionsHelper, self.autodiffSettings) + ) + + @property + def name(self) -> str: + return 'pty-chi' + + @property + def logger_name(self) -> str: + return 'ptychi' + + def __iter__(self) -> Iterator[Reconstructor]: + return iter(self.reconstructor_list) diff --git a/src/ptychodus/model/ptychi/device.py b/src/ptychodus/model/ptychi/device.py new file mode 100644 index 00000000..aba3ddbe --- /dev/null +++ b/src/ptychodus/model/ptychi/device.py @@ -0,0 +1,35 @@ +from collections.abc import Sequence +from typing import overload +import logging + +logger = logging.getLogger(__name__) + + +class PtyChiDeviceRepository(Sequence[str]): + def __init__(self, *, isDeveloperModeEnabled: bool) -> None: + self._devices: list[str] = list() + + try: + import ptychi + except ModuleNotFoundError: + if isDeveloperModeEnabled: + self._devices.extend(f'gpu:{n}' for n in range(4)) + else: + for device in ptychi.list_available_devices(): + logger.info(device) + self._devices.append(f'{device.name} ({device.torch_device})') + + if not self._devices: + logger.info('No devices found!') + + @overload + def __getitem__(self, index: int) -> str: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[str]: ... + + def __getitem__(self, index: int | slice) -> str | Sequence[str]: + return self._devices[index] + + def __len__(self) -> int: + return len(self._devices) diff --git a/src/ptychodus/model/ptychi/dm.py b/src/ptychodus/model/ptychi/dm.py new file mode 100644 index 00000000..2872a827 --- /dev/null +++ b/src/ptychodus/model/ptychi/dm.py @@ -0,0 +1,161 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + DMOPRModeWeightsOptions, + DMObjectOptions, + DMOptions, + DMProbeOptions, + DMProbePositionOptions, + DMReconstructorOptions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import Probe +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import Scan + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiDMSettings + +logger = logging.getLogger(__name__) + + +class DMReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiDMSettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'DM' + + def _create_reconstructor_options(self) -> DMReconstructorOptions: + helper = self._options_helper.reconstructor_helper + return DMReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + use_low_memory_forward_model=helper.use_low_memory_forward_model, + exit_wave_update_relaxation=self._settings.exitWaveUpdateRelaxation.getValue(), + chunk_length=self._settings.chunkLength.getValue(), + ) + + def _create_object_options(self, object_: Object) -> DMObjectOptions: + helper = self._options_helper.object_helper + return DMObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + pixel_size_m=helper.get_pixel_size_m(object_), + l1_norm_constraint=helper.l1_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + amplitude_clamp_limit=self._settings.objectAmplitudeClampLimit.getValue(), + ) + + def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> DMProbeOptions: + helper = self._options_helper.probe_helper + return DMProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probe), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + ) + + def _create_probe_position_options( + self, scan: Scan, object_geometry: ObjectGeometry + ) -> DMProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return DMProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + magnitude_limit=helper.magnitude_limit, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + ) + + def _create_opr_mode_weight_options(self) -> DMOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return DMOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> DMOptions: + product = parameters.product + return DMOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probe, product.metadata), + probe_position_options=self._create_probe_position_options( + product.scan, product.object_.getGeometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_mode_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/enums.py b/src/ptychodus/model/ptychi/enums.py new file mode 100644 index 00000000..5e40d3cc --- /dev/null +++ b/src/ptychodus/model/ptychi/enums.py @@ -0,0 +1,82 @@ +from collections.abc import Iterator, Sequence + + +class PtyChiEnumerators: + def __init__(self) -> None: + try: + from ptychi.api import ( + BatchingModes, + Directions, + ForwardModels, + ImageGradientMethods, + ImageIntegrationMethods, + LossFunctions, + NoiseModels, + OPRWeightSmoothingMethods, + Optimizers, + OrthogonalizationMethods, + PatchInterpolationMethods, + PositionCorrectionTypes, + ) + except ModuleNotFoundError: + self._batchingModes: Sequence[str] = list() + self._directions: Sequence[str] = list() + self._forwardModels: Sequence[str] = list() + self._imageGradientMethods: Sequence[str] = list() + self._imageIntegrationMethods: Sequence[str] = list() + self._lossFunctions: Sequence[str] = list() + self._noiseModels: Sequence[str] = list() + self._oprWeightSmoothingMethods: Sequence[str] = list() + self._optimizers: Sequence[str] = list() + self._orthogonalizationMethods: Sequence[str] = list() + self._patchInterpolationMethods: Sequence[str] = list() + self._positionCorrectionTypes: Sequence[str] = list() + else: + self._batchingModes = [member.name for member in BatchingModes] + self._directions = [member.name for member in Directions] + self._forwardModels = [member.name for member in ForwardModels] + self._imageGradientMethods = [member.name for member in ImageGradientMethods] + self._imageIntegrationMethods = [member.name for member in ImageIntegrationMethods] + self._lossFunctions = [member.name for member in LossFunctions] + self._noiseModels = [member.name for member in NoiseModels] + self._oprWeightSmoothingMethods = [member.name for member in OPRWeightSmoothingMethods] + self._optimizers = [member.name for member in Optimizers] + self._orthogonalizationMethods = [member.name for member in OrthogonalizationMethods] + self._patchInterpolationMethods = [member.name for member in PatchInterpolationMethods] + self._positionCorrectionTypes = [member.name for member in PositionCorrectionTypes] + + def batchingModes(self) -> Iterator[str]: + return iter(self._batchingModes) + + def directions(self) -> Iterator[str]: + return iter(self._directions) + + def forwardModels(self) -> Iterator[str]: + return iter(self._forwardModels) + + def imageGradientMethods(self) -> Iterator[str]: + return iter(self._imageGradientMethods) + + def imageIntegrationMethods(self) -> Iterator[str]: + return iter(self._imageIntegrationMethods) + + def lossFunctions(self) -> Iterator[str]: + return iter(self._lossFunctions) + + def noiseModels(self) -> Iterator[str]: + return iter(self._noiseModels) + + def oprWeightSmoothingMethods(self) -> Iterator[str]: + return iter(self._oprWeightSmoothingMethods) + + def optimizers(self) -> Iterator[str]: + return iter(self._optimizers) + + def orthogonalizationMethods(self) -> Iterator[str]: + return iter(self._orthogonalizationMethods) + + def patchInterpolationMethods(self) -> Iterator[str]: + return iter(self._patchInterpolationMethods) + + def positionCorrectionTypes(self) -> Iterator[str]: + return iter(self._positionCorrectionTypes) diff --git a/src/ptychodus/model/ptychi/epie.py b/src/ptychodus/model/ptychi/epie.py new file mode 100644 index 00000000..da1cfbda --- /dev/null +++ b/src/ptychodus/model/ptychi/epie.py @@ -0,0 +1,160 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + EPIEOptions, + EPIEReconstructorOptions, + PIEOPRModeWeightsOptions, + PIEObjectOptions, + PIEProbeOptions, + PIEProbePositionOptions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import Probe +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import Scan + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiPIESettings + +logger = logging.getLogger(__name__) + + +class EPIEReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'ePIE' + + def _create_reconstructor_options(self) -> EPIEReconstructorOptions: + helper = self._options_helper.reconstructor_helper + return EPIEReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + use_low_memory_forward_model=helper.use_low_memory_forward_model, + ) + + def _create_object_options(self, object_: Object) -> PIEObjectOptions: + helper = self._options_helper.object_helper + return PIEObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + pixel_size_m=helper.get_pixel_size_m(object_), + l1_norm_constraint=helper.l1_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + alpha=self._settings.objectAlpha.getValue(), + ) + + def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions: + helper = self._options_helper.probe_helper + return PIEProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probe), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + alpha=self._settings.probeAlpha.getValue(), + ) + + def _create_probe_position_options( + self, scan: Scan, object_geometry: ObjectGeometry + ) -> PIEProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return PIEProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + magnitude_limit=helper.magnitude_limit, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + ) + + def _create_opr_mode_weight_options(self) -> PIEOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return PIEOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> EPIEOptions: + product = parameters.product + return EPIEOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probe, product.metadata), + probe_position_options=self._create_probe_position_options( + product.scan, product.object_.getGeometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_mode_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/helper.py b/src/ptychodus/model/ptychi/helper.py new file mode 100644 index 00000000..da7acb14 --- /dev/null +++ b/src/ptychodus/model/ptychi/helper.py @@ -0,0 +1,601 @@ +from collections.abc import Sequence +import logging + +from torch import Tensor +import numpy + +from ptychi.api import ( + BatchingModes, + Devices, + Directions, + Dtypes, + ImageGradientMethods, + ImageIntegrationMethods, + LossFunctions, + OPRWeightSmoothingMethods, + OptimizationPlan, + Optimizers, + OrthogonalizationMethods, + PatchInterpolationMethods, + PositionCorrectionTypes, + PtychographyDataOptions, +) +from ptychi.api.options.base import ( + OPRModeWeightsSmoothingOptions, + ObjectL1NormConstraintOptions, + ObjectMultisliceRegularizationOptions, + ObjectSmoothnessConstraintOptions, + ObjectTotalVariationOptions, + PositionCorrectionOptions, + ProbeCenterConstraintOptions, + ProbeOrthogonalizeIncoherentModesOptions, + ProbeOrthogonalizeOPRModesOptions, + ProbePositionMagnitudeLimitOptions, + ProbePowerConstraintOptions, + ProbeSupportConstraintOptions, + RemoveGridArtifactsOptions, +) + +from ptychodus.api.object import Object, ObjectArrayType, ObjectGeometry, ObjectPoint +from ptychodus.api.probe import Probe, WavefieldArrayType +from ptychodus.api.product import Product, ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput +from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.typing import RealArrayType + +from ..patterns import Detector +from .settings import ( + PtyChiOPRSettings, + PtyChiObjectSettings, + PtyChiProbePositionSettings, + PtyChiProbeSettings, + PtyChiReconstructorSettings, +) + + +__all__ = ['PtyChiOptionsHelper'] + +logger = logging.getLogger(__name__) + + +def create_optimization_plan(start: int, stop: int, stride: int) -> OptimizationPlan: + return OptimizationPlan(start, None if stop < 0 else stop, stride) + + +def parse_optimizer(text: str) -> Optimizers: + try: + optimizer = Optimizers[text.upper()] + except KeyError: + logger.warning('Failed to parse optimizer "{text}"!') + optimizer = Optimizers.SGD + + return optimizer + + +class PtyChiReconstructorOptionsHelper: + def __init__(self, settings: PtyChiReconstructorSettings) -> None: + self._settings = settings + + @property + def num_epochs(self) -> int: + return self._settings.numEpochs.getValue() + + @property + def batch_size(self) -> int: + return self._settings.batchSize.getValue() + + @property + def batching_mode(self) -> BatchingModes: + batching_mode_str = self._settings.batchingMode.getValue() + + try: + return BatchingModes[batching_mode_str.upper()] + except KeyError: + logger.warning('Failed to parse batching mode "{batching_mode_str}"!') + return BatchingModes.RANDOM + + @property + def compact_mode_update_clustering(self) -> bool: + return self._settings.batchStride.getValue() > 0 + + @property + def compact_mode_update_clustering_stride(self) -> int: + return self._settings.batchStride.getValue() + + @property + def default_device(self) -> Devices: + return Devices.GPU if self._settings.useDevices.getValue() else Devices.CPU + + @property + def default_dtype(self) -> Dtypes: + return Dtypes.FLOAT64 if self._settings.useDoublePrecision.getValue() else Dtypes.FLOAT32 + + @property + def random_seed(self) -> int | None: + return None # TODO + + @property + def displayed_loss_function(self) -> LossFunctions | None: + return LossFunctions.MSE_SQRT # TODO + + @property + def use_low_memory_forward_model(self) -> bool: + return self._settings.useLowMemoryForwardModel.getValue() + + +class PtyChiObjectOptionsHelper: + def __init__(self, settings: PtyChiObjectSettings) -> None: + self._settings = settings + + @property + def optimizable(self) -> bool: + return self._settings.isOptimizable.getValue() + + @property + def optimization_plan(self) -> OptimizationPlan: + return create_optimization_plan( + self._settings.optimizationPlanStart.getValue(), + self._settings.optimizationPlanStop.getValue(), + self._settings.optimizationPlanStride.getValue(), + ) + + @property + def optimizer(self) -> Optimizers: + return parse_optimizer(self._settings.optimizer.getValue()) + + @property + def step_size(self) -> float: + return self._settings.stepSize.getValue() + + @property + def optimizer_params(self) -> dict: # TODO + return dict() + + @property + def l1_norm_constraint(self) -> ObjectL1NormConstraintOptions: + return ObjectL1NormConstraintOptions( + enabled=self._settings.constrainL1Norm.getValue(), + optimization_plan=create_optimization_plan( + self._settings.constrainL1NormStart.getValue(), + self._settings.constrainL1NormStop.getValue(), + self._settings.constrainL1NormStride.getValue(), + ), + weight=self._settings.constrainL1NormWeight.getValue(), + ) + + @property + def smoothness_constraint(self) -> ObjectSmoothnessConstraintOptions: + return ObjectSmoothnessConstraintOptions( + enabled=self._settings.constrainSmoothness.getValue(), + optimization_plan=create_optimization_plan( + self._settings.constrainSmoothnessStart.getValue(), + self._settings.constrainSmoothnessStop.getValue(), + self._settings.constrainSmoothnessStride.getValue(), + ), + alpha=self._settings.constrainSmoothnessAlpha.getValue(), + ) + + @property + def total_variation(self) -> ObjectTotalVariationOptions: + return ObjectTotalVariationOptions( + enabled=self._settings.constrainTotalVariation.getValue(), + optimization_plan=create_optimization_plan( + self._settings.constrainTotalVariationStart.getValue(), + self._settings.constrainTotalVariationStop.getValue(), + self._settings.constrainTotalVariationStride.getValue(), + ), + weight=self._settings.constrainTotalVariationWeight.getValue(), + ) + + @property + def remove_grid_artifacts(self) -> RemoveGridArtifactsOptions: + direction_str = self._settings.removeGridArtifactsDirection.getValue() + + try: + direction = Directions[direction_str.upper()] + except KeyError: + logger.warning('Failed to parse direction "{direction_str}"!') + direction = Directions.XY + + return RemoveGridArtifactsOptions( + enabled=self._settings.removeGridArtifacts.getValue(), + optimization_plan=create_optimization_plan( + self._settings.removeGridArtifactsStart.getValue(), + self._settings.removeGridArtifactsStop.getValue(), + self._settings.removeGridArtifactsStride.getValue(), + ), + period_x_m=self._settings.removeGridArtifactsPeriodXInMeters.getValue(), + period_y_m=self._settings.removeGridArtifactsPeriodYInMeters.getValue(), + window_size=self._settings.removeGridArtifactsWindowSizeInPixels.getValue(), + direction=direction, + ) + + @property + def multislice_regularization(self) -> ObjectMultisliceRegularizationOptions: + unwrap_image_grad_method_str = ( + self._settings.regularizeMultisliceUnwrapPhaseImageGradientMethod.getValue() + ) + + try: + unwrap_image_grad_method = ImageGradientMethods[unwrap_image_grad_method_str.upper()] + except KeyError: + logger.warning( + 'Failed to parse image gradient method "{unwrap_image_grad_method_str}"!' + ) + unwrap_image_grad_method = ImageGradientMethods.FOURIER_SHIFT + + unwrap_image_integration_method_str = ( + self._settings.regularizeMultisliceUnwrapPhaseImageIntegrationMethod.getValue() + ) + + try: + unwrap_image_integration_method = ImageIntegrationMethods[ + unwrap_image_integration_method_str.upper() + ] + except KeyError: + logger.warning( + 'Failed to parse image integrationient method "{unwrap_image_integration_method_str}"!' + ) + unwrap_image_integration_method = ImageIntegrationMethods.DECONVOLUTION + + return ObjectMultisliceRegularizationOptions( + enabled=self._settings.regularizeMultislice.getValue(), + optimization_plan=create_optimization_plan( + self._settings.regularizeMultisliceStart.getValue(), + self._settings.regularizeMultisliceStop.getValue(), + self._settings.regularizeMultisliceStride.getValue(), + ), + weight=self._settings.regularizeMultisliceWeight.getValue(), + unwrap_phase=self._settings.regularizeMultisliceUnwrapPhase.getValue(), + unwrap_image_grad_method=unwrap_image_grad_method, + unwrap_image_integration_method=unwrap_image_integration_method, + ) + + @property + def patch_interpolation_method(self) -> PatchInterpolationMethods: + method_str = self._settings.patchInterpolator.getValue() + + try: + return PatchInterpolationMethods[method_str.upper()] + except KeyError: + logger.warning('Failed to parse patch interpolation method "{method_str}"!') + return PatchInterpolationMethods.FOURIER + + def get_initial_guess(self, object_: Object) -> ObjectArrayType: + return object_.getArray() + + def get_slice_spacings_m(self, object_: Object) -> RealArrayType: + return numpy.array(object_.layerDistanceInMeters[:-1]) # FIXME iff multislice + + def get_pixel_size_m(self, object_: Object) -> float: + pixel_geometry = object_.getPixelGeometry() + + if pixel_geometry is None: + logger.error('Missing object pixel geometry!') + return 1.0 + + return pixel_geometry.widthInMeters + + +class PtyChiProbeOptionsHelper: + def __init__(self, settings: PtyChiProbeSettings) -> None: + self._settings = settings + + @property + def optimizable(self) -> bool: + return self._settings.isOptimizable.getValue() + + @property + def optimization_plan(self) -> OptimizationPlan: + return create_optimization_plan( + self._settings.optimizationPlanStart.getValue(), + self._settings.optimizationPlanStop.getValue(), + self._settings.optimizationPlanStride.getValue(), + ) + + @property + def optimizer(self) -> Optimizers: + return parse_optimizer(self._settings.optimizer.getValue()) + + @property + def step_size(self) -> float: + return self._settings.stepSize.getValue() + + @property + def optimizer_params(self) -> dict: # TODO + return dict() + + @property + def orthogonalize_incoherent_modes(self) -> ProbeOrthogonalizeIncoherentModesOptions: + method_str = self._settings.orthogonalizeIncoherentModesMethod.getValue() + + try: + method = OrthogonalizationMethods[method_str.upper()] + except KeyError: + logger.warning('Failed to parse batching mode "{method_str}"!') + method = OrthogonalizationMethods.GS + + return ProbeOrthogonalizeIncoherentModesOptions( + enabled=self._settings.orthogonalizeIncoherentModes.getValue(), + optimization_plan=create_optimization_plan( + self._settings.orthogonalizeIncoherentModesStart.getValue(), + self._settings.orthogonalizeIncoherentModesStop.getValue(), + self._settings.orthogonalizeIncoherentModesStride.getValue(), + ), + method=method, + ) + + @property + def orthogonalize_opr_modes(self) -> ProbeOrthogonalizeOPRModesOptions: + return ProbeOrthogonalizeOPRModesOptions( + enabled=self._settings.orthogonalizeOPRModes.getValue(), + optimization_plan=create_optimization_plan( + self._settings.orthogonalizeOPRModesStart.getValue(), + self._settings.orthogonalizeOPRModesStop.getValue(), + self._settings.orthogonalizeOPRModesStride.getValue(), + ), + ) + + @property + def support_constraint(self) -> ProbeSupportConstraintOptions: + return ProbeSupportConstraintOptions( + enabled=self._settings.constrainSupport.getValue(), + optimization_plan=create_optimization_plan( + self._settings.constrainSupportStart.getValue(), + self._settings.constrainSupportStop.getValue(), + self._settings.constrainSupportStride.getValue(), + ), + threshold=self._settings.constrainSupportThreshold.getValue(), + ) + + @property + def center_constraint(self) -> ProbeCenterConstraintOptions: + return ProbeCenterConstraintOptions( + enabled=self._settings.constrainCenter.getValue(), + optimization_plan=create_optimization_plan( + self._settings.constrainCenterStart.getValue(), + self._settings.constrainCenterStop.getValue(), + self._settings.constrainCenterStride.getValue(), + ), + ) + + @property + def eigenmode_update_relaxation(self) -> float: + return self._settings.relaxEigenmodeUpdate.getValue() + + def get_initial_guess(self, probe: Probe) -> WavefieldArrayType: + return probe.getArray() + + def get_power_constraint(self, metadata: ProductMetadata) -> ProbePowerConstraintOptions: + return ProbePowerConstraintOptions( + enabled=self._settings.constrainProbePower.getValue(), + optimization_plan=create_optimization_plan( + self._settings.constrainProbePowerStart.getValue(), + self._settings.constrainProbePowerStop.getValue(), + self._settings.constrainProbePowerStride.getValue(), + ), + probe_power=metadata.probePhotonCount, + ) + + +class PtyChiProbePositionOptionsHelper: + def __init__(self, settings: PtyChiProbePositionSettings) -> None: + self._settings = settings + + @property + def optimizable(self) -> bool: + return self._settings.isOptimizable.getValue() + + @property + def optimization_plan(self) -> OptimizationPlan: + return create_optimization_plan( + self._settings.optimizationPlanStart.getValue(), + self._settings.optimizationPlanStop.getValue(), + self._settings.optimizationPlanStride.getValue(), + ) + + @property + def optimizer(self) -> Optimizers: + return parse_optimizer(self._settings.optimizer.getValue()) + + @property + def step_size(self) -> float: + return self._settings.stepSize.getValue() + + @property + def optimizer_params(self) -> dict: # TODO + return dict() + + @property + def magnitude_limit(self) -> ProbePositionMagnitudeLimitOptions: + return ProbePositionMagnitudeLimitOptions( + enabled=self._settings.limitMagnitudeUpdate.getValue(), + optimization_plan=create_optimization_plan( + self._settings.limitMagnitudeUpdateStart.getValue(), + self._settings.limitMagnitudeUpdateStop.getValue(), + self._settings.limitMagnitudeUpdateStride.getValue(), + ), + limit=self._settings.magnitudeUpdateLimit.getValue(), + ) + + @property + def constrain_position_mean(self) -> bool: + return self._settings.constrainCentroid.getValue() + + @property + def correction_options(self) -> PositionCorrectionOptions: + correction_type_str = self._settings.positionCorrectionType.getValue() + + try: + correction_type = PositionCorrectionTypes[correction_type_str.upper()] + except KeyError: + logger.warning('Failed to parse batching mode "{correction_type_str}"!') + correction_type = PositionCorrectionTypes.GRADIENT + + return PositionCorrectionOptions( + correction_type=correction_type, + cross_correlation_scale=self._settings.crossCorrelationScale.getValue(), + cross_correlation_real_space_width=self._settings.crossCorrelationRealSpaceWidth.getValue(), + cross_correlation_probe_threshold=self._settings.crossCorrelationProbeThreshold.getValue(), + ) + + def get_positions_px( + self, scan: Scan, object_geometry: ObjectGeometry + ) -> tuple[RealArrayType, RealArrayType]: + position_x_px: list[float] = list() + position_y_px: list[float] = list() + rx_px = object_geometry.widthInPixels / 2 + ry_px = object_geometry.heightInPixels / 2 + + for scan_point in scan: + object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + position_x_px.append(object_point.positionXInPixels - rx_px) + position_y_px.append(object_point.positionYInPixels - ry_px) + + return numpy.array(position_x_px), numpy.array(position_y_px) + + +class PtyChiOPROptionsHelper: + def __init__(self, settings: PtyChiOPRSettings) -> None: + self._settings = settings + + @property + def optimizable(self) -> bool: + return self._settings.isOptimizable.getValue() + + @property + def optimization_plan(self) -> OptimizationPlan: + return create_optimization_plan( + self._settings.optimizationPlanStart.getValue(), + self._settings.optimizationPlanStop.getValue(), + self._settings.optimizationPlanStride.getValue(), + ) + + @property + def optimizer(self) -> Optimizers: + return parse_optimizer(self._settings.optimizer.getValue()) + + @property + def step_size(self) -> float: + return self._settings.stepSize.getValue() + + @property + def optimizer_params(self) -> dict: # TODO + return dict() + + @property + def smoothing(self) -> OPRModeWeightsSmoothingOptions: + method_str = self._settings.smoothingMethod.getValue() + + try: + method: OPRWeightSmoothingMethods | None = OPRWeightSmoothingMethods[method_str.upper()] + except KeyError: + method = None + logger.warning('Failed to parse OPR weight smoothing method "{method_str}"!') + + return OPRModeWeightsSmoothingOptions( + enabled=self._settings.smoothModeWeights.getValue(), + optimization_plan=create_optimization_plan( + self._settings.smoothModeWeightsStart.getValue(), + self._settings.smoothModeWeightsStop.getValue(), + self._settings.smoothModeWeightsStride.getValue(), + ), + method=method, + polynomial_degree=self._settings.polynomialSmoothingDegree.getValue(), + ) + + @property + def optimize_eigenmode_weights(self) -> bool: + return self._settings.optimizeEigenmodeWeights.getValue() + + @property + def optimize_intensity_variation(self) -> bool: + return self._settings.optimizeIntensities.getValue() + + @property + def update_relaxation(self) -> float: + return self._settings.relaxUpdate.getValue() + + def get_initial_weights(self) -> RealArrayType: + return numpy.array([0.0]) # FIXME + + +class PtyChiOptionsHelper: + def __init__( + self, + reconstructor_settings: PtyChiReconstructorSettings, + object_settings: PtyChiObjectSettings, + probe_settings: PtyChiProbeSettings, + probe_position_settings: PtyChiProbePositionSettings, + opr_settings: PtyChiOPRSettings, + detector: Detector, + ) -> None: + self._reconstructor_settings = reconstructor_settings + self._detector = detector + + self.reconstructor_helper = PtyChiReconstructorOptionsHelper(reconstructor_settings) + self.object_helper = PtyChiObjectOptionsHelper(object_settings) + self.probe_helper = PtyChiProbeOptionsHelper(probe_settings) + self.probe_position_helper = PtyChiProbePositionOptionsHelper(probe_position_settings) + self.opr_helper = PtyChiOPROptionsHelper(opr_settings) + + def create_data_options(self, parameters: ReconstructInput) -> PtychographyDataOptions: + metadata = parameters.product.metadata + return PtychographyDataOptions( + data=parameters.patterns, + free_space_propagation_distance_m=metadata.detectorDistanceInMeters, + wavelength_m=metadata.probeWavelengthInMeters, + detector_pixel_size_m=self._detector.pixelWidthInMeters.getValue(), + valid_pixel_mask=parameters.goodPixelMask, + save_data_on_device=self._reconstructor_settings.saveDataOnDevice.getValue(), + ) + + def create_product( + self, + product: Product, + position_x_px: Tensor | numpy.ndarray, + position_y_px: Tensor | numpy.ndarray, + probe_array: Tensor | numpy.ndarray, + object_array: Tensor | numpy.ndarray, + opr_mode_weights: Tensor | numpy.ndarray, + costs: Sequence[float], + ) -> Product: + object_in = product.object_ + object_out = Object( + array=numpy.array(object_array), + layerDistanceInMeters=object_in.layerDistanceInMeters, + pixelGeometry=object_in.getPixelGeometry(), + center=object_in.getCenter(), + ) + + # TODO OPR + probe_out = Probe( + array=numpy.array(probe_array[0]), + pixelGeometry=product.probe.getPixelGeometry(), + ) + + corrected_scan_points: list[ScanPoint] = list() + object_geometry = object_in.getGeometry() + rx_px = object_geometry.widthInPixels / 2 + ry_px = object_geometry.heightInPixels / 2 + + for uncorrected_point, pos_x_px, pos_y_px in zip( + product.scan, position_x_px, position_y_px + ): + object_point = ObjectPoint( + index=uncorrected_point.index, + positionXInPixels=pos_x_px + rx_px, + positionYInPixels=pos_y_px + ry_px, + ) + scan_point = object_geometry.mapObjectPointToScanPoint(object_point) + corrected_scan_points.append(scan_point) + + scan_out = Scan(corrected_scan_points) + + return Product( + metadata=product.metadata, + scan=scan_out, + probe=probe_out, + object_=object_out, + costs=costs, + ) diff --git a/src/ptychodus/model/ptychi/lsqml.py b/src/ptychodus/model/ptychi/lsqml.py new file mode 100644 index 00000000..dabf24da --- /dev/null +++ b/src/ptychodus/model/ptychi/lsqml.py @@ -0,0 +1,190 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + LSQMLOPRModeWeightsOptions, + LSQMLObjectOptions, + LSQMLOptions, + LSQMLProbeOptions, + LSQMLProbePositionOptions, + LSQMLReconstructorOptions, + NoiseModels, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import Probe +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import Scan + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiLSQMLSettings + +logger = logging.getLogger(__name__) + + +class LSQMLReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiLSQMLSettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'LSQML' + + def _create_reconstructor_options(self) -> LSQMLReconstructorOptions: + helper = self._options_helper.reconstructor_helper + + #### + + noise_model_str = self._settings.noiseModel.getValue() + + try: + noise_model = NoiseModels[noise_model_str.upper()] + except KeyError: + logger.warning('Failed to parse batching mode "{noise_model_str}"!') + noise_model = NoiseModels.GAUSSIAN + + #### + + momentum_acceleration_gradient_mixing_factor: float | None = None + + if self._settings.useMomentumAccelerationGradientMixingFactor.getValue(): + momentum_acceleration_gradient_mixing_factor = ( + self._settings.momentumAccelerationGradientMixingFactor.getValue() + ) + + #### + + return LSQMLReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + use_low_memory_forward_model=helper.use_low_memory_forward_model, + noise_model=noise_model, + gaussian_noise_std=self._settings.gaussianNoiseDeviation.getValue(), + solve_obj_prb_step_size_jointly_for_first_slice_in_multislice=self._settings.solveObjectProbeStepSizeJointlyForFirstSliceInMultislice.getValue(), + solve_step_sizes_only_using_first_probe_mode=self._settings.solveStepSizesOnlyUsingFirstProbeMode.getValue(), + momentum_acceleration_gain=self._settings.momentumAccelerationGain.getValue(), + momentum_acceleration_gradient_mixing_factor=momentum_acceleration_gradient_mixing_factor, + ) + + def _create_object_options(self, object_: Object) -> LSQMLObjectOptions: + helper = self._options_helper.object_helper + return LSQMLObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + pixel_size_m=helper.get_pixel_size_m(object_), + l1_norm_constraint=helper.l1_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + optimal_step_size_scaler=self._settings.objectOptimalStepSizeScaler.getValue(), + multimodal_update=self._settings.objectMultimodalUpdate.getValue(), + ) + + def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> LSQMLProbeOptions: + helper = self._options_helper.probe_helper + return LSQMLProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probe), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + optimal_step_size_scaler=self._settings.probeOptimalStepSizeScaler.getValue(), + ) + + def _create_probe_position_options( + self, scan: Scan, object_geometry: ObjectGeometry + ) -> LSQMLProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return LSQMLProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + magnitude_limit=helper.magnitude_limit, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + ) + + def _create_opr_mode_weight_options(self) -> LSQMLOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return LSQMLOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> LSQMLOptions: + product = parameters.product + return LSQMLOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probe, product.metadata), + probe_position_options=self._create_probe_position_options( + product.scan, product.object_.getGeometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_mode_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/pie.py b/src/ptychodus/model/ptychi/pie.py new file mode 100644 index 00000000..21b9f40f --- /dev/null +++ b/src/ptychodus/model/ptychi/pie.py @@ -0,0 +1,160 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + PIEOPRModeWeightsOptions, + PIEObjectOptions, + PIEOptions, + PIEProbeOptions, + PIEProbePositionOptions, + PIEReconstructorOptions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import Probe +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import Scan + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiPIESettings + +logger = logging.getLogger(__name__) + + +class PIEReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'PIE' + + def _create_reconstructor_options(self) -> PIEReconstructorOptions: + helper = self._options_helper.reconstructor_helper + return PIEReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + use_low_memory_forward_model=helper.use_low_memory_forward_model, + ) + + def _create_object_options(self, object_: Object) -> PIEObjectOptions: + helper = self._options_helper.object_helper + return PIEObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + pixel_size_m=helper.get_pixel_size_m(object_), + l1_norm_constraint=helper.l1_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + alpha=self._settings.objectAlpha.getValue(), + ) + + def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions: + helper = self._options_helper.probe_helper + return PIEProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probe), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + alpha=self._settings.probeAlpha.getValue(), + ) + + def _create_probe_position_options( + self, scan: Scan, object_geometry: ObjectGeometry + ) -> PIEProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return PIEProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + magnitude_limit=helper.magnitude_limit, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + ) + + def _create_opr_mode_weight_options(self) -> PIEOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return PIEOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> PIEOptions: + product = parameters.product + return PIEOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probe, product.metadata), + probe_position_options=self._create_probe_position_options( + product.scan, product.object_.getGeometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_mode_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/rpie.py b/src/ptychodus/model/ptychi/rpie.py new file mode 100644 index 00000000..22970f4c --- /dev/null +++ b/src/ptychodus/model/ptychi/rpie.py @@ -0,0 +1,160 @@ +from collections.abc import Sequence +import logging + + +from ptychi.api import ( + PIEOPRModeWeightsOptions, + PIEObjectOptions, + PIEProbeOptions, + PIEProbePositionOptions, + RPIEOptions, + RPIEReconstructorOptions, +) +from ptychi.api.task import PtychographyTask + +from ptychodus.api.object import Object, ObjectGeometry +from ptychodus.api.probe import Probe +from ptychodus.api.product import ProductMetadata +from ptychodus.api.reconstructor import ReconstructInput, ReconstructOutput, Reconstructor +from ptychodus.api.scan import Scan + +from .helper import PtyChiOptionsHelper +from .settings import PtyChiPIESettings + +logger = logging.getLogger(__name__) + + +class RPIEReconstructor(Reconstructor): + def __init__(self, options_helper: PtyChiOptionsHelper, settings: PtyChiPIESettings) -> None: + super().__init__() + self._options_helper = options_helper + self._settings = settings + + @property + def name(self) -> str: + return 'rPIE' + + def _create_reconstructor_options(self) -> RPIEReconstructorOptions: + helper = self._options_helper.reconstructor_helper + return RPIEReconstructorOptions( + num_epochs=helper.num_epochs, + batch_size=helper.batch_size, + batching_mode=helper.batching_mode, + compact_mode_update_clustering=helper.compact_mode_update_clustering, + compact_mode_update_clustering_stride=helper.compact_mode_update_clustering_stride, + default_device=helper.default_device, + default_dtype=helper.default_dtype, + random_seed=helper.random_seed, + displayed_loss_function=helper.displayed_loss_function, + use_low_memory_forward_model=helper.use_low_memory_forward_model, + ) + + def _create_object_options(self, object_: Object) -> PIEObjectOptions: + helper = self._options_helper.object_helper + return PIEObjectOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(object_), + slice_spacings_m=helper.get_slice_spacings_m(object_), + pixel_size_m=helper.get_pixel_size_m(object_), + l1_norm_constraint=helper.l1_norm_constraint, + smoothness_constraint=helper.smoothness_constraint, + total_variation=helper.total_variation, + remove_grid_artifacts=helper.remove_grid_artifacts, + multislice_regularization=helper.multislice_regularization, + patch_interpolation_method=helper.patch_interpolation_method, + alpha=self._settings.objectAlpha.getValue(), + ) + + def _create_probe_options(self, probe: Probe, metadata: ProductMetadata) -> PIEProbeOptions: + helper = self._options_helper.probe_helper + return PIEProbeOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_guess=helper.get_initial_guess(probe), + power_constraint=helper.get_power_constraint(metadata), + orthogonalize_incoherent_modes=helper.orthogonalize_incoherent_modes, + orthogonalize_opr_modes=helper.orthogonalize_opr_modes, + support_constraint=helper.support_constraint, + center_constraint=helper.center_constraint, + eigenmode_update_relaxation=helper.eigenmode_update_relaxation, + alpha=self._settings.probeAlpha.getValue(), + ) + + def _create_probe_position_options( + self, scan: Scan, object_geometry: ObjectGeometry + ) -> PIEProbePositionOptions: + helper = self._options_helper.probe_position_helper + position_x_px, position_y_px = helper.get_positions_px(scan, object_geometry) + return PIEProbePositionOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + position_x_px=position_x_px, + position_y_px=position_y_px, + magnitude_limit=helper.magnitude_limit, + constrain_position_mean=helper.constrain_position_mean, + correction_options=helper.correction_options, + ) + + def _create_opr_mode_weight_options(self) -> PIEOPRModeWeightsOptions: + helper = self._options_helper.opr_helper + return PIEOPRModeWeightsOptions( + optimizable=helper.optimizable, + optimization_plan=helper.optimization_plan, + optimizer=helper.optimizer, + step_size=helper.step_size, + optimizer_params=helper.optimizer_params, + initial_weights=helper.get_initial_weights(), + optimize_eigenmode_weights=helper.optimize_eigenmode_weights, + optimize_intensity_variation=helper.optimize_intensity_variation, + smoothing=helper.smoothing, + update_relaxation=helper.update_relaxation, + ) + + def _create_task_options(self, parameters: ReconstructInput) -> RPIEOptions: + product = parameters.product + return RPIEOptions( + data_options=self._options_helper.create_data_options(parameters), + reconstructor_options=self._create_reconstructor_options(), + object_options=self._create_object_options(product.object_), + probe_options=self._create_probe_options(product.probe, product.metadata), + probe_position_options=self._create_probe_position_options( + product.scan, product.object_.getGeometry() + ), + opr_mode_weight_options=self._create_opr_mode_weight_options(), + ) + + def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: + task_options = self._create_task_options(parameters) + task = PtychographyTask(task_options) + task.run() # TODO (n_epochs: int | None = None) + + costs: Sequence[float] = list() + task_reconstructor = task.reconstructor + + if task_reconstructor is not None: + loss_tracker = task_reconstructor.loss_tracker + # TODO update api to include epoch and loss + # epoch = loss_tracker.table['epoch'].to_numpy() + loss = loss_tracker.table['loss'].to_numpy() + costs = [float(x) for x in loss.flatten()] + + product = self._options_helper.create_product( + product=parameters.product, + position_x_px=task.get_probe_positions_x(as_numpy=True), + position_y_px=task.get_probe_positions_y(as_numpy=True), + probe_array=task.get_data_to_cpu('probe', as_numpy=True), + object_array=task.get_data_to_cpu('object', as_numpy=True), + opr_mode_weights=task.get_data_to_cpu('opr_mode_weights', as_numpy=True), + costs=costs, + ) + return ReconstructOutput(product, 0) diff --git a/src/ptychodus/model/ptychi/settings.py b/src/ptychodus/model/ptychi/settings.py new file mode 100644 index 00000000..45db9ba9 --- /dev/null +++ b/src/ptychodus/model/ptychi/settings.py @@ -0,0 +1,462 @@ +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.settings import SettingsRegistry + + +class PtyChiReconstructorSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChi') + self._settingsGroup.addObserver(self) + + self.numEpochs = self._settingsGroup.createIntegerParameter('NumEpochs', 100, minimum=1) + self.batchSize = self._settingsGroup.createIntegerParameter('BatchSize', 100, minimum=1) + self.batchingMode = self._settingsGroup.createStringParameter('BatchingMode', 'random') + self.batchStride = self._settingsGroup.createIntegerParameter('BatchStride', 1, minimum=1) + self.useDoublePrecision = self._settingsGroup.createBooleanParameter( + 'UseDoublePrecision', False + ) + self.useDevices = self._settingsGroup.createBooleanParameter('UseDevices', True) + self.useLowMemoryForwardModel = self._settingsGroup.createBooleanParameter( + 'UseLowMemoryForwardModel', False + ) + self.saveDataOnDevice = self._settingsGroup.createBooleanParameter( + 'SaveDataOnDevice', False + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() + + +class PtyChiObjectSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChiObject') + self._settingsGroup.addObserver(self) + + self.isOptimizable = self._settingsGroup.createBooleanParameter('IsOptimizable', True) + self.optimizationPlanStart = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStart', 0, minimum=0 + ) + self.optimizationPlanStop = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStop', -1 + ) + self.optimizationPlanStride = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStride', 1, minimum=1 + ) + self.optimizer = self._settingsGroup.createStringParameter('Optimizer', 'SGD') + self.stepSize = self._settingsGroup.createRealParameter('StepSize', 1.0, minimum=0.0) + + self.patchInterpolator = self._settingsGroup.createStringParameter( + 'PatchInterpolator', 'FOURIER' + ) + + self.constrainL1Norm = self._settingsGroup.createBooleanParameter('ConstrainL1Norm', False) + self.constrainL1NormStart = self._settingsGroup.createIntegerParameter( + 'ConstrainL1NormStart', 0, minimum=0 + ) + self.constrainL1NormStop = self._settingsGroup.createIntegerParameter( + 'ConstrainL1NormStop', -1 + ) + self.constrainL1NormStride = self._settingsGroup.createIntegerParameter( + 'ConstrainL1NormStride', 1, minimum=1 + ) + self.constrainL1NormWeight = self._settingsGroup.createRealParameter( + 'ConstrainL1NormWeight', 0.0, minimum=0.0 + ) + + self.constrainSmoothness = self._settingsGroup.createBooleanParameter( + 'ConstrainSmoothness', False + ) + self.constrainSmoothnessStart = self._settingsGroup.createIntegerParameter( + 'ConstrainSmoothnessStart', 0, minimum=0 + ) + self.constrainSmoothnessStop = self._settingsGroup.createIntegerParameter( + 'ConstrainSmoothnessStop', -1 + ) + self.constrainSmoothnessStride = self._settingsGroup.createIntegerParameter( + 'ConstrainSmoothnessStride', 1, minimum=1 + ) + self.constrainSmoothnessAlpha = self._settingsGroup.createRealParameter( + 'ConstrainSmoothnessAlpha', 0.0, minimum=0.0, maximum=1.0 / 8 + ) + + self.constrainTotalVariation = self._settingsGroup.createBooleanParameter( + 'ConstrainTotalVariation', False + ) + self.constrainTotalVariationStart = self._settingsGroup.createIntegerParameter( + 'ConstrainTotalVariationStart', 0, minimum=0 + ) + self.constrainTotalVariationStop = self._settingsGroup.createIntegerParameter( + 'ConstrainTotalVariationStop', -1 + ) + self.constrainTotalVariationStride = self._settingsGroup.createIntegerParameter( + 'ConstrainTotalVariationStride', 1, minimum=1 + ) + self.constrainTotalVariationWeight = self._settingsGroup.createRealParameter( + 'ConstrainTotalVariationWeight', 0.0, minimum=0.0 + ) + + self.removeGridArtifacts = self._settingsGroup.createBooleanParameter( + 'RemoveGridArtifacts', False + ) + self.removeGridArtifactsStart = self._settingsGroup.createIntegerParameter( + 'RemoveGridArtifactsStart', 0, minimum=0 + ) + self.removeGridArtifactsStop = self._settingsGroup.createIntegerParameter( + 'RemoveGridArtifactsStop', -1 + ) + self.removeGridArtifactsStride = self._settingsGroup.createIntegerParameter( + 'RemoveGridArtifactsStride', 1, minimum=1 + ) + self.removeGridArtifactsPeriodXInMeters = self._settingsGroup.createRealParameter( + 'RemoveGridArtifactsPeriodXInMeters', 1e-7, minimum=0.0 + ) + self.removeGridArtifactsPeriodYInMeters = self._settingsGroup.createRealParameter( + 'RemoveGridArtifactsPeriodYInMeters', 1e-7, minimum=0.0 + ) + self.removeGridArtifactsWindowSizeInPixels = self._settingsGroup.createIntegerParameter( + 'RemoveGridArtifactsWindowSizeInPixels', + 5, + minimum=1, + ) + self.removeGridArtifactsDirection = self._settingsGroup.createStringParameter( + 'RemoveGridArtifactsDirection', 'XY' + ) + + self.regularizeMultislice = self._settingsGroup.createBooleanParameter( + 'RegularizeMultislice', False + ) + self.regularizeMultisliceStart = self._settingsGroup.createIntegerParameter( + 'RegularizeMultisliceStart', 0, minimum=0 + ) + self.regularizeMultisliceStop = self._settingsGroup.createIntegerParameter( + 'RegularizeMultisliceStop', -1 + ) + self.regularizeMultisliceStride = self._settingsGroup.createIntegerParameter( + 'RegularizeMultisliceStride', 1, minimum=1 + ) + self.regularizeMultisliceWeight = self._settingsGroup.createRealParameter( + 'RegularizeMultisliceWeight', 0.0, minimum=0.0 + ) + self.regularizeMultisliceUnwrapPhase = self._settingsGroup.createBooleanParameter( + 'RegularizeMultisliceUnwrapPhase', True + ) + self.regularizeMultisliceUnwrapPhaseImageGradientMethod = ( + self._settingsGroup.createStringParameter( + 'RegularizeMultisliceUnwrapPhaseImageGradientMethod', 'FOURIER_SHIFT' + ) + ) + self.regularizeMultisliceUnwrapPhaseImageIntegrationMethod = ( + self._settingsGroup.createStringParameter( + 'RegularizeMultisliceUnwrapPhaseImageIntegrationMethod', 'DECONVOLUTION' + ) + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() + + +class PtyChiProbeSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChiProbe') + self._settingsGroup.addObserver(self) + + self.isOptimizable = self._settingsGroup.createBooleanParameter('IsOptimizable', True) + self.optimizationPlanStart = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStart', 0, minimum=0 + ) + self.optimizationPlanStop = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStop', -1 + ) + self.optimizationPlanStride = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStride', 1, minimum=1 + ) + self.optimizer = self._settingsGroup.createStringParameter('Optimizer', 'SGD') + self.stepSize = self._settingsGroup.createRealParameter('StepSize', 1.0, minimum=0.0) + + self.constrainProbePower = self._settingsGroup.createBooleanParameter( + 'ConstrainProbePower', False + ) + self.constrainProbePowerStart = self._settingsGroup.createIntegerParameter( + 'ConstrainProbePowerStart', 0, minimum=0 + ) + self.constrainProbePowerStop = self._settingsGroup.createIntegerParameter( + 'ConstrainProbePowerStop', -1 + ) + self.constrainProbePowerStride = self._settingsGroup.createIntegerParameter( + 'ConstrainProbePowerStride', 1, minimum=1 + ) + + self.orthogonalizeIncoherentModes = self._settingsGroup.createBooleanParameter( + 'OrthogonalizeIncoherentModes', True + ) + self.orthogonalizeIncoherentModesStart = self._settingsGroup.createIntegerParameter( + 'OrthogonalizeIncoherentModesStart', 0, minimum=0 + ) + self.orthogonalizeIncoherentModesStop = self._settingsGroup.createIntegerParameter( + 'OrthogonalizeIncoherentModesStop', -1 + ) + self.orthogonalizeIncoherentModesStride = self._settingsGroup.createIntegerParameter( + 'OrthogonalizeIncoherentModesStride', 1, minimum=1 + ) + self.orthogonalizeIncoherentModesMethod = self._settingsGroup.createStringParameter( + 'OrthogonalizeIncoherentModesMethod', 'GS' + ) + + self.orthogonalizeOPRModes = self._settingsGroup.createBooleanParameter( + 'OrthogonalizeOPRModes', True + ) + self.orthogonalizeOPRModesStart = self._settingsGroup.createIntegerParameter( + 'OrthogonalizeOPRModesStart', 0, minimum=0 + ) + self.orthogonalizeOPRModesStop = self._settingsGroup.createIntegerParameter( + 'OrthogonalizeOPRModesStop', -1 + ) + self.orthogonalizeOPRModesStride = self._settingsGroup.createIntegerParameter( + 'OrthogonalizeOPRModesStride', 1, minimum=1 + ) + + self.constrainSupport = self._settingsGroup.createBooleanParameter( + 'ConstrainSupport', False + ) + self.constrainSupportStart = self._settingsGroup.createIntegerParameter( + 'ConstrainSupportStart', 0, minimum=0 + ) + self.constrainSupportStop = self._settingsGroup.createIntegerParameter( + 'ConstrainSupportStop', -1 + ) + self.constrainSupportStride = self._settingsGroup.createIntegerParameter( + 'ConstrainSupportStride', 1, minimum=1 + ) + self.constrainSupportThreshold = self._settingsGroup.createRealParameter( + 'ConstrainSupportThreshold', 0.005, minimum=0.0 + ) + + self.constrainCenter = self._settingsGroup.createBooleanParameter('ConstrainCenter', False) + self.constrainCenterStart = self._settingsGroup.createIntegerParameter( + 'ConstrainCenterStart', 0, minimum=0 + ) + self.constrainCenterStop = self._settingsGroup.createIntegerParameter( + 'ConstrainCenterStop', -1 + ) + self.constrainCenterStride = self._settingsGroup.createIntegerParameter( + 'ConstrainCenterStride', 1, minimum=1 + ) + + self.relaxEigenmodeUpdate = self._settingsGroup.createRealParameter( + 'RelaxEigenmodeUpdate', 1.0, minimum=0.0, maximum=1.0 + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() + + +class PtyChiProbePositionSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChiProbePosition') + self._settingsGroup.addObserver(self) + + self.isOptimizable = self._settingsGroup.createBooleanParameter('IsOptimizable', False) + self.optimizationPlanStart = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStart', 0, minimum=0 + ) + self.optimizationPlanStop = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStop', -1 + ) + self.optimizationPlanStride = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStride', 1, minimum=1 + ) + self.optimizer = self._settingsGroup.createStringParameter('Optimizer', 'SGD') + self.stepSize = self._settingsGroup.createRealParameter('StepSize', 1.0, minimum=0.0) + + self.positionCorrectionType = self._settingsGroup.createStringParameter( + 'PositionCorrectionType', 'Gradient' + ) + self.crossCorrelationScale = self._settingsGroup.createIntegerParameter( + 'CrossCorrelationScale', 20000, minimum=1 + ) + self.crossCorrelationRealSpaceWidth = self._settingsGroup.createRealParameter( + 'CrossCorrelationRealSpaceWidth', 0.01, minimum=0.0 + ) + self.crossCorrelationProbeThreshold = self._settingsGroup.createRealParameter( + 'CrossCorrelationProbeThreshold', 0.1, minimum=0.0, maximum=1.0 + ) + + self.limitMagnitudeUpdate = self._settingsGroup.createBooleanParameter( + 'LimitMagnitudeUpdate', False + ) + self.limitMagnitudeUpdateStart = self._settingsGroup.createIntegerParameter( + 'LimitMagnitudeUpdateStart', 0, minimum=0 + ) + self.limitMagnitudeUpdateStop = self._settingsGroup.createIntegerParameter( + 'LimitMagnitudeUpdateStop', -1 + ) + self.limitMagnitudeUpdateStride = self._settingsGroup.createIntegerParameter( + 'LimitMagnitudeUpdateStride', 1, minimum=1 + ) + self.magnitudeUpdateLimit = self._settingsGroup.createRealParameter( + 'MagnitudeUpdateLimit', 0.0, minimum=0.0 + ) + + self.constrainCentroid = self._settingsGroup.createBooleanParameter( + 'ConstrainCentroid', False + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() + + +class PtyChiOPRSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChiOPR') + self._settingsGroup.addObserver(self) + + self.isOptimizable = self._settingsGroup.createBooleanParameter('IsOptimizable', False) + self.optimizationPlanStart = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStart', 0, minimum=0 + ) + self.optimizationPlanStop = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStop', -1 + ) + self.optimizationPlanStride = self._settingsGroup.createIntegerParameter( + 'OptimizationPlanStride', 1, minimum=1 + ) + self.optimizer = self._settingsGroup.createStringParameter('Optimizer', 'SGD') + self.stepSize = self._settingsGroup.createRealParameter('StepSize', 1.0, minimum=0.0) + + self.optimizeIntensities = self._settingsGroup.createBooleanParameter( + 'OptimizeIntensities', False + ) + self.optimizeEigenmodeWeights = self._settingsGroup.createBooleanParameter( + 'OptimizeEigenmodeWeigts', True + ) + + self.smoothModeWeights = self._settingsGroup.createBooleanParameter( + 'SmoothModeWeights', False + ) + self.smoothModeWeightsStart = self._settingsGroup.createIntegerParameter( + 'SmoothModeWeightsStart', 0, minimum=0 + ) + self.smoothModeWeightsStop = self._settingsGroup.createIntegerParameter( + 'SmoothModeWeightsStop', -1 + ) + self.smoothModeWeightsStride = self._settingsGroup.createIntegerParameter( + 'SmoothModeWeightsStride', 1, minimum=1 + ) + self.smoothingMethod = self._settingsGroup.createStringParameter('SmoothingMethod', '') + self.polynomialSmoothingDegree = self._settingsGroup.createIntegerParameter( + 'PolynomialSmoothingDegree', 4, minimum=0, maximum=10 + ) + + self.relaxUpdate = self._settingsGroup.createRealParameter( + 'RelaxUpdate', 1.0, minimum=0.0, maximum=1.0 + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() + + +class PtyChiAutodiffSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChiAutodiff') + self._settingsGroup.addObserver(self) + + self.lossFunction = self._settingsGroup.createStringParameter('LossFunction', 'MSE_SQRT') + self.forwardModelClass = self._settingsGroup.createStringParameter( + 'ForwardModelClass', 'PLANAR_PTYCHOGRAPHY' + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() + + +class PtyChiDMSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChiDM') + self._settingsGroup.addObserver(self) + + self.exitWaveUpdateRelaxation = self._settingsGroup.createRealParameter( + 'ExitWaveUpdateRelaxation', 1.0, minimum=0.0, maximum=1.0 + ) + self.chunkLength = self._settingsGroup.createIntegerParameter('ChunkLength', 1, minimum=1) + self.objectAmplitudeClampLimit = self._settingsGroup.createRealParameter( + 'ObjectAmplitudeClampLimit', 1000, minimum=0.0 + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() + + +class PtyChiLSQMLSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChiLSQML') + self._settingsGroup.addObserver(self) + + self.noiseModel = self._settingsGroup.createStringParameter('NoiseModel', 'GAUSSIAN') + self.gaussianNoiseDeviation = self._settingsGroup.createRealParameter( + 'GaussianNoiseDeviation', 0.5 + ) + self.solveObjectProbeStepSizeJointlyForFirstSliceInMultislice = ( + self._settingsGroup.createBooleanParameter( + 'SolveObjectProbeStepSizeJointlyForFirstSliceInMultislice', False + ) + ) + self.solveStepSizesOnlyUsingFirstProbeMode = self._settingsGroup.createBooleanParameter( + 'SolveStepSizesOnlyUsingFirstProbeMode', False + ) + self.momentumAccelerationGain = self._settingsGroup.createRealParameter( + 'MomentumAccelerationGain', 0.0, minimum=0.0 + ) + self.useMomentumAccelerationGradientMixingFactor = ( + self._settingsGroup.createBooleanParameter( + 'UseMomentumAccelerationGradientMixingFactor', False + ) + ) + self.momentumAccelerationGradientMixingFactor = self._settingsGroup.createRealParameter( + 'MomentumAccelerationGradientMixingFactor', 1.0 + ) + + self.probeOptimalStepSizeScaler = self._settingsGroup.createRealParameter( + 'ProbeOptimalStepSizeScaler', 0.9, minimum=0.0 + ) + self.objectOptimalStepSizeScaler = self._settingsGroup.createRealParameter( + 'ObjectOptimalStepSizeScaler', 0.9, minimum=0.0 + ) + self.objectMultimodalUpdate = self._settingsGroup.createBooleanParameter( + 'ObjectMultimodalUpdate', True + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() + + +class PtyChiPIESettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('PtyChiPIE') + self._settingsGroup.addObserver(self) + + self.probeAlpha = self._settingsGroup.createRealParameter( + 'ProbeAlpha', 0.1, minimum=0.0, maximum=1.0 + ) + self.objectAlpha = self._settingsGroup.createRealParameter( + 'ObjectAlpha', 0.1, minimum=0.0, maximum=1.0 + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() diff --git a/src/ptychodus/model/ptychonn/common.py b/src/ptychodus/model/ptychonn/common.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/ptychodus/model/ptychonn/core.py b/src/ptychodus/model/ptychonn/core.py index b5f343d4..ebf1c655 100644 --- a/src/ptychodus/model/ptychonn/core.py +++ b/src/ptychodus/model/ptychonn/core.py @@ -180,5 +180,9 @@ def createInstance( def name(self) -> str: return 'PtychoNN' + @property + def logger_name(self) -> str: + return 'ptychonn' + def __iter__(self) -> Iterator[Reconstructor]: return iter(self._reconstructors) diff --git a/src/ptychodus/model/ptychonn/reconstructor.py b/src/ptychodus/model/ptychonn/reconstructor.py index 9903fcb2..eb1cfa58 100644 --- a/src/ptychodus/model/ptychonn/reconstructor.py +++ b/src/ptychodus/model/ptychonn/reconstructor.py @@ -1,4 +1,3 @@ -from collections.abc import Sequence from importlib.metadata import version from pathlib import Path from typing import Final @@ -18,7 +17,6 @@ ) from ..analysis import ObjectLinearInterpolator, ObjectStitcher -from .buffers import ObjectPatchCircularBuffer, PatternCircularBuffer from .model import PtychoNNModelProvider from .settings import PtychoNNModelSettings, PtychoNNTrainingSettings @@ -40,8 +38,6 @@ def __init__( self._modelSettings = modelSettings self._trainingSettings = trainingSettings self._modelProvider = modelProvider - self._patternBuffer = PatternCircularBuffer.createZeroSized() - self._objectPatchBuffer = ObjectPatchCircularBuffer.createZeroSized() ptychonnVersion = version('ptychonn') logger.info(f'\tPtychoNN {ptychonnVersion}') @@ -114,55 +110,49 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: return ReconstructOutput(product, 0) - def ingestTrainingData(self, parameters: ReconstructInput) -> None: - interpolator = ObjectLinearInterpolator(parameters.product.object_) - probeExtent = parameters.product.probe.getExtent() - - if self._patternBuffer.isZeroSized: - patternExtent = ImageExtent( - widthInPixels=parameters.patterns.shape[-1], - heightInPixels=parameters.patterns.shape[-2], - ) - maximumSize = max(1, self._trainingSettings.maximumTrainingDatasetSize.getValue()) - self._patternBuffer = PatternCircularBuffer(patternExtent, maximumSize) - self._objectPatchBuffer = ObjectPatchCircularBuffer( - patternExtent, self._modelProvider.getNumberOfChannels(), maximumSize - ) - - for scanPoint in parameters.product.scan: - objectPatch = interpolator.getPatch(scanPoint, probeExtent) - self._objectPatchBuffer.append(objectPatch.array) - - for pattern in parameters.patterns.astype(numpy.float32): - self._patternBuffer.append(pattern) - - def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: - return [self.getOpenTrainingDataFileFilter()] - - def getOpenTrainingDataFileFilter(self) -> str: - return self.TRAINING_DATA_FILE_FILTER + def getModelFileFilter(self) -> str: + return self.MODEL_FILE_FILTER - def openTrainingData(self, filePath: Path) -> None: - logger.debug(f'Reading "{filePath}" as "NPZ"') - trainingData = numpy.load(filePath) - self._patternBuffer.setBuffer(trainingData[self.PATTERNS_KW]) - self._objectPatchBuffer.setBuffer(trainingData[self.PATCHES_KW]) + def openModel(self, filePath: Path) -> None: + self._modelProvider.openModel(filePath) - def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: - return [self.getSaveTrainingDataFileFilter()] + def saveModel(self, filePath: Path) -> None: + self._modelProvider.saveModel(filePath) - def getSaveTrainingDataFileFilter(self) -> str: + def getTrainingDataFileFilter(self) -> str: return self.TRAINING_DATA_FILE_FILTER - def saveTrainingData(self, filePath: Path) -> None: + def exportTrainingData(self, filePath: Path, parameters: ReconstructInput) -> None: + interpolator = ObjectLinearInterpolator(parameters.product.object_) + num_channels = self._modelProvider.getNumberOfChannels() + probe_extent = ImageExtent( + widthInPixels=parameters.product.probe.widthInPixels, + heightInPixels=parameters.product.probe.heightInPixels, + ) + patches = numpy.zeros( + (len(parameters.product.scan), num_channels, *probe_extent.shape), dtype=numpy.float32 + ) + + for index, scan_point in enumerate(parameters.product.scan): + patch = interpolator.get_patch(scan_point, probe_extent).getArray() + patches[index, 0, :, :] = numpy.angle(patch) + + if num_channels > 1: + patches[index, 1, :, :] = numpy.absolute(patch) + logger.debug(f'Writing "{filePath}" as "NPZ"') trainingData = { - self.PATTERNS_KW: self._patternBuffer.getBuffer(), - self.PATCHES_KW: self._objectPatchBuffer.getBuffer(), + self.PATTERNS_KW: parameters.patterns.astype(numpy.float32), + self.PATCHES_KW: patches, } numpy.savez_compressed(filePath, **trainingData) - def train(self) -> TrainOutput: + def train(self, dataPath: Path) -> TrainOutput: + logger.debug(f'Reading "{dataPath}" as "NPZ"') + trainingData = numpy.load(dataPath) + + # FIXME phase centering? + model = self._modelProvider.getModel() logger.debug('Training...') trainingSetFractionalSize = ( @@ -172,8 +162,8 @@ def train(self) -> TrainOutput: model=model, batch_size=self._modelSettings.batchSize.getValue(), out_dir=None, - X_train=self._patternBuffer.getBuffer(), - Y_train=self._objectPatchBuffer.getBuffer(), + X_train=trainingData[self.PATTERNS_KW], + Y_train=trainingData[self.PATCHES_KW], epochs=self._trainingSettings.trainingEpochs.getValue(), training_fraction=float(trainingSetFractionalSize), log_frequency=self._trainingSettings.statusIntervalInEpochs.getValue(), @@ -199,25 +189,3 @@ def train(self) -> TrainOutput: validationLoss=validationLoss, result=0, ) - - def clearTrainingData(self) -> None: - self._patternBuffer = PatternCircularBuffer.createZeroSized() - self._objectPatchBuffer = ObjectPatchCircularBuffer.createZeroSized() - - def getOpenModelFileFilterList(self) -> Sequence[str]: - return [self.getOpenModelFileFilter()] - - def getOpenModelFileFilter(self) -> str: - return self.MODEL_FILE_FILTER - - def openModel(self, filePath: Path) -> None: - self._modelProvider.openModel(filePath) - - def getSaveModelFileFilterList(self) -> Sequence[str]: - return [self.getSaveModelFileFilter()] - - def getSaveModelFileFilter(self) -> str: - return self.MODEL_FILE_FILTER - - def saveModel(self, filePath: Path) -> None: - self._modelProvider.saveModel(filePath) diff --git a/src/ptychodus/model/ptychonn/settings.py b/src/ptychodus/model/ptychonn/settings.py index 3199971e..587c4944 100644 --- a/src/ptychodus/model/ptychonn/settings.py +++ b/src/ptychodus/model/ptychonn/settings.py @@ -27,9 +27,6 @@ def __init__(self, registry: SettingsRegistry) -> None: self._settingsGroup = registry.createGroup('PtychoNNTraining') self._settingsGroup.addObserver(self) - self.maximumTrainingDatasetSize = self._settingsGroup.createIntegerParameter( - 'MaximumTrainingDatasetSize', 100000 - ) self.validationSetFractionalSize = self._settingsGroup.createRealParameter( 'ValidationSetFractionalSize', 0.1 ) diff --git a/src/ptychodus/model/reconstructor/api.py b/src/ptychodus/model/reconstructor/api.py index c04fd699..5ca6bbba 100644 --- a/src/ptychodus/model/reconstructor/api.py +++ b/src/ptychodus/model/reconstructor/api.py @@ -11,6 +11,7 @@ from ..product import ProductRepository from .matcher import DiffractionPatternPositionMatcher, ScanIndexFilter +from .queue import ReconstructionQueue logger = logging.getLogger(__name__) @@ -18,142 +19,118 @@ class ReconstructorAPI: def __init__( self, + reconstructionQueue: ReconstructionQueue, dataMatcher: DiffractionPatternPositionMatcher, productRepository: ProductRepository, reconstructorChooser: PluginChooser[Reconstructor], ) -> None: + self._reconstructionQueue = reconstructionQueue self._dataMatcher = dataMatcher self._productRepository = productRepository self._reconstructorChooser = reconstructorChooser + @property + def isReconstructing(self) -> bool: + return self._reconstructionQueue.isReconstructing + + def processResults(self, *, block: bool) -> None: + self._reconstructionQueue.processResults(block=block) + def reconstruct( self, inputProductIndex: int, - outputProductName: str, + *, + outputProductSuffix: str = '', indexFilter: ScanIndexFilter = ScanIndexFilter.ALL, ) -> int: reconstructor = self._reconstructorChooser.currentPlugin.strategy parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( inputProductIndex, indexFilter ) - outputProductIndex = self._productRepository.insertNewProduct(likeIndex=inputProductIndex) outputProduct = self._productRepository[outputProductIndex] - tic = time.perf_counter() - result = reconstructor.reconstruct(parameters) - toc = time.perf_counter() - logger.info(f'Reconstruction time {toc - tic:.4f} seconds. (code={result.result})') + outputProductName = ( + self._dataMatcher.getProductName(inputProductIndex) + + f'_{self._reconstructorChooser.currentPlugin.simpleName}' + ) - outputProduct.assign(result.product) + if outputProductSuffix: + outputProductName += f'_{outputProductSuffix}' + outputProduct.setName(outputProductName) + self._reconstructionQueue.put(reconstructor, parameters, outputProduct) return outputProductIndex - def reconstructSplit(self, inputProductIndex: int, outputProductName: str) -> tuple[int, int]: + def reconstructSplit(self, inputProductIndex: int) -> tuple[int, int]: outputProductIndexOdd = self.reconstruct( inputProductIndex, - f'{outputProductName}_odd', - ScanIndexFilter.ODD, + outputProductSuffix='odd', + indexFilter=ScanIndexFilter.ODD, ) outputProductIndexEven = self.reconstruct( inputProductIndex, - f'{outputProductName}_even', - ScanIndexFilter.EVEN, + outputProductSuffix='even', + indexFilter=ScanIndexFilter.EVEN, ) return outputProductIndexOdd, outputProductIndexEven - def ingestTrainingData(self, inputProductIndex: int) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Preparing input data...') - tic = time.perf_counter() - parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( - inputProductIndex, ScanIndexFilter.ALL - ) - toc = time.perf_counter() - logger.info(f'Data preparation time {toc - tic:.4f} seconds.') - - logger.info('Ingesting...') - tic = time.perf_counter() - reconstructor.ingestTrainingData(parameters) - toc = time.perf_counter() - logger.info(f'Ingest time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') - - def openTrainingData(self, filePath: Path) -> None: + def openModel(self, filePath: Path) -> None: reconstructor = self._reconstructorChooser.currentPlugin.strategy if isinstance(reconstructor, TrainableReconstructor): - logger.info('Opening training data...') + logger.info('Opening model...') tic = time.perf_counter() - reconstructor.openTrainingData(filePath) + reconstructor.openModel(filePath) toc = time.perf_counter() logger.info(f'Open time {toc - tic:.4f} seconds.') else: logger.warning('Reconstructor is not trainable!') - def saveTrainingData(self, filePath: Path) -> None: + def saveModel(self, filePath: Path) -> None: reconstructor = self._reconstructorChooser.currentPlugin.strategy if isinstance(reconstructor, TrainableReconstructor): - logger.info('Saving training data...') + logger.info('Saving model...') tic = time.perf_counter() - reconstructor.saveTrainingData(filePath) + reconstructor.saveModel(filePath) toc = time.perf_counter() logger.info(f'Save time {toc - tic:.4f} seconds.') else: logger.warning('Reconstructor is not trainable!') - def train(self) -> TrainOutput: + def exportTrainingData(self, filePath: Path, inputProductIndex: int) -> None: reconstructor = self._reconstructorChooser.currentPlugin.strategy - result = TrainOutput([], [], -1) if isinstance(reconstructor, TrainableReconstructor): - logger.info('Training...') + logger.info('Preparing input data...') tic = time.perf_counter() - result = reconstructor.train() + parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( + inputProductIndex, ScanIndexFilter.ALL + ) toc = time.perf_counter() - logger.info(f'Training time {toc - tic:.4f} seconds. (code={result.result})') - else: - logger.warning('Reconstructor is not trainable!') - - return result - - def clearTrainingData(self) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy + logger.info(f'Data preparation time {toc - tic:.4f} seconds.') - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Resetting...') + logger.info('Exporting...') tic = time.perf_counter() - reconstructor.clearTrainingData() + reconstructor.exportTrainingData(filePath, parameters) toc = time.perf_counter() - logger.info(f'Reset time {toc - tic:.4f} seconds.') + logger.info(f'Export time {toc - tic:.4f} seconds.') else: logger.warning('Reconstructor is not trainable!') - def openModel(self, filePath: Path) -> None: + def train(self, dataPath: Path) -> TrainOutput: reconstructor = self._reconstructorChooser.currentPlugin.strategy + result = TrainOutput([], [], -1) if isinstance(reconstructor, TrainableReconstructor): - logger.info('Opening model...') + logger.info('Training...') tic = time.perf_counter() - reconstructor.openModel(filePath) + result = reconstructor.train(dataPath) toc = time.perf_counter() - logger.info(f'Open time {toc - tic:.4f} seconds.') + logger.info(f'Training time {toc - tic:.4f} seconds. (code={result.result})') else: logger.warning('Reconstructor is not trainable!') - def saveModel(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Saving model...') - tic = time.perf_counter() - reconstructor.saveModel(filePath) - toc = time.perf_counter() - logger.info(f'Save time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + return result diff --git a/src/ptychodus/model/reconstructor/core.py b/src/ptychodus/model/reconstructor/core.py index b501efcf..6bfd88b4 100644 --- a/src/ptychodus/model/reconstructor/core.py +++ b/src/ptychodus/model/reconstructor/core.py @@ -12,12 +12,12 @@ from ..patterns import ActiveDiffractionDataset from ..product import ProductRepository from .api import ReconstructorAPI +from .log import ReconstructorLogHandler from .matcher import DiffractionPatternPositionMatcher from .presenter import ReconstructorPresenter +from .queue import ReconstructionQueue from .settings import ReconstructorSettings -logger = logging.getLogger(__name__) - class ReconstructorCore: def __init__( @@ -29,21 +29,40 @@ def __init__( ) -> None: self.settings = ReconstructorSettings(settingsRegistry) self._pluginChooser = PluginChooser[Reconstructor]() + self._logHandler = ReconstructorLogHandler() + self._logHandler.setFormatter( + logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') + ) for library in librarySeq: for reconstructor in library: self._pluginChooser.registerPlugin( reconstructor, + simpleName=f'{library.name}_{reconstructor.name}', displayName=f'{library.name}/{reconstructor.name}', ) + libraryLogger = logging.getLogger(library.logger_name) + libraryLogger.addHandler(self._logHandler) + if not self._pluginChooser: self._pluginChooser.registerPlugin(NullReconstructor('None'), displayName='None/None') + self._reconstructionQueue = ReconstructionQueue() self.dataMatcher = DiffractionPatternPositionMatcher(diffractionDataset, productRepository) self.reconstructorAPI = ReconstructorAPI( - self.dataMatcher, productRepository, self._pluginChooser + self._reconstructionQueue, self.dataMatcher, productRepository, self._pluginChooser ) self.presenter = ReconstructorPresenter( - self.settings, self._pluginChooser, self.reconstructorAPI, settingsRegistry + self.settings, + self._pluginChooser, + self._logHandler, + self.reconstructorAPI, + settingsRegistry, ) + + def start(self) -> None: + self._reconstructionQueue.start() + + def stop(self) -> None: + self._reconstructionQueue.stop() diff --git a/src/ptychodus/model/reconstructor/log.py b/src/ptychodus/model/reconstructor/log.py new file mode 100644 index 00000000..ee3c5d47 --- /dev/null +++ b/src/ptychodus/model/reconstructor/log.py @@ -0,0 +1,21 @@ +from collections.abc import Iterator +import queue +import logging + + +class ReconstructorLogHandler(logging.Handler): + def __init__(self) -> None: + super().__init__() + self._log: queue.Queue[str] = queue.Queue() + + def messages(self) -> Iterator[str]: + while True: + try: + yield self._log.get(block=False) + self._log.task_done() + except queue.Empty: + break + + def emit(self, record: logging.LogRecord) -> None: + text = self.format(record) + self._log.put(text) diff --git a/src/ptychodus/model/reconstructor/presenter.py b/src/ptychodus/model/reconstructor/presenter.py index d310c2e6..2c6b3095 100644 --- a/src/ptychodus/model/reconstructor/presenter.py +++ b/src/ptychodus/model/reconstructor/presenter.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from pathlib import Path import logging @@ -11,7 +11,7 @@ ) from .api import ReconstructorAPI -from .matcher import ScanIndexFilter +from .log import ReconstructorLogHandler from .settings import ReconstructorSettings logger = logging.getLogger(__name__) @@ -22,12 +22,14 @@ def __init__( self, settings: ReconstructorSettings, reconstructorChooser: PluginChooser[Reconstructor], + logHandler: ReconstructorLogHandler, reconstructorAPI: ReconstructorAPI, reinitObservable: Observable, ) -> None: super().__init__() self._settings = settings self._reconstructorChooser = reconstructorChooser + self._logHandler = logHandler self._reconstructorAPI = reconstructorAPI self._reinitObservable = reinitObservable @@ -50,92 +52,33 @@ def _syncFromSettings(self) -> None: def _syncToSettings(self) -> None: self._settings.algorithm.setValue(self._reconstructorChooser.currentPlugin.simpleName) - def reconstruct( - self, - inputProductIndex: int, - outputProductName: str, - indexFilter: ScanIndexFilter = ScanIndexFilter.ALL, - ) -> int: - return self._reconstructorAPI.reconstruct(inputProductIndex, outputProductName, indexFilter) + def reconstruct(self, inputProductIndex: int) -> int: + return self._reconstructorAPI.reconstruct(inputProductIndex) - def reconstructSplit(self, inputProductIndex: int, outputProductName: str) -> tuple[int, int]: - return self._reconstructorAPI.reconstructSplit(inputProductIndex, outputProductName) + def reconstructSplit(self, inputProductIndex: int) -> tuple[int, int]: + return self._reconstructorAPI.reconstructSplit(inputProductIndex) @property - def isTrainable(self) -> bool: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - return isinstance(reconstructor, TrainableReconstructor) - - def ingestTrainingData(self, inputProductIndex: int) -> None: - return self._reconstructorAPI.ingestTrainingData(inputProductIndex) - - def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getOpenTrainingDataFileFilterList() - else: - logger.warning('Reconstructor is not trainable!') - - return list() - - def getOpenTrainingDataFileFilter(self) -> str: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getOpenTrainingDataFileFilter() - else: - logger.warning('Reconstructor is not trainable!') - - return str() - - def openTrainingData(self, filePath: Path) -> None: - return self._reconstructorAPI.openTrainingData(filePath) - - def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getSaveTrainingDataFileFilterList() - else: - logger.warning('Reconstructor is not trainable!') - - return list() + def isReconstructing(self) -> bool: + return self._reconstructorAPI.isReconstructing - def getSaveTrainingDataFileFilter(self) -> str: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getSaveTrainingDataFileFilter() - else: - logger.warning('Reconstructor is not trainable!') - - return str() + def flushLog(self) -> Iterator[str]: + for text in self._logHandler.messages(): + yield text - def saveTrainingData(self, filePath: Path) -> None: - return self._reconstructorAPI.saveTrainingData(filePath) + def processResults(self, *, block: bool) -> None: + self._reconstructorAPI.processResults(block=block) - def train(self) -> TrainOutput: - return self._reconstructorAPI.train() - - def clearTrainingData(self) -> None: - self._reconstructorAPI.clearTrainingData() - - def getOpenModelFileFilterList(self) -> Sequence[str]: + @property + def isTrainable(self) -> bool: reconstructor = self._reconstructorChooser.currentPlugin.strategy + return isinstance(reconstructor, TrainableReconstructor) - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getOpenModelFileFilterList() - else: - logger.warning('Reconstructor is not trainable!') - - return list() - - def getOpenModelFileFilter(self) -> str: + def getModelFileFilter(self) -> str: reconstructor = self._reconstructorChooser.currentPlugin.strategy if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getOpenModelFileFilter() + return reconstructor.getModelFileFilter() else: logger.warning('Reconstructor is not trainable!') @@ -144,28 +87,24 @@ def getOpenModelFileFilter(self) -> str: def openModel(self, filePath: Path) -> None: return self._reconstructorAPI.openModel(filePath) - def getSaveModelFileFilterList(self) -> Sequence[str]: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getSaveModelFileFilterList() - else: - logger.warning('Reconstructor is not trainable!') - - return list() + def saveModel(self, filePath: Path) -> None: + return self._reconstructorAPI.saveModel(filePath) - def getSaveModelFileFilter(self) -> str: + def getTrainingDataFileFilter(self) -> str: reconstructor = self._reconstructorChooser.currentPlugin.strategy if isinstance(reconstructor, TrainableReconstructor): - return reconstructor.getSaveModelFileFilter() + return reconstructor.getTrainingDataFileFilter() else: logger.warning('Reconstructor is not trainable!') return str() - def saveModel(self, filePath: Path) -> None: - return self._reconstructorAPI.saveModel(filePath) + def exportTrainingData(self, filePath: Path, inputProductIndex: int) -> None: + return self._reconstructorAPI.exportTrainingData(filePath, inputProductIndex) + + def train(self, dataPath: Path) -> TrainOutput: + return self._reconstructorAPI.train(dataPath) def update(self, observable: Observable) -> None: if observable is self._reconstructorChooser: diff --git a/src/ptychodus/model/reconstructor/queue.py b/src/ptychodus/model/reconstructor/queue.py new file mode 100644 index 00000000..4a50a45d --- /dev/null +++ b/src/ptychodus/model/reconstructor/queue.py @@ -0,0 +1,117 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +import logging +import queue +import threading +import time + +from ptychodus.api.reconstructor import Reconstructor, ReconstructInput, ReconstructOutput + +from ..product import ProductRepositoryItem + +logger = logging.getLogger(__name__) + +__all__ = ['ReconstructionQueue'] + + +class ReconstructionTask(ABC): + @abstractmethod + def execute(self) -> ReconstructionTask | None: + pass + + +class UpdateProductTask(ReconstructionTask): + def __init__(self, result: ReconstructOutput, product: ProductRepositoryItem) -> None: + self._result = result + self._product = product + + def execute(self) -> None: + name = self._product.getName() + self._product.assign(self._result.product) + self._product.setName(name) + + +class ExecuteReconstructorTask(ReconstructionTask): + def __init__( + self, + reconstructor: Reconstructor, + parameters: ReconstructInput, + product: ProductRepositoryItem, + ) -> None: + self._reconstructor = reconstructor + self._parameters = parameters + self._product = product + + def execute(self) -> UpdateProductTask: + tic = time.perf_counter() + result = self._reconstructor.reconstruct(self._parameters) + toc = time.perf_counter() + logger.info(f'Reconstruction time {toc - tic:.4f} seconds. (code={result.result})') + return UpdateProductTask(result, self._product) + + +class ReconstructionQueue: + def __init__(self) -> None: + self._inputQueue: queue.Queue[ExecuteReconstructorTask] = queue.Queue() + self._outputQueue: queue.Queue[UpdateProductTask] = queue.Queue() + self._stopWorkEvent = threading.Event() + self._worker = threading.Thread(target=self._reconstruct) + + @property + def isReconstructing(self) -> bool: + return self._inputQueue.unfinished_tasks > 0 + + def _reconstruct(self) -> None: + while not self._stopWorkEvent.is_set(): + try: + inputTask = self._inputQueue.get(block=True, timeout=1) + + try: + outputTask = inputTask.execute() + except Exception: + logger.exception('Reconstructor error!') + else: + self._outputQueue.put(outputTask) + finally: + self._inputQueue.task_done() + except queue.Empty: + pass + + def put( + self, + reconstructor: Reconstructor, + parameters: ReconstructInput, + product: ProductRepositoryItem, + ) -> None: + task = ExecuteReconstructorTask(reconstructor, parameters, product) + self._inputQueue.put(task) + + def processResults(self, *, block: bool) -> None: + while True: + try: + task = self._outputQueue.get(block=block) + + try: + task.execute() + finally: + self._outputQueue.task_done() + except queue.Empty: + break + + def start(self) -> None: + logger.info('Starting reconstructor...') + self._worker.start() + logger.info('Reconstructor started.') + + def stop(self) -> None: + logger.info('Finishing reconstructions...') + self._inputQueue.join() + + logger.info('Stopping reconstructor...') + self._stopWorkEvent.set() + self._worker.join() + self.processResults(block=False) + logger.info('Reconstructor stopped.') + + def __len__(self) -> int: + return self._inputQueue.qsize() diff --git a/src/ptychodus/model/tike/core.py b/src/ptychodus/model/tike/core.py index cc7c2d79..4587c948 100644 --- a/src/ptychodus/model/tike/core.py +++ b/src/ptychodus/model/tike/core.py @@ -1,5 +1,6 @@ from __future__ import annotations from collections.abc import Iterator +from importlib.metadata import version import logging from ptychodus.api.reconstructor import ( @@ -48,6 +49,9 @@ def createInstance( core.reconstructorList.append(NullReconstructor('rpie')) core.reconstructorList.append(NullReconstructor('lstsq_grad')) else: + tikeVersion = version('tike') + logger.info(f'Tike {tikeVersion}') + tikeReconstructor = TikeReconstructor( core.settings, core.multigridSettings, @@ -64,5 +68,9 @@ def createInstance( def name(self) -> str: return 'Tike' + @property + def logger_name(self) -> str: + return 'tike' + def __iter__(self) -> Iterator[Reconstructor]: return iter(self.reconstructorList) diff --git a/src/ptychodus/model/tike/reconstructor.py b/src/ptychodus/model/tike/reconstructor.py index c97469b2..e432794d 100644 --- a/src/ptychodus/model/tike/reconstructor.py +++ b/src/ptychodus/model/tike/reconstructor.py @@ -1,4 +1,3 @@ -from importlib.metadata import version from typing import Any import logging import pprint @@ -44,9 +43,6 @@ def __init__( self._probeCorrectionSettings = probeCorrectionSettings self._objectCorrectionSettings = objectCorrectionSettings - tikeVersion = version('tike') - logger.info(f'\tTike {tikeVersion}') - def getObjectOptions(self) -> tike.ptycho.ObjectOptions: settings = self._objectCorrectionSettings options = None @@ -129,11 +125,16 @@ def __call__( objectInput = parameters.product.object_ objectGeometry = objectInput.getGeometry() - # TODO change array[0] -> array when multislice is available - objectInputArray = objectInput.array[0].astype('complex64') + objectInputArray = objectInput.getArray().astype('complex64') + numberOfLayers = objectInput.numberOfLayers + + if numberOfLayers == 1: + objectInputArray = objectInputArray[0] + else: + raise ValueError(f'Tike does not support multislice (layers={numberOfLayers})!') probeInput = parameters.product.probe - probeInputArray = probeInput.array[numpy.newaxis, numpy.newaxis, ...].astype('complex64') + probeInputArray = probeInput.getArray().astype('complex64') scanInput = parameters.product.scan scanInputCoords: list[float] = list() @@ -163,8 +164,7 @@ def __call__( logger.debug(f'num_gpu={numGpus}') exitwave_options = tike.ptycho.ExitWaveOptions( - # TODO: Use a user supplied `measured_pixels` instead - measured_pixels=numpy.ones(probeInputArray.shape[-2:], dtype=numpy.bool_), + measured_pixels=parameters.goodPixelMask, noise_model=self._settings.noiseModel.getValue(), ) @@ -211,11 +211,7 @@ def __call__( scanOutput = Scan(scanOutputPoints) if self._probeCorrectionSettings.useProbeCorrection.getValue(): - probeOutput = Probe( - array=result.probe[0, 0], - pixelWidthInMeters=probeInput.pixelWidthInMeters, - pixelHeightInMeters=probeInput.pixelHeightInMeters, - ) + probeOutput = Probe(array=result.probe, pixelGeometry=probeInput.getPixelGeometry()) else: probeOutput = probeInput.copy() @@ -223,10 +219,8 @@ def __call__( objectOutput = Object( array=result.psi, layerDistanceInMeters=objectInput.layerDistanceInMeters, - pixelWidthInMeters=objectInput.pixelWidthInMeters, - pixelHeightInMeters=objectInput.pixelHeightInMeters, - centerXInMeters=objectInput.centerXInMeters, - centerYInMeters=objectInput.centerYInMeters, + pixelGeometry=objectInput.getPixelGeometry(), + center=objectInput.getCenter(), ) else: objectOutput = objectInput.copy() diff --git a/src/ptychodus/model/workflow/api.py b/src/ptychodus/model/workflow/api.py index 49f90319..29ce37e3 100644 --- a/src/ptychodus/model/workflow/api.py +++ b/src/ptychodus/model/workflow/api.py @@ -69,11 +69,10 @@ def buildObject( else: self._objectAPI.buildObject(self._productIndex, builderName, builderParameters) - def reconstructLocal(self, outputProductName: str) -> WorkflowProductAPI: + def reconstructLocal(self) -> WorkflowProductAPI: logger.debug(f'Reconstruct: index={self._productIndex}') - outputProductIndex = self._reconstructorAPI.reconstruct( - self._productIndex, outputProductName - ) + outputProductIndex = self._reconstructorAPI.reconstruct(self._productIndex) + return ConcreteWorkflowProductAPI( self._productAPI, self._scanAPI, @@ -156,7 +155,7 @@ def createProduct( comments: str = '', detectorDistanceInMeters: float | None = None, probeEnergyInElectronVolts: float | None = None, - probePhotonsPerSecond: float | None = None, + probePhotonCount: float | None = None, exposureTimeInSeconds: float | None = None, ) -> WorkflowProductAPI: productIndex = self._productAPI.insertNewProduct( @@ -164,7 +163,7 @@ def createProduct( comments=comments, detectorDistanceInMeters=detectorDistanceInMeters, probeEnergyInElectronVolts=probeEnergyInElectronVolts, - probePhotonsPerSecond=probePhotonsPerSecond, + probePhotonCount=probePhotonCount, exposureTimeInSeconds=exposureTimeInSeconds, ) return self._createProductAPI(productIndex) diff --git a/src/ptychodus/model/workflow/core.py b/src/ptychodus/model/workflow/core.py index e08eaa82..2ae2da5b 100644 --- a/src/ptychodus/model/workflow/core.py +++ b/src/ptychodus/model/workflow/core.py @@ -1,4 +1,3 @@ -from __future__ import annotations from collections.abc import Sequence from datetime import datetime from pathlib import Path @@ -38,20 +37,10 @@ def __init__( self._computeDataLocator = computeDataLocator self._outputDataLocator = outputDataLocator - @classmethod - def createInstance( - cls, - settings: WorkflowSettings, - inputDataLocator: DataLocator, - computeDataLocator: DataLocator, - outputDataLocator: OutputDataLocator, - ) -> WorkflowParametersPresenter: - presenter = cls(settings, inputDataLocator, computeDataLocator, outputDataLocator) - settings.addObserver(presenter) - inputDataLocator.addObserver(presenter) - computeDataLocator.addObserver(presenter) - outputDataLocator.addObserver(presenter) - return presenter + settings.addObserver(self) + inputDataLocator.addObserver(self) + computeDataLocator.addObserver(self) + outputDataLocator.addObserver(self) def setInputDataEndpointID(self, endpointID: UUID) -> None: self._inputDataLocator.setEndpointID(endpointID) @@ -122,11 +111,11 @@ def getOutputDataPosixPath(self) -> Path: def update(self, observable: Observable) -> None: if observable is self._settings: self.notifyObservers() - elif observable is self._inputDataLocator: - self.notifyObservers() - elif observable is self._computeDataLocator: - self.notifyObservers() - elif observable is self._outputDataLocator: + elif observable in ( + self._inputDataLocator, + self._computeDataLocator, + self._outputDataLocator, + ): self.notifyObservers() @@ -145,13 +134,16 @@ def setCodeFromAuthorizeURL(self, code: str) -> None: self._authorizer.setCodeFromAuthorizeURL(code) -class WorkflowStatusPresenter: +class WorkflowStatusPresenter(Observable, Observer): def __init__( self, settings: WorkflowSettings, statusRepository: WorkflowStatusRepository ) -> None: + super().__init__() self._settings = settings self._statusRepository = statusRepository + settings.addObserver(self) + def getRefreshIntervalLimitsInSeconds(self) -> Interval[int]: return Interval[int](10, 86400) @@ -180,6 +172,10 @@ def getStatusDateTime(self) -> datetime: def refreshStatus(self) -> None: self._statusRepository.refreshStatus() + def update(self, observable: Observable) -> None: + if observable is self._settings: + self.notifyObservers() + class WorkflowExecutionPresenter: def __init__(self, executor: WorkflowExecutor) -> None: @@ -238,7 +234,7 @@ def __init__( self._authorizer, self._statusRepository, self._executor ) - self.parametersPresenter = WorkflowParametersPresenter.createInstance( + self.parametersPresenter = WorkflowParametersPresenter( self._settings, self._inputDataLocator, self._computeDataLocator, diff --git a/src/ptychodus/model/workflow/executor.py b/src/ptychodus/model/workflow/executor.py index 594a59fb..6916aa1f 100644 --- a/src/ptychodus/model/workflow/executor.py +++ b/src/ptychodus/model/workflow/executor.py @@ -77,9 +77,9 @@ def runFlow(self, inputProductIndex: int) -> None: ) flowInput = { - 'input_data_transfer_source_endpoint_id': str(self._inputDataLocator.getEndpointID()), + 'input_data_transfer_source_endpoint': str(self._inputDataLocator.getEndpointID()), 'input_data_transfer_source_path': inputDataGlobusPath, - 'input_data_transfer_destination_endpoint_id': str( + 'input_data_transfer_destination_endpoint': str( self._computeDataLocator.getEndpointID() ), 'input_data_transfer_destination_path': computeDataGlobusPath, @@ -91,11 +91,9 @@ def runFlow(self, inputProductIndex: int) -> None: 'ptychodus_patterns_file': str(computeDataPosixPath / patternsFile), 'ptychodus_input_file': str(computeDataPosixPath / inputFile), 'ptychodus_output_file': str(computeDataPosixPath / outputFile), - 'output_data_transfer_source_endpoint_id': str( - self._computeDataLocator.getEndpointID() - ), + 'output_data_transfer_source_endpoint': str(self._computeDataLocator.getEndpointID()), 'output_data_transfer_source_path': f'{computeDataGlobusPath}/{outputFile}', - 'output_data_transfer_destination_endpoint_id': str( + 'output_data_transfer_destination_endpoint': str( self._outputDataLocator.getEndpointID() ), 'output_data_transfer_destination_path': f'{outputDataGlobusPath}/{outputFile}', diff --git a/src/ptychodus/model/workflow/globus.py b/src/ptychodus/model/workflow/globus.py index 7df6975e..59e14120 100644 --- a/src/ptychodus/model/workflow/globus.py +++ b/src/ptychodus/model/workflow/globus.py @@ -60,8 +60,6 @@ class PtychodusReconstruct(gladier.GladierBaseTool): @gladier.generate_flow_definition class PtychodusClient(gladier.GladierBaseClient): client_id = PTYCHODUS_CLIENT_ID - globus_group = '13e5512f-e761-11ec-8a9e-ff9dc0f99d56' - gladier_tools = [ 'gladier_tools.globus.transfer.Transfer:InputData', PtychodusReconstruct, diff --git a/src/ptychodus/plugins/csaxsDiffractionFile.py b/src/ptychodus/plugins/csaxsDiffractionFile.py new file mode 100644 index 00000000..af755d8e --- /dev/null +++ b/src/ptychodus/plugins/csaxsDiffractionFile.py @@ -0,0 +1,85 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py + +from ptychodus.api.geometry import ImageExtent, PixelGeometry +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry + +from .h5DiffractionFile import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder + +logger = logging.getLogger(__name__) + + +class CSAXSDiffractionFileReader(DiffractionFileReader): + SIMPLE_NAME: Final[str] = 'SLS_cSAXS' + DISPLAY_NAME: Final[str] = 'SLS cSAXS Diffraction Files (*.h5 *.hdf5)' + ONE_MICRON_M: Final[float] = 1e-6 + ONE_MILLIMETER_M: Final[float] = 1e-3 + + def __init__(self) -> None: + self._dataPath = '/entry/data/data' + self._treeBuilder = H5DiffractionFileTreeBuilder() + + def read(self, filePath: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.createNullInstance(filePath) + + try: + with h5py.File(filePath, 'r') as h5File: + contentsTree = self._treeBuilder.build(h5File) + + try: + data = h5File[self._dataPath] + x_pixel_size_um = h5File['/entry/instrument/eiger_4/x_pixel_size'] + y_pixel_size_um = h5File['/entry/instrument/eiger_4/y_pixel_size'] + distance_mm = h5File['/entry/instrument/monochromator/distance'] + energy_keV = h5File['/entry/instrument/monochromator/energy'] + except KeyError: + logger.warning('Unable to load data.') + else: + numberOfPatterns, detectorHeight, detectorWidth = data.shape + detectorDistanceInMeters = float(distance_mm[()]) * self.ONE_MILLIMETER_M + detectorPixelGeometry = PixelGeometry( + widthInMeters=float(x_pixel_size_um[()]) * self.ONE_MICRON_M, + heightInMeters=float(y_pixel_size_um[()]) * self.ONE_MICRON_M, + ) + probeEnergyInElectronVolts = 1000 * float(energy_keV[()]) + + metadata = DiffractionMetadata( + numberOfPatternsPerArray=numberOfPatterns, + numberOfPatternsTotal=numberOfPatterns, + patternDataType=data.dtype, + detectorDistanceInMeters=abs(detectorDistanceInMeters), + detectorExtent=ImageExtent(detectorWidth, detectorHeight), + detectorPixelGeometry=detectorPixelGeometry, + probeEnergyInElectronVolts=probeEnergyInElectronVolts, + filePath=filePath, + ) + + array = H5DiffractionPatternArray( + label=filePath.stem, + index=0, + filePath=filePath, + dataPath=self._dataPath, + ) + + dataset = SimpleDiffractionDataset(metadata, contentsTree, [array]) + except OSError: + logger.warning(f'Unable to read file "{filePath}".') + + return dataset + + +def registerPlugins(registry: PluginRegistry) -> None: + registry.diffractionFileReaders.registerPlugin( + CSAXSDiffractionFileReader(), + simpleName=CSAXSDiffractionFileReader.SIMPLE_NAME, + displayName=CSAXSDiffractionFileReader.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/csvObjectFile.py b/src/ptychodus/plugins/csvObjectFile.py index 69064259..12ea97e1 100644 --- a/src/ptychodus/plugins/csvObjectFile.py +++ b/src/ptychodus/plugins/csvObjectFile.py @@ -9,12 +9,12 @@ class CSVObjectFileReader(ObjectFileReader): def read(self, filePath: Path) -> Object: array = numpy.genfromtxt(filePath, delimiter=',', dtype='complex') - return Object(array) + return Object(array=array, pixelGeometry=None, center=None) class CSVObjectFileWriter(ObjectFileWriter): def write(self, filePath: Path, object_: Object) -> None: - array = object_.array + array = object_.getArray() numpy.savetxt(filePath, array, delimiter=',') diff --git a/src/ptychodus/plugins/csvProbeFile.py b/src/ptychodus/plugins/csvProbeFile.py index 81120e6e..507fcd8a 100644 --- a/src/ptychodus/plugins/csvProbeFile.py +++ b/src/ptychodus/plugins/csvProbeFile.py @@ -17,12 +17,12 @@ def read(self, filePath: Path) -> Probe: if numberOfModes > 1: array = arrayFlat.reshape(numberOfModes, arrayFlat.shape[1], arrayFlat.shape[1]) - return Probe(array) + return Probe(array=array, pixelGeometry=None) class CSVProbeFileWriter(ProbeFileWriter): def write(self, filePath: Path, probe: Probe) -> None: - array = probe.array + array = probe.getArray() arrayFlat = array.reshape(-1, array.shape[-1]) numpy.savetxt(filePath, arrayFlat, delimiter=',') diff --git a/src/ptychodus/plugins/cxiDiffractionFile.py b/src/ptychodus/plugins/cxiFile.py similarity index 65% rename from src/ptychodus/plugins/cxiDiffractionFile.py rename to src/ptychodus/plugins/cxiFile.py index ab629705..7713c8d2 100644 --- a/src/ptychodus/plugins/cxiDiffractionFile.py +++ b/src/ptychodus/plugins/cxiFile.py @@ -1,9 +1,10 @@ from pathlib import Path +from typing import Final import logging import h5py -from ptychodus.api.constants import ELECTRON_VOLT_J +from .h5DiffractionFile import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder from ptychodus.api.geometry import ImageExtent, PixelGeometry from ptychodus.api.patterns import ( DiffractionDataset, @@ -12,8 +13,10 @@ SimpleDiffractionDataset, ) from ptychodus.api.plugins import PluginRegistry - -from .h5DiffractionFile import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder +from ptychodus.api.probe import Probe, ProbeFileReader +from ptychodus.api.product import ELECTRON_VOLT_J +from ptychodus.api.propagator import WavefieldArrayType +from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint logger = logging.getLogger(__name__) @@ -48,6 +51,9 @@ def read(self, filePath: Path) -> DiffractionDataset: probeEnergyInJoules = float(h5File['/entry_1/instrument_1/source_1/energy'][()]) probeEnergyInElectronVolts = probeEnergyInJoules / ELECTRON_VOLT_J + # TODO load detector mask; zeros are good pixels + # /entry_1/instrument_1/detector_1/mask Dataset {512, 512} + metadata = DiffractionMetadata( numberOfPatternsPerArray=numberOfPatterns, numberOfPatternsTotal=numberOfPatterns, @@ -73,9 +79,46 @@ def read(self, filePath: Path) -> DiffractionDataset: return dataset +class CXIScanFileReader(ScanFileReader): + def read(self, filePath: Path) -> Scan: + scanPointList: list[ScanPoint] = list() + + with h5py.File(filePath, 'r') as h5File: + xyz_m = h5File['/entry_1/data_1/translation'][()] + + for idx, (x, y, z) in enumerate(xyz_m): + point = ScanPoint(idx, x, y) + scanPointList.append(point) + + return Scan(scanPointList) + + +class CXIProbeFileReader(ProbeFileReader): + def read(self, filePath: Path) -> Probe: + array: WavefieldArrayType | None = None + + with h5py.File(filePath, 'r') as h5File: + array = h5File['/entry_1/instrument_1/source_1/illumination'][()] + + return Probe(array=array, pixelGeometry=None) + + def registerPlugins(registry: PluginRegistry) -> None: + SIMPLE_NAME: Final[str] = 'CXI' + DISPLAY_NAME: Final[str] = 'Coherent X-ray Imaging Files (*.cxi)' + registry.diffractionFileReaders.registerPlugin( CXIDiffractionFileReader(), - simpleName='CXI', - displayName='Coherent X-ray Imaging Files (*.cxi)', + simpleName=SIMPLE_NAME, + displayName=DISPLAY_NAME, + ) + registry.scanFileReaders.registerPlugin( + CXIScanFileReader(), + simpleName=SIMPLE_NAME, + displayName=DISPLAY_NAME, + ) + registry.probeFileReaders.registerPlugin( + CXIProbeFileReader(), + simpleName=SIMPLE_NAME, + displayName=DISPLAY_NAME, ) diff --git a/src/ptychodus/plugins/cxiProbeFile.py b/src/ptychodus/plugins/cxiProbeFile.py deleted file mode 100644 index b3d4fd9a..00000000 --- a/src/ptychodus/plugins/cxiProbeFile.py +++ /dev/null @@ -1,31 +0,0 @@ -from pathlib import Path -import logging - -import h5py -import numpy - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileReader - -logger = logging.getLogger(__name__) - - -class CXIProbeFileReader(ProbeFileReader): - def read(self, filePath: Path) -> Probe: - array = numpy.zeros((0, 0, 0), dtype=complex) - - with h5py.File(filePath, 'r') as h5File: - try: - array = h5File['/entry_1/instrument_1/source_1/illumination'][()] - except KeyError: - logger.warning('Unable to load probe.') - - return Probe(array) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.probeFileReaders.registerPlugin( - CXIProbeFileReader(), - simpleName='CXI', - displayName='Coherent X-ray Imaging Files (*.cxi)', - ) diff --git a/src/ptychodus/plugins/cxiScanFile.py b/src/ptychodus/plugins/cxiScanFile.py deleted file mode 100644 index 88485c7b..00000000 --- a/src/ptychodus/plugins/cxiScanFile.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path -import logging - -import h5py - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint - -logger = logging.getLogger(__name__) - - -class CXIScanFileReader(ScanFileReader): - def read(self, filePath: Path) -> Scan: - pointList: list[ScanPoint] = list() - - with h5py.File(filePath, 'r') as h5File: - try: - xyzArray = h5File['/entry_1/data_1/translation'][()] - except KeyError: - logger.exception('Unable to load scan.') - else: - for idx, xyz in enumerate(xyzArray): - try: - x, y, z = xyz - except ValueError: - logger.exception(f'Unable to load scan point {xyz=}.') - else: - point = ScanPoint(idx, x, y) - pointList.append(point) - - return Scan(pointList) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.scanFileReaders.registerPlugin( - CXIScanFileReader(), - simpleName='CXI', - displayName='Coherent X-ray Imaging Files (*.cxi)', - ) diff --git a/src/ptychodus/plugins/h5DiffractionFile.py b/src/ptychodus/plugins/h5DiffractionFile.py index 8c9a978b..fb69dfcd 100644 --- a/src/ptychodus/plugins/h5DiffractionFile.py +++ b/src/ptychodus/plugins/h5DiffractionFile.py @@ -67,13 +67,15 @@ def _addAttributes( itemDetails = f'STRING = "{value}"' elif isinstance(value, h5py.Empty): logger.debug(f'Skipping empty attribute {name}.') + elif isinstance(value, numpy.ndarray): + itemDetails = f'ARRAY = {value}' else: stringInfo = h5py.check_string_dtype(value.dtype) - itemDetails = ( - f'STRING = "{value.decode(stringInfo.encoding)}"' - if stringInfo - else f'SCALAR {value.dtype} = {value}' - ) + + if stringInfo: + itemDetails = f'STRING = "{value.decode(stringInfo.encoding)}"' + else: + itemDetails = f'SCALAR {value.dtype} = {value}' treeNode.createChild([str(name), 'Attribute', itemDetails]) @@ -112,13 +114,18 @@ def build(self, h5File: h5py.File) -> SimpleTreeNode: if isinstance(value, bytes): itemDetails = value.decode() + elif isinstance(value, numpy.ndarray): + itemDetails = f'STRING = {h5Item.asstr()}' else: stringInfo = h5py.check_string_dtype(value.dtype) - itemDetails = ( - f'STRING = "{value.decode(stringInfo.encoding)}"' - if stringInfo - else f'SCALAR {value.dtype} = {value}' - ) + + if stringInfo: + itemDetails = f'STRING = "{value.decode(stringInfo.encoding)}"' + else: + itemDetails = f'SCALAR {value.dtype} = {value}' + elif h5Item.size == 1: + value = h5Item[()] + itemDetails = f'DATASET {value.dtype} = {value}' else: itemDetails = f'{h5Item.shape} {h5Item.dtype}' elif isinstance(h5Item, h5py.SoftLink): @@ -196,8 +203,8 @@ def registerPlugins(registry: PluginRegistry) -> None: ) registry.diffractionFileReaders.registerPlugin( H5DiffractionFileReader(dataPath='/entry/measurement/Eiger/data'), - simpleName='NanoMax', - displayName='NanoMax Diffraction Files (*.h5 *.hdf5)', + simpleName='MAX_IV_NanoMax', + displayName='MAX IV NanoMax Diffraction Files (*.h5 *.hdf5)', ) registry.diffractionFileReaders.registerPlugin( H5DiffractionFileReader(dataPath='/dp'), diff --git a/src/ptychodus/plugins/h5ProductFile.py b/src/ptychodus/plugins/h5ProductFile.py index 4337465b..9c8b6b47 100644 --- a/src/ptychodus/plugins/h5ProductFile.py +++ b/src/ptychodus/plugins/h5ProductFile.py @@ -4,7 +4,8 @@ import h5py -from ptychodus.api.object import Object +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import Object, ObjectCenter from ptychodus.api.plugins import PluginRegistry from ptychodus.api.probe import Probe from ptychodus.api.product import ( @@ -26,7 +27,7 @@ class H5ProductFileIO(ProductFileReader, ProductFileWriter): COMMENTS: Final[str] = 'comments' DETECTOR_OBJECT_DISTANCE: Final[str] = 'detector_object_distance_m' PROBE_ENERGY: Final[str] = 'probe_energy_eV' - PROBE_PHOTON_FLUX: Final[str] = 'probe_photons_per_s' + PROBE_PHOTON_COUNT: Final[str] = 'probe_photon_count' EXPOSURE_TIME: Final[str] = 'exposure_time_s' PROBE_ARRAY: Final[str] = 'probe' @@ -49,12 +50,19 @@ def read(self, filePath: Path) -> Product: scanPointList: list[ScanPoint] = list() with h5py.File(filePath, 'r') as h5File: + probePhotonCount = 0.0 + + try: + probePhotonCount = float(h5File.attrs[self.PROBE_PHOTON_COUNT]) + except KeyError: + logger.debug('Probe photon count not found.') + metadata = ProductMetadata( name=str(h5File.attrs[self.NAME]), comments=str(h5File.attrs[self.COMMENTS]), detectorDistanceInMeters=float(h5File.attrs[self.DETECTOR_OBJECT_DISTANCE]), probeEnergyInElectronVolts=float(h5File.attrs[self.PROBE_ENERGY]), - probePhotonsPerSecond=float(h5File.attrs[self.PROBE_PHOTON_FLUX]), + probePhotonCount=probePhotonCount, exposureTimeInSeconds=float(h5File.attrs[self.EXPOSURE_TIME]), ) @@ -67,21 +75,30 @@ def read(self, filePath: Path) -> Product: scanPointList.append(point) h5Probe = h5File[self.PROBE_ARRAY] + probePixelGeometry = PixelGeometry( + widthInMeters=float(h5Probe.attrs[self.PROBE_PIXEL_WIDTH]), + heightInMeters=float(h5Probe.attrs[self.PROBE_PIXEL_HEIGHT]), + ) probe = Probe( array=h5Probe[()], - pixelWidthInMeters=float(h5Probe.attrs[self.PROBE_PIXEL_WIDTH]), - pixelHeightInMeters=float(h5Probe.attrs[self.PROBE_PIXEL_HEIGHT]), + pixelGeometry=probePixelGeometry, ) h5Object = h5File[self.OBJECT_ARRAY] + objectPixelGeometry = PixelGeometry( + widthInMeters=float(h5Object.attrs[self.OBJECT_PIXEL_WIDTH]), + heightInMeters=float(h5Object.attrs[self.OBJECT_PIXEL_HEIGHT]), + ) + objectCenter = ObjectCenter( + positionXInMeters=float(h5Object.attrs[self.OBJECT_CENTER_X]), + positionYInMeters=float(h5Object.attrs[self.OBJECT_CENTER_Y]), + ) h5ObjectLayerDistance = h5File[self.OBJECT_LAYER_DISTANCE] object_ = Object( array=h5Object[()], + pixelGeometry=objectPixelGeometry, + center=objectCenter, layerDistanceInMeters=h5ObjectLayerDistance[()], - pixelWidthInMeters=float(h5Object.attrs[self.OBJECT_PIXEL_WIDTH]), - pixelHeightInMeters=float(h5Object.attrs[self.OBJECT_PIXEL_HEIGHT]), - centerXInMeters=float(h5Object.attrs[self.OBJECT_CENTER_X]), - centerYInMeters=float(h5Object.attrs[self.OBJECT_CENTER_Y]), ) h5Costs = h5File[self.COSTS_ARRAY] @@ -111,7 +128,7 @@ def write(self, filePath: Path, product: Product) -> None: h5File.attrs[self.COMMENTS] = metadata.comments h5File.attrs[self.DETECTOR_OBJECT_DISTANCE] = metadata.detectorDistanceInMeters h5File.attrs[self.PROBE_ENERGY] = metadata.probeEnergyInElectronVolts - h5File.attrs[self.PROBE_PHOTON_FLUX] = metadata.probePhotonsPerSecond + h5File.attrs[self.PROBE_PHOTON_COUNT] = metadata.probePhotonCount h5File.attrs[self.EXPOSURE_TIME] = metadata.exposureTimeInSeconds h5File.create_dataset(self.PROBE_POSITION_INDEXES, data=scanIndexes) @@ -120,13 +137,13 @@ def write(self, filePath: Path, product: Product) -> None: probe = product.probe probeGeometry = probe.getGeometry() - h5Probe = h5File.create_dataset(self.PROBE_ARRAY, data=probe.array) + h5Probe = h5File.create_dataset(self.PROBE_ARRAY, data=probe.getArray()) h5Probe.attrs[self.PROBE_PIXEL_WIDTH] = probeGeometry.pixelWidthInMeters h5Probe.attrs[self.PROBE_PIXEL_HEIGHT] = probeGeometry.pixelHeightInMeters object_ = product.object_ objectGeometry = object_.getGeometry() - h5Object = h5File.create_dataset(self.OBJECT_ARRAY, data=object_.array) + h5Object = h5File.create_dataset(self.OBJECT_ARRAY, data=object_.getArray()) h5Object.attrs[self.OBJECT_CENTER_X] = objectGeometry.centerXInMeters h5Object.attrs[self.OBJECT_CENTER_Y] = objectGeometry.centerYInMeters h5Object.attrs[self.OBJECT_PIXEL_WIDTH] = objectGeometry.pixelWidthInMeters diff --git a/src/ptychodus/plugins/lynxDiffractionFile.py b/src/ptychodus/plugins/lynxDiffractionFile.py index 477dc6d6..67627842 100644 --- a/src/ptychodus/plugins/lynxDiffractionFile.py +++ b/src/ptychodus/plugins/lynxDiffractionFile.py @@ -63,6 +63,6 @@ def read(self, filePath: Path) -> DiffractionDataset: def registerPlugins(registry: PluginRegistry) -> None: registry.diffractionFileReaders.registerPlugin( LYNXDiffractionFileReader(), - simpleName='LYNX', - displayName='LYNX Diffraction Files (*.h5 *.hdf5)', + simpleName='APS_LYNX', + displayName='APS LYNX Diffraction Files (*.h5 *.hdf5)', ) diff --git a/src/ptychodus/plugins/matObjectFile.py b/src/ptychodus/plugins/matObjectFile.py deleted file mode 100644 index aaaa9a92..00000000 --- a/src/ptychodus/plugins/matObjectFile.py +++ /dev/null @@ -1,49 +0,0 @@ -from pathlib import Path - -import numpy -import scipy.io - -from ptychodus.api.object import Object, ObjectFileReader, ObjectFileWriter -from ptychodus.api.plugins import PluginRegistry - - -class MATObjectFileReader(ObjectFileReader): - def read(self, filePath: Path) -> Object: - matDict = scipy.io.loadmat(filePath) - array = matDict['object'] - - if array.ndim == 3: - # array[width, height, num_layers] - array = array.transpose(2, 0, 1) - - try: - p = matDict['p'][0, 0] - multi_slice_param = p['multi_slice_param'][0, 0] - layerDistanceInMeters = numpy.squeeze(multi_slice_param['z_distance']) - except ValueError: - object_ = Object(array) - else: - object_ = Object(array, layerDistanceInMeters) - - return object_ - - -class MATObjectFileWriter(ObjectFileWriter): - def write(self, filePath: Path, object_: Object) -> None: - array = object_.array - matDict = {'object': array.transpose(1, 2, 0)} - # TODO layer distance to p.z_distance - scipy.io.savemat(filePath, matDict) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.objectFileReaders.registerPlugin( - MATObjectFileReader(), - simpleName='MAT', - displayName='MAT Files (*.mat)', - ) - registry.objectFileWriters.registerPlugin( - MATObjectFileWriter(), - simpleName='MAT', - displayName='MAT Files (*.mat)', - ) diff --git a/src/ptychodus/plugins/matProbeFile.py b/src/ptychodus/plugins/matProbeFile.py deleted file mode 100644 index c72613bb..00000000 --- a/src/ptychodus/plugins/matProbeFile.py +++ /dev/null @@ -1,42 +0,0 @@ -from pathlib import Path - -import scipy.io - -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileReader, ProbeFileWriter - - -class MATProbeFileReader(ProbeFileReader): - def read(self, filePath: Path) -> Probe: - matDict = scipy.io.loadmat(filePath) - array = matDict['probe'] - - if array.ndim == 4: - # array[width, height, num_shared_modes, num_varying_modes] - array = array[..., 0] - - if array.ndim == 3: - # array[width, height, num_shared_modes] - array = array.transpose(2, 0, 1) - - return Probe(array) - - -class MATProbeFileWriter(ProbeFileWriter): - def write(self, filePath: Path, probe: Probe) -> None: - array = probe.array - matDict = {'probe': array.transpose(1, 2, 0)} - scipy.io.savemat(filePath, matDict) - - -def registerPlugins(registry: PluginRegistry) -> None: - registry.probeFileReaders.registerPlugin( - MATProbeFileReader(), - simpleName='MAT', - displayName='MAT Files (*.mat)', - ) - registry.probeFileWriters.registerPlugin( - MATProbeFileWriter(), - simpleName='MAT', - displayName='MAT Files (*.mat)', - ) diff --git a/src/ptychodus/plugins/mdaScanFile.py b/src/ptychodus/plugins/mdaScanFile.py index 8ae4728b..112cb5f6 100644 --- a/src/ptychodus/plugins/mdaScanFile.py +++ b/src/ptychodus/plugins/mdaScanFile.py @@ -460,8 +460,8 @@ def read(self, filePath: Path) -> Scan: for x in xarray: point = ScanPoint( index=len(pointList), - positionXInMeters=x * self.MICRONS_TO_METERS, - positionYInMeters=y * self.MICRONS_TO_METERS, + positionXInMeters=float(x) * self.MICRONS_TO_METERS, + positionYInMeters=float(y) * self.MICRONS_TO_METERS, ) pointList.append(point) @@ -482,8 +482,8 @@ def read(self, filePath: Path) -> Scan: for idx, (x, y) in enumerate(zip(xarray, yarray)): point = ScanPoint( index=idx, - positionXInMeters=x * self.MICRONS_TO_METERS, - positionYInMeters=y * self.MICRONS_TO_METERS, + positionXInMeters=float(x) * self.MICRONS_TO_METERS, + positionYInMeters=float(y) * self.MICRONS_TO_METERS, ) pointList.append(point) diff --git a/src/ptychodus/plugins/npyObjectFile.py b/src/ptychodus/plugins/npyObjectFile.py index 7389b277..24320ba9 100644 --- a/src/ptychodus/plugins/npyObjectFile.py +++ b/src/ptychodus/plugins/npyObjectFile.py @@ -9,12 +9,12 @@ class NPYObjectFileReader(ObjectFileReader): def read(self, filePath: Path) -> Object: array = numpy.load(filePath) - return Object(array) + return Object(array=array, pixelGeometry=None, center=None) class NPYObjectFileWriter(ObjectFileWriter): def write(self, filePath: Path, object_: Object) -> None: - array = object_.array + array = object_.getArray() numpy.save(filePath, array) diff --git a/src/ptychodus/plugins/npyProbeFile.py b/src/ptychodus/plugins/npyProbeFile.py index b5c58fe0..e5298c58 100644 --- a/src/ptychodus/plugins/npyProbeFile.py +++ b/src/ptychodus/plugins/npyProbeFile.py @@ -9,12 +9,12 @@ class NPYProbeFileReader(ProbeFileReader): def read(self, filePath: Path) -> Probe: array = numpy.load(filePath) - return Probe(array) + return Probe(array=array, pixelGeometry=None) class NPYProbeFileWriter(ProbeFileWriter): def write(self, filePath: Path, probe: Probe) -> None: - array = probe.array + array = probe.getArray() numpy.save(filePath, array) diff --git a/src/ptychodus/plugins/npzProductFile.py b/src/ptychodus/plugins/npzProductFile.py index 848a346d..615f914c 100644 --- a/src/ptychodus/plugins/npzProductFile.py +++ b/src/ptychodus/plugins/npzProductFile.py @@ -1,9 +1,11 @@ from pathlib import Path from typing import Any, Final +import logging import numpy -from ptychodus.api.object import Object, ObjectFileReader +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import Object, ObjectCenter, ObjectFileReader from ptychodus.api.plugins import PluginRegistry from ptychodus.api.probe import Probe, ProbeFileReader from ptychodus.api.product import ( @@ -14,16 +16,18 @@ ) from ptychodus.api.scan import Scan, ScanFileReader, ScanPoint +logger = logging.getLogger(__name__) + class NPZProductFileIO(ProductFileReader, ProductFileWriter): SIMPLE_NAME: Final[str] = 'NPZ' - DISPLAY_NAME: Final[str] = 'NumPy Zipped Archive (*.npz)' + DISPLAY_NAME: Final[str] = 'Ptychodus NumPy Zipped Archive (*.npz)' NAME: Final[str] = 'name' COMMENTS: Final[str] = 'comments' DETECTOR_OBJECT_DISTANCE: Final[str] = 'detector_object_distance_m' PROBE_ENERGY: Final[str] = 'probe_energy_eV' - PROBE_PHOTON_FLUX: Final[str] = 'probe_photons_per_s' + PROBE_PHOTON_COUNT: Final[str] = 'probe_photon_count' EXPOSURE_TIME: Final[str] = 'exposure_time_s' PROBE_ARRAY: Final[str] = 'probe' @@ -44,12 +48,19 @@ class NPZProductFileIO(ProductFileReader, ProductFileWriter): def read(self, filePath: Path) -> Product: with numpy.load(filePath) as npzFile: + probePhotonCount = 0.0 + + try: + probePhotonCount = float(npzFile[self.PROBE_PHOTON_COUNT]) + except KeyError: + logger.debug('Probe photon count not found.') + metadata = ProductMetadata( name=str(npzFile[self.NAME]), comments=str(npzFile[self.COMMENTS]), detectorDistanceInMeters=float(npzFile[self.DETECTOR_OBJECT_DISTANCE]), probeEnergyInElectronVolts=float(npzFile[self.PROBE_ENERGY]), - probePhotonsPerSecond=float(npzFile[self.PROBE_PHOTON_FLUX]), + probePhotonCount=probePhotonCount, exposureTimeInSeconds=float(npzFile[self.EXPOSURE_TIME]), ) @@ -57,19 +68,25 @@ def read(self, filePath: Path) -> Product: scanXInMeters = npzFile[self.PROBE_POSITION_X] scanYInMeters = npzFile[self.PROBE_POSITION_Y] - probe = Probe( - array=npzFile[self.PROBE_ARRAY], - pixelWidthInMeters=float(npzFile[self.PROBE_PIXEL_WIDTH]), - pixelHeightInMeters=float(npzFile[self.PROBE_PIXEL_HEIGHT]), + probePixelGeometry = PixelGeometry( + widthInMeters=float(npzFile[self.PROBE_PIXEL_WIDTH]), + heightInMeters=float(npzFile[self.PROBE_PIXEL_HEIGHT]), ) + probe = Probe(array=npzFile[self.PROBE_ARRAY], pixelGeometry=probePixelGeometry) + objectPixelGeometry = PixelGeometry( + widthInMeters=float(npzFile[self.OBJECT_PIXEL_WIDTH]), + heightInMeters=float(npzFile[self.OBJECT_PIXEL_HEIGHT]), + ) + objectCenter = ObjectCenter( + positionXInMeters=float(npzFile[self.OBJECT_CENTER_X]), + positionYInMeters=float(npzFile[self.OBJECT_CENTER_Y]), + ) object_ = Object( array=npzFile[self.OBJECT_ARRAY], + pixelGeometry=objectPixelGeometry, + center=objectCenter, layerDistanceInMeters=npzFile[self.OBJECT_LAYER_DISTANCE], - pixelWidthInMeters=float(npzFile[self.OBJECT_PIXEL_WIDTH]), - pixelHeightInMeters=float(npzFile[self.OBJECT_PIXEL_HEIGHT]), - centerXInMeters=float(npzFile[self.OBJECT_CENTER_X]), - centerYInMeters=float(npzFile[self.OBJECT_CENTER_Y]), ) costs = npzFile[self.COSTS_ARRAY] @@ -104,7 +121,7 @@ def write(self, filePath: Path, product: Product) -> None: contents[self.COMMENTS] = metadata.comments contents[self.DETECTOR_OBJECT_DISTANCE] = metadata.detectorDistanceInMeters contents[self.PROBE_ENERGY] = metadata.probeEnergyInElectronVolts - contents[self.PROBE_PHOTON_FLUX] = metadata.probePhotonsPerSecond + contents[self.PROBE_PHOTON_COUNT] = metadata.probePhotonCount contents[self.EXPOSURE_TIME] = metadata.exposureTimeInSeconds contents[self.PROBE_POSITION_INDEXES] = scanIndexes @@ -113,13 +130,13 @@ def write(self, filePath: Path, product: Product) -> None: probe = product.probe probeGeometry = probe.getGeometry() - contents[self.PROBE_ARRAY] = probe.array + contents[self.PROBE_ARRAY] = probe.getArray() contents[self.PROBE_PIXEL_WIDTH] = probeGeometry.pixelWidthInMeters contents[self.PROBE_PIXEL_HEIGHT] = probeGeometry.pixelHeightInMeters object_ = product.object_ objectGeometry = object_.getGeometry() - contents[self.OBJECT_ARRAY] = object_.array + contents[self.OBJECT_ARRAY] = object_.getArray() contents[self.OBJECT_CENTER_X] = objectGeometry.centerXInMeters contents[self.OBJECT_CENTER_Y] = objectGeometry.centerYInMeters contents[self.OBJECT_PIXEL_WIDTH] = objectGeometry.pixelWidthInMeters @@ -150,23 +167,29 @@ def read(self, filePath: Path) -> Scan: class NPZProbeFileReader(ProbeFileReader): def read(self, filePath: Path) -> Probe: with numpy.load(filePath) as npzFile: - return Probe( - array=npzFile[NPZProductFileIO.PROBE_ARRAY], - pixelWidthInMeters=float(npzFile[NPZProductFileIO.PROBE_PIXEL_WIDTH]), - pixelHeightInMeters=float(npzFile[NPZProductFileIO.PROBE_PIXEL_HEIGHT]), + pixelGeometry = PixelGeometry( + widthInMeters=float(npzFile[NPZProductFileIO.PROBE_PIXEL_WIDTH]), + heightInMeters=float(npzFile[NPZProductFileIO.PROBE_PIXEL_HEIGHT]), ) + return Probe(array=npzFile[NPZProductFileIO.PROBE_ARRAY], pixelGeometry=pixelGeometry) class NPZObjectFileReader(ObjectFileReader): def read(self, filePath: Path) -> Object: with numpy.load(filePath) as npzFile: + pixelGeometry = PixelGeometry( + widthInMeters=float(npzFile[NPZProductFileIO.OBJECT_PIXEL_WIDTH]), + heightInMeters=float(npzFile[NPZProductFileIO.OBJECT_PIXEL_HEIGHT]), + ) + center = ObjectCenter( + positionXInMeters=float(npzFile[NPZProductFileIO.OBJECT_CENTER_X]), + positionYInMeters=float(npzFile[NPZProductFileIO.OBJECT_CENTER_Y]), + ) return Object( array=npzFile[NPZProductFileIO.OBJECT_ARRAY], + pixelGeometry=pixelGeometry, + center=center, layerDistanceInMeters=npzFile[NPZProductFileIO.OBJECT_LAYER_DISTANCE], - pixelWidthInMeters=float(npzFile[NPZProductFileIO.OBJECT_PIXEL_WIDTH]), - pixelHeightInMeters=float(npzFile[NPZProductFileIO.OBJECT_PIXEL_HEIGHT]), - centerXInMeters=float(npzFile[NPZProductFileIO.OBJECT_CENTER_X]), - centerYInMeters=float(npzFile[NPZProductFileIO.OBJECT_CENTER_Y]), ) diff --git a/src/ptychodus/plugins/nslsIIDiffractionFile.py b/src/ptychodus/plugins/nslsIIDiffractionFile.py new file mode 100644 index 00000000..80df4913 --- /dev/null +++ b/src/ptychodus/plugins/nslsIIDiffractionFile.py @@ -0,0 +1,77 @@ +from pathlib import Path +from typing import Final +import logging + +import h5py +import numpy + +from ptychodus.api.geometry import ImageExtent, PixelGeometry +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + SimpleDiffractionDataset, +) +from ptychodus.api.plugins import PluginRegistry + +from .h5DiffractionFile import H5DiffractionPatternArray, H5DiffractionFileTreeBuilder + +logger = logging.getLogger(__name__) + + +class NSLSIIDiffractionFileReader(DiffractionFileReader): + SIMPLE_NAME: Final[str] = 'NSLS-II' + DISPLAY_NAME: Final[str] = 'NSLS-II Diffraction Files (*.mat)' + ONE_MICRON_M: Final[float] = 1e-6 + + def __init__(self) -> None: + self._dataPath = 'det_data' + self._treeBuilder = H5DiffractionFileTreeBuilder() + + def read(self, filePath: Path) -> DiffractionDataset: + dataset = SimpleDiffractionDataset.createNullInstance(filePath) + + try: + with h5py.File(filePath, 'r') as h5File: + contentsTree = self._treeBuilder.build(h5File) + + try: + data = h5File[self._dataPath] + pixelSizeInMicrons = h5File['det_pixel_size'] + except KeyError: + logger.warning('Unable to load data.') + else: + numberOfPatterns, detectorHeight, detectorWidth = data.shape + pixelSizeInMeters = ( + float(numpy.squeeze(pixelSizeInMicrons[()])) * self.ONE_MICRON_M + ) + + metadata = DiffractionMetadata( + numberOfPatternsPerArray=numberOfPatterns, + numberOfPatternsTotal=numberOfPatterns, + patternDataType=data.dtype, + detectorExtent=ImageExtent(detectorWidth, detectorHeight), + detectorPixelGeometry=PixelGeometry(pixelSizeInMeters, pixelSizeInMeters), + filePath=filePath, + ) + + array = H5DiffractionPatternArray( + label=filePath.stem, + index=0, + filePath=filePath, + dataPath=self._dataPath, + ) + + dataset = SimpleDiffractionDataset(metadata, contentsTree, [array]) + except OSError: + logger.warning(f'Unable to read file "{filePath}".') + + return dataset + + +def registerPlugins(registry: PluginRegistry) -> None: + registry.diffractionFileReaders.registerPlugin( + NSLSIIDiffractionFileReader(), + simpleName=NSLSIIDiffractionFileReader.SIMPLE_NAME, + displayName=NSLSIIDiffractionFileReader.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/nslsIIProductFile.py b/src/ptychodus/plugins/nslsIIProductFile.py new file mode 100644 index 00000000..4ed551bf --- /dev/null +++ b/src/ptychodus/plugins/nslsIIProductFile.py @@ -0,0 +1,75 @@ +from pathlib import Path +from typing import Final, Sequence + +import h5py + +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import Object +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import Probe +from ptychodus.api.product import Product, ProductFileReader, ProductMetadata +from ptychodus.api.scan import Scan, ScanPoint + + +class NSLSIIProductFileReader(ProductFileReader): + SIMPLE_NAME: Final[str] = 'NSLS-II' + DISPLAY_NAME: Final[str] = 'NSLS-II Product Files (*.mat)' + ONE_MICRON_M: Final[float] = 1e-6 + + def read(self, filePath: Path) -> Product: + point_list: list[ScanPoint] = list() + + with h5py.File(filePath, 'r') as h5File: + detector_distance_m = float(h5File['det_dist'][()]) * self.ONE_MICRON_M + probe_energy_eV = 1000.0 * float(h5File['energy'][()]) + + metadata = ProductMetadata( + name=filePath.stem, + comments='', + detectorDistanceInMeters=detector_distance_m, + probeEnergyInElectronVolts=probe_energy_eV, + probePhotonCount=0.0, # not included in file + exposureTimeInSeconds=0.0, # not included in file + ) + + pixel_width_m = h5File['img_pixel_size_x'][()] + pixel_height_m = h5File['img_pixel_size_y'][()] + pixel_geometry = PixelGeometry( + widthInMeters=pixel_width_m, heightInMeters=pixel_height_m + ) + positions_m = h5File['pos_xy'][()].T * self.ONE_MICRON_M + + for index, _xy in enumerate(positions_m): + point = ScanPoint( + index=index, + positionXInMeters=_xy[1], + positionYInMeters=_xy[2], + ) + point_list.append(point) + + probe_array = h5File['prb'][()].astype(complex) + probe = Probe(array=probe_array, pixelGeometry=pixel_geometry) + + object_array = h5File['obj'][()].astype(complex) + object_ = Object( + array=object_array, + pixelGeometry=pixel_geometry, + center=None, + ) + costs: Sequence[float] = list() + + return Product( + metadata=metadata, + scan=Scan(point_list), + probe=probe, + object_=object_, + costs=costs, + ) + + +def registerPlugins(registry: PluginRegistry) -> None: + registry.productFileReaders.registerPlugin( + NSLSIIProductFileReader(), + simpleName=NSLSIIProductFileReader.SIMPLE_NAME, + displayName=NSLSIIProductFileReader.DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/ptychoShelvesProductFile.py b/src/ptychodus/plugins/ptychoShelvesProductFile.py index c3f08036..a2d4c5e8 100644 --- a/src/ptychodus/plugins/ptychoShelvesProductFile.py +++ b/src/ptychodus/plugins/ptychoShelvesProductFile.py @@ -1,44 +1,28 @@ from pathlib import Path from typing import Final, Sequence +import numpy import scipy.io -from ptychodus.api.constants import ( +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import Object +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import Probe +from ptychodus.api.product import ( ELECTRON_VOLT_J, LIGHT_SPEED_M_PER_S, PLANCK_CONSTANT_J_PER_HZ, + Product, + ProductFileReader, + ProductMetadata, ) -from ptychodus.api.object import Object, ObjectArrayType, ObjectFileWriter -from ptychodus.api.plugins import PluginRegistry -from ptychodus.api.probe import Probe, ProbeFileWriter -from ptychodus.api.product import Product, ProductFileReader, ProductMetadata -from ptychodus.api.propagator import WavefieldArrayType from ptychodus.api.scan import Scan, ScanPoint -class MATProductFileReader(ProductFileReader): +class PtychoShelvesProductFileReader(ProductFileReader): SIMPLE_NAME: Final[str] = 'PtychoShelves' DISPLAY_NAME: Final[str] = 'PtychoShelves Files (*.mat)' - def _load_probe_array(self, probeMatrix: WavefieldArrayType) -> WavefieldArrayType: - if probeMatrix.ndim == 4: - # probeMatrix[width, height, num_shared_modes, num_varying_modes] - # TODO support spatially varying probe modes - probeMatrix = probeMatrix[..., 0] - - if probeMatrix.ndim == 3: - # probeMatrix[width, height, num_shared_modes] - probeMatrix = probeMatrix - - return probeMatrix.transpose(2, 0, 1) - - def _load_object_array(self, objectMatrix: ObjectArrayType) -> ObjectArrayType: - if objectMatrix.ndim == 3: - # objectMatrix[width, height, num_layers] - objectMatrix = objectMatrix.transpose(2, 0, 1) - - return objectMatrix - def read(self, filePath: Path) -> Product: scanPointList: list[ScanPoint] = list() @@ -52,13 +36,14 @@ def read(self, filePath: Path) -> Product: comments='', detectorDistanceInMeters=0.0, # not included in file probeEnergyInElectronVolts=probe_energy_eV, - probePhotonsPerSecond=0.0, # not included in file + probePhotonCount=0.0, # not included in file exposureTimeInSeconds=0.0, # not included in file ) dx_spec = p_struct['dx_spec'] pixel_width_m = dx_spec[0] pixel_height_m = dx_spec[1] + pixel_geometry = PixelGeometry(widthInMeters=pixel_width_m, heightInMeters=pixel_height_m) outputs_struct = matDict['outputs'] probe_positions = outputs_struct['probe_positions'] @@ -71,31 +56,39 @@ def read(self, filePath: Path) -> Product: ) scanPointList.append(point) - probe = Probe( - self._load_probe_array(matDict['probe']), - pixelWidthInMeters=pixel_width_m, - pixelHeightInMeters=pixel_height_m, - ) + probe_array = matDict['probe'] + + if probe_array.ndim == 3: + # probe_array[height, width, num_shared_modes] + probe_array = probe_array.transpose(2, 0, 1) + elif probe_array.ndim == 4: + # probe_array[height, width, num_shared_modes, num_varying_modes] + probe_array = probe_array.transpose(3, 2, 0, 1) - layer_distance_m: Sequence[float] | None = None + probe = Probe(array=probe_array, pixelGeometry=pixel_geometry) + + object_array = matDict['object'] + + if object_array.ndim == 3: + # object_array[height, width, num_layers] + object_array = object_array.transpose(2, 0, 1) + + layer_distance_m: Sequence[float] = list() try: multi_slice_param = p_struct['multi_slice_param'] + z_distance = multi_slice_param['z_distance'] except KeyError: pass else: - try: - z_distance = multi_slice_param['z_distance'] - except KeyError: - pass - else: - layer_distance_m = z_distance.tolist() + num_spaces = object_array.shape[-3] - 1 + layer_distance_m = numpy.squeeze(z_distance)[:num_spaces] object_ = Object( - self._load_object_array(matDict['object']), - layer_distance_m, - pixelWidthInMeters=pixel_width_m, - pixelHeightInMeters=pixel_height_m, + array=object_array, + pixelGeometry=pixel_geometry, + center=None, + layerDistanceInMeters=layer_distance_m, ) costs = outputs_struct['fourier_error_out'] @@ -108,34 +101,9 @@ def read(self, filePath: Path) -> Product: ) -class MATObjectFileWriter(ObjectFileWriter): - def write(self, filePath: Path, object_: Object) -> None: - array = object_.array - matDict = {'object': array.transpose(1, 2, 0)} - # TODO layer distance to p.z_distance - scipy.io.savemat(filePath, matDict) - - -class MATProbeFileWriter(ProbeFileWriter): - def write(self, filePath: Path, probe: Probe) -> None: - array = probe.array - matDict = {'probe': array.transpose(1, 2, 0)} - scipy.io.savemat(filePath, matDict) - - def registerPlugins(registry: PluginRegistry) -> None: registry.productFileReaders.registerPlugin( - MATProductFileReader(), - simpleName=MATProductFileReader.SIMPLE_NAME, - displayName=MATProductFileReader.DISPLAY_NAME, - ) - registry.probeFileWriters.registerPlugin( - MATProbeFileWriter(), - simpleName=MATProductFileReader.SIMPLE_NAME, - displayName=MATProductFileReader.DISPLAY_NAME, - ) - registry.objectFileWriters.registerPlugin( - MATObjectFileWriter(), - simpleName=MATProductFileReader.SIMPLE_NAME, - displayName=MATProductFileReader.DISPLAY_NAME, + PtychoShelvesProductFileReader(), + simpleName=PtychoShelvesProductFileReader.SIMPLE_NAME, + displayName=PtychoShelvesProductFileReader.DISPLAY_NAME, ) diff --git a/src/ptychodus/plugins/slacNPZFile.py b/src/ptychodus/plugins/slacNPZFile.py new file mode 100644 index 00000000..e3e3beb3 --- /dev/null +++ b/src/ptychodus/plugins/slacNPZFile.py @@ -0,0 +1,103 @@ +from pathlib import Path +from typing import Final, Sequence +import logging + +import numpy + +from ptychodus.api.geometry import ImageExtent +from ptychodus.api.object import Object +from ptychodus.api.patterns import ( + DiffractionDataset, + DiffractionFileReader, + DiffractionMetadata, + DiffractionPatternState, + SimpleDiffractionDataset, + SimpleDiffractionPatternArray, +) +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import Probe +from ptychodus.api.product import Product, ProductFileReader, ProductMetadata +from ptychodus.api.scan import Scan, ScanPoint +from ptychodus.api.tree import SimpleTreeNode + +logger = logging.getLogger(__name__) + + +class SLACDiffractionFileReader(DiffractionFileReader): + def read(self, filePath: Path) -> DiffractionDataset: + with numpy.load(filePath) as npzFile: + patterns = numpy.transpose(npzFile['diffraction'], [2, 0, 1]) + + numberOfPatterns, detectorHeight, detectorWidth = patterns.shape + + metadata = DiffractionMetadata( + numberOfPatternsPerArray=numberOfPatterns, + numberOfPatternsTotal=numberOfPatterns, + patternDataType=patterns.dtype, + detectorExtent=ImageExtent(detectorWidth, detectorHeight), + filePath=filePath, + ) + + contentsTree = SimpleTreeNode.createRoot(['Name', 'Type', 'Details']) + contentsTree.createChild( + [filePath.stem, type(patterns).__name__, f'{patterns.dtype}{patterns.shape}'] + ) + + array = SimpleDiffractionPatternArray( + label=filePath.stem, + index=0, + data=patterns, + state=DiffractionPatternState.FOUND, + ) + + return SimpleDiffractionDataset(metadata, contentsTree, [array]) + + +class SLACProductFileReader(ProductFileReader): + def read(self, filePath: Path) -> Product: + with numpy.load(filePath) as npzFile: + scanXInMeters = npzFile['xcoords_start'] + scanYInMeters = npzFile['ycoords_start'] + probeArray = npzFile['probeGuess'] + objectArray = npzFile['objectGuess'] + + metadata = ProductMetadata( + name=filePath.stem, + comments='', + detectorDistanceInMeters=0.0, # not included in file + probeEnergyInElectronVolts=0.0, # not included in file + probePhotonCount=0.0, # not included in file + exposureTimeInSeconds=0.0, # not included in file + ) + + scanPointList: list[ScanPoint] = list() + + for idx, (x_m, y_m) in enumerate(zip(scanXInMeters, scanYInMeters)): + point = ScanPoint(idx, x_m, y_m) + scanPointList.append(point) + + costs: Sequence[float] = list() # not included in file + + return Product( + metadata=metadata, + scan=Scan(scanPointList), + probe=Probe(array=probeArray, pixelGeometry=None), + object_=Object(array=objectArray, pixelGeometry=None, center=None), + costs=costs, + ) + + +def registerPlugins(registry: PluginRegistry) -> None: + SIMPLE_NAME: Final[str] = 'SLAC' + DISPLAY_NAME: Final[str] = 'SLAC NumPy Zipped Archive (*.npz)' + + registry.diffractionFileReaders.registerPlugin( + SLACDiffractionFileReader(), + simpleName=SIMPLE_NAME, + displayName=DISPLAY_NAME, + ) + registry.productFileReaders.registerPlugin( + SLACProductFileReader(), + simpleName=SIMPLE_NAME, + displayName=DISPLAY_NAME, + ) diff --git a/src/ptychodus/plugins/workflow.py b/src/ptychodus/plugins/workflow.py index 6e20e0d8..214d0366 100644 --- a/src/ptychodus/plugins/workflow.py +++ b/src/ptychodus/plugins/workflow.py @@ -148,7 +148,7 @@ def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: inputProductAPI.buildProbe() inputProductAPI.buildObject() # TODO would prefer to write instructions and submit to queue - outputProductAPI = inputProductAPI.reconstructLocal(f'{productName}_out') + outputProductAPI = inputProductAPI.reconstructLocal() outputProductAPI.saveProduct( experimentDir / 'ptychodus' / f'{productName}.h5', fileType='HDF5' ) diff --git a/src/ptychodus/ptychodus_bdp.py b/src/ptychodus/ptychodus_bdp.py index b08eb0d5..1896d53b 100755 --- a/src/ptychodus/ptychodus_bdp.py +++ b/src/ptychodus/ptychodus_bdp.py @@ -125,9 +125,9 @@ def main() -> int: type=float, ) parser.add_argument( - '--probe_photon_flux_Hz', - metavar='FLUX', - help='probe number of photons per second', + '--probe_photon_count', + metavar='NUMBER', + help='probe number of photons', type=float, ) parser.add_argument( @@ -175,7 +175,7 @@ def main() -> int: replacementPathPrefix=args.remote_path_prefix, ) elif bool(args.local_path_prefix) ^ bool(args.remote_path_prefix): - parser.error('--local_path_prefix and --remote_path_prefix' 'must be given together.') + parser.error('--local_path_prefix and --remote_path_prefix must be given together.') if args.crop_center_x_px is not None and args.crop_center_y_px is not None: cropCenter = CropCenter( @@ -211,7 +211,7 @@ def main() -> int: comments=args.comment, detectorDistanceInMeters=args.detector_distance_m, probeEnergyInElectronVolts=args.probe_energy_eV, - probePhotonsPerSecond=args.probe_photon_flux_Hz, + probePhotonCount=args.probe_photon_count, exposureTimeInSeconds=args.exposure_time_s, ) workflowProductAPI.openScan(Path(args.scan_file_path.name)) diff --git a/src/ptychodus/view/core.py b/src/ptychodus/view/core.py index 3a57fb33..5f628346 100644 --- a/src/ptychodus/view/core.py +++ b/src/ptychodus/view/core.py @@ -21,7 +21,7 @@ from .image import ImageView from .patterns import PatternsView from .product import ProductView -from .reconstructor import ReconstructorParametersView, ReconstructorPlotView +from .reconstructor import ReconstructorView, ReconstructorPlotView from .repository import RepositoryTableView, RepositoryTreeView from .scan import ScanPlotView from .settings import SettingsView @@ -58,7 +58,7 @@ def __init__(self, parent: QWidget | None) -> None: self.productView = ProductView() self.productDiagramView = QWidget() - self.scanAction = self.navigationToolBar.addAction(QIcon(':/icons/scan'), 'Scan') + self.scanAction = self.navigationToolBar.addAction(QIcon(':/icons/scan'), 'Positions') self.scanView = RepositoryTableView() self.scanPlotView = ScanPlotView.createInstance() @@ -73,8 +73,8 @@ def __init__(self, parent: QWidget | None) -> None: self.reconstructorAction = self.navigationToolBar.addAction( QIcon(':/icons/reconstructor'), 'Reconstructor' ) - self.reconstructorParametersView = ReconstructorParametersView.createInstance() - self.reconstructorPlotView = ReconstructorPlotView.createInstance() + self.reconstructorView = ReconstructorView() + self.reconstructorPlotView = ReconstructorPlotView() self.workflowAction = self.navigationToolBar.addAction( QIcon(':/icons/workflow'), 'Workflow' @@ -115,7 +115,7 @@ def createInstance( view.parametersWidget.addWidget(view.scanView) view.parametersWidget.addWidget(view.probeView) view.parametersWidget.addWidget(view.objectView) - view.parametersWidget.addWidget(view.reconstructorParametersView) + view.parametersWidget.addWidget(view.reconstructorView) view.parametersWidget.addWidget(view.workflowParametersView) view.parametersWidget.addWidget(view.automationView) view.parametersWidget.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Minimum) diff --git a/src/ptychodus/view/reconstructor.py b/src/ptychodus/view/reconstructor.py index 1f482ab0..2d5b02bc 100644 --- a/src/ptychodus/view/reconstructor.py +++ b/src/ptychodus/view/reconstructor.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from PyQt5.QtWidgets import ( QComboBox, QDialog, @@ -22,106 +20,81 @@ from matplotlib.figure import Figure -class ReconstructorView(QGroupBox): - def __init__(self, parent: QWidget | None) -> None: +class ReconstructorParametersView(QGroupBox): + def __init__(self, parent: QWidget | None = None) -> None: super().__init__('Parameters', parent) self.algorithmComboBox = QComboBox() self.productComboBox = QComboBox() - self.modelButton = QPushButton('Model') - self.modelMenu = QMenu() - self.trainerButton = QPushButton('Trainer') - self.trainerMenu = QMenu() - self.reconstructorButton = QPushButton('Reconstructor') - self.reconstructorMenu = QMenu() - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ReconstructorView: - view = cls(parent) + self.reconstructorMenu = QMenu() + self.reconstructorButton = QPushButton('Reconstructor') + self.reconstructorButton.setMenu(self.reconstructorMenu) - view.modelButton.setMenu(view.modelMenu) - view.trainerButton.setMenu(view.trainerMenu) - view.reconstructorButton.setMenu(view.reconstructorMenu) + self.trainerMenu = QMenu() + self.trainerButton = QPushButton('Trainer') + self.trainerButton.setMenu(self.trainerMenu) actionLayout = QHBoxLayout() actionLayout.setContentsMargins(0, 0, 0, 0) - actionLayout.addWidget(view.modelButton) - actionLayout.addWidget(view.trainerButton) - actionLayout.addWidget(view.reconstructorButton) + actionLayout.addWidget(self.reconstructorButton) + actionLayout.addWidget(self.trainerButton) layout = QFormLayout() - layout.addRow('Algorithm:', view.algorithmComboBox) - layout.addRow('Product:', view.productComboBox) + layout.addRow('Algorithm:', self.algorithmComboBox) + layout.addRow('Product:', self.productComboBox) layout.addRow('Action:', actionLayout) - view.setLayout(layout) - - return view + self.setLayout(layout) class ReconstructorProgressDialog(QDialog): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) self.textEdit = QPlainTextEdit() self.progressBar = QProgressBar() self.buttonBox = QDialogButtonBox() - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ReconstructorProgressDialog: - dialog = cls(parent) - dialog.setWindowTitle('Reconstruction Progress') - dialog.buttonBox.addButton(QDialogButtonBox.Ok) - dialog.buttonBox.accepted.connect(dialog.accept) - dialog.buttonBox.addButton(QDialogButtonBox.Cancel) - dialog.buttonBox.rejected.connect(dialog.reject) + self.setWindowTitle('Reconstruction Progress') + self.buttonBox.addButton(QDialogButtonBox.Ok) + self.buttonBox.accepted.connect(self.accept) + self.buttonBox.addButton(QDialogButtonBox.Cancel) + self.buttonBox.rejected.connect(self.reject) layout = QVBoxLayout() - layout.addWidget(dialog.textEdit) - layout.addWidget(dialog.progressBar) - layout.addWidget(dialog.buttonBox) - dialog.setLayout(layout) - - return dialog + layout.addWidget(self.textEdit) + layout.addWidget(self.progressBar) + layout.addWidget(self.buttonBox) + self.setLayout(layout) -class ReconstructorParametersView(QWidget): - def __init__(self, parent: QWidget | None) -> None: +class ReconstructorView(QWidget): + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.reconstructorView = ReconstructorView.createInstance() - self.stackedWidget = QStackedWidget() - self.scrollArea = QScrollArea() - self.progressDialog = ReconstructorProgressDialog.createInstance() # TODO use this - - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ReconstructorParametersView: - view = cls(parent) + self.parametersView = ReconstructorParametersView() - view.scrollArea.setWidgetResizable(True) - view.scrollArea.setWidget(view.stackedWidget) + self.stackedWidget = QStackedWidget() + self.stackedWidget.layout().setContentsMargins(0, 0, 0, 0) - view.stackedWidget.layout().setContentsMargins(0, 0, 0, 0) + self.scrollArea = QScrollArea() + self.scrollArea.setWidgetResizable(True) + self.scrollArea.setWidget(self.stackedWidget) layout = QVBoxLayout() - layout.addWidget(view.reconstructorView) - layout.addWidget(view.scrollArea) - view.setLayout(layout) + layout.addWidget(self.parametersView) + layout.addWidget(self.scrollArea) + self.setLayout(layout) - return view + self.progressDialog = ReconstructorProgressDialog() class ReconstructorPlotView(QWidget): - def __init__(self, parent: QWidget | None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) self.figure = Figure() self.figureCanvas = FigureCanvasQTAgg(self.figure) self.navigationToolbar = NavigationToolbar(self.figureCanvas, self) self.axes = self.figure.add_subplot(111) - @classmethod - def createInstance(cls, parent: QWidget | None = None) -> ReconstructorPlotView: - view = cls(parent) - layout = QVBoxLayout() - layout.addWidget(view.navigationToolbar) - layout.addWidget(view.figureCanvas) - view.setLayout(layout) - - return view + layout.addWidget(self.navigationToolbar) + layout.addWidget(self.figureCanvas) + self.setLayout(layout) diff --git a/tests/test_phase_unwrap.py b/tests/test_phase_unwrap.py new file mode 100644 index 00000000..3661b149 --- /dev/null +++ b/tests/test_phase_unwrap.py @@ -0,0 +1,25 @@ +import os + +import numpy as np +import matplotlib.pyplot as plt + +import ptychodus.model.analysis.phaseUnwrapper as pu + + +def test_phase_unwrap() -> None: + phase_unwrapper = pu.PhaseUnwrapper( + image_grad_method='fourier_differentiation', + image_integration_method='fourier', + ) + img = np.load(os.path.join("data", "phase_unwrap", "recon_20241220_epoch_400.npy")) + img = img[0] + + phase = phase_unwrapper.unwrap(img) + + plt.figure() + plt.imshow(phase) + plt.show() + + +if __name__ == "__main__": + test_phase_unwrap() diff --git a/tests/test_propagation.py b/tests/test_propagation.py index be0d2913..4fe15b31 100644 --- a/tests/test_propagation.py +++ b/tests/test_propagation.py @@ -1,5 +1,5 @@ -if __name__ == "__main__": +if __name__ == '__main__': import matplotlib - matplotlib.use("Agg") + matplotlib.use('Agg') import matplotlib.pyplot as plt diff --git a/tests/test_zernike.py b/tests/test_zernike.py index f93fcf36..b7e01ea6 100644 --- a/tests/test_zernike.py +++ b/tests/test_zernike.py @@ -2,11 +2,11 @@ def test_indexing() -> None: idx = 0 for n in range(10): - print("") + print('') for m in range(-n, n + 1, 2): idx_calc = (n * (n + 2) + m) // 2 - print(f"{n=} {m=:+d} {idx=} {idx_calc=}") + print(f'{n=} {m=:+d} {idx=} {idx_calc=}') assert idx == idx_calc idx += 1 @@ -15,7 +15,7 @@ def test_pyramid() -> None: import numpy import matplotlib - matplotlib.use("Agg") + matplotlib.use('Agg') import matplotlib.colors import matplotlib.pyplot as plt @@ -47,10 +47,10 @@ def test_pyramid() -> None: col = max_radial_degree + angular_frequency ax = fig.add_subplot(gs[row : row + 1, col : col + 2]) - ax.pcolormesh(X, Y, Z, norm=matplotlib.colors.CenteredNorm(), cmap="seismic") - ax.set_aspect("equal") + ax.pcolormesh(X, Y, Z, norm=matplotlib.colors.CenteredNorm(), cmap='seismic') + ax.set_aspect('equal') ax.set_title(str(polynomial)) - ax.axis("off") + ax.axis('off') - plt.savefig("zernike_pyramid.png", bbox_inches="tight", dpi=my_dpi) + plt.savefig('zernike_pyramid.png', bbox_inches='tight', dpi=my_dpi) plt.close(fig)