# Source code for aemcmc.ffbs

from typing import Mapping, Tuple

import aesara
import aesara.tensor as at
from aesara.tensor.random import RandomStream
from aesara.tensor.var import TensorVariable


def ffbs_step(
    gamma_0: TensorVariable,
    Gammas: TensorVariable,
    log_lik: TensorVariable,
    srng: RandomStream,
    lower_prec_bound: float = 1e-20,
) -> Tuple[TensorVariable, Mapping]:
    r"""Draw a discrete state sequence sample using forward-filtering backward-sampling (FFBS) [fs]_.

    FFBS draws samples according to

    .. math::

        S_T &\sim p(S_T \mid y_{1:T}) \\
        S_t \mid S_{t+1} &\sim p(S_t \mid S_{t+1}, y_{1:t})
            \propto p(S_{t+1} \mid S_t) \, p(S_t \mid y_{1:t})

    for discrete states in the sequence :math:`S_t`, :math:`t \in \{0, \dots, T\}`
    and observations :math:`y_t`.

    The argument `gamma_0` holds the initial state probabilities (i.e. the
    distribution of :math:`S_0`), `Gammas` corresponds to
    :math:`p(S_{t+1} \mid S_t)`, and `log_lik` to
    :math:`\log p(y_t \mid S_t)`.  The latter is used to derive the
    (unnormalized) forward/filtered quantities :math:`p(S_t \mid y_{1:t})`.

    This implementation is most similar to the one in [nr]_, which forgoes
    expensive log-space calculations by using fixed precision bounds and
    re-normalization/scaling.

    Parameters
    ----------
    gamma_0
        The initial state probabilities as an array of base shape ``(M,)``.
        They are normalized to sum to one before use.
    Gammas
        The transition probability matrices.  This array should take the
        base shape ``(N, M, M)``, where ``N`` is the state sequence length and
        ``M`` is the number of distinct states.  If the leading dimension is
        ``1``, the single transition matrix will broadcast across all elements
        of the state sequence.
    log_lik
        An array of base shape ``(M, N)`` consisting of the log-likelihood
        values for each state value at each point in the sequence.
    srng
        The random stream used to draw the categorical state samples.
    lower_prec_bound
        A number indicating when forward probabilities are too small and need
        to be re-normalized/scaled.

    Returns
    -------
    A tuple containing the tensor representing the FFBS sampled states (in
    forward sequence order) and the `Scan`-generated updates of the backward
    sampling pass.

    References
    ----------
    .. [fs] Sylvia Fruehwirth-Schnatter, "Markov chain Monte Carlo estimation
       of classical and dynamic switching and mixture models", Journal of the
       American Statistical Association 96 (2001): 194--209
    .. [nr] Press, William H., Saul A. Teukolsky, William T. Vetterling, and
       Brian P. Flannery. 2007. "Numerical Recipes 3rd Edition: The Art of
       Scientific Computing." Cambridge university press.

    """
    gamma_0 = at.as_tensor(gamma_0)
    Gammas = at.as_tensor(Gammas)
    log_lik = at.as_tensor(log_lik)

    # Number of observations
    N = log_lik.shape[-1]

    # Initial state probabilities
    gamma_0_normed = gamma_0 / at.sum(gamma_0)

    # Make sure we have a transition matrix for each element in a state
    # sequence.  The target shape is assembled from explicit scalar entries:
    # writing `(N,) + Gammas.shape[-2:]` is only a tuple concatenation when
    # `Gammas.shape` is a plain tuple; for a symbolic tensor the `+` would be
    # dispatched to the tensor's reflected element-wise addition instead.
    Gammas = at.broadcast_to(Gammas, (N, Gammas.shape[-2], Gammas.shape[-1]))

    def forward_step(log_lik_T_n, Gamma_n, alpha_nm1):
        log_lik_n = log_lik_T_n.T
        # Shift by the maximum before exponentiating to avoid underflow; the
        # resulting scale factor is irrelevant because the forward quantities
        # are normalized before sampling.
        lik_n = at.exp(log_lik_n - log_lik_n.max())
        alpha_n = at.dot(alpha_nm1, Gamma_n) * lik_n
        alpha_n_sum = at.sum(alpha_n, axis=-1)
        # Scale the forward probabilities back up once their total mass drops
        # below the precision bound.
        rescaled_alpha_n = at.switch(
            at.lt(alpha_n_sum, lower_prec_bound),
            alpha_n / lower_prec_bound,
            alpha_n,
        )
        return rescaled_alpha_n

    alphas, _ = aesara.scan(
        forward_step,
        sequences=[log_lik.T, Gammas],
        outputs_info=[{"initial": gamma_0_normed, "taps": [-1]}],
        non_sequences=[],
        n_steps=N,
        strict=True,
        name="forward-pass",
    )
    # Put the sequence dimension last
    alphas = alphas.T

    def backward_step(alphas_T_n, Gamma_np1, state_n):
        # `state_n` is the most recently sampled state, i.e. the state that
        # follows the one being sampled in this step.
        alphas_n = alphas_T_n.T
        beta_n = alphas_n * Gamma_np1[:, state_n]
        beta_n = beta_n / at.sum(beta_n, axis=-1)
        state_np1 = srng.categorical(beta_n)
        return state_np1

    # Sample the final state from the normalized last forward probabilities
    alpha_N = alphas[..., N - 1]
    beta_N = alpha_N / alpha_N.sum(axis=-1)
    state_np1 = srng.categorical(beta_N)

    # Sample the remaining states from the second-to-last back to the first
    state_samples_rev, updates = aesara.scan(
        backward_step,
        sequences=[alphas.T[: N - 1], Gammas[1:]],
        outputs_info=[{"initial": state_np1, "taps": [-1]}],
        non_sequences=[],
        go_backwards=True,
        strict=True,
        name="backward-pass",
    )

    # Prepend the final state to the reverse-ordered samples, then flip the
    # result so that it is returned in forward sequence order.
    state_samples_rev = at.join(
        0, at.atleast_Nd(state_np1, n=state_samples_rev.type.ndim), state_samples_rev
    )

    return state_samples_rev[::-1], updates