# Source code for emg3d.io

"""
Utility functions for writing and reading data.
"""
# Copyright 2018-2020 The emg3d Developers.
#
# This file is part of emg3d.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License.  You may obtain a copy
# of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
# License for the specific language governing permissions and limitations under
# the License.

import os
import json
from datetime import datetime

import numpy as np

try:
    import h5py
except ImportError:
    h5py = ("'.h5'-files require `h5py`. Install it via\n"
            "`pip install h5py` or `conda install -c conda-forge h5py`.")

from emg3d import fields, maps, models, utils, meshes, surveys, simulations

__all__ = ['save', 'load']

# Known classes to serialize and de-serialize.
# Maps the string stored in a dict's '__class__'-entry to the actual emg3d
# class, so that `load` (via `_dict_deserialize`) can re-instantiate it
# through the corresponding `from_dict()`-method.
KNOWN_CLASSES = {
    '_Map': maps._Map,
    'Model': models.Model,
    'Field': fields.Field,
    'Survey': surveys.Survey,
    'Dipole': surveys.Dipole,
    'TensorMesh': meshes.TensorMesh,
    'SourceField': fields.SourceField,
    'Simulation': simulations.Simulation,
}


def save(fname, **kwargs):
    """Save surveys, meshes, models, fields, and more to disk.

    Serialize the provided keyword arguments and write them to ``fname``;
    the file ending defines the storage format. The main emg3d-classes
    (type `emg3d.io.KNOWN_CLASSES` to get a list) are serialized via their
    `to_dict()`-methods and are de-serialized again if loaded with
    :func:`load`. Any other (non-emg3d) object can be added too, as long
    as it knows how to serialize itself.

    Parameters
    ----------
    fname : str
        File name inclusive ending, which defines the used data format.
        Implemented are currently:

        - `.h5`: Uses `h5py` to store inputs to a hierarchical, compressed
          binary hdf5 file. Recommended file format, but requires the
          module `h5py`.
        - `.npz`: Uses `numpy` to store inputs to a flat, compressed binary
          file.
        - `.json`: Uses `json` to store inputs to a hierarchical, plain
          text file.

    compression : int or str, optional
        Passed through to h5py, default is 'gzip'.

    json_indent : int or None
        Passed through to json, default is 2.

    collect_classes : bool
        If True, input data is collected in folders for the principal
        emg3d-classes and everything else is collected in a `Data`-folder.
        Defaults to False.

    verb : int
        If 1 (default) verbose, if 0 silent.

    kwargs : Keyword arguments, optional
        Data to save using its key as name. Note that the provided data
        cannot contain the before described parameters as keys.
    """
    # Pop the optional settings; everything that remains is payload.
    h5_compression = kwargs.pop('compression', 'gzip')
    indent = kwargs.pop('json_indent', 2)
    collect = kwargs.pop('collect_classes', False)
    verbose = kwargs.pop('verb', 1)

    # Work with the absolute path throughout.
    path = os.path.abspath(fname)

    # Attach meta-data; stored in the root level of the file.
    kwargs['_date'] = datetime.today().isoformat()
    kwargs['_version'] = 'emg3d v' + utils.__version__
    kwargs['_format'] = '0.13.0'  # File format; version of emg3d when changed.

    # Serialize known classes and sort everything into a (nested) dict.
    data = _dict_serialize(kwargs, collect_classes=collect)

    # Dispatch on the file ending.
    if path.endswith('.npz'):
        # numpy-binary wants a flat dict.
        np.savez_compressed(path, **_dict_flatten(data))

    elif path.endswith('.h5'):
        # `h5py` is a string if the import at module level failed.
        if isinstance(h5py, str):
            raise ImportError(h5py)
        with h5py.File(path, "w") as h5file:
            _hdf5_add_to(data, h5file, h5_compression)

    elif path.endswith('.json'):
        # json knows neither arrays nor complex numbers.
        with open(path, "w") as f:
            json.dump(_dict_dearray_decomp(data), f, indent=indent)

    else:
        ext = path.split('.')[-1]
        raise ValueError(f"Unknown extension '.{ext}'.")

    # Print file info.
    if verbose > 0:
        print(f"Data saved to «{path}»\n[{kwargs['_version']} "
              f"(format {kwargs['_format']}) on {kwargs['_date']}].")
