# Source code for pycbc.inference.sampler.base

# Copyright (C) 2016  Christopher M. Biwer, Collin Capano
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""
Defines the base sampler class to be inherited by all samplers.
"""

from __future__ import absolute_import

from abc import ABCMeta, abstractmethod, abstractproperty
import os
import shutil
import logging

from six import add_metaclass

from pycbc import distributions
from pycbc.inference.io import validate_checkpoint_files

#
# =============================================================================
#
#                           Base Sampler definition
#
# =============================================================================
#


@add_metaclass(ABCMeta)
class BaseSampler(object):
    """Abstract interface shared by every inference sampler.

    Concrete samplers must subclass this and provide implementations for
    all of the abstract methods and properties declared here.

    Parameters
    ----------
    model : Model
        An instance of a model from ``pycbc.inference.models``.
    """
    name = None

    def __init__(self, model):
        self.model = model

    # @classmethod <--uncomment when we move to python 3.3
    @abstractmethod
    def from_config(cls, cp, model, nprocesses=1, use_mpi=False):
        """Initialize the sampler from a config file.

        Must be implemented by subclasses.
        """

    @property
    def variable_params(self):
        """The parameters that are varied in the model."""
        model = self.model
        return model.variable_params

    @property
    def sampling_params(self):
        """The sampling parameters used by the model."""
        model = self.model
        return model.sampling_params

    @property
    def static_params(self):
        """The parameters that are held fixed in the model."""
        model = self.model
        return model.static_params

    @abstractproperty
    def samples(self):
        """A dict mapping variable_params to arrays of the samples that are
        currently in memory.

        The dictionary may also contain sampling_params. The sample arrays
        may have any shape, and may or may not be thinned.
        """

    @abstractproperty
    def model_stats(self):
        """A dict mapping the model's metadata fields to arrays of values,
        one value per sample.

        The arrays may have any shape, and may or may not be thinned.
        """

    @abstractmethod
    def run(self):
        """Run the sampler.

        Any checkpointing should be handled internally by this method.
        """

    @abstractproperty
    def io(self):
        """A class inheriting from ``BaseInferenceFile`` that handles IO
        with an hdf file.

        This must be the class itself rather than an instance, so that the
        sampler can instantiate it when needed.
        """

    @abstractmethod
    def set_initial_conditions(self, initial_distribution=None,
                               samples_file=None):
        """Set up the starting point for the sampler.

        Implementations should also set the sampler's random state.
        """

    @abstractmethod
    def checkpoint(self):
        """Dump raw samples and stats to the file type defined by ``io``."""

    @abstractmethod
    def finalize(self):
        """Do any finalization to the samples file before exiting."""

    def setup_output(self, output_file, force=False, injection_file=None):
        """Prepare the sampler's checkpoint and backup files.

        The checkpoint file is named ``output_file + '.checkpoint'`` and the
        backup file ``output_file + '.bkup'``. If no valid checkpoint is
        found, a new checkpoint file is created and copied to the backup.
        The command line and resume point are then written to both files.

        On return the following attributes are set on the sampler:
        ``new_checkpoint`` (True if a new checkpoint file had to be
        created), ``checkpoint_file``, ``backup_file`` and
        ``checkpoint_valid`` (the result of the validity check, taken before
        any new file was created).

        Parameters
        ----------
        output_file : str
            Name of the output file.
        force : bool, optional
            If the output file already exists, remove it instead of raising
            an error. Default is False.
        injection_file : str, optional
            If an injection was added to the data, write its information.

        Raises
        ------
        OSError
            If ``output_file`` already exists and ``force`` is False.
        """
        # refuse to clobber an existing output file unless explicitly asked
        if os.path.exists(output_file):
            if not force:
                raise OSError("output-file already exists; use force if you "
                              "wish to overwrite it.")
            os.remove(output_file)
        # names of the checkpoint file and of its backup copy
        ckpt_name = output_file + '.checkpoint'
        bkup_name = output_file + '.bkup'
        # see whether a usable checkpoint and/or backup is already on disk
        logging.info("Looking for checkpoint file")
        is_valid = validate_checkpoint_files(ckpt_name, bkup_name)
        # flag recording whether the checkpoint had to be created from scratch
        self.new_checkpoint = False
        if not is_valid:
            # missing or corrupted checkpoint: start over with a fresh file
            logging.info("Checkpoint not found or not valid")
            create_new_output_file(self, ckpt_name,
                                   injection_file=injection_file)
            self.new_checkpoint = True
            # keep the backup in sync with the newly created checkpoint
            shutil.copy(ckpt_name, bkup_name)
        # record the command line and resume point in both files
        for target in (ckpt_name, bkup_name):
            with self.io(target, "a") as handle:
                handle.write_command_line()
                handle.write_resume_point()
        # remember the file names and the outcome of the validity check
        self.checkpoint_file = ckpt_name
        self.backup_file = bkup_name
        self.checkpoint_valid = is_valid
#
# =============================================================================
#
#                           Convenience functions
#
# =============================================================================
#
def create_new_output_file(sampler, filename, injection_file=None, **kwargs):
    r"""Creates a new output file and writes the sampler's metadata to it.

    The file is opened with ``sampler.io(filename, "w")``. This function does
    no existence check of its own, so what happens to a pre-existing file is
    determined by the ``io`` class (presumably it is overwritten, as is usual
    for write mode -- confirm against the ``io`` implementation).

    Parameters
    ----------
    sampler : sampler instance
        Sampler whose ``io`` class is used to open the file and whose
        metadata is written to it.
    filename : str
        Name of the file to create.
    injection_file : str, optional
        If an injection was added to the data, write its information.
    \**kwargs :
        Accepted for backward compatibility with existing callers; this
        function does not currently use them.
    """
    logging.info("Creating file {}".format(filename))
    with sampler.io(filename, "w") as fp:
        # create the samples group and sampler info group
        fp.create_group(fp.samples_group)
        fp.create_group(fp.sampler_group)
        # save the sampler's metadata
        fp.write_sampler_metadata(sampler)
        # save injection parameters
        if injection_file is not None:
            logging.info("Writing injection file to output")
            fp.write_injections(injection_file)
def initial_dist_from_config(cp, variable_params):
    r"""Builds the sampler's starting distribution from a config file.

    A distribution is only constructed when the config file contains one or
    more [initial-\*] sections; otherwise ``None`` is returned.

    Parameters
    ----------
    cp : Config parser
        The config parser to try to load from.
    variable_params : list of str
        The variable parameters for the distribution.

    Returns
    -------
    JointDistribution or None :
        The initial distribution, or ``None`` if the config file has no
        [initial-\*] section.
    """
    # without any [initial-*] section there is nothing to build
    if not len(cp.get_subsections("initial")):
        return None
    logging.info("Using a different distribution for the starting points "
                 "than the prior.")
    # read the per-parameter distributions and any constraints on them
    dists = distributions.read_distributions_from_config(
        cp, section="initial")
    dist_constraints = distributions.read_constraints_from_config(
        cp, constraint_section="initial_constraint")
    return distributions.JointDistribution(
        variable_params, *dists, constraints=dist_constraints)