"""Module provides h5 file interafaces.
"""
from typing import Dict
from typing import Optional
from typing import Union
from typing import List
from typing import Any
from typing import Iterable
import os
import numpy as np
import h5py
from nucleon_elastic_ff.utilities import set_up_logger
from nucleon_elastic_ff.utilities import has_match
LOGGER = set_up_logger("nucleon_elastic_ff")
def get_dsets(
container: Union[h5py.File, h5py.Group],
parent_name: Optional[str] = None,
load_dsets: bool = False,
ignore_containers: Optional[List[str]] = None,
) -> Dict[str, Union[h5py.Dataset, np.ndarray]]:
"""Access an HDF5 container and extracts datasets.
The method is iteratively called if the container contains further containers.
**Arguments**
container: Union[h5py.File, h5py.Group]
The HDF5 group or file to search recursively.
parent_name: Optional[str] = None
The name of the parent container.
load_dsets: bool = False
If False, datasets are not loaded into memory (lazy access; values are
`h5py.Dataset` handles). If True, datasets are read and the returned
dictionary contains numpy arrays as values.
ignore_containers: Optional[List[str]] = None
A list of HDF5 containers to ignore when recursively searching.
Entries can be regular expressions.
**Returns**
datasets: Dict[str, Union[h5py.Dataset, np.ndarray]]
A dictionary with the full HDF5 path of each dataset (e.g., `groupA/subgroupB`)
as keys and the (lazily loaded) datasets as values.
"""
if isinstance(container, h5py.File):
LOGGER.info("Locating all dsets of h5 file `%s`", container.filename)
dsets = {}
ignore_containers = [] if ignore_containers is None else ignore_containers
for key in container:
obj = container[key]
if has_match(key, ignore_containers):
continue
address = os.path.join(parent_name, key) if parent_name else key
if isinstance(obj, h5py.Dataset):
LOGGER.debug("\t`%s`", address)
dsets[address] = obj[()] if load_dsets else obj
elif isinstance(obj, h5py.Group):
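# Recurse into the sub-group and collect its datasets under the full address.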
dsets.update(get_dsets(obj, parent_name=address, load_dsets=load_dsets))
return dsets
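# Usage sketch for `get_dsets`: lazily collect all dataset addresses of a file,
# skipping containers whose names match "local_current". The file name
# "example.h5" and the ignore pattern are illustrative only.
def _example_get_dsets():
    with h5py.File("example.h5", "r") as h5f:
        dsets = get_dsets(h5f, ignore_containers=["local_current"])
        for address, dset in dsets.items():
            LOGGER.info("Found dataset `%s` with shape %s", address, dset.shape)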
def create_dset(h5f: h5py.File, key: str, data: Any, overwrite: bool = False):
"""Creates or overwrites (if requested) dataset in HDF5 file.
**Arguments**
h5f: h5py.File
The file to write to.
key: str
The name of the dataset.
data: Any
The data for the dataset.
overwrite: bool = False
Whether existing data shall be overwritten.
"""
LOGGER.debug("Writing dataset:`%s`", key)
if key in h5f:
if overwrite:
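# Remove the existing dataset first; `create_dataset` fails for existing names.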
del h5f[key]
h5f.create_dataset(key, data=data)
else:
LOGGER.info("Skipping dataset because exists:`%s`", key)
else:
h5f.create_dataset(key, data=data)
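# Usage sketch for `create_dset`: write an array and then overwrite it. The
# file name "example.h5" and the dataset key are illustrative only.
def _example_create_dset():
    with h5py.File("example.h5", "a") as h5f:
        create_dset(h5f, "groupA/dsetB", np.arange(10))
        # Without `overwrite=True`, a second call for the same key is skipped.
        create_dset(h5f, "groupA/dsetB", np.zeros(10), overwrite=True)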
def assert_h5files_equal( # pylint: disable=R0913
actual: str,
expected: str,
atol: float = 0.0,
rtol: float = 1.0e-7,
group_actual: Optional[str] = None,
group_expected: Optional[str] = None,
):
"""Reads to HDF5 files, compares if they have equal datasets.
Checks if for each entry `|actual - expected| < atol + rtol * |expected|`
(uses `numpy.testing.assert_allclose`).
**Arguments**
actual: str
File name for actual input data.
expected: str
File name for expected input data.
atol: float = 0.0
Absolute error tolerance. See `numpy.testing.assert_allclose`.
rtol: float = 1.0e-7
Relative error tolerance. See `numpy.testing.assert_allclose`.
group_actual: Optional[str] = None
If given, compare only this group or dataset of the actual file.
group_expected: Optional[str] = None
If given, compare only this group or dataset of the expected file.
**Raises**
AssertionError:
If the datasets differ (e.g., a dataset is missing or the actual values differ).
"""
with h5py.File(actual, "r") as h5f_a:
dsets_a = (
get_dsets(h5f_a, load_dsets=False)
if group_actual is None
else {group_actual: h5f_a[group_actual]}
)
with h5py.File(expected, "r") as h5f_e:
dsets_e = (
get_dsets(h5f_e, load_dsets=False)
if group_expected is None
else {group_expected: h5f_e[group_expected]}
)
actual_keys = set(dsets_a.keys())
expected_keys = set(dsets_e.keys())
if actual_keys != expected_keys:
raise AssertionError(
(
"Files have different datasets:"
"\n---Dsets in actual but not in expected---\n\t%s"
"\n---Dsets in expected but not in actual---\n\t%s"
)
% (
"\n\t".join(actual_keys.difference(expected_keys)),
"\n\t".join(expected_keys.difference(actual_keys)),
)
)
for key in actual_keys:
np.testing.assert_allclose(
dsets_a[key],
dsets_e[key],
atol=atol,
rtol=rtol,
err_msg="Dataset `%s` has unequal values." % key,
)
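# Usage sketch for `assert_h5files_equal`: compare two files with a looser
# relative tolerance, optionally restricted to a single group. The file and
# group names are illustrative only.
def _example_assert_h5files_equal():
    assert_h5files_equal("actual.h5", "expected.h5", rtol=1.0e-6)
    assert_h5files_equal(
        "actual.h5", "expected.h5", group_actual="groupA", group_expected="groupA"
    )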
def get_dset_chunks(dset: h5py.Dataset, chunk_size: int) -> Iterable[np.ndarray]:
"""Returns components of data sliced in chunks determined by the chunk size.
This reduces the memory size when loading the array.
**Arguments**
dset: h5py.Dataset
Input data set to read.
chunk_size: int
Size of the chunks to load. Slices the first dimension of the input
dataset. Must be smaller than or equal to the size of the first dataset
dimension.
"""
n_chunks = dset.shape[0] // chunk_size
if n_chunks < 1:
raise ValueError("Received ``chunck_size`` such that ``n_chunks < 1``.")
chunks = [
(n_chunk * chunk_size, (n_chunk + 1) * chunk_size) for n_chunk in range(n_chunks)
]
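# If the first dimension is not divisible by `chunk_size`, append the
# remainder as a final, smaller chunk.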
if chunks[-1][1] < dset.shape[0]:
chunks.append((chunks[-1][1], dset.shape[0]))
LOGGER.debug("Iterating `%s` in chunks `%s`", dset, chunks)
for n_start, n_end in chunks:
yield dset[n_start:n_end]
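# Usage sketch for `get_dset_chunks`: accumulate a sum over a large dataset
# without loading it into memory at once. File, dataset name, and chunk size
# are illustrative only.
def _example_get_dset_chunks():
    with h5py.File("example.h5", "r") as h5f:
        dset = h5f["groupA/large_dset"]
        total = 0.0
        for chunk in get_dset_chunks(dset, chunk_size=100):
            total += chunk.sum()
        LOGGER.info("Sum over all chunks: %s", total)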