# Source code for emg3d.io

"""

:mod:`io` -- I/O utilities
==========================

Utility functions for writing and reading data.

"""
# Copyright 2018-2020 The emg3d Developers.
#
# This file is part of emg3d.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License.  You may obtain a copy
# of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
# License for the specific language governing permissions and limitations under
# the License.


import os
import shelve
import warnings
import numpy as np
from datetime import datetime

from emg3d import fields, models, utils, meshes

try:
    import h5py
except ImportError:
    h5py = ("\n* ERROR   :: '.h5'-files require `h5py`."
            "\n             Install it via `pip install h5py` or"
            "\n             `conda install -c conda-forge h5py`.\n")


__all__ = ['save', 'load']

# Classes handled by the (de-)serialization helpers. When loading, a
# serialized dict is mapped back to its class by looking up the value of its
# '__class__' entry in this mapping.
KNOWN_CLASSES = dict(
    Model=models.Model,
    Field=fields.Field,
    SourceField=fields.SourceField,
    TensorMesh=meshes.TensorMesh,
)


def data_write(fname, keys, values, path='data', exists=0):
    """DEPRECATED; USE :func:`save`.

    Store one or several values in a :mod:`shelve` database located at
    ``os.path.join(path, fname)``.


    Parameters
    ----------
    fname : str
        File name.

    keys : str or list of str
        Name(s) of the values to store in file.

    values : anything
        Values to store with keys in file. If `keys` is a list or tuple,
        `values` must be a sequence providing one value per key (matched by
        position).

    path : str, optional
        Absolute or relative path where to store. Default is 'data'. The
        directory is created if it does not exist.

    exists : int, optional
        Flag how to act if a shelve with the given name already exists:

        - < 0: Delete existing shelve.
        - 0 (default): Do nothing (print that it exists).
        - > 0: Append to existing shelve.


    Notes
    -----
    Values whose class is named 'TensorMesh' get their ``_vol`` attribute
    deleted *in place* (if present) before they are stored, to save space.

    """
    # Issue warning. `stacklevel=2` attributes the warning to the caller;
    # Python's default filters only display DeprecationWarnings that are
    # triggered from `__main__`, so without it users would never see it.
    mesg = ("\n    The use of `data_write` and `data_read` is deprecated.\n"
            "    These functions will be removed before v1.0.\n"
            "    Use `emg3d.save` and `emg3d.load` instead.")
    warnings.warn(mesg, DeprecationWarning, stacklevel=2)

    import dbm  # Only needed by the deprecated shelve functions.

    # Get absolute path, create if it doesn't exist.
    path = os.path.abspath(path)
    os.makedirs(path, exist_ok=True)
    full_path = os.path.join(path, fname)

    # Check if shelve exists. The files on disk depend on the dbm backend
    # used by `shelve`: dbm.dumb writes '.dat'/'.dir'/'.bak' files, dbm.gnu
    # a single file without extension, dbm.ndbm '.db' or '.pag'/'.dir'.
    # `dbm.whichdb` recognizes all of them; the explicit check additionally
    # catches incomplete dbm.dumb leftovers.
    dumb_files = [full_path+ext for ext in (".dat", ".bak", ".dir")]
    shelve_exists = (dbm.whichdb(full_path) is not None or
                     any(os.path.isfile(fpath) for fpath in dumb_files))

    flag = 'c'  # Open an existing shelve or create a new one.
    if shelve_exists:
        print("   > File exists, ", end="")
        if exists == 0:
            print("NOT SAVING THE DATA.")
            return
        elif exists > 0:
            print("appending to it", end='')
        else:
            print("overwriting it.")
            for dumb_file in dumb_files:
                try:
                    os.remove(dumb_file)
                except FileNotFoundError:
                    pass
            # 'n' always creates a new, empty database; this also truncates
            # an existing (single-file) database of the default dbm backend.
            flag = 'n'

    # Cast into list.
    if not isinstance(keys, (list, tuple)):
        keys = [keys, ]
        values = [values, ]

    # Shelve it.
    with shelve.open(full_path, flag=flag) as db:

        # If appending, finish the status line started above with the keys
        # which will be overwritten.
        if shelve_exists and exists > 0:
            over = [key for key in keys if key in db]
            if over:
                print(" (overwriting existing key(s) "+f"{over}"[1:-1]+").")
            else:
                print(".")

        # Writing it to the shelve.
        for i, key in enumerate(keys):

            # If the parameter is a TensorMesh instance, we set the volume
            # None. This saves space, and it will simply be reconstructed if
            # required.
            if type(values[i]).__name__ == 'TensorMesh':
                if hasattr(values[i], '_vol'):
                    delattr(values[i], '_vol')

            db[key] = values[i]


