Source code for nucleon_elastic_ff.data.scripts.fft

"""Script for time fourier transforming correlator data
"""
from typing import Optional
from typing import List

import os

import h5py
import numpy as np

from nucleon_elastic_ff.utilities import set_up_logger
from nucleon_elastic_ff.utilities import find_all_files
from nucleon_elastic_ff.utilities import has_match

from nucleon_elastic_ff.data.h5io import get_dsets
from nucleon_elastic_ff.data.h5io import create_dset
from nucleon_elastic_ff.data.h5io import get_dset_chunks

from nucleon_elastic_ff.data.arraymanip import get_fft

# Package logger used by all routines in this module.
LOGGER = set_up_logger("nucleon_elastic_ff")


def fft(  # pylint: disable = R0913
    root: str,
    name_input: str,
    name_output: str,
    max_momentum: Optional[int] = None,
    chunk_size: Optional[int] = None,
    overwrite: bool = False,
    cuda: bool = True,
):
    """Recursively scan ``root`` for HDF5 files and FFT their correlator data.

    Every file below ``root`` whose name matches ``name_input`` and ends with
    ``.h5`` is processed, unless its name matches ``name_output``. The output
    address is the input address with ``name_input`` replaced by
    ``name_output``. This replacement also applies to directory names, and
    missing output directories are created. Unless ``overwrite`` is set, files
    whose output address already exists are skipped.

    Each selected file is passed to ``fft_file``. That function FFTs the
    spatial components of matching datasets, renames ``local_current``
    datasets to ``momentum_current``, and optionally chops the momenta.

    .. note::
        For ``max_momentum = 5`` the momenta kept in each direction
        ``x, y, z`` are ``[0,1,2,3,4,5,-5,-4,-3,-2,-1]``. This is regular FFT
        ordering with the higher valued modes chopped out.

    .. Note::
        Datasets to transform are assumed to have shape
        ``shape1 + [Nz, Ny, Nx] + [2]``. Here ``shape1`` can be anything,
        ``[Nz, Ny, Nx]`` are the transformed axes and the last axis holds the
        real / imaginary parts.

    **Arguments**
        root: str
            The directory to look for files.

        name_input: str
            Files must match this pattern to be submitted for the FFT.

        name_output: str
            Files must not match this pattern to be submitted for the FFT.
            The output files have the input name replaced by the output name,
            including in directory names.

        max_momentum: Optional[int] = None
            The momentum at which the FT is cut off in each spatial dimension.
            ``None`` keeps all modes.

        chunk_size: Optional[int] = None
            Read arrays in chunks and apply the FFT chunkwise, which reduces
            the memory load. Currently only the zeroth dimension is chunked.

        overwrite: bool = False
            Overwrite existing output files.

        cuda: bool = True
            Use `cupy` to run the FFT if possible.
    """
    LOGGER.info("Starting FFT of files")
    LOGGER.info("Looking into `%s`", root)
    LOGGER.info(
        "Using naming convention `%s` -> `%s` (for sliced data) ",
        name_input,
        name_output,
    )

    all_files = find_all_files(
        root,
        file_patterns=[name_input + r".*\.h5$"],
        exclude_file_patterns=[name_output],
    )

    if not overwrite:
        all_files = [
            file
            for file in all_files
            if not os.path.exists(file.replace(name_input, name_output))
        ]

    LOGGER.info(
        "Found %d files which match the pattern%s",
        len(all_files),
        " " if overwrite else " (and do not exist)",
    )

    for n_file, file_address in enumerate(all_files):
        LOGGER.info("--- File %d of %d ---", n_file + 1, len(all_files))

        file_address_out = file_address.replace(name_input, name_output)

        dir_out = os.path.dirname(file_address_out)
        # `dirname` is "" for bare file names and `os.makedirs("")` raises.
        # `exist_ok` avoids the exists-then-create race.
        if dir_out:
            os.makedirs(dir_out, exist_ok=True)

        fft_file(
            file_address,
            file_address_out,
            max_momentum=max_momentum,
            chunk_size=chunk_size,
            overwrite=overwrite,
            cuda=cuda,
        )

    LOGGER.info("Done")
