Source code for smoofit.model

from collections import OrderedDict
from itertools import chain
from functools import partial
from typing import Dict, Callable, Optional, List, Tuple, Any, Union

from jax.config import config
config.update("jax_enable_x64", True) # enforce double precision throughout

import jax.numpy as jnp
from jax import grad, jit, vmap, hessian
from jax import random
from jax import lax
from jax.scipy.stats import poisson, multivariate_normal
from jax.scipy.linalg import inv, cholesky

import numpy as np

from scipy.optimize import minimize, root_scalar, RootResults, OptimizeResult
from scipy.optimize import NonlinearConstraint, Bounds

from .systematics import *
from .utils import *


[docs]class Variable: """ Represents any parameter in the model, be it a free parameter of interest (POI) or a (constrained) nuisance parameter. Variables have an initial (default) value, that is taken as initial value in the fits, and that can be used when generating Asimov toy datasets. Variables can either be scalar- or vector-valued. Vector-valued variables are restricted to a single dimension, i.e. they have array shape ``(n,)``. For vector-valued variables, every component can have a specific "name", which by default is just the index of the component. """
[docs] def __init__(self, name: str, init: Union[float, List[float], np.ndarray, jnp.DeviceArray] = 0., lower_bound: Optional[Union[float, List[float], np.ndarray, jnp.DeviceArray]] = None, upper_bound: Optional[Union[float, List[float], np.ndarray, jnp.DeviceArray]] = None, sub_names: Optional[List[str]] = None, nuisance: bool = False): """Constructor Lower or upper bounds for vector-valued parameters can either be specified as scalars (in which case the same bound will be used for every component) or as list/array in which case a different bound can be specified for each component). :param name: name of the parameter - should be unique among all the parameters linked to a :py:class:`Model` object. :param init: initial value of the parameter (scalar or jax/numpy 1D array) :param lower_bound: lower bound of the parameter in the fit (default is no bound) :param upper_bound: upper bound of the parameter in the fit (default is no bound) :param sub_names: names of the sub-components of vector-valued parameters (by default, will just be the indices) :param nuisance: specify if the parameter is a (Gaussian-constrained) nuisance parameter """ self.name = name self.set(init) self.dim = len(self.val) self.nuisance = nuisance self.lower_bound = lower_bound if self.lower_bound is not None: self.lower_bound = to_np(self.lower_bound) if self.lower_bound.shape[0] == 1: self.lower_bound = np.broadcast_to(self.lower_bound, self.val.shape) assert(self.lower_bound.shape == self.val.shape) self.upper_bound = upper_bound if self.upper_bound is not None: self.upper_bound = to_np(self.upper_bound) if self.upper_bound.shape[0] == 1: self.upper_bound = np.broadcast_to(self.upper_bound, self.val.shape) assert(self.upper_bound.shape == self.val.shape) if sub_names is not None: assert(len(sub_names) == self.dim) else: sub_names = list(range(self.dim)) self.sub_names = sub_names
[docs] def set(self, val: Union[float, List[float], np.ndarray, jnp.DeviceArray]): """ Set a new value for the parameter(s) The dimension of the parameter cannot be changed. :param val: the new value (scalar or jax/numpy array) """ val = to_jnp(val) assert(len(val.shape) == 1) if hasattr(self, "val"): assert(self.val.shape == val.shape) self.val = val
[docs]class ChannelContrib: """ Specify how a process contributes to a channel. Channels are any orthogonal selections of events, i.e. they can be data-taking eras, signal/control regions, lepton flavour channels, bins, etc. :py:class:`ChannelContrib` objects specify how processes contribute to a given channel - and also implicitly define which channels are to be considered in the fit. Indeed, channels are never declared explicitly, but are collected through all the :py:class:`ChannelContrib` objects of all processes registered with a model. The name of a :py:class:`ChannelContrib` is is used to collect all conributions to a given channel. If a process does not contribute to a given channel (i.e. there is no corresponding :py:class:`ChannelContrib`), templates with zero yields are automatically inserted. Systematic uncertainties are also specified through :py:class:`ChannelContrib` objects, and so are the statistical uncertainties on the templates. Once a :py:class:`ChannelContrib` object has been defined, it has to be registered with the process it belongs to, which is done using the :py:meth:`Process.add_contrib` method. """
[docs] def __init__(self, channel_name: str, yields: Union[List[float], np.ndarray], sumw2: Optional[Union[List[float], np.ndarray]] = None, sub_procs: Optional[List[str]] = None): """Constructor :param channel_name: the name of the channel - used to collect the contributions of various processes to a given channel :param yields: the nominal yields of the considered process in this channel, as 1D or 2D numpy array :param sumw2: the sum of squared weights of the considered process in this channel, used to evaluate the template statistical uncertainties (defaults to zero, i.e. infinite statistics), should have the same shape as the yields :param sub_procs: list of sub-processes of the considered process which contribute to this channel. The length of this list should match the length of the first axis of `yields`. This list should be a sub-list of the sub-processes of the :py:class:`Process` that this contribution will be attached to. You might use the :py:attr:`Process.sub_procs` attribute if all of the sub-processes contribute. """ self.name = channel_name self.yields = to_np(yields) if len(self.yields.shape) == 1: self.yields = self.yields[np.newaxis, :] if sumw2 is not None: self.register_sumw2(sumw2) else: self.sumw2 = np.zeros(self.yields.shape) self.n_bins = self.yields.shape[-1] if isinstance(sub_procs, str): sub_procs = [sub_procs] self.n_sub_procs = self.yields.shape[0] if sub_procs is not None: assert(len(sub_procs) == self.n_sub_procs) if self.n_sub_procs > 1: assert(sub_procs is not None) if self.n_sub_procs == 1 and sub_procs is None: sub_procs = [0] self.sub_procs = sub_procs self.systematics = [] self.scaled_by = []
[docs] def register_sumw2(self, sumw2: Union[List[float], np.ndarray]): """Register the sums of squared weights for bin-by-bin statistical uncertainties :param sumw2: 1D or 2D numpy array with the sums of squared weights for each bin, should have the same shape as the yields in this channel """ sumw2 = to_np(sumw2) if len(sumw2.shape) == 1: sumw2 = sumw2[np.newaxis, :] assert(sumw2.shape == self.yields.shape) self.sumw2 = sumw2
[docs] def scale_by_fn(self, fn: Callable[[jnp.DeviceArray], jnp.DeviceArray], variables: Union[Variable, Tuple[Variable]]): """Scale the yields in this channel by an arbitrary function The rescaling is restricted to this channel and to the registered sub-processes, i.e. the length of the first axis of the return value of ``fn`` should match the number of sub-processes of this channel. See :py:meth:`Process.scale_by_fn` for more information about how ``fn`` is expected to behave. :param fn: a callable returning the factors by which the yields should be scaled :param variables: a :py:class:`Variable` or a tuple of :py:class:`Variable` objects whose values will be passed to ``fn`` as positional arguments in the form of 1D ``jnp.DeviceArray`` arrays """ if isinstance(variables, Variable): variables = (variables,) self.scaled_by.append(PaddedFunction(fn, variables))
[docs] def scale_by(self, var: Variable): """Register a :py:class:`Variable` as a linear yield modifier (e.g. signal strength modifier) If the parameter is a scalar, the yields in this channel, for all sub-processes registered with this channel, will be scaled linearly by the parameter. If the parameter is vector-valued, the dimensionality of the variable should match the number of sub-processes registered with this channel. Each of those sub-processes will then be scaled linearly by a single component of the variable (restricted to this channel). :param var: a :py:class:`Variable` scaling the yields of this process """ if var.dim > 1: assert(var.dim == self.n_sub_procs) # expand for broadcasting over bins and multiplying along sub-processes self.scale_by_fn(lambda x: jnp.expand_dims(x, -1), var)
[docs] def scale_bins_by(self, var: Variable): """Register a :py:class:`Variable` as a linear yield modifier across bins If the parameter is a scalar, the yields in this channel, for all sub-processes registered with this channel, will be scaled linearly by the parameter. If the parameter is vector-valued, the dimensionality of the variable should match the number of bins of this channel. The yields in each bin will then be scaled linearly by a single component of the variable. :param var: a :py:class:`Variable` scaling the yields of this process """ if var.dim > 1: assert(var.dim == self.n_bins) self.scale_by_fn(lambda x: x, var)
[docs] def add_shape_syst(self, var: Variable, up: Union[List[float], np.ndarray], down: Union[List[float], np.ndarray], sub_vars: Optional[Union[List[str], List[int]]] = None, sub_procs: Optional[Union[List[str], List[int]]] = None): """Add a shape systematic uncertainty in this channel The uncertainty is specified through a pair of arrays corresponding to the up and down variations of the nuisance parameter (by :math:`\\pm 1 \\sigma`). The shape of the arrays should be ``([n_sources,][n_sub_procs,]n_bins)``, where: * ``n_sources`` is the number of uncertainty sources, matching either the dimension of ``var`` of the number of entries in ``sub_vars``. This axis might not be present, if only a single source is considered. * ``n_sub_procs`` is the number of sub-processes of the considered process, or the length of the ``sub_procs`` argument. This axis might not be present, if the process only has a single component. * ``n_bins`` is the only mandatory axis and should match the number of bins in this channel. In summary, the ``up`` and ``down`` arrays might be 1D, 2D or 3D. Writing :math:`H^{\\text{nom}}_i = I^{\\text{nom}} N^{\\text{nom}}_i` for the nominal template and bin ``i``, with :math:`I` the integral of the template yields, and similarly for the up/down variations :math:`H^{\\text{up/down}}_i = I^{\\text{up/down}} N^{\\text{up/down}}_i` (for any given uncertainty source), the templates are interpolated vertically using the following scheme: .. math:: H_i(\\alpha) = I(\\alpha) N_i(\\alpha) where :math:`\\alpha` is the (Gaussian-constrained) nuisance parameter, :math:`I(\\alpha)` is an inter/extrapolation of the template normalization an asymmetric log-normal uncertainty (with :math:`K^{\\text{up/down}} = I^{\\text{up/down}} / I^{\\text{nom}}`), and :math:`N_i(\\alpha)` is a morphing of the normalized template shapes defined as: .. math:: & N^{\\text{up}}_i + (\\alpha - 1) ( N^{\\text{up}}_i - N^{\\text{nom}}_i ) & \\text{ if } \\alpha > 1 \\\\ & N^{\\text{nom}}_i + \\frac{\\alpha}{2} \\left( N_{i}^{\\text{up}} - N_{i}^{\\text{down}} \\right) + \\frac{1}{16} ( 3 \\alpha^6 - 10 \\alpha^4 + 15 \\alpha^2 ) \\left( N_{i}^{\\text{up}} + N_{i}^{\\text{down}} - 2 N_{i}^{\\text{nom}} \\right) & \\text{ if } |\\alpha| \\leq 1 \\\\ & N^{\\text{down}}_i - (\\alpha + 1) ( N^{\\text{down}}_i - N^{\\text{nom}}_i ) & \\text{ if } \\alpha < -1 \\\\ The resulting interpolation satisfies :math:`H_i(1) = H^{\\text{up}}_i`, :math:`H_i(-1) = H^{\\text{down}}_i` and :math:`H_i(0) = H^{\\text{nom}}_i`, and its second derivative is continuous. Note that what is actually interpolated is a factor :math:`F_i(\\alpha) = H_i(\\alpha) / H^{\\text{nom}}_i`. Those factors for all the systematic sources of the model are multiplied together and with the (fixed) nominal yields :math:`H^{\\text{nom}}_i` to define the "predicted" yields. :param var: the :py:class:`Variable` nuisance parameter corresponding to this uncertainty :param up: the "up" variation (:math:`+1\\sigma`) :param down: the "down" variation (:math:`-1\\sigma`) :param sub_vars: list of sub-variable names in case ``var`` is vector-valued but only a subset of its components is to be used as nuisance parameters :param sub_procs: list of sub-process names in case this uncertainty only affects a subset of the sub-processes in this channel """ if sub_procs is None: sub_procs = self.sub_procs else: assert(all(c in self.sub_procs for c in sub_procs)) self.systematics.append(LazyNormedSplineShapeUnc(self.yields, up, down, var, sub_vars, sub_procs))
[docs] def add_lnN(self, var: Variable, lnN: Union[float, Tuple[float, float], np.ndarray, Tuple[np.ndarray, np.ndarray]], sub_vars: Optional[Union[List[str], List[int]]] = None, sub_procs: Optional[Union[List[str], List[int]]] = None): """Add a log-normal systematic uncertainty in this channel The log-normal uncertainties can be specified in several way depending on the shape of the ``lnN`` argument: * scalar: single uncertainty source (i.e. either ``var`` is scalar or ``sub_vars`` has length one), same uncertainty for all sub-processes (if relevant). * ``(n_sub_procs,)``: single uncertainty source (i.e. either ``var`` is scalar or ``sub_vars`` has length one), different uncertainty for each sub-process. The length should match the number of sub-processes of the process or the length of specified ``sub_procs``. * ``(n_var,)``: separate uncertainty sources, same uncertainties for all sub-processes. The length should match the dimensionality of ``var`` or the length of specified ``sub_vars``. * ``(n_var, n_sub_procs)``: separate uncertainty sources, different uncertaintiy for each sub-process (same rules as above apply). All of those above options can be passed directly as argument, in which case the uncertainty will be a symmetric log-normal, or as a tuple ``(up, down)``, in which case the uncertainty will be an asymmetric log-normal. Given the (Gaussian-constrained) nuisance parameter :math:`\\alpha` for a particular uncertainty source, the process yields are multiplied by a factor :math:`K(\\alpha)^{\\alpha}` where :math:`K(\\alpha)` is defined as: .. math:: & K_{\\text{up}} & \\text{ if } \\alpha > 1 \\\\ & \\left( \\frac{\\alpha}{4} (3 - \\alpha^2 ) (K_{\\text{up}}-K_{\\text{down}}) + \\frac12 (K_{\\text{up}}+K_{\\text{down}}) \\right) & \\text{ if } |\\alpha| \\leq 1 \\\\ & K_{\\text{down}} & \\text{ if } \\alpha < -1 This inter-/extrapolation is such that :math:`K(1) = K_{\\text{up}}` and :math:`K(-1) = K_{\\text{down}}`, and its second derivative is continuous. Note that this assumes that :math:`K_{\\text{down}} = \\text{up}/\\text{nominal}`, and :math:`K_{\\text{up}} = \\text{nominal}/\\text{up}`, so that a 3\% up and 5\% down asymmetric uncertainty is given as :math:`(K_{\\text{up}}, K_{\\text{down}}) = (1.03, 1.05)`. Conversely, a -3\% "up" and -5\% down uncertainty is given as :math:`(K_{\\text{up}}, K_{\\text{down}}) = (0.97, 0.95)`. Specifying e.g. :math:`(K_{\\text{up}}, K_{\\text{down}}) = (1.03, 0.95)` will however result in a one-sided variation, which is typically not what is desired. :param var: the :py:class:`Variable` nuisance parameter corresponding to this uncertainty :param lnN: specification of the log-normal uncertainties :param sub_vars: list of sub-variable names in case ``var`` is vector-valued but only a subset of its components is to be used as nuisance parameters :param sub_procs: list of sub-process names in case this uncertainty only affects a subset of the sub-processes in this channel """ # - n_sub_procs=1, var.dim=1 -> lnN is scalar # - n_sub_procs=1, var.dim>1 -> lnN is [var.dim] (sources) # - n_sub_procs>1, var.dim=1 -> lnN can be # - scalar (same source and lnN for all procs) --> if restrict subprocs, make it [n_sub_procs] # - [n_sub_procs] # - n_sub_procs>1, var.dim>1 -> lnN can be # - [var.dim] (each source is same lnN for all procs) --> if restrict subprocs, make it [var.dim]x[n_sub_procs] # - [var.dim]x[n_sub_procs] --> if restrict subprocs, change lnN if sub_procs is None: sub_procs = self.sub_procs else: assert(all(c in self.sub_procs for c in sub_procs)) self.systematics.append(LazyLogNormal(lnN, self.yields.shape, var, sub_vars, sub_procs))
def _insert_missing_sub_procs(self, to_list): """Insert subprocesses with zero yields into the vertical axis of `yields`. `from_list` is the list of subprocesses along the vertical axis of `yields`; `to_list` is the target list of subprocesses. """ def insert(yields): if self.sub_procs != to_list: new_yields = np.zeros((len(to_list), yields.shape[-1])) for i,sub_proc in enumerate(to_list): if sub_proc in self.sub_procs: j = from_list.index(sub_proc) new_yields[i,:] = yields[j,:] return new_yields else: return yields self.yields = insert(self.yields) self.sumw2 = insert(self.sumw2) for s in self.systematics: s.expand_sub_procs(to_list, self.yields)
[docs]class Process: """ Represents a process entering the analysis. A process is identified by its name. A process can be a "simple" process, or it can be composed of "sub-processes", when it makes physical sense to consider several processes together. For instance, when using an EFT, the sub-processes could correspond to the various contributions of the EFT expansion (SM prediction, interferences, squared/cross terms), or the points of a morphing basis. Another use case would be differential measurements (unfolding), where each generator-level bin of the measured distribution is treated as a different signal. These various signals can then be naturally represented as a single :py:class:`Process` object. Using sub-processes instead of several simple processes in those cases is not mandatory, but makes it easier (and more efficient) to specify how the process components are affected by the parameters of the model (which can be vectors!). The actual predicted yields of the process, or the systematic uncertainties that affect it, are specified through :py:class:`ChannelContrib` objects. """
[docs] def __init__(self, name: str, sub_procs: Optional[List[str]] = None): """ Constructor :param name: name of the process - should be unique among all the processes linked to a :py:class:`Model` object. :param sub_procs: names of the sub-processes - also defines how many there are (default is "simple" process, i.e. no sub-components) """ self.name = name self.sub_procs = sub_procs if sub_procs is not None else [0] self.n_sub_procs = len(self.sub_procs) self.channel_contribs = {} self.scaled_by = []
[docs] def pred(self, values: jnp.DeviceArray) -> Tuple[jnp.DeviceArray, jnp.DeviceArray]: """ Compute the yields of the process given the parameter values. .. note:: This method is only available after the model to which this process has been assigned has been prepared! :param values: 1D array with the parameter values in the order expected by the compiled model, e.g. as returned by :py:meth:`Model.values_from_dict` :returns: a tuple of two 2D ``jnp.DeviceArray`` where the entries ``(j,:)`` in the first (second) array contain the predicted yields (sumw2), across all channels, of the sub-process with index ``j`` (if the process has no sub-processes, this axis has only one entry) """ pass
[docs] def batch_pred(self, values: jnp.DeviceArray) -> Tuple[jnp.DeviceArray, jnp.DeviceArray]: """ Compute the yields of the process using a batch of parameter values. .. note:: This method is only available after the model to which this process has been assigned has been prepared! :param values: 2D array where the first axis is the batch dimension, i.e. every row specifies parameter values in the order expeced by the compiled model :returns: a tuple of two 3D ``jnp.DeviceArray`` where the entries ``(i,j,:)`` in the first (second) array contain the predicted yields (sumw2), across all channels, of the sub-process with index ``j`` given the parameter values in row ``i`` of ``values`` (if the process has no sub-processes, this axis has only one entry) """ pass
[docs] def add_contrib(self, channel_contrib: ChannelContrib): """Register a channel contribution for this process :param channel_contrib: a :py:class:`ChannelContrib` object """ if not isinstance(channel_contrib, ChannelContrib): raise ValueError("Expected a ChannelContrib object") assert(channel_contrib.name not in self.channel_contribs) assert(all(sp in self.sub_procs for sp in channel_contrib.sub_procs)) self.channel_contribs[channel_contrib.name] = channel_contrib
[docs] def scale_by_fn(self, fn: Callable[[jnp.DeviceArray], jnp.DeviceArray], variables: Union[Variable, Tuple[Variable]]): """Scale the yields of this process by an arbitrary function The input variables to the function are passed as positional arguments. Each argument is a 1D ``jnp.DeviceArray`` with the same shape as the corresponding :py:class:`Variable`, except if the latter is a scalar, in which case the argument would have shape ``(1,)``. The return value of the function should be broadcastable (with standard broadcasting rules) to a 2D shape ``(number of sub-processes, number of bins across all channels)``. This means that: * a scalar or array with shape ``(1,)`` or ``(1, 1)`` will be broadcast to all sub-processes and bins * an array with shape ``(n_sub_procs, 1)`` will scale each sub-process by a different value, but uniformly across all bins * an array with shape ``(1, n_bins)`` will scale each bin by a different value, but uniformly across all sub-processes * an array with shape ``(n_sub_procs, n_bins)`` will scale each sub-process and each bin by a different value :param fn: a callable returning the factors by which the yields should be scaled :param variables: a :py:class:`Variable` or a tuple of :py:class:`Variable` objects whose values will be passed to ``fn`` """ if isinstance(variables, Variable): variables = (variables,) self.scaled_by.append(PaddedFunction(fn, variables))
[docs] def scale_by(self, var: Variable): """Register a :py:class:`Variable` as a linear yield modifier (e.g. signal strength modifier) If the parameter is a scalar, the yields in all bins and sub-processes of this process will be scaled linearly by the parameter. If the parameter is vector-valued, the dimensionality of the variable should match the number of sub-processes of this process. Each sub-process will then be scaled linearly by a single component of the variable. :param var: a :py:class:`Variable` scaling the yields of this process """ if var.dim > 1: assert(var.dim == self.n_sub_procs) # expand for broadcasting over bins and multiplying along sub-processes self.scale_by_fn(lambda x: jnp.expand_dims(x, -1), var)
[docs] def scale_bins_by(self, var: Variable): """Register a :py:class:`Variable` as a linear yield modifier across bins If the parameter is a scalar, the yields in all bins and sub-processes of this process will be scaled linearly by the parameter. If the parameter is vector-valued, the dimensionality of the variable should match the number of bins (across all channels) of this process. The yields in each bin will then be scaled linearly by a single component of the variable. :param var: a :py:class:`Variable` scaling the yields of this process """ # caution - no check possible if var matches the total number of bins self.scale_by_fn(lambda x: x, var)
def _merge_systematics(self, syst_type): # First collect all sources (=var+subvars) syst_sources = OrderedDict() # Variable -> set(sub_vars...) for chan in self.channel_contribs: for lazy_syst in chan.systematics: if not isinstance(lazy_syst, syst_type): continue var = lazy_syst.var sub_vars = syst_sources.get(var, set()) if lazy_syst.sub_vars is not None: for subN in lazy_syst.sub_vars: sub_vars.add(subN) else: # no sub-var specified by systematic -> insert all the sub-vars of the variable for subN in var.sub_names: sub_vars.add(subN) syst_sources[var] = sub_vars # Convert to a dict of list, sort according to original order of sub-vars within the Variable for v in syst_sources.keys(): syst_sources[v] = list(syst_sources[v]) or [0] # default to [0] if scalar variable (no sub_vars) syst_sources[v].sort(key=lambda sv: v.sub_names.index(sv)) # Then expand and merge merged_lazy_systs_chans = {} for chan in self.channel_contribs: merged_sub_var_systs = {} # Variable -> lazy systematic using that Variable (merging as we go along) for lazy_syst in chan.systematics: if not isinstance(lazy_syst, syst_type): continue if lazy_syst.var not in merged_sub_var_systs: merged_sub_var_systs[lazy_syst.var] = lazy_syst else: # Merge all the systematics using the same variable (stack the sub-variables appropriately) merged_sub_var_systs[lazy_syst.var].merge_sub_vars(lazy_syst) for lazy_syst in merged_sub_var_systs.values(): # Expand all the sub-variables to the full list of sub-variables used for each variable lazy_syst.expand_sub_vars() merged_lazy_systs_chans[chan] = merged_sub_var_systs merged_lazy_systs = {} # Variable -> lazy syst merged across channels for var,sub_vars in syst_sources.items(): merged_lazy_syst = None for chan in self.channel_contribs: to_merge = merged_lazy_systs_chans[chan].get(var, None) # merge systematics along bins/channels axis, insert dummy systematic if none is present for a channel, for the current Variable merged_lazy_syst = syst_type.merge_channels(merged_lazy_syst, to_merge, chan.yields, var, sub_vars, self.sub_procs) merged_lazy_systs[var] = merged_lazy_syst return merged_lazy_systs def _prepare(self, channel_list): # Insert sub-process with zero yields for each sub-process of this Process which doesn't contribute to that ChannelContrib # (happens only if there is more than 1 registered sub-process) for c in self.channel_contribs.values(): c._insert_missing_sub_procs(self.sub_procs) # Insert channel with zero yields everywhere this process doesn't contribute prepared_channel_contribs = [] for chan in channel_list: contrib = self.channel_contribs.get(chan.name, None) if contrib is None: contrib = ChannelContrib(chan.name, np.zeros((self.n_sub_procs, chan.n_bins)), sub_procs=self.sub_procs) assert(contrib.n_bins == chan.n_bins) prepared_channel_contribs.append(contrib) self.channel_contribs = prepared_channel_contribs # now it's a list, with same order as the channel list of the model self.yields = np.concatenate([c.yields for c in self.channel_contribs], axis=-1) self.sumw2 = np.concatenate([c.sumw2 for c in self.channel_contribs], axis=-1) # Expand (across channels) and regroup (across sub-vars) uncertainties systematic_types = set(type(s) for s in c.systematics for c in self.channel_contribs if isinstance(s, LazySystematic)) self.merged_systematics = {} # syst_type -> Variable -> syst for syst_type in systematic_types: self.merged_systematics[syst_type] = self._merge_systematics(syst_type) for chan in self.channel_contribs: del chan.systematics for contrib in self.channel_contribs: chan_idxs = next(c for c in channel_list if c.name == contrib.name).sub_idxs for fn in contrib.scaled_by: fn.notify_chan_idxs(chan_idxs)
[docs]class MergedProcess:
[docs] def __init__(self, processes, channels): self.processes = processes channel_idxs = [ c.sub_idxs for c in channels ] global_proc_idx = 0 self.padded_fns = [] for p in self.processes: idxs = jnp.arange(global_proc_idx, global_proc_idx + len(p.sub_procs)) p.sub_idxs = idxs def proc_pred(values, proc): pred,sumw2 = self.pred(values) return pred[proc.sub_idxs],sumw2[proc.sub_idxs] p.pred = partial(proc_pred, proc=p) p.batch_pred = vmap(partial(proc_pred, proc=p), 0, 0) for fn in p.scaled_by: fn.notify_proc_idxs(idxs) self.padded_fns.append(fn) for chan in p.channel_contribs: for fn in chan.scaled_by: fn.notify_proc_idxs(idxs) self.padded_fns.append(fn) global_proc_idx += len(idxs) # we could create jax arrays here but it looks like memory usage is lower when they remain in numpy before jitting self.yields = np.concatenate([p.yields for p in self.processes], axis=0) self.sumw2 = np.concatenate([p.sumw2 for p in self.processes], axis=0) # First collect all the variables across all processes for a given systematic type merged_lazy_syst_vars = {} # syst type -> used Variables for p in self.processes: for syst_type,vars_systs in p.merged_systematics.items(): var_set = merged_lazy_syst_vars.get(syst_type, set()) var_set = var_set.union(vars_systs.keys()) merged_lazy_syst_vars[syst_type] = var_set for syst_type in list(merged_lazy_syst_vars.keys()): merged_lazy_syst_vars[syst_type] = list(merged_lazy_syst_vars[syst_type]) # Then merge the systematics, for a given type, across processes and variables self.systematics = [] for syst_type,variables in merged_lazy_syst_vars.items(): merged_systs = [] for v in variables: for p in self.processes: if syst_type in p.merged_systematics and v in p.merged_systematics[syst_type]: merged_systs.append(p.merged_systematics[syst_type][v]) else: # pass the channel indices here to make sure we compute yield sums separately for each channel merged_systs.append(syst_type.dummy(p.yields, v, v.sub_names, p.sub_procs, channel_idxs=channel_idxs)) concrete_syst = syst_type.build_concrete_from_list(merged_systs, self.yields, variables) self.systematics.append(concrete_syst) self.batch_pred = vmap(self.pred, 0, 0)
[docs] def pred(self, values: jnp.DeviceArray) -> Tuple[jnp.DeviceArray, jnp.DeviceArray]: """ Compute the yields of the processes given the parameter values. .. note:: This method does not include the effect of the BB-lite parameters! use :py:meth:`Model.pred` for that. :param values: 1D array with the parameter values in the order expected by the compiled model, e.g. as returned by :py:meth:`Model.values_from_dict` :returns: a tuple of two 2D ``jnp.DeviceArray`` where the entries ``(j,:)`` in the first (second) array contain the predicted yields (sumw2), across all channels, of the process with index ``j`` """ factors = jnp.ones(self.yields.shape) for fn in self.systematics: factors = fn(values) * factors for fn in self.padded_fns: factors = fn(values, factors) yields = factors * self.yields sumw2 = factors * self.sumw2 return yields,sumw2
[docs] def batch_pred(self, values: jnp.DeviceArray) -> Tuple[jnp.DeviceArray, jnp.DeviceArray]: """ Compute the yields of the processes using a batch of parameter values. .. note:: This method does not include the effect of the BB-lite parameters! use :py:meth:`Model.pred` for that. :param values: 2D array where the first axis is the batch dimension, i.e. every row specifies parameter values in the order expeced by the compiled model :returns: a tuple of two 3D ``jnp.DeviceArray`` where the entries ``(i,j,:)`` in the first (second) array contain the predicted yields (sumw2), across all channels, of the process with index ``j`` given the parameter values in row ``i`` of ``values`` """ pass
class Channel: def __init__(self, name, sub_idxs): self.name = name self.n_bins = len(sub_idxs) self.sub_idxs = sub_idxs
[docs]class Model: """Main class defining the statistical model (likelihood function), fitting etc. """
[docs] def __init__(self, do_bblite: bool = False, seed: int = 0): """Constructor :param do_bblite: Enable the computation of bin-by-bin Monte-Carlo statistical uncertainties using the Barlow-Beeston lite method. For more details see :py:meth:`Model.enable_bblite` :param seed: Initial seed for random toy generation. See `here <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers>`__ for more information about random numbers in jax """ self.procs = [] self.vars = [] self.channels = [] self.add_constraints = [] self.correlations = [] self.prepared = False self.do_bblite = do_bblite self.key = random.PRNGKey(seed) # This won't trigger JITting or grad tracing yet... self.nll_grad = jit(grad(self.nll_fast, argnums=0)) self.nll_hess = hessian(self.nll_fast, argnums=0) self.batch_pred = vmap(self.pred, 0, 0) self.batch_nll = vmap(self.nll, (0, None, None), 0)
[docs] def add_proc(self, proc: Process): """Register a :py:class:`Process` with the model Further modifying the process (e.g. adding channel contributions etc.) is still possible until :py:meth:`Model.prepare` is called. """ self.procs.append(proc)
[docs] def add_constraint(self, fn: Callable[[jnp.DeviceArray, jnp.DeviceArray, jnp.DeviceArray], jnp.DeviceArray]): """Register additional constraint function The constraint function should return a scalar, and takes as argument three jax arrays (obviously, not all of them have to be used): * the parameters of the model (concatenated as a single 1D array) * the observed yields across all channels (concatenated as a single 1D array) * the vector of global observables Sign convention: the constraint is **added** to the **log-likelihood**. Any number of constraints can be added, but it is more effective if it is possible to express a single function in array form. :param fn: the constraint function """ self.add_constraints.append(fn)
[docs] def correlate_nuisances(self, v1: Union[Variable, Tuple[Variable, int]], v2: Union[Variable, Tuple[Variable, int]], corr: float): """Specify the correlation between a pair of nuisance parameters :param v1: either a :py:class:`Variable` if the variable is a scalar, or a tuple ``(var, index)`` specifying the index of the component of ``var`` to be correlated if ``var`` is vector-valued :param v2: as ``v1`` :param corr: the correlation coefficient, should lie in the interval ``]-1, 1[`` """ if isinstance(v1, Variable): assert(v1.dim == 1) v1 = (v1, 0) elif isinstance(v1, tuple): assert(v1[1] < v1[0].dim) else: raise ValueError("not supported") if isinstance(v2, Variable): assert(v2.dim == 1) v2 = (v2, 0) elif isinstance(v2, tuple): assert(v2[1] < v2[0].dim) else: raise ValueError("not supported") assert(-1 < corr < 1) self.correlations.append((v1, v2, corr))
[docs] def enable_bblite(self): """Enable the computation of bin-by-bin Monte-Carlo statistical uncertainties using the Barlow-Beeston lite method When preparing the model, ``n`` Gaussian nuisance parameters will be added to the model, where ``n`` is the number of bins across all channels. The corresponding :py:class:`Variable` is called ``bblite`` and has ``n`` components. The (Gaussian) uncertainty in each bin is equal to ``sqrt(sumW2)``, where ``sumW2`` is the sum of the sums of squared weight entries for that bin across all processes. If a process has no registered ``sumw2`` array, it won't contribute to the uncertainty (i.e. it has infinite statistics). """ self.do_bblite = True
[docs] def prepare(self): """Prepare the model for inference This collects all the parameters, processes and channels that have been registered. A global parameter array is built, to be passed to the likelihood function as a single argument. Yield arrays for every process and channel are concatenated into a single large array, within a :py:class:`MergedProcess` object accessible though the attribute :py:attr:`Model.merged_process`. Systematic uncertainties are concatenated across sources, processes and channels. The nuisance parameter covariance matrix is also built, accessible through the attribute :py:attr:`Model.nuisance_cov_matrix`. """ # total number of bins across all channels self.n_bins = 0 # Collect all channels from processes for proc in self.procs: for name,contrib in proc.channel_contribs.items(): channel = next((c for c in self.channels if c.name == name), None) if channel is None: sub_idxs = jnp.arange(self.n_bins, self.n_bins + contrib.n_bins) channel = Channel(contrib.name, sub_idxs) self.channels.append(channel) self.n_bins += channel.n_bins assert(contrib.n_bins == channel.n_bins) self.n_channels = len(self.channels) # Now that we have the full channel list we can prepare and merge processes for proc in self.procs: proc._prepare(self.channels) self.merged_process = MergedProcess(self.procs, self.channels) # Only now can we gather all the variables for v in chain.from_iterable(fn.variables for fn in self.merged_process.padded_fns + self.merged_process.systematics): if hasattr(v, "sub_idxs") or hasattr(v, "idx"): raise RuntimeError(f"Variable {v.name} already belongs to a compiled model!") if v not in self.vars: self.vars.append(v) # Also add BB-lite nuisance parameters if self.do_bblite: bblite_names = [ f"{chan.name}_bin{binno}" for chan in self.channels for binno in range(chan.n_bins) ] self.bblite_vars = Variable("bblite", jnp.zeros(self.n_bins), sub_names=bblite_names, nuisance=True) self.vars.append(self.bblite_vars) # Build arrays of indices corresponding to each variable self.n_var = len(self.vars) self.dim_var = 0 self.var_sub_indices = [] self.nuisances = [] self.nuisance_sub_indices = [] self.nuisance_only_sub_indices = [] self.dim_nuisances = 0 for i,v in enumerate(self.vars): idxs = jnp.arange(self.dim_var, self.dim_var + v.dim) v.idx = i v.sub_idxs = idxs self.var_sub_indices.append(idxs) self.dim_var += v.dim if v.nuisance: self.nuisances.append(v) self.nuisance_sub_indices.append(idxs) idxs_nuis_only = jnp.arange(self.dim_nuisances, self.dim_nuisances + v.dim) self.nuisance_only_sub_indices.append(idxs_nuis_only) self.dim_nuisances += v.dim if self.do_bblite: self.bb_sub_indices = self.bblite_vars.sub_idxs # Save for each scaling function which indices the corresponding variables use for fn in self.merged_process.systematics + self.merged_process.padded_fns: fn.build_idxs() # Define bounds for variables with lower or upper bounds lower_bounds = - np.inf * np.ones(self.dim_var) upper_bounds = np.inf * np.ones(self.dim_var) for v in self.vars: if v.lower_bound is not None: lower_bounds[to_np(v.sub_idxs)] = v.lower_bound if v.upper_bound is not None: upper_bounds[to_np(v.sub_idxs)] = v.upper_bound self.var_bounds = Bounds(lower_bounds, upper_bounds) # Build the nuisance parameter covariance matrix if self.dim_nuisances > 0: self.flat_nuisance_sub_indices = jnp.concatenate(self.nuisance_sub_indices, axis=-1) self.n_nuisances = len(self.nuisances) nuisance_cov_matrix = np.eye(self.dim_nuisances) for (v1,i1),(v2,i2),corr in self.correlations: i1 = self.nuisance_only_sub_indices[self.nuisances.index(v1)][i1] i2 = self.nuisance_only_sub_indices[self.nuisances.index(v2)][i2] nuisance_cov_matrix[i1,i2] = corr nuisance_cov_matrix[i2,i1] = corr self.nuisance_cov_matrix = jnp.asarray(nuisance_cov_matrix) self.prepared = True
[docs] def compile(self, compile_pred: bool = False, compile_hess: bool = False): """JIT-compile the NLL function and the analytical NLL gradient .. note:: This method is only available after the model has been prepared! Calling this function is not strictly necessary - jitting and gradient tracing will happen anyway the first time the NLL, the gradient or the Hessian are called. This function merely allows to "get it done" all at once. :param compile_hess: Also compile :py:meth:`Model.pred`, which can be useful if you expect to call it very often. However using :py:meth:`Model.batch_pred` instead is probably better. :param compile_hess: Also compile the Hessian matrix. This is typically quite slow, and takes at least as much time as computing a non-JITed Hessian once. Since the Hessian is typically only computed once (on the best-fit point), JITing is not done by default and only useful when you expect to call the Hessian very often (for instance, for studying coverage properties of intervals using toys). """ default_values = jnp.concatenate([v.val for v in self.vars]) if compile_pred: print("Compiling pred") self.pred = jit(self.pred) default_pred = self.pred(default_values).block_until_ready() else: default_pred = self.pred(default_values) default_glob = self.default_glob() print("Compiling NLL") self.nll_fast(default_values, default_pred, default_glob).block_until_ready() print("Compiling NLL grad") self.nll_grad(default_values, default_pred, default_glob).block_until_ready() if compile_hess: print("Compiling Hessian") self.nll_hess = jit(self.nll_hess) self.nll_hess(default_values, default_pred, default_glob).block_until_ready()
[docs] def var_index(self, var: Variable) -> int: """Return the index of a :py:class:`Variable` within the compiled model .. note:: This method is only available after the model has been prepared! :param var: the :py:class:`Variable` :returns: the index of the variable """ return self.vars.index(var)
[docs] def var_sub_index(self, var: Variable) -> jnp.DeviceArray: """Return the sub-indices of a :py:class:`Variable` within the global parameter array of the compiled model .. note:: This method is only available after the model has been prepared! :param var: the :py:class:`Variable` :returns: the array of indices of the variable's components within the global parameter array of the compiled model """ var_idx = self.var_index(var) return self.var_sub_indices[var_idx]
[docs] def channel_sub_idxs(self, channel_name: str) -> jnp.DeviceArray: """Return the indices of a channel in the yield arrays The yield arrays of different channels of a process are concatenated. This function returns the indices of a given channel along that concatenated axis. Example: two channels are attached to the model, one (called ``SR``) with 2 bins, the other (called ``CR``) with 3 bins. Then ``model.channel_sub_idxs("SR")`` might return ``[0,1]`` or ``[3,4]``, depending on the order in which the channels have been concatenated when the model is prepared. .. note:: This method is only available after the model has been prepared! :param channel_name: the name of the channel :returns: the array of indices of that channel """ chan = next(c for c in self.channels if c.name == channel_name) return chan.sub_idxs
[docs] def values_from_dict(self, values: Dict[Variable, Union[float, List[float], np.ndarray, jnp.DeviceArray]] = None) -> jnp.DeviceArray: """Build a global parameter vector Build a 1D array with the values of all of the model's parameters in the expected order. By default, the initial values of the :py:class:`Variable` objects are used, but these can be overridden using the ``values`` argument. This will not modify the "set" values of the model's :py:class:`Variable` objects. .. note:: This method is only available after the model has been prepared! :param values: a dictionary where the keys are :py:class:`Variable` objects and the values are the values to use for the corresponding variable :returns: a 1D ``jnp.DeviceArray`` with all the parameter values in the order expected by the model """ values_by_index = [] if values is None: values = {} for var in self.vars: val = to_jnp(values.get(var, var.val)) if val.shape != var.val.shape: raise ValueError(f"Unexpected shape given for variable {var.name}: {val.shape}") values_by_index.append(val) return jnp.concatenate(values_by_index)
[docs] def default_glob(self) -> jnp.DeviceArray: """Return default vector of global observables .. note:: This method is only available after the model has been prepared! :returns: a 1D ``jnp.DeviceArray`` with the default vector of global observables corresponding to the nuisance parameters in the model (= all zeros) """ return jnp.zeros(shape=(self.dim_nuisances,))
[docs] def toy_pred(self, values: jnp.DeviceArray, nToys: int = 1) -> jnp.DeviceArray: """Generate random toys for the predicted yields .. note:: This method is only available after the model has been prepared! The toys are drawn from Poisson distributions with means given by :py:meth:`Model.pred` evaluated on ``values``. :param values: a 1D ``jnp.DeviceArray`` with the global parameter values used to generate the toy :param nToys: the number of toys to generate :returns: a 2D ``jnp.DeviceArray`` where each row is a toy, i.e. the yields in all bins across all channels """ self.key,subkey = random.split(self.key) nPred = self.pred(values) toys = random.poisson(subkey, nPred, shape=(nToys,len(nPred))) # convert to float to avoid retracings of the NLL for float/int return jnp.array(toys, dtype=jnp.float64)
[docs] def toy_glob(self, values: jnp.DeviceArray, nToys: int = 1) -> jnp.DeviceArray: """Generate random toys for the global observables associated with the model's nuisance parameters .. note:: This method is only available after the model has been prepared! The toys are drawn from a multivariate normal distribution with means given by the values of the nuisance parameters in ``values`` and with covariance matrix fixed from the model's :py:attr:`Model.nuisance_cov_matrix` attribute. :param values: a 1D ``jnp.DeviceArray`` with the parameter values used to generate the toy :param nToys: the number of toys to generate :returns: a 2D ``jnp.DeviceArray`` where each row contains a generated vector of global observables """ self.key,subkey = random.split(self.key) if self.dim_nuisances > 0: means = values[self.flat_nuisance_sub_indices] toys = random.multivariate_normal(subkey, means, cov=self.nuisance_cov_matrix, shape=(nToys,)) else: toys = jnp.zeros((nToys, 0)) # convert to float to avoid retracings of the NLL for float/int return jnp.array(toys, dtype=jnp.float64)
[docs] def concat_obs(self, nObs: Union[np.ndarray, jnp.DeviceArray, Dict[str, np.ndarray], Dict[str, jnp.DeviceArray]]) -> jnp.DeviceArray: """Concatenate observed yields across all channels in the order expected by the prepared model .. note:: This method is only available after the model has been prepared! :param nObs: either a 1D array with the observed yields in all bins and channels in the order expected by the prepared model, or a dictionary where the keys are the channel names and the values are 1D arrays with the observed yields in the corresponding channel :returns: a 1D ``jnp.DeviceArray`` with the observed yields in all bins and channels in the order expected by the prepared model """ if isinstance(nObs, dict): return jnp.concatenate([to_jnp(nObs[c.name]) for c in self.channels], axis=-1) return to_jnp(nObs)
[docs] def split_obs(self, nObs: Union[np.ndarray, jnp.DeviceArray]) -> Dict[str, jnp.DeviceArray]: """Split prediction array into channels Split a full yields array, returned e.g. by :py:meth:`model.pred`, into the different channels. .. note:: This method is only available after the model has been prepared! :param nObs: array with the observed yields in all bins and channels (1D or 2D, in the case of batch predictions) :returns: a dictionary with the channel names as keys, and the yields in the channels (as 1D or 2D ``jnp.DeviceArray``) as values """ return { c.name: np.take(nObs, self.channel_sub_idxs(c.name), axis=-1) for c in self.channels }
[docs] def get_bblite_errs(self, values: jnp.DeviceArray, sumw2: jnp.DeviceArray) -> jnp.DeviceArray: """Compute the contribution to the predicted yields due to the fluctuated bin-by-bin statistics .. note:: This method is only available after the model has been prepared! :param values: a 1D ``jnp.DeviceArray`` with the parameter values :param sumw2: a 1D ``jnp.DeviceArray`` with the sums of the sumw2 arrays in each bin :returns: a 1D ``jnp.DeviceArray`` with the contribution of the BB-lite parameters to the yields in each bin """ return values[self.bb_sub_indices] * jnp.where(sumw2 > 0., jnp.sqrt(jnp.where(sumw2 > 0., sumw2, 1.)), 0.)
[docs] def pred(self, values: jnp.DeviceArray) -> jnp.DeviceArray: """ Compute the summed yields of all processes given the parameter values, including the effect of the BB-lite parameters (if activated). .. note:: This method is only available after the model has been prepared! :param values: 1D array with the parameter values in the order expected by the prepared model, e.g. as returned by :py:meth:`Model.values_from_dict` :returns: a 1D ``jnp.DeviceArray`` containing the predicted yields across all channels and bins """ if not do_jit: values = to_jnp(values) yields,sumw2 = self.merged_process.pred(values) yields = jnp.sum(yields, axis=0) if self.do_bblite: sumw2 = jnp.sum(sumw2, axis=0) yields += self.get_bblite_errs(values, sumw2) return yields
[docs] def batch_pred(self, values: jnp.DeviceArray) -> jnp.DeviceArray: """ Compute the summed yields of all processes using a batch of parameter values, including the effect of the BB-lite parameters (if activated). .. note:: This method is only available after the model has been prepared! :param values: 2D array where the first axis is the batch dimension, i.e. every row specifies parameter values in the order expected by the prepared model :returns: a 2D ``jnp.DeviceArray`` where the entries ``(i,:)`` contain the predicted yields, across all channels and bins, given the parameter values in row ``i`` of ``values`` """ pass
[docs] def nll(self, values: Union[np.ndarray, jnp.DeviceArray], nObs: Union[np.ndarray, jnp.DeviceArray, Dict[str, np.ndarray], Dict[str, jnp.DeviceArray]], nGlob: Optional[Union[np.ndarray, jnp.DeviceArray]] = None) -> jnp.DeviceArray: """Evaluate the negative log-likelihood (NLL) of the model This is a convenience wrapper of the jitted method :py:meth:`Model.nll_fast` .. note:: This method is only available after the model has been prepared! :param values: 1D array with the parameter values in the order expected by the prepared model, e.g. as returned by :py:meth:`Model.values_from_dict` :param nObs: either a 1D array with the observed yields in all bins and channels in the order expected by the prepared model, or a dictionary where the keys are the channel names and the values are 1D arrays with the observed yields in the corresponding channel :param nGlob: a 1D array with the vector of global observables associated with the nuisance parameters. If not specified, it is generated from :py:meth:`Model.default_glob`. :returns: a scalar ``jnp.DeviceArray`` with the NLL """ nObs = self.concat_obs(nObs) if nGlob is None: nGlob = self.default_glob() else: nGlob = to_jnp(nGlob) values = to_jnp(values) return self.nll_fast(values, nObs, nGlob)
[docs] def batch_nll(self, values, nObs, nGlob): pass
[docs] def nll_grad(self, values, nObs, nGlob): pass
[docs] def nll_hess(self, values, nObs, nGlob): pass
[docs] @partial(jit, static_argnums=(0,)) def nll_fast(self, values: jnp.DeviceArray, nObs: jnp.DeviceArray, nGlob: jnp.DeviceArray) -> jnp.DeviceArray: """Evaluate the negative log-likelihood (NLL) of the model This method expects all the inputs to be properly ordered ``jnp.DeviceArray`` arrays. Use :py:meth:`Model.nll` for a more convenient interface. .. note:: This method is only available after the model has been prepared! :param values: 1D ``jnp.DeviceArray`` with the parameter values in the order expected by the prepared model :param nObs: a 1D ``jnp.DeviceArray`` with the observed yields in all bins and channels in the order expected by the prepared model :param nGlob: a 1D ``jnp.DeviceArray`` array with the vector of global observables associated with the nuisance parameters :returns: a scalar ``jnp.DeviceArray`` with the NLL """ if do_jit and debug: print(f"Tracing NLL on {values}, {nObs}, {nGlob}") if not do_jit: values = to_jnp(values) nPred = self.pred(values) log_prob = jnp.sum(poisson.logpmf(nObs, nPred), axis=-1) if self.dim_nuisances > 0: log_prob += multivariate_normal.logpdf(nGlob, values[self.flat_nuisance_sub_indices], self.nuisance_cov_matrix) for c in self.add_constraints: log_prob += c(values, nObs, nGlob) return -log_prob
def _get_updated_bounds(self, freeze_params=None): if not freeze_params: return self.var_bounds ub, lb = self.var_bounds.ub.copy(), self.var_bounds.lb.copy() for v in freeze_params.keys(): if not isinstance(freeze_params[v], tuple): val,idxs = (freeze_params[v], np.arange(0, v.dim)) else: val,idxs = freeze_params[v] val,idxs = to_np(val),to_np(idxs) assert(len(val) == len(idxs)) for sub_val, sub_idx in zip(val, idxs): ub[v.sub_idxs[sub_idx]] = sub_val lb[v.sub_idxs[sub_idx]] = sub_val return Bounds(lb, ub)
[docs] def fit(self, nObs: Union[np.ndarray, jnp.DeviceArray, Dict[str, np.ndarray], Dict[str, jnp.DeviceArray]], nGlob: Optional[Union[np.ndarray, jnp.DeviceArray]] = None, init: Union[np.ndarray, jnp.DeviceArray] = None, store_hess: bool = True, freeze_params: Dict[Variable, Union[float, np.ndarray, jnp.DeviceArray, Tuple[Union[np.ndarray, jnp.DeviceArray], Union[List[int], np.ndarray, jnp.DeviceArray]]]] = None, method: str = "trust-constr", minimizer_opts: Optional[Dict[str, Any]] = None) -> OptimizeResult: """Fit the model to the data Minimize the NLL given the observed yields (``nObs``). The default minimizer used is the ``trust-constr`` method in scipy, see the documentation for the overall minimizer options `here <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.OptimizeResult.html#scipy.optimize.OptimizeResult>`__ and for the ``trust-constr`` method `here <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustconstr.html>`__. Minimization starts from the default values of the parameters, and takes the bounds on the parameters into account. This method can also be used e.g. for profiled NLL scans (in which case the POIs are fixed to the scanned values), or for nuisance parameter impacts (to obtain a new best-fit point when fixing a nuisance parameter to a specific value). Specifying the profiled parameters is done using the ``freeze_params`` argument, which is a dictionary where the keys are :py:class:`Variable` objects and the values can be either: * a number or a 1D array (with the same number of components as the variable in the key) holding the value to which the parameter should be fixed in the fit * a tuple ``(val, idxs)``, which allows to fix only specific components of vector-valued parameters (schematically, fix ``var[idxs]=val`` in the fit) .. note:: This method is only available after the model has been prepared! :param nObs: either a 1D array with the observed yields in all bins and channels in the order expected by the prepared model, or a dictionary where the keys are the channel names and the values are 1D arrays with the observed yields in the corresponding channel :param nGlob: a 1D array with the vector of global observables associated with the nuisance parameters. If not specified, it is generated from :py:meth:`Model.default_glob`. :param init: specify initial point from which to minimize (if not specified, use default values of variables) :param store_hess: compute the analytical covariance matrix (inverse Hessian) at the minimum and store it as the ``hess_inv`` attribute of the returned object. :param freeze_params: specification of the parameters to keep fixed in a profiled fit (see above for details) :param method: name of the scipy minimizer, see full list of accepted methods `here <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html#scipy.optimize.minimize>`__ :param minimizer_opts: a dictionary of options passed to the minimizer, see the scipy documentation. In particular, the ``maxiter`` field sets the maximum number of iterations. :returns: a ``scipy.optimize.OptimizeResult`` object (see `here <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.OptimizeResult.html>`__) containing in particular the fitted parameter values (``x``), the NLL value at the minimum (``fun``), and a success flag (``success``). See the ``trust-constr`` documentation linked above for a full list of attributes. """ nObs = self.concat_obs(nObs) if nGlob is None: nGlob = self.default_glob() else: nGlob = to_jnp(nGlob) init = init if init is not None else jnp.concatenate([v.val for v in self.vars]) bounds = self._get_updated_bounds(freeze_params) result = minimize( self.nll_fast, init, args=(nObs, nGlob), jac=self.nll_grad, bounds=bounds, method=method, options=minimizer_opts, ) result.x = to_jnp(result.x) if store_hess: print("Computing covariance matrix...") result.hess_inv = self.cov(result.x, nObs, nGlob) return result
[docs] def cov(self, values: jnp.DeviceArray, nObs: Union[np.ndarray, jnp.DeviceArray, Dict[str, np.ndarray], Dict[str, jnp.DeviceArray]], nGlob: Optional[Union[np.ndarray, jnp.DeviceArray]] = None) -> jnp.DeviceArray: """Compute the fit covariance matrix analytically The order of the entries in the matrix follows the order of the variables in the model. .. note:: This method is only available after the model has been prepared! :param values: 1D array with the parameter values in the order expected by the prepared model, e.g. as returned by :py:meth:`Model.values_from_dict` :param nObs: either a 1D array with the observed yields in all bins and channels in the order expected by the prepared model, or a dictionary where the keys are the channel names and the values are 1D arrays with the observed yields in the corresponding channel :param nGlob: a 1D array with the vector of global observables associated with the nuisance parameters. If not specified, it is generated from :py:meth:`Model.default_glob`. :returns: the analytical covariance matrix (inverse Hessian) evaluated at ``values`` (as 2D ``jnp.DeviceArray``) """ nObs = self.concat_obs(nObs) if nGlob is None: nGlob = self.default_glob() else: nGlob = to_jnp(nGlob) return inv(self.nll_hess(values, nObs, nGlob))
[docs] def sample_from_covariance(self, best_fit: OptimizeResult, nSamples: int, respect_bounds: bool = True) -> jnp.DeviceArray: """Sample from the fit covariance matrix Generate a batch of parameter values distributed as a multivariate normal whose mean is the best fit value and whose covariance matrix is the fit covariance matrix as estimated from the inverse Hessian of the NLL at the minimum. .. note:: This method is only available after the model has been prepared! :param best_fit: the fit result as returned from :py:meth:`Model.fit`, it **must** contain the covariance matrix :param nSamples: number of samples to draw :param respect_bounds: enforce the bounds on the parameters (i.e. truncate the generated values to respect the bounds) :returns: a 2D ``jnp.DeviceArray`` where each row is a sampled parameter array (with the order of variables in the model) """ self.key,subkey = random.split(self.key) values = random.multivariate_normal(subkey, best_fit.x, cov=best_fit.hess_inv, shape=(nSamples,)) if respect_bounds: values = np.minimum(values, self.var_bounds.ub) values = np.maximum(values, self.var_bounds.lb) return values
[docs] def nll_crossings(self, direction: Union[np.ndarray, jnp.DeviceArray], best_fit: OptimizeResult, nObs: Union[np.ndarray, jnp.DeviceArray, Dict[str, np.ndarray], Dict[str, jnp.DeviceArray]], nGlob: Optional[Union[np.ndarray, jnp.DeviceArray]] = None, level: float = 0.5, init: float = 0.1, bracket: Tuple[float, float] = (0, 1), root_opts: Optional[Dict[str, Any]] = None) -> RootResults: """Find the level crossing of the NLL in a specific direction This function finds the solution ``s`` to ``NLL(x(s)) = NLL0 + level``, where ``x(s) = x0 + s * direction``, ``x0`` is the best-fit point, and ``NLL0 = NLL(x0)`` is the value of the NLL at the minimum. The parameter ``s`` is constrained to the range ``[bracket[0], bracket[1]]``, and the algorithm starts with ``a=init``. The ``bracket`` and ``direction`` arguments should be chosen to make sure that ``NLL[x(bracket[1])] >= level``. Bounds on the model's parameters are taken into account. This function can be used to find statistical-only uncertainties from likelihood scans, e.g. to find the 1-sigma uncertainty on the ``i`` th parameter, set ``direction[j]`` to ``1`` if ``j=i`` and ``0`` otherwise, and ``level=0.5``. .. note:: This method is only available after the model has been prepared! :param direction: direction (in the form of a 1D array analogous to the model parameter array) in which to search the NLL crossing :param best_fit: the fit result as returned from :py:meth:`Model.fit` :param nObs: either a 1D array with the observed yields in all bins and channels in the order expected by the prepared model, or a dictionary where the keys are the channel names and the values are 1D arrays with the observed yields in the corresponding channel :param nGlob: a 1D array with the vector of global observables associated with the nuisance parameters. If not specified, it is generated from :py:meth:`Model.default_glob`. :param level: increase in the NLL from the minimum :param init: initial position of the algorithm along the specified direction :param bracket: bounds on the range of values in which the crossing is searched :param root_opts: a dictionary of options passed to the root-finding algorithm :returns: a ``scipy.optimize.RootResults`` whose attribute ``x`` is set to ``x = x0 + s * direction``, with ``s`` the solution of the equation above """ if nGlob is None: nGlob = self.default_glob() else: nGlob = to_jnp(nGlob) nObs = self.concat_obs(nObs) def crossing_fn(x): values = best_fit.x + x * direction nll = self.nll_fast(values, nObs, nGlob) return nll - best_fit.fun - level def crossing_grad(x): values = best_fit.x + x * direction grad = self.nll_grad(values, nObs, nGlob) return grad.dot(direction) lb,ub = self.var_bounds.lb, self.var_bounds.ub ub = ub - best_fit.x lb = lb - best_fit.x s_ub = np.min(np.where(direction == 0, np.inf, np.where(direction > 0, ub/direction, lb/direction))) s_lb = np.max(np.where(direction == 0, -np.inf, np.where(direction > 0, lb/direction, ub/direction))) if bracket is not None: bracket = (max(bracket[0], s_lb), min(bracket[1], s_ub)) else: bracket = (max(0., s_lb), s_ub) if bracket is not None: if crossing_fn(bracket[1]) < 0.: result = RootResults(bracket[1], 0, 0, True) result.x = best_fit.x + to_jnp(result.root) * direction return result init = min(init, bracket[1]) init = max(init, bracket[0]) result = root_scalar( crossing_fn, bracket=bracket, x0=init, fprime=crossing_grad, options=root_opts) result.x = best_fit.x + to_jnp(result.root) * direction return result
[docs] def minos_bounds(self, var: Variable, best_fit: OptimizeResult, nObs: Union[np.ndarray, jnp.DeviceArray, Dict[str, np.ndarray], Dict[str, jnp.DeviceArray]], nGlob: Optional[Union[np.ndarray, jnp.DeviceArray]] = None, sub_idx: int = 0, level: float = 0.5, init_up: Union[np.ndarray, jnp.DeviceArray] = None, init_down: Union[np.ndarray, jnp.DeviceArray] = None, freeze_params: Dict[Variable, Union[float, np.ndarray, jnp.DeviceArray, Tuple[Union[np.ndarray, jnp.DeviceArray], Union[List[int], np.ndarray, jnp.DeviceArray]]]] = None, minimizer_opts: Optional[Dict[str, Any]] = None) -> Tuple[jnp.DeviceArray, jnp.DeviceArray, OptimizeResult, OptimizeResult]: """Find the confidence interval on a parameter using the Minos method Minos intervals are obtained by finding the points where the profiled likelihood increases by a given amount w.r.t. the minimum (best fit). Any parameter (or any component of a vector-valued parameter) can be profiled. The returned ``up``, ``down`` values are such that ``up >= x0[i] >= down``, where ``x0[i]`` is the best-fit value of the chosen parameter. This function is a convenience wrapper for calling :py:meth:`Model.minos_direction`, see more details about some of the arguments there. Bounds on the model's parameters are taken into account. .. note:: This method is only available after the model has been prepared! :param Variable: the :py:class:`Variable` on which to find the bounds; for vector-valued parameters, specify the component using ``sub_idxs`` :param best_fit: the fit result as returned from :py:meth:`Model.fit` :param nObs: either a 1D array with the observed yields in all bins and channels in the order expected by the prepared model, or a dictionary where the keys are the channel names and the values are 1D arrays with the observed yields in the corresponding channel :param nGlob: a 1D array with the vector of global observables associated with the nuisance parameters. If not specified, it is generated from :py:meth:`Model.default_glob`. :param sub_idx: the index of the component of the variable to consider if the variable is vector-valued :param level: increase in the profile NLL from the minimum :param init_up: specify initial point from which to minimize for the up uncertainty :param init_down: specify initial point from which to minimize for the down uncertainty :param freeze_params: specification of parameters to keep fixed in the fit (see docs of :py:meth:`Model.fit` for details) :param minimizer_opts: a dictionary of options passed to the mimimization algorithm :returns: a tuple ``(up_bound, down_bound, up_result, down_result)`` where the first two entries are the upper and lower ends of the interval on the parameter and the latter two are the full ``scipy.optimize.OptimizeResult`` objects from the calls to :py:meth:`Model.minos_direction` to find the upper and lower bound, respectively """ var_sub_idx = var.sub_idxs[sub_idx] direction_up = np.zeros(self.dim_var) direction_down = np.zeros(self.dim_var) direction_up[var_sub_idx] = 1. direction_down[var_sub_idx] = -1. up = self.minos_direction(direction_up, best_fit, nObs, nGlob, level, init_up, freeze_params, minimizer_opts) down = self.minos_direction(direction_down, best_fit, nObs, nGlob, level, init_down, freeze_params, minimizer_opts) return up.x[var_sub_idx],down.x[var_sub_idx],up,down
[docs] def minos_direction(self, direction: Union[np.ndarray, jnp.DeviceArray], best_fit: OptimizeResult, nObs: Union[np.ndarray, jnp.DeviceArray, Dict[str, np.ndarray], Dict[str, jnp.DeviceArray]], nGlob: Optional[Union[np.ndarray, jnp.DeviceArray]] = None, level: float = 0.5, init: Union[np.ndarray, jnp.DeviceArray] = None, freeze_params: Dict[Variable, Union[float, np.ndarray, jnp.DeviceArray, Tuple[Union[np.ndarray, jnp.DeviceArray], Union[List[int], np.ndarray, jnp.DeviceArray]]]] = None, minimizer_opts: Optional[Dict[str, Any]] = None) -> OptimizeResult: """Find the level crossing of the profile NLL in a specific direction Profiling the component ``i`` of ``x`` means finding the point ``prof(x_i,i)`` that minimizes ``NLL(x)`` while fixing ``x[i]=x_i``. This functions finds the point ``x=prof(x_i,i)`` that satisfies ``NLL(prof(x_i,i)) = NLL0 + level``, but is more general in that the direction in which the profiling is done is arbitrary. This function can therefore be used to obtain the profiled, so-called Minos uncertainties on a parameter or on a linear combination of parameters. In practice, this minimizes the function ``f(x) = -dot(direction, x)`` under the constraint that ``NLL(x) <= NLL0 + level``, where NLL0 is the value of the NLL at the best fit. The minimization algorithm is scipy's ``trust-constr``, see :py:meth:`Model.fit` for more details. The initial point at which the minimization starts can be changed in case of convergence issues (by default, minimization starts from the best fit). Bounds on the model's parameters are taken into account at every step. .. note:: This method is only available after the model has been prepared! :param direction: direction (in the form of a 1D array analogous to the model parameter array) in which to search the profile NLL crossing :param best_fit: the fit result as returned from :py:meth:`Model.fit` :param nObs: either a 1D array with the observed yields in all bins and channels in the order expected by the prepared model, or a dictionary where the keys are the channel names and the values are 1D arrays with the observed yields in the corresponding channel :param nGlob: a 1D array with the vector of global observables associated with the nuisance parameters. If not specified, it is generated from :py:meth:`Model.default_glob`. :param level: increase in the profile NLL from the minimum :param init: specify initial point from which to minimize :param freeze_params: specification of parameters to keep fixed in the fit (see docs of :py:meth:`Model.fit` for details) :param minimizer_opts: a dictionary of options passed to the mimimization algorithm :returns: a ``scipy.optimize.OptimizeResult`` """ if nGlob is None: nGlob = self.default_glob() else: nGlob = to_jnp(nGlob) nObs = self.concat_obs(nObs) hess = jnp.zeros((len(best_fit.x), len(best_fit.x))) def crossing_fn(x): nll = self.nll_fast(x, nObs, nGlob) return nll - best_fit.fun - level def crossing_grad(x): return self.nll_grad(x, nObs, nGlob) init = init if init is not None else best_fit.x try: result = minimize( lambda x: -direction.dot(x), init, jac=lambda x: -direction, constraints=NonlinearConstraint(crossing_fn, -np.inf, 0., crossing_grad), bounds=self._get_updated_bounds(freeze_params), hess=lambda x: hess, method="trust-constr", options=minimizer_opts ) except ValueError as e: print(f"Failed to find the minos bound in direction {direction} starting from {init}.") raise e else: result.x = to_jnp(result.x) # print(f"Result: {result.success} - {result.nit} iteratons - DNLL is {self.nll(result.x, nObs, nGlob) - best_fit.fun}") return result