def data_read(fname, keys=None, path="data"):
    """DEPRECATED; USE :func:`load`.

    Read one, several, or all values from a :mod:`shelve` database located
    at ``os.path.join(path, fname)``.


    Parameters
    ----------
    fname : str
        File name.

    keys : str, list of str, or None; optional
        Name(s) of the values to get from file. If None, returns everything as
        a dict. Default is None.

    path : str, optional
        Absolute or relative path where fname is stored. Default is 'data'.


    Returns
    -------
    out : values or dict
        Requested value(s) or dict containing everything if keys=None. A
        single key returns its value; a list/tuple of keys returns a list of
        values. If the shelve does not exist, None is returned (a tuple of
        Nones, one per key, if `keys` is a list or tuple).

    """
    # Issue warning; `stacklevel=2` attributes it to the caller, so Python's
    # default warning filters actually display it.
    mesg = ("\n    The use of `data_write` and `data_read` is deprecated.\n"
            "    These functions will be removed before v1.0.\n"
            "    Use `emg3d.save` and `emg3d.load` instead.")
    warnings.warn(mesg, DeprecationWarning, stacklevel=2)

    import dbm  # Only needed by the deprecated shelve functions.

    # Get absolute path.
    path = os.path.abspath(path)
    full_path = os.path.join(path, fname)

    # Check if shelve exists. `dbm.whichdb` recognizes the files of every dbm
    # backend; checking only for the '.dat'/'.bak'/'.dir' files of dbm.dumb
    # reported existing dbm.gnu or dbm.ndbm shelves as missing.
    if dbm.whichdb(full_path) is None:
        print(f"   > File <{full_path}> does not exist.")
        if isinstance(keys, (list, tuple)):
            return len(keys)*(None, )
        return None

    # Get it from shelve; read-only, so nothing is created or modified.
    with shelve.open(full_path, flag='r') as db:
        if keys is None:                           # Everything.
            return dict(db.items())

        elif not isinstance(keys, (list, tuple)):  # Single parameter.
            return db[keys]

        else:                                      # List/tuple of parameters.
            return [db[key] for key in keys]


def save(fname, backend="h5py", compression="gzip", **kwargs):
    """Save meshes, models, fields, and other data to disk.

    Serialize and save :class:`emg3d.meshes.TensorMesh`,
    :class:`emg3d.fields.Field`, and :class:`emg3d.models.Model` instances
    together with arbitrary other data; instances of the same type are
    grouped together. Reading the file back with :func:`load` de-serializes
    these instances again.

    Parameters
    ----------
    fname : str
        File name. The extension belonging to `backend` ('.h5' or '.npz') is
        appended if `fname` does not already end with it.

    backend : str, optional
        Backend to use. Currently implemented are:

        - `h5py` (default): Uses `h5py` to store inputs to a hierarchical,
          compressed binary hdf5 file with the extension '.h5'. Recommended
          and default backend, but requires the module `h5py`. Use `numpy`
          if you don't want to install `h5py`.
        - `numpy`: Uses `numpy` to store inputs to a flat, compressed binary
          file with the extension '.npz'.

    compression : int or str, optional
        Passed through to h5py, default is 'gzip'. Only used by the `h5py`
        backend.

    kwargs : Keyword arguments, optional
        Data to save using its key as name. The following instances will be
        properly serialized: :class:`emg3d.meshes.TensorMesh`,
        :class:`emg3d.fields.Field`, and :class:`emg3d.models.Model`; they
        are de-serialized again if loaded with :func:`load`. These instances
        are collected in their own group if h5py is used.

    Raises
    ------
    ImportError
        If `backend='h5py'` but `h5py` is not installed.
    NotImplementedError
        If `backend` is neither 'h5py' nor 'numpy'.

    """
    full_path = os.path.abspath(fname)

    # Meta-data stored next to the user data; '_format' is the emg3d version
    # in which the file format last changed.
    kwargs.update({
        '_date': datetime.today().isoformat(),
        '_version': 'emg3d v' + utils.__version__,
        '_format': '0.10.0',
    })

    # Hierarchical dictionary with the serialized and grouped instances.
    data = _dict_serialize(kwargs)

    if backend == "numpy":
        # Flat, compressed numpy archive.
        target = full_path if full_path.endswith('.npz') else full_path+'.npz'
        np.savez_compressed(target, **_dict_flatten(data))

    elif backend == "h5py":
        target = full_path if full_path.endswith('.h5') else full_path+'.h5'

        # The module-level import fallback stores an error message (a str)
        # in `h5py` if the package is missing.
        if isinstance(h5py, str):
            print(h5py)
            raise ImportError("backend='h5py'")

        with h5py.File(target, "w") as h5file:
            _hdf5_add_to(data, h5file, compression)

    else:
        raise NotImplementedError(f"Backend '{backend}' is not implemented.")