def load(fname, **kwargs):
    """Load meshes, models, fields, and other data from disk.

    Load and de-serialize everything that was stored with :func:`save`;
    members of `emg3d.io.KNOWN_CLASSES` are re-instantiated via their
    `from_dict()`-methods, arbitrary other data is returned as stored.

    Parameters
    ----------
    fname : str
        File name including extension. Possibilities:

        - '.npz': numpy-binary
        - '.h5': h5py-binary (needs `h5py`)
        - '.json': json

    verb : int
        If 1 (default) verbose, if 0 silent.

    Returns
    -------
    out : dict
        A dictionary containing the data stored in fname; emg3d-instances
        are de-serialized and returned as instances.
    """
    verbose = kwargs.pop('verb', 1)
    # allow_pickle is undocumented, but kept, just in case...
    allow_pickle = kwargs.pop('allow_pickle', False)

    # Anything else left in kwargs is a user error.
    if kwargs:
        raise TypeError(f"Unexpected **kwargs: {list(kwargs.keys())}")

    # Work with the absolute path throughout.
    path = os.path.abspath(fname)

    # Dispatch on the file ending.
    if path.endswith('.npz'):
        # Read the flat dict and restore the hierarchy.
        with np.load(path, allow_pickle=allow_pickle) as dat:
            data = {key: dat[key] for key in dat.files}
        data = _dict_unflatten(data)

    elif path.endswith('.h5'):
        # `h5py` is a string if the import at module level failed.
        if isinstance(h5py, str):
            raise ImportError(h5py)
        with h5py.File(path, 'r') as h5file:
            data = _hdf5_get_from(h5file)

    elif path.endswith('.json'):
        with open(path, 'r') as f:
            data = json.load(f)
        # Move lists back to arrays and compose complex data.
        data = _dict_array_comp(data)

    else:
        ext = path.split('.')[-1]
        raise ValueError(f"Unknown extension '.{ext}'.")

    # De-serialize known classes in-place.
    _dict_deserialize(data)

    if verbose > 0:
        print(f"Data loaded from «{path}»")

    # Check if file was (supposedly) created by emg3d and print file info.
    try:
        version = data['_version']
        date = data['_date']
        form = data['_format']
        if verbose > 0:
            print(f"[{version} (format {form}) on {date}].")
    except KeyError:
        if verbose > 0:
            print("[version/format/date unknown; not created by emg3d].")

    return data