def _check_dset_shape(name: str, shape: tuple) -> None:
    """Raise ``ValueError`` unless ``shape`` ends in ``[N, N, N, 2]``."""
    # Check the rank first so the negative indices below are always valid.
    if len(shape) < 4:
        raise ValueError(
            f"Expected dset `{name}` to have at least 4 dimensions but"
            f" only found {len(shape)}"
        )
    if shape[-1] != 2:
        raise ValueError(
            f"Expected last shape entry of dset `{name}` to be 2 but"
            f" received {shape[-1]}"
        )
    if not shape[-2] == shape[-3] == shape[-4]:
        raise ValueError(
            f"Expected dset `{name}` to have same dimensions in x, y, z"
            f" but found {shape}"
        )


def _to_complex(arr: np.ndarray) -> np.ndarray:
    """Combine the trailing real / imaginary axis into one complex array."""
    return arr[..., 0] + arr[..., 1] * 1j


def _fft_dset(dset, chunk_size: Optional[int], cuda: bool) -> np.ndarray:
    """Load ``dset`` as a complex array and FFT its last three axes.

    If ``chunk_size`` is given, the dataset is read and transformed chunkwise
    via ``get_dset_chunks`` and the results are concatenated along axis 0.
    """
    LOGGER.debug("\tAdding imag part to real part (removing last dim)")
    if chunk_size is None:
        arr = _to_complex(dset[()])
        LOGGER.debug("\tStart fft")
        return get_fft(arr, cuda=cuda, axes=(-1, -2, -3))

    out = []
    for n_chunk, chunk in enumerate(get_dset_chunks(dset, chunk_size)):
        chunk = _to_complex(chunk)
        LOGGER.debug("\tStart fft of %d. chunk", n_chunk)
        out.append(get_fft(chunk, cuda=cuda, axes=(-1, -2, -3)))
    return np.concatenate(out, axis=0)


def _chop_momenta(out: np.ndarray, max_momentum: int, n1d: int) -> np.ndarray:
    """Keep momenta ``[0..max_momentum, -max_momentum..-1]`` on the last 3 axes.

    NOTE(review): if ``2 * max_momentum + 1 > n1d`` the index list contains
    duplicate modes. This is kept as-is for backward compatibility.
    """
    LOGGER.debug("\tSlicing fft")
    slice_index = list(range(max_momentum + 1))
    slice_index += [
        el % n1d for el in range(-max_momentum, 0)  # pylint: disable=E1130
    ]
    for n_axis, key in enumerate(["x", "y", "z"]):
        axis = -1 * (n_axis + 1)
        LOGGER.debug("\t\t Axis %d: %s -> %s[%s]", axis, key, key, slice_index)
        out = np.take(out, slice_index, axis=axis)
    return out