def load(fname, **kwargs):
    """Load meshes, models, fields, and other data from disk.

    Load and de-serialize :class:`emg3d.meshes.TensorMesh`,
    :class:`emg3d.fields.Field`, and :class:`emg3d.models.Model` instances
    and add arbitrary other data that were saved with :func:`save`.

    Parameters
    ----------
    fname : str
        File name including extension. Used backend depends on the file
        extensions:

        - '.npz': numpy-binary
        - '.h5': h5py-binary (needs `h5py`)

    verb : int
        If 1 (default) verbose, if 0 silent.

    Returns
    -------
    out : dict
        A dictionary containing the data stored in fname;
        :class:`emg3d.meshes.TensorMesh`, :class:`emg3d.fields.Field`, and
        :class:`emg3d.models.Model` instances are de-serialized and returned
        as instances.

    Raises
    ------
    TypeError
        If unknown keyword arguments are provided.
    NotImplementedError
        If `fname` ends neither with '.npz' nor with '.h5'.
    ImportError
        If a '.h5'-file is loaded but `h5py` is not installed.

    """
    # Get kwargs.
    verb = kwargs.pop('verb', 1)
    # allow_pickle is undocumented, but kept, just in case...
    allow_pickle = kwargs.pop('allow_pickle', False)
    # Ensure no kwargs left.
    if kwargs:
        raise TypeError(f"Unexpected **kwargs: {list(kwargs.keys())}")

    # Get absolute path.
    full_path = os.path.abspath(fname)

    # Load data depending on the file extension. Match the full extension
    # including the dot (as `save` does), so that e.g. 'results_npz' is not
    # mistaken for a '.npz'-file.
    if fname.endswith('.npz'):

        # Load .npz into a flat dict.
        with np.load(full_path, allow_pickle=allow_pickle) as dat:
            data = {key: dat[key] for key in dat.files}

        # Un-flatten data.
        data = _dict_unflatten(data)

    elif fname.endswith('.h5'):

        # Check if h5py is installed (the import fallback stores a message).
        if isinstance(h5py, str):
            print(h5py)
            raise ImportError("backend='h5py'")

        # Load data.
        with h5py.File(full_path, 'r') as h5file:
            data = _hdf5_get_from(h5file)

    else:
        ext = fname.split('.')[-1]
        raise NotImplementedError(f"Extension '.{ext}' is not implemented.")

    # De-serialize data (in-place).
    _dict_deserialize(data)

    # Check if file was (supposedly) created by emg3d: `save` always stores
    # these three meta-data entries.
    try:
        version = data['_version']
        date = data['_date']
        form = data['_format']
    except KeyError:
        if verb > 0:
            print(f"\n* NOTE :: {full_path} was not created by emg3d.")
    else:
        # Print file info.
        if verb > 0:
            print(f" Loaded file {full_path}")
            print(f" -> Stored with {version} (format {form}) on {date}")

    return data