def _dict_serialize(inp, out=None, collect_classes=False): """Serialize emg3d-classes and other objects in inp-dict. Returns a serialized dictionary <out> of <inp>, where all members of `emg3d.io.KNOWN_CLASSES` are serialized with their respective `to_dict()` methods. These instances are additionally grouped together in dictionaries, and all other stuff is put into 'Data' if `collect_classes=True`. Any other (non-emg3d) object can be added too, as long as it knows how to serialize itself. There are some limitations: 1. Key names are converted to strings. 2. None values are converted to 'NoneType'. 3. TensorMesh instances from discretize will be stored as if they would be simpler emg3d-meshes. Parameters ---------- inp : dict Input dictionary to serialize. out : dict Output dictionary; created if not provided. collect_classes : bool If True, input data is collected in folders for the principal emg3d-classes (type `emg3d.io.KNOWN_CLASSES` to get a list) and everything else collected in a `Data`-folder. Default is False. Returns ------- out : dict Serialized <inp>-dict. """ # Initiate output dictionary if not provided. if out is None: output = True out = {} else: output = False # Loop over items. for key, value in inp.items(): # Limitation 1: Cast keys to string if not isinstance(key, str): key = str(key) # Take care of the following instances # (if we are in the root-directory they get their own category): if (isinstance(value, tuple(KNOWN_CLASSES.values())) or hasattr(value, 'x0')): # Name of the instance name = value.__class__.__name__ # Workaround for discretize.TensorMesh -> stored as if TensorMesh. if hasattr(value, 'to_dict'): to_dict = value.to_dict() else: try: to_dict = {'hx': value.hx, 'hy': value.hy, 'hz': value.hz, 'x0': value.x0, '__class__': name} except AttributeError as e: # Gracefully fail. print(f"* WARNING :: Could not serialize <{key}>.\n" f" {e}") continue # If we are in the root-directory put them in their own category. 
# `collect_classes` can only be True in root-directory, as it is # set to False in recursion. if collect_classes: value = {key: to_dict} key = name else: value = to_dict elif collect_classes: # `collect_classes` can only be True in root-directory, as it is # set to False in recursion. if key.startswith('_'): # Store meta-data in root-level... out[key] = value continue else: # ...rest falls into Data/. value = {key: value} key = 'Data' # Initiate if necessary. if key not in out.keys(): out[key] = {} # If value is a dict use recursion, else store. if isinstance(value, dict): _dict_serialize(value, out[key], collect_classes=False) else: # Limitation 2: None if value is None: out[key] = 'NoneType' else: out[key] = value # Return if it wasn't provided. if output: return out def _dict_deserialize(inp, first_call=True): """De-serialize emg3d-classes and other objects in inp-dict. De-serializes in-place dictionary <inp>, where all members of `emg3d.io.KNOWN_CLASSES` are de-serialized with their respective `from_dict()` methods. It also converts back `'NoneType'`-strings to `None`, and `np.bool_` to `bool`. Parameters ---------- inp : dict Input dictionary to de-serialize. """ # Recursively replace `'NoneType'` by `None` and `np.bool_` by `bool`. if first_call: _nonetype_to_none(inp) # Loop over items. for key, value in inp.items(): # If it is a dict, deserialize if KNOWN_CLASS or recursion. if isinstance(value, dict): # If it has a __class__-key, de-serialize. if '__class__' in value.keys(): # De-serialize, overwriting all the existing entries. try: inst = KNOWN_CLASSES[value['__class__']] inp[key] = inst.from_dict(value) continue except (NotImplementedError, AttributeError, KeyError) as e: # Gracefully fail. print(f"* WARNING :: Could not de-serialize <{key}>.\n" f" {e}") # In no __class__-key or de-serialization fails, use recursion. _dict_deserialize(value, False) def _nonetype_to_none(inp): """Recursively replace side-effects in inp-dict from storing to disc. 
Changes: - Replace `NoneType'` by `None`. - `np.bool_` are cast back to `bool` (because `bool` is converted to `np.bool_` for some file formats). """ for k, v in inp.items(): if isinstance(v, dict): _nonetype_to_none(v) elif isinstance(v, str) and v == 'NoneType': inp[k] = None elif isinstance(v, np.bool_): inp[k] = bool(v) elif isinstance(v, np.ndarray) and v.dtype == np.bool_: inp[k] = bool(np.squeeze(v)) def _dict_flatten(data): """Return flattened dict of input dict <data>. After https://codereview.stackexchange.com/revisions/21035/3 Parameters ---------- data : dict Input dict to flatten. Returns ------- fdata : dict Flattened dict. """ def expand(key, value): """Expand list.""" if isinstance(value, dict): return [(key+'>'+k, v) for k, v in _dict_flatten(value).items()] else: return [(key, value)] return dict([item for k, v in data.items() for item in expand(k, v)]) def _dict_unflatten(data): """Return un-flattened dict of input dict <data>. After https://stackoverflow.com/a/6037657 Parameters ---------- data : dict Input dict to un-flatten. Returns ------- udata : dict Un-flattened dict. """ # Initialize output dict. out = {} # Loop over items. for key, value in data.items(): # Split the keys. parts = key.split(">") # Initiate tmp dict. tmp = out # Loop over key-parts. for part in parts[:-1]: # If subkey does not exist yet, initiate subdict. if part not in tmp: tmp[part] = {} # Add value to subdict. tmp = tmp[part] # Convert numpy strings to str. if '<U' in str(np.asarray(value).dtype): value = str(value) # Store actual value of this key. tmp[parts[-1]] = value return out def _dict_dearray_decomp(data): """Return dict where arrays are replaced by lists, complex by real numbers. Parameters ---------- data : dict Input dict to decompose. Returns ------- ddata : dict As input, but arrays are moved to lists, and complex number to real numbers like [real, imag]. """ # Output dict. ddata = {} # Loop over keys. for key, value in data.items(): # Recursion. 
if isinstance(value, dict): value = _dict_dearray_decomp(value) # Test if complex. if np.iscomplexobj(value): key += '__complex' value = np.stack([np.asarray(value).real, np.asarray(value).imag]) # Convert to lists if no arrays wanted. if isinstance(value, np.ndarray): key += '__array-'+value.dtype.name value = value.tolist() # Store this key-value-pair. ddata[key] = value return ddata def _dict_array_comp(data): """Return dict where lists/complex are moved back to arrays. Parameters ---------- data : dict Input dict to compose. Returns ------- ddata : dict As input, but lists are again arrays and complex data are complex again. """ # Output dict. ddata = {} # Loop over keys. for key, value in data.items(): # Recursion. if isinstance(value, dict): value = _dict_array_comp(value) # Get arrays back. if '__array' in key: arraytype = key.split('__')[-1] dtype = getattr(np, arraytype[6:]) value = np.asarray(value, dtype=dtype, order='F') key = key.replace(key[-len(arraytype)-2:], '') # Compose complex numbers. if '__complex' in key: value = np.asarray(value)[0, ...] + 1j*np.asarray(value)[1, ...] key = key.replace('__complex', '') # Store this key-value-pair. ddata[key] = value return ddata def _hdf5_add_to(data, h5file, compression): """Adds dictionary entries recursively to h5. Parameters ---------- data : dict Dictionary containing the data. h5file : file Opened by h5py. compression : str or int Passed through to h5py. """ # Loop over items. for key, value in data.items(): # Use recursion if value is a dict, creating a new group. if isinstance(value, dict): _hdf5_add_to(value, h5file.create_group(key), compression) elif np.ndim(value) > 0: # Use compression where possible... h5file.create_dataset(key, data=value, compression=compression) else: # else store without compression. h5file.create_dataset(key, data=value) def _hdf5_get_from(h5file): """Return data from h5file in a dictionary. Parameters ---------- h5file : file Opened by h5py. 
Returns ------- data : dict Dictionary containing the data. """ # Initiate dictionary. data = {} # Loop over items. for key, value in h5file.items(): # If it is a dataset add value to key, else use recursion to dig in. if isinstance(value, h5py._hl.dataset.Dataset): data[key] = value[()] # h5py>=3.0 changed strings to byte strings. if isinstance(data[key], bytes): data[key] = data[key].decode("utf-8") elif isinstance(value, h5py._hl.group.Group): data[key] = _hdf5_get_from(value) return data def _compare_dicts(dict1, dict2, verb=False, **kwargs): """Return True if the two dicts `dict1` and `dict2` are the same. Private method, not foolproof. Useful for developing new extensions. If `verb=True`, it prints it key starting with the following legend: - True : Values are the same. - False : Values are not the same. - {1} : Key is only in dict1 present. - {2} : Key is only in dict2 present. Private keys (starting with an underscore) are ignored. Parameters ---------- dict1, dict2 : dicts Dictionaries to compare. verb : bool If True, prints all keys and if they are the same for that key. kwargs : dict For recursion. Returns ------- same : bool True if dicts are the same, False otherwise. """ # Get recursion kwargs. s = kwargs.pop('s', '') reverse = kwargs.pop('reverse', False) gsame = kwargs.pop('gsame', True) # Check if we are at the base level and in reverse mode or not. do_reverse = len(s) == 0 and reverse is False # Loop over key-value pairs. for key, value in dict1.items(): # Recursion if value is dict and present in both dicts. if isinstance(value, dict) and key in dict2.keys(): # Add current key to string. s += f"{key[:10]:11}> " # Recursion. _compare_dicts(dict1[key], dict2[key], verb=verb, s=s, reverse=reverse, gsame=gsame) # Remove current key. s = s[:-13] elif key.startswith('_'): # Ignoring private keys. pass else: # Do actual comparison. # Check if key in both dicts. if key in dict2.keys(): # If reverse, the key has already been checked. 
if reverse is False: # Compare. same = np.all(value == dict2[key]) # Update global bool. gsame *= same if verb: print(f"{bool(same)!s:^7}:: {s}{key}") # Clean string. s = len(s)*' ' else: # If only in one dict -> False. gsame = False if verb: print(f" {{{2 if reverse else 1}}} :: {s}{key}") # Do the same reverse, do check for keys in dict2 which are not in dict1. if do_reverse: gsame = _compare_dicts(dict2, dict1, verb, reverse=True, gsame=gsame) return gsame