def fft_file(  # pylint: disable = R0914, R0913, R0912
    file_address_in: str,
    file_address_out: str,
    max_momentum: Optional[int] = None,
    dset_patterns: List[str] = (
        "local_current",
        "x[0-9]+_y[0-9]+_z[0-9]+_t[0-9]+",
        "4D_correlator",
    ),
    chunk_size: Optional[int] = None,
    overwrite: bool = False,
    cuda: bool = True,
):
    """Read the input file and write FFTed (and chopped) data to the output file.

    This method scans all datasets within the input file. Datasets whose name
    matches at least one of ``dset_patterns`` are FFTed in their spatial
    components. Names containing ``local_current`` have it replaced by
    ``momentum_current``. All other datasets are copied unchanged, without
    their attributes.

    If ``max_momentum`` is given, the transformed data are chopped. The
    slicing info is appended to the dataset's existing ``meta`` attribute and
    stored in the ``meta`` attribute of the output dataset.

    .. note::
        For ``max_momentum = 5`` the momenta kept in each direction
        ``x, y, z`` are ``[0,1,2,3,4,5,-5,-4,-3,-2,-1]``. This is regular FFT
        ordering with the higher valued modes chopped out.

    .. Note::
        Datasets to transform must have shape ``shape1 + [Nz, Ny, Nx] + [2]``
        with ``Nz == Ny == Nx``. Here ``shape1`` can be anything and the last
        axis holds the real / imaginary parts.

    **Arguments**
        file_address_in: str
            Address of the HDF5 file to scan and transform.

        file_address_out: str
            Address of the output HDF5 file. It is opened in append mode.

        max_momentum: Optional[int] = None
            The momentum at which the FT is cut off in each spatial dimension.
            ``None`` keeps all modes.

        dset_patterns: List[str] = (
                "local_current",
                "x[0-9]+_y[0-9]+_z[0-9]+_t[0-9]+",
                "4D_correlator",
            )
            Regex patterns; a dataset must match at least one to be FFTed.

        chunk_size: Optional[int] = None
            Read arrays in chunks and apply the FFT chunkwise, which reduces
            the memory load. Currently only the zeroth dimension is chunked.

        overwrite: bool = False
            Overwrite existing datasets in the output file.

        cuda: bool = True
            Use `cupy` to run the FFT if possible.

    **Raises**
        ValueError:
            If a matching dataset does not have the expected shape.

        KeyError:
            If no dataset was transformed.
    """
    LOGGER.info("Sclicing\n\t `%s`\n\t->`%s`", file_address_in, file_address_out)

    transformed_dsets = 0

    with h5py.File(file_address_in, "r") as h5f:
        dsets = get_dsets(h5f, load_dsets=False)
        LOGGER.info("Start fft for %d dsets", len(dsets))

        # Explicit append mode: since h5py 3.0 the default mode is read-only.
        with h5py.File(file_address_out, "a") as h5f_out:
            for name, dset in dsets.items():
                # Reset per dataset; it was previously unbound for matching
                # dsets when `max_momentum` was None.
                meta = None

                if has_match(name, dset_patterns, match_all=False):
                    if "local_current" in name:
                        name = name.replace("local_current", "momentum_current")
                    LOGGER.debug("Start fft procedure for dset `%s`", name)

                    shape = dset.shape
                    _check_dset_shape(name, shape)
                    n1d = shape[-2]

                    out = _fft_dset(dset, chunk_size=chunk_size, cuda=cuda)

                    if max_momentum is not None:
                        meta = dset.attrs.get("meta", None)
                        meta = str(meta) + "&" if meta else ""
                        meta += f"max_momentum=={max_momentum}&n1d_prev=={n1d}"
                        out = _chop_momenta(out, max_momentum, n1d)

                    transformed_dsets += 1
                else:
                    out = dset[()]

                create_dset(h5f_out, name, out, overwrite=overwrite)
                if meta:
                    h5f_out[name].attrs["meta"] = meta

    if transformed_dsets == 0:
        # `dset_patterns` is a tuple, so %-formatting it raised TypeError.
        raise KeyError(
            "Could not identify any dsets to parse."
            f" Must match one out of `{dset_patterns}`"
        )
def main():
    """Parse command line arguments and run ``fft_file`` on a single file."""
    import argparse

    parser = argparse.ArgumentParser(description="Interface for `fft_file`")
    parser.add_argument("input", type=str, help="Name of the input hdf5 file.")
    parser.add_argument(
        "output",
        type=str,
        help="Name of the output hdf5 file."
        " Matching datasets are FFTed (`local_current` is renamed to"
        " `momentum_current`); all other datasets are copied.",
    )
    parser.add_argument(
        "--max-momentum",
        "-m",
        type=int,
        default=5,
        help="Maximal momentum kept in each spatial direction after the FFT."
        " [default=%(default)s]",
    )
    parser.add_argument(
        "--chunk-size",
        "-s",
        type=int,
        default=None,
        help="Number of first data set dimension array entries to read in at a time."
        " Reduces memory loads and defaults to whole data set. [default=%(default)s]",
    )
    parser.add_argument(
        "--overwrite",
        "-f",
        action="store_true",
        default=False,
        help="Overwrite hdf5 files if they already exist. [default=%(default)s]",
    )
    # New optional flag; omitting it keeps the previous behaviour (cuda=True).
    parser.add_argument(
        "--no-cuda",
        action="store_true",
        default=False,
        help="Do not use `cupy` for the FFT even if it is available."
        " [default=%(default)s]",
    )
    args = parser.parse_args()

    fft_file(
        args.input,
        args.output,
        max_momentum=args.max_momentum,
        chunk_size=args.chunk_size,
        overwrite=args.overwrite,
        cuda=not args.no_cuda,
    )
if __name__ == "__main__": main()