Source code for nxtomomill.converter.hdf5.acquisition.baseacquisition

# coding: utf-8

"""
Base class for tomography acquisition (defined by Bliss)
"""

from __future__ import annotations

import os
import pint

import h5py

from nxtomo.application.nxtomo import NXtomo
from nxtomomill.utils.h5pyutils import from_data_url_to_virtual_source
from nxtomomill.utils.hdf5 import EntryReader, get_dataset_unit
from nxtomomill.utils.utils import embed_url, get_file_name  # noqa F401

try:
    import hdf5plugin  # noqa F401
except ImportError:
    pass
import logging
from collections import OrderedDict

import numpy
from silx.io.url import DataUrl
from silx.io.utils import h5py_read_dataset

from nxtomomill.io.config import TomoHDF5Config

_ureg = pint.get_application_registry()
_logger = logging.getLogger(__name__)

__all__ = ["BaseAcquisition", "get_dataset_name_from_motor"]


def _ask_for_file_removal(file_path):
    res = input(f"Overwrite {file_path} ? (y/N)").lower()
    return res == "y"


[docs]class BaseAcquisition: """ Util class to group several hdf5 group together and to write the data Nexus / NXtomo compliant """ _ENERGY_PATH = "technique/scan/energy" _DATASET_NAME_PATH = ("technique/scan/name",) _GRP_SIZE_PATH = ("technique/scan/nb_scans",) _SAMPLE_NAME_PATH = ("sample/name",) TITLE_PATHS = ( "technique/scan/sequence", "title", ) _INSTRUMENT_NAME_PATH = ( "technique/saving/beamline", "instrument/title", ) _FOV_PATH = "technique/scan/field_of_view" _NB_LOOP_PATH = ("technique/scan/nb_loop", "technique/proj/nb_loop") _NB_TOMO_PATH = ("technique/scan/nb_tomo", "technique/proj/nb_tomo") _NB_TURNS_PATH = ( "technique/proj/nb_turns", "technique/scan/nb_turns", ) _TOMO_N_PATH = ( "technique/proj/tomo_n", "technique/scan/tomo_n", "technique/proj/proj_n", "technique/scan/proj_n", ) _START_TIME_PATH = ("start_time",) _END_TIME_PATH = ("end_time",) _TECHNIQUE_MOTOR_PATHS = ("technique/scan/motor", "technique/proj/motor") _DETECTOR_ROI = ("technique/detector/roi",) _SOURCE_NAME = ("instrument/machine/name",) _SOURCE_TYPE = ("instrument/machine/type",) _FRAME_FLIP_PATHS = ( "technique/detector/{detector_name}/flipping", "technique/detector/flipping", ) _PROPAGATION_DISTANCE_PATHS = ( "technique/scan/effective_propagation_distance", "technique/scan/propagation_distance", )
[docs] def __init__( self, root_url: DataUrl | None, configuration: TomoHDF5Config, detector_sel_callback, start_index: int, ): self._root_url = root_url self._detector_sel_callback = detector_sel_callback self._registered_entries = OrderedDict() self._copy_frames = OrderedDict() # key is the entry, value his type self._entries_o_path = dict() # key is the entry, value his the entry path from the original file. # this value is different from `.name`. Don't know if this is a bug ? """user can have defined already some parameter values as energy. The idea is to avoid asking him if """ self._configuration = configuration self._start_index = start_index # set of properties when an acquisition needs to copy dark / flat (used in the case of z-series version 3) self._dark_at_start = None self._dark_at_end = None self._flat_at_start = None self._flat_at_end = None
@property def configuration(self): return self._configuration @property def start_index(self) -> int: return self._start_index
[docs] def write_as_nxtomo( self, shift_entry: int, input_file_path: str, request_input: bool, divide_into_sub_files: bool, input_callback=None, ) -> tuple: """ This function will dump the acquisition to disk as an NXtomo :param shift_entry: index of the entry to start saving new nxtomos. :param input_file_path: output file path :param request_input: if True the conversion can ask user some missing metadata :param divide_into_sub_files: if True then create one file per NXtomo :param input_callback: function to call for users to provide missing metadata """ nx_tomos = self.to_NXtomos( request_input=request_input, input_callback=input_callback, check_tomo_n=True, ) # preprocessing to define output file name possible_extensions = (".hdf5", ".h5", ".nx", ".nexus") output_file_basename = os.path.basename(self.configuration.output_file) file_extension_ = None for possible_extension in possible_extensions: if output_file_basename.endswith(possible_extension): output_file_basename.rstrip(possible_extension) file_extension_ = possible_extension def get_file_name_and_entry(index, divide_sub_files): entry = "entry" + str(index).zfill(4) if self.configuration.single_file or not divide_sub_files: en_output_file = self.configuration.output_file else: ext = file_extension_ or self.configuration.file_extension file_name = ( os.path.splitext(output_file_basename)[0] + "_" + str(index).zfill(4) + ext ) en_output_file = os.path.join( os.path.dirname(self.configuration.output_file), file_name ) if os.path.exists(en_output_file): if self.configuration.overwrite is True: _logger.warning(en_output_file + " will be removed") _logger.info("remove " + en_output_file) os.remove(en_output_file) elif _ask_for_file_removal(en_output_file) is False: raise OSError(f"unable to overwrite {en_output_file}, exit") else: os.remove(en_output_file) return en_output_file, entry result = [] for i_nx_tomo, nx_tomo in enumerate(nx_tomos): output_file, data_path = get_file_name_and_entry( (shift_entry + i_nx_tomo), divide_sub_files=divide_into_sub_files ) output_file = os.path.abspath(os.path.relpath(output_file, os.getcwd())) output_file = os.path.abspath(output_file) # For multi-tomo for example the data path is modified in order to handle with the splitting # that does not fit at 100 % with the current API. for now there is # no convenient way to handle this in a better way assert isinstance(nx_tomo, NXtomo) # embed data if requested if len(nx_tomo.instrument.detector.data) == 0: _logger.warning( f"No frame found for NXtomo number {i_nx_tomo}. Won't write any output for it" ) continue vs_0 = nx_tomo.instrument.detector.data[0] if nx_tomo.detector_data_is_defined_by_url(): new_urls = [] for url in nx_tomo.instrument.detector.data: if self._copy_frames[url.path()]: created_url = embed_url( url=url, output_file=output_file, ) new_urls.append(created_url) else: new_urls.append(url) nx_tomo.instrument.detector.data = new_urls elif nx_tomo.detector_data_is_defined_by_virtual_source(): new_vs = [] for vs in nx_tomo.instrument.detector.data: assert isinstance(vs, h5py.VirtualSource) assert isinstance(vs_0, h5py.VirtualSource) url = DataUrl(file_path=vs.path, data_path=vs.name, scheme="silx") if ( url.path() in self._copy_frames and self._copy_frames[url.path()] ): new_url = embed_url(url, output_file=output_file) n_vs, _, _ = from_data_url_to_virtual_source(new_url) new_vs.append(n_vs) else: new_vs.append(vs) nx_tomo.instrument.detector.data = new_vs # provide some extra data like origin of the input_file if input_file_path is not None: nx_tomo.bliss_original_files = (os.path.abspath(input_file_path),) # save data nx_tomo.save(file_path=output_file, data_path=data_path) # check rotation angle if nx_tomo.sample.rotation_angle is not None: unique_angles = numpy.unique( nx_tomo.sample.rotation_angle.to(_ureg.degree).magnitude ) if len(unique_angles) == 1: _logger.warning( f"NXtomo {data_path}@{output_file} seems to have a single value ({unique_angles}) for rotation angle. Seems it fails to find correct path to the rotation angle dataset" ) result.append((output_file, data_path)) return tuple(result)
def to_NXtomos(self, request_input, input_callback, check_tomo_n=True) -> tuple: raise NotImplementedError("Base class") @property def raise_error_if_issue(self): """ Should we raise an error if we encounter or an issue or should we just log an error message """ return self.configuration.raises_error @property def root_url(self): return self._root_url
[docs] def get_expected_nx_tomo(self): """ Return the expected number of nxtomo created for this acquisition. This is required to get consistent entry and file name. At lest for automation """ raise NotImplementedError("Base class")
def read_entry(self): return EntryReader(self._root_url)
[docs] def is_different_sequence(self, entry): """ Can we have several entries 1.1, 1.2, 1.3... to consider. """ raise ValueError("Base class")
def is_part_of_same_series(self, other: BaseAcquisition) -> bool: if not isinstance(other, BaseAcquisition): raise TypeError("Can only compar two z-series") if self.root_url is None or other.root_url is None: return False def read_sample_node(url: DataUrl) -> str | None: with EntryReader(url=url) as entry: sample_node = self._get_sample_node(entry) if sample_node is None: return None sample_name = sample_node.get("name", None) if sample_name is not None: return h5py_read_dataset(sample_name) return None self_sample_name = read_sample_node(self.root_url) other_sample_name = read_sample_node(other.root_url) if self_sample_name is None or other_sample_name is None: return False def filter_scan_index(sample_name: str) -> str: # for z_series the sample name is post pone with _{sample_name} that we need to filter parts = sample_name.split("_") if len(parts) > 1 and numpy.char.isnumeric(parts[-1]): return "_".join(parts[:-1]) return sample_name return filter_scan_index(self_sample_name) == filter_scan_index( other_sample_name ) @staticmethod def _get_node_values_for_frame_array( node: h5py.Group, n_frame: int | None, keys: list | tuple, info_retrieve: str, expected_unit: pint.Unit, ) -> tuple[numpy.ndarray | None, str]: assert isinstance(expected_unit, pint.Unit), "Invalid expected unit" def get_values(): # this is a two step process: first step we parse all the # the keys until we found one with the expected length # if first iteration fails then we return the first existing key for respect_length in (True, False): for possible_key in keys: if possible_key in node and isinstance( node[possible_key], h5py.Dataset ): values = h5py_read_dataset(node[possible_key]) unit = get_dataset_unit( dataset=node[possible_key], default=expected_unit, from_dataset=f"{node[possible_key].name} (reading {info_retrieve})", ) # skip values containing '*DIS*' if isinstance(values, str) and values == "*DIS*": continue if n_frame is not None and respect_length is True: if numpy.isscalar(values): length = 1 else: length = len(values) if length in (n_frame, n_frame + 1): return values, unit else: return values, unit return None, None values, unit = get_values() if values is None: raise ValueError( f"Unable to retrieve {info_retrieve} for {node.name}. Was looking for {keys} datasets" ) elif n_frame is None: return values, unit elif numpy.isscalar(values): return numpy.array([values] * n_frame), unit elif len(values) == n_frame: return values.tolist(), unit elif len(values) == (n_frame + 1): # for now we can have one extra position for rotation, # translation_x... # because saved after the last projection. It is recording the # motor position. For example in this case: 1 is the motor movement # (saved) and 2 is the acquisition # # 1 2 1 2 1 # ----- ----- # ----- ----- ----- # return values[:-1].tolist(), unit elif len(values) > n_frame: _logger.warning( f"Incoherent number of values found for {info_retrieve}. Can come from an acquisition canceled. Else please investigate." ) # in this case only get the values which have a frame return values[0:n_frame], unit elif len(values) < n_frame: _logger.warning( f"Incoherent number of values found for {info_retrieve}. Can come from an acquisition canceled. Else please investigate." ) # in this case append 0 to existing values. Maybe -1 would be better ? return list(values) + [0] * (n_frame - values), unit elif len(values) == 1: return numpy.array([values[0]] * n_frame), unit else: raise ValueError("incoherent number of angle position vs number of frame")
[docs] def register_step(self, url: DataUrl, entry_type, copy_frames) -> None: """ Add a bliss entry to the acquisition :param url: :param entry_type: """ raise NotImplementedError("Base class")
@staticmethod def _get_instrument_node(entry_node: h5py.Group) -> h5py.Group: if not isinstance(entry_node, h5py.Group): raise TypeError("entry_node: h5py.group expected") return entry_node["instrument"] @staticmethod def _get_positioners_node(entry_node: h5py.Group) -> h5py.Group | None: if not isinstance(entry_node, h5py.Group): raise TypeError("entry_node is expected to be a h5py.Group") parent_node = BaseAcquisition._get_instrument_node(entry_node) if "positioners" in parent_node: return parent_node["positioners"] else: return None @staticmethod def _get_measurement_node(entry_node: h5py.Group) -> h5py.Group | None: if not isinstance(entry_node, h5py.Group): raise TypeError("entry_node is expected to be a h5py.Group") if "measurement" in entry_node: return entry_node["measurement"] else: return None @staticmethod def _get_machine_node(entry_node: h5py.Group) -> h5py.Group | None: if not isinstance(entry_node, h5py.Group): raise TypeError("entry_node is expected to be a h5py.Group") if "instrument/machine" in entry_node: return entry_node["instrument/machine"] else: return None @staticmethod def _get_sample_node(entry_node: h5py.Group) -> h5py.Group | None: if not isinstance(entry_node, h5py.Group): raise TypeError("entry_node is expected to be a h5py.Group") if "sample" in entry_node: return entry_node["sample"] else: return None @staticmethod def _get_scan_flags(entry_node: h5py.Group) -> h5py.Group | None: if not isinstance(entry_node, h5py.Group): raise TypeError("entry_node is expected to be a h5py.Group") if "technique/scan_flags" in entry_node: return entry_node["technique/scan_flags"] else: return None def _read_rotation_motor_name(self) -> str | None: """read rotation motor from root_url/technique/scan/motor :return: name of the motor used for rotation. None if cannot find """ if self._root_url is None: _logger.warning("no root url. Unable to read rotation motor") return None else: with EntryReader(self._root_url) as entry: for motor_path in self._TECHNIQUE_MOTOR_PATHS: if motor_path in entry: try: rotation_motor = get_dataset_name_from_motor( motors=h5py_read_dataset( numpy.asarray(entry[motor_path]) ), motor_name="rotation", ) except Exception as e: _logger.error(e) else: return rotation_motor else: _logger.warning( f"{motor_path} unable to find rotation motor from {self._root_url}" ) return None def _get_machine_current(self, root_node) -> None | pint.Quantity: """retrieve electric current provide a time stamp for each of them""" if root_node is None: _logger.warning("no root url. Unable to read electric current") return None else: grps = [ root_node, ] measurement_node = self._get_measurement_node(root_node) if measurement_node is not None: grps.append(measurement_node) machine_node = self._get_machine_node(root_node) if machine_node is not None: grps.append(machine_node) for grp in grps: try: mach_current, unit = self._get_node_values_for_frame_array( node=grp, keys=self.configuration.machine_current_keys, info_retrieve="machine current", expected_unit=_ureg.mA, n_frame=None, ) except (ValueError, KeyError): pass else: # handle case where mach_current is a scalar. Cast it to list before return if numpy.isscalar(mach_current): mach_current = [ mach_current, ] return mach_current * unit else: _logger.warning( f"Unable to retrieve machine current for {root_node.name}" ) return None def _get_rotation_angle(self, root_node: h5py.Group, n_frame: int) -> pint.Quantity: """ return the list of rotation angle for each frame. If unable to find it will either: * raise an error (if 'raise_error_if_issue' is True) * return a zeros for all required angles (kwnow from n_frame) """ if not isinstance(root_node, h5py.Group): raise TypeError("root_node is expected to be a h5py.Group") for grp in ( self._get_positioners_node(root_node), root_node, self._get_measurement_node(root_node), ): try: angles, unit = self._get_node_values_for_frame_array( node=grp, n_frame=n_frame, keys=self.configuration.rotation_angle_keys, info_retrieve="rotation angle", expected_unit=_ureg.degree, ) except (ValueError, KeyError): pass else: return angles * unit mess = f"Unable to find rotation angle for {root_node.name}" if self.raise_error_if_issue: raise ValueError(mess) else: mess += "default value will be set. (0)" _logger.warning(mess) return ([0] * n_frame) * _ureg.degree def _get_sample_x(self, root_node, n_frame) -> pint.Quantity: """return the list of translation for each frame""" for grp in self._get_positioners_node(root_node), root_node: try: x_tr, unit = self._get_node_values_for_frame_array( node=grp, n_frame=n_frame, keys=self.configuration.sample_x_keys, info_retrieve="sample x translation", expected_unit=_ureg.mm, ) x_tr = numpy.asarray(x_tr) except (ValueError, KeyError): pass else: return x_tr * unit mess = f"Unable to find sample x for {root_node.name}" if self.raise_error_if_issue: raise ValueError(mess) else: mess += "default value will be set. (0)" _logger.warning(mess) return ([0] * n_frame) * _ureg.meter def _get_sample_y(self, root_node, n_frame) -> pint.Quantity: """return the list of translation for each frame""" for grp in self._get_positioners_node(root_node), root_node: try: y_tr, unit = self._get_node_values_for_frame_array( node=grp, n_frame=n_frame, keys=self.configuration.sample_y_keys, info_retrieve="sample y translation", expected_unit=_ureg.mm, ) y_tr = numpy.asarray(y_tr) except (ValueError, KeyError): pass else: return y_tr * unit mess = f"Unable to find sample y for {root_node.name}" if self.raise_error_if_issue: raise ValueError(mess) else: mess += "default value will be set. (0)" _logger.warning(mess) return ([0] * n_frame) * _ureg.meter def _get_translation_y(self, root_node, n_frame) -> pint.quantity: """return the list of translation for each frame""" for grp in self._get_positioners_node(root_node), root_node: try: y_tr, unit = self._get_node_values_for_frame_array( node=grp, n_frame=n_frame, keys=self.configuration.translation_y_keys, info_retrieve="y base translation", expected_unit=_ureg.mm, ) y_tr = numpy.asarray(y_tr) except (ValueError, KeyError): pass else: return y_tr * unit mess = f"Unable to find translation y for {root_node.name}" if self.raise_error_if_issue: raise ValueError(mess) else: mess += "default value will be set. (0)" _logger.warning(mess) return ([0] * n_frame) * _ureg.meter @staticmethod def get_translation_z_frm( root_node, n_frame: int, configuration: TomoHDF5Config ) -> pint.Quantity: for grp in BaseAcquisition._get_positioners_node(root_node), root_node: try: z_tr, unit = BaseAcquisition._get_node_values_for_frame_array( node=grp, n_frame=n_frame, keys=configuration.translation_z_keys, info_retrieve="z translation", expected_unit=_ureg.mm, ) z_tr = numpy.asarray(z_tr) except (ValueError, KeyError): pass else: return z_tr * unit mess = f"Unable to find translation z on node {root_node.name}" if configuration.raises_error: raise ValueError(mess) else: mess += "default value will be set. (0)" _logger.warning(mess) return ([0] * n_frame) * _ureg.meter def _get_translation_z(self, root_node, n_frame) -> pint.Quantity: """return the list of translation z for each frame""" return self.get_translation_z_frm( root_node=root_node, n_frame=n_frame, configuration=self.configuration, ) def _get_expo_time(self, root_node, n_frame, detector_node) -> pint.Quantity: """return expo time for each frame""" for grp in (detector_node.get("acq_parameters", None), root_node): if grp is None: continue try: expo, unit = self._get_node_values_for_frame_array( node=grp, n_frame=n_frame, keys=self.configuration.exposure_time_keys, info_retrieve="exposure time", expected_unit=_ureg.second, ) except (ValueError, KeyError): pass else: return expo * unit mess = f"Unable to find frame exposure time on entry {self.root_url.path()}" if self.raise_error_if_issue: raise ValueError(mess) else: mess += "default value will be set. (0)" _logger.warning(mess) return 0 * _ureg.second
[docs] def get_axis_scale_types(self): """ Return axis display for the detector data to be used by silx view """ return ["linear", "linear"]
def __str__(self): if self.root_url is None: return "NXTomo" else: return self.root_url.path() def get_detector_roi(self): if self._root_url is None: _logger.warning("no root url. Unable to read detector roi") return None else: with EntryReader(self._root_url) as entry: for roi_path in self._DETECTOR_ROI: if roi_path in entry: try: roi = h5py_read_dataset(numpy.asarray(entry[roi_path])) except Exception as e: _logger.error(e) else: return roi else: _logger.warning( f"{roi_path} unable to find detector roi from {self._root_url}" ) return None
[docs]def get_dataset_name_from_motor(motors, motor_name): motors = numpy.asarray(motors) indexes = numpy.where(motors == motor_name)[0] if len(indexes) == 0: return None elif len(indexes) == 1: index = indexes[0] index_dataset_id = index + 1 if index_dataset_id < len(motors): return motors[index_dataset_id] else: raise ValueError( f"{motor_name} found but unable to find dataset name from {motors}" ) else: raise ValueError(f"More than one instance of {motor_name} as been found.")