from typing import Dict, Tuple
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable
import aemcmc.nuts as nuts
from aemcmc.rewriting import (
SamplerTracker,
construct_ir_fgraph,
expand_subsumptions,
sampler_rewrites_db,
)
from aemcmc.types import Sampler
def construct_sampler(
    obs_rvs_to_values: Dict[TensorVariable, TensorVariable], srng: RandomStream
) -> Tuple[Sampler, Dict[TensorVariable, TensorVariable]]:
    r"""Eagerly construct a sampler for a given set of observed variables and their observations.

    Sampler steps registered by the ``"+basic"`` sampler rewrites are used for
    every unobserved variable that has one; all remaining unobserved variables
    are handed to the NUTS sampler.

    Parameters
    ==========
    obs_rvs_to_values
        A ``dict`` of variables that maps stochastic elements
        (e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their
        observed values.
    srng
        The `RandomStream` passed to the `SamplerTracker` feature and to the
        NUTS sampler constructor.

    Returns
    =======
    A tuple containing a `Sampler` (built from the sampling step of each
    unobserved random variable, the updates generated by the sampler steps,
    and the sampler parameters) and a ``dict`` that maps each unobserved
    random variable to the variable that stands for its initial value.

    """
    fgraph, obs_rvs_to_values, _, new_to_old_rvs = construct_ir_fgraph(
        obs_rvs_to_values
    )

    fgraph.attach_feature(SamplerTracker(srng))

    # The rewrites are run for their side effect: they populate
    # `fgraph.sampler_mappings` through the `SamplerTracker` feature.
    sampler_rewrites_db.query("+basic").rewrite(fgraph)

    random_vars = tuple(rv for rv in fgraph.outputs if rv not in obs_rvs_to_values)

    discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers

    # Each unobserved variable starts out as a clone of itself, which serves
    # as the placeholder for its initial value.
    rvs_to_init_vals = {rv: rv.clone() for rv in random_vars}
    posterior_sample_steps = rvs_to_init_vals.copy()
    # Replace occurrences of observed variables with their observed values
    posterior_sample_steps.update(obs_rvs_to_values)

    # TODO FIXME: Get/extract `Scan`-generated updates
    posterior_updates: Dict[Variable, Variable] = {}
    parameters: Dict[Apply, Tuple[TensorVariable]] = {}
    # A list (rather than a set) keeps the variables in `fgraph.outputs`
    # order, so the arguments given to the NUTS constructor are deterministic.
    rvs_without_samplers = []

    for rv in fgraph.outputs:
        if rv in obs_rvs_to_values:
            continue

        rv_steps = discovered_samplers.get(rv)

        if not rv_steps:
            rvs_without_samplers.append(rv)
            continue

        # TODO FIXME: Just choosing one for now, but we should consider them all.
        # NOTE: `pop` also removes the chosen entry from the container held
        # by the tracker's mapping.
        _step_desc, step, updates = rv_steps.pop()

        # Put the step and its updates (keys, then values) in a single graph
        # so that the replacements below apply to all of them at once.
        if updates:
            update_keys, update_values = zip(*updates.items())
        else:
            update_keys, update_values = tuple(), tuple()

        sfgraph = FunctionGraph(
            outputs=(step,) + tuple(update_keys) + tuple(update_values),
            clone=False,
            copy_inputs=False,
            copy_orphans=False,
        )

        # Update the other sampled random variables in this step's graph
        sfgraph.replace_all(list(posterior_sample_steps.items()), import_missing=True)

        # Expand subsumed `DimShuffle`d inputs to `Elemwise`s
        expand_subsumptions.rewrite(sfgraph)

        posterior_sample_steps[rv] = sfgraph.outputs[0]

        if updates:
            # Recover the (possibly replaced) update keys and values from the
            # rewritten graph's outputs: keys come right after the step.
            keys_offset = len(update_keys) + 1
            update_keys = sfgraph.outputs[1:keys_offset]
            update_values = sfgraph.outputs[keys_offset:]
            posterior_updates.update(zip(update_keys, update_values))

    # Use the NUTS sampler for the remaining variables
    if rvs_without_samplers:
        to_sample_rvs = {rv: posterior_sample_steps[rv] for rv in rvs_without_samplers}
        realized_values = {
            rv: vv
            for rv, vv in posterior_sample_steps.items()
            if rv not in to_sample_rvs
        }

        nuts_sample_steps, nuts_updates, nuts_parameters = nuts.construct_sampler(
            srng, to_sample_rvs, realized_values
        )
        posterior_sample_steps.update(nuts_sample_steps)
        posterior_updates.update(nuts_updates)
        parameters.update(nuts_parameters)

    # Key the results through `new_to_old_rvs`
    # (presumably back to the caller's original variables -- confirm against
    # `construct_ir_fgraph`).
    sampling_steps = {
        new_to_old_rvs[rv]: step
        for rv, step in posterior_sample_steps.items()
        if rv not in obs_rvs_to_values
    }
    initial_values = {
        new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()
    }

    return Sampler(sampling_steps, posterior_updates, parameters), initial_values