def _dict_serialize(inp, out=None, top=True): """Serialize TensorMesh, Field, and Model instances in dict. Returns a serialized dictionary <out> of <inp>, where all :class:`emg3d.meshes.TensorMesh`, :class:`emg3d.fields.Field`, and :class:`emg3d.models.Model` instances have been serialized. These instances are additionally grouped together in dictionaries, and all other stuff is put into 'Data'. There are some limitations: 1. Key names are converted to strings. 2. None values are converted to 'NoneType'. 3. TensorMesh instances from discretize will be stored as if they would be simpler emg3d-meshes. Parameters ---------- inp : dict Input dictionary to serialize. out : dict Output dictionary; created if not provided. top : bool Used for recursion. Returns ------- out : dict Serialized <inp>-dict. """ # Initiate output dictionary if not provided. if out is None: output = True out = {} else: output = False # Loop over items. for key, value in inp.items(): # Limitation 1: Cast keys to string if not isinstance(key, str): key = str(key) # Take care of the following instances (if we are in the # top-directory they get their own category): if (isinstance(value, tuple(KNOWN_CLASSES.values())) or hasattr(value, 'x0')): # Name of the instance name = value.__class__.__name__ # Workaround for discretize.TensorMesh -> stored as if TensorMesh. if hasattr(value, 'to_dict'): to_dict = value.to_dict() else: try: to_dict = {'hx': value.hx, 'hy': value.hy, 'hz': value.hz, 'x0': value.x0, '__class__': name} except AttributeError: # Gracefully fail. print(f"* WARNING :: Could not serialize <{key}>") continue # If we are in the top-directory put them in their own category. if top: value = {key: to_dict} key = name else: value = to_dict elif top: if key.startswith('_'): # Store meta-data in top-level... out[key] = value continue else: # ...rest falls into Data/. value = {key: value} key = 'Data' # Initiate if necessary. 
if key not in out.keys(): out[key] = {} # If value is a dict use recursion, else store. if isinstance(value, dict): _dict_serialize(value, out[key], False) else: # Limitation 2: None if value is None: out[key] = 'NoneType' else: out[key] = value # Return if it wasn't provided. if output: return out def _dict_deserialize(inp): """De-serialize TensorMesh, Field, and Model instances in dict. De-serializes in-place dictionary <inp>, where all :class:`emg3d.meshes.TensorMesh`, :class:`emg3d.fields.Field`, and :class:`emg3d.models.Model` instances have been de-serialized. It also converts back 'NoneType'-strings to None. Parameters ---------- inp : dict Input dictionary to de-serialize. """ # Loop over items. for key, value in inp.items(): # Analyze if it is a dict, else ignore (check for 'NoneType'). if isinstance(value, dict): # If it has a __class__-key, de-serialize. if '__class__' in value.keys(): for k2, v2 in value.items(): if isinstance(v2, str) and v2 == 'NoneType': value[k2] = None # De-serialize, overwriting all the existing entries. try: inst = KNOWN_CLASSES[value['__class__']] inp[key] = inst.from_dict(value) continue except (AttributeError, KeyError): # Gracefully fail. print(f"* WARNING :: Could not de-serialize <{key}>") # In no __class__-key or de-serialization fails, use recursion. _dict_deserialize(value) elif isinstance(value, str) and value == 'NoneType': inp[key] = None def _dict_flatten(data): """Return flattened dict of input dict <data>. After https://codereview.stackexchange.com/revisions/21035/3 Parameters ---------- data : dict Input dict to flatten Returns ------- fdata : dict Flattened dict. """ def expand(key, value): """Expand list.""" if isinstance(value, dict): return [(key+'>'+k, v) for k, v in _dict_flatten(value).items()] else: return [(key, value)] return dict([item for k, v in data.items() for item in expand(k, v)]) def _dict_unflatten(data): """Return un-flattened dict of input dict <data>. 
After https://stackoverflow.com/a/6037657 Parameters ---------- data : dict Input dict to un-flatten Returns ------- udata : dict Un-flattened dict. """ # Initialize output dict. out = {} # Loop over items. for key, value in data.items(): # Split the keys. parts = key.split(">") # Initiate tmp dict. tmp = out # Loop over key-parts. for part in parts[:-1]: # If subkey does not exist yet, initiate subdict. if part not in tmp: tmp[part] = {} # Add value to subdict. tmp = tmp[part] # Convert numpy strings to str. if '<U' in str(value.dtype): value = str(value) # Store actual value of this key. tmp[parts[-1]] = value return out def _hdf5_add_to(data, h5file, compression): """Adds dictionary entries recursively to h5. Parameters ---------- data : dict Dictionary containing the data. h5file : file Opened by h5py. compression : str or int Passed through to h5py. """ # Loop over items. for key, value in data.items(): # Use recursion if value is a dict, creating a new group. if isinstance(value, dict): _hdf5_add_to(value, h5file.create_group(key), compression) elif np.ndim(value) > 0: # Use compression where possible... h5file.create_dataset(key, data=value, compression=compression) else: # else store without compression. h5file.create_dataset(key, data=value) def _hdf5_get_from(h5file): """Return data from h5file in a dictionary. Parameters ---------- h5file : file Opened by h5py. Returns ------- data : dict Dictionary containing the data. """ # Initiate dictionary. data = {} # Loop over items. for key, value in h5file.items(): # If it is a dataset add value to key, else use recursion to dig in. if isinstance(value, h5py._hl.dataset.Dataset): data[key] = value[()] elif isinstance(value, h5py._hl.group.Group): data[key] = _hdf5_get_from(value) return data