# Source code for aemcmc.basic
from typing import Dict, Tuple
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable
import aemcmc.nuts as nuts
from aemcmc.rewriting import (
SamplerTracker,
construct_ir_fgraph,
expand_subsumptions,
sampler_rewrites_db,
)
from aemcmc.types import Sampler
def construct_sampler(
    obs_rvs_to_values: Dict[TensorVariable, TensorVariable], srng: RandomStream
) -> Tuple[Sampler, Dict[TensorVariable, TensorVariable]]:
    r"""Eagerly construct a sampler for a given set of observed variables and their observations.

    The model graph is first converted to an intermediate representation and
    rewritten with the ``"+basic"`` sampler rewrites.  Every unobserved output
    for which a rewrite discovered a sampler gets that sampler's step; all the
    remaining unobserved outputs are handed to the NUTS sampler together.

    Parameters
    ==========
    obs_rvs_to_values
        A ``dict`` of variables that maps stochastic elements
        (e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their
        observed values.
    srng
        The `RandomStream` given to the `SamplerTracker` feature and to the
        NUTS sampler constructor.

    Returns
    =======
    A tuple with two elements.  The first is a `Sampler` built from the
    sampling step of each unobserved random variable, the updates generated by
    the sampler steps, and the parameters returned by the NUTS sampler
    constructor (empty when NUTS is not used).  The second is a ``dict`` that
    maps each unobserved random variable to the cloned variable that stands in
    for it as an initial value in the sampling steps.

    """
    fgraph, obs_rvs_to_values, _, new_to_old_rvs = construct_ir_fgraph(
        obs_rvs_to_values
    )

    # The tracker records the samplers found by the rewrites applied below
    fgraph.attach_feature(SamplerTracker(srng))
    _ = sampler_rewrites_db.query("+basic").rewrite(fgraph)

    random_vars = tuple(rv for rv in fgraph.outputs if rv not in obs_rvs_to_values)
    discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers

    rvs_to_init_vals = {rv: rv.clone() for rv in random_vars}
    posterior_sample_steps = rvs_to_init_vals.copy()
    # Replace occurrences of observed variables with their observed values
    posterior_sample_steps.update(obs_rvs_to_values)

    # TODO FIXME: Get/extract `Scan`-generated updates
    posterior_updates: Dict[Variable, Variable] = {}
    parameters: Dict[Apply, Tuple[TensorVariable]] = {}

    # Kept as a list (in graph-output order) rather than a set, so that the
    # variables are handed to NUTS in a reproducible order.
    rvs_without_samplers = []

    for rv in random_vars:
        rv_steps = discovered_samplers.get(rv)

        if not rv_steps:
            rvs_without_samplers.append(rv)
            continue

        # TODO FIXME: Just choosing one for now, but we should consider them all.
        _, step, updates = rv_steps.pop()

        if updates:
            update_keys, update_values = zip(*updates.items())
        else:
            update_keys, update_values = tuple(), tuple()

        # Put the step and its updates in a single graph so that the
        # replacements and rewrites below are applied to all of them at once.
        sfgraph = FunctionGraph(
            outputs=(step,) + tuple(update_keys) + tuple(update_values),
            clone=False,
            copy_inputs=False,
            copy_orphans=False,
        )

        # Update the other sampled random variables in this step's graph
        sfgraph.replace_all(list(posterior_sample_steps.items()), import_missing=True)

        # Expand subsumed `DimShuffle`d inputs to `Elemwise`s
        expand_subsumptions.rewrite(sfgraph)

        posterior_sample_steps[rv] = sfgraph.outputs[0]

        if updates:
            # The outputs are laid out as `[step, *keys, *values]`
            keys_offset = len(update_keys) + 1
            posterior_updates.update(
                zip(
                    sfgraph.outputs[1:keys_offset],
                    sfgraph.outputs[keys_offset:],
                )
            )

    # Use the NUTS sampler for the remaining variables
    if rvs_without_samplers:
        to_sample_rvs = {rv: posterior_sample_steps[rv] for rv in rvs_without_samplers}
        realized_values = {
            rv: vv
            for rv, vv in posterior_sample_steps.items()
            if rv not in to_sample_rvs
        }

        nuts_sample_steps, nuts_updates, nuts_parameters = nuts.construct_sampler(
            srng, to_sample_rvs, realized_values
        )

        posterior_sample_steps.update(nuts_sample_steps)
        posterior_updates.update(nuts_updates)
        parameters.update(nuts_parameters)

    # Express the results in terms of the caller's original random variables
    sampling_steps = {
        new_to_old_rvs[rv]: step
        for rv, step in posterior_sample_steps.items()
        if rv not in obs_rvs_to_values
    }
    initial_values = {
        new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()
    }

    return Sampler(sampling_steps, posterior_updates, parameters), initial_values