# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import itertools
from pathlib import Path
import jax
from numpyro import handlers
import numpyro.distributions as dist
def get_model_relations(model, model_args=None, model_kwargs=None, num_tries=10):
    """
    Trace ``model`` and report how its sample sites and plates relate.
    See https://github.com/pyro-ppl/numpyro/issues/949 for more details.

    The returned dictionary has four keys:

    - "sample_sample" maps each sample site to the list of other sample
      sites it was found to depend on;
    - "sample_dist" maps each sample site to the class name of the
      distribution at that site (``Independent``, ``ExpandedDistribution``
      and ``MaskedDistribution`` wrappers are unwrapped first);
    - "plate_sample" maps each plate name to the list of sample sites
      within that plate; and
    - "observed" is the list of observed sample sites.

    For example, for the model::

        def model(data):
            m = numpyro.sample('m', dist.Normal(0, 1))
            sd = numpyro.sample('sd', dist.LogNormal(m, 1))
            with numpyro.plate('N', len(data)):
                numpyro.sample('obs', dist.Normal(m, sd), obs=data)

    the relation is::

        {'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
         'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
         'plate_sample': {'N': ['obs']},
         'observed': ['obs']}

    Dependencies on continuous latent sites are read off the Jacobian of
    every site's log probability with respect to the traced latent values.
    Dependencies on discrete latent sites are detected by re-running the
    model ``num_tries`` times with that site left free (fresh seed each
    time) and checking whether any other site's log probability changes.

    Plates whose sample sets partially overlap are split: the second plate
    keeps the shared sites and a new entry named ``<plate>__CLONE`` receives
    the rest.

    :param callable model: A model to inspect.
    :param model_args: Optional tuple of model args.
    :param model_kwargs: Optional dict of model kwargs.
    :param int num_tries: Optional number of times to trace the model to
        detect discrete -> continuous dependency.
    :rtype: dict
    """
    args = model_args or ()
    kwargs = model_kwargs or {}

    trace = handlers.trace(handlers.seed(model, 0)).get_trace(*args, **kwargs)

    # Every later step only looks at sample sites, so filter them once.
    # Trace order is preserved because dicts keep insertion order.
    sample_sites = {
        name: site for name, site in trace.items() if site["type"] == "sample"
    }

    obs_sites = [name for name, site in sample_sites.items() if site["is_observed"]]

    def _get_dist_name(fn):
        # Peel off wrapper distributions until the underlying one is reached.
        wrappers = (
            dist.Independent,
            dist.ExpandedDistribution,
            dist.MaskedDistribution,
        )
        while isinstance(fn, wrappers):
            fn = fn.base_dist
        return type(fn).__name__

    sample_dist = {
        name: _get_dist_name(site["fn"]) for name, site in sample_sites.items()
    }

    plates_of_sample = {
        name: [frame.name for frame in site["cond_indep_stack"]]
        for name, site in sample_sites.items()
    }
    plate_samples = {
        plate: {name for name, plates in plates_of_sample.items() if plate in plates}
        for plate, site in trace.items()
        if site["type"] == "plate"
    }

    def _resolve_plate_samples(groups):
        # Repeatedly look for the first ordered pair of plates whose sample
        # sets overlap without one containing the other; split the second
        # plate of that pair and start the scan again from the beginning.
        groups = dict(groups)
        rescan = True
        while rescan:
            rescan = False
            pairs = itertools.product(tuple(groups.items()), repeat=2)
            for (_, pv), (q, qv) in pairs:
                if pv & qv and pv - qv and qv - pv:
                    groups[q] = pv & qv
                    groups[q + "__CLONE"] = qv - pv
                    rescan = True
                    break
        return groups

    # convert sets to lists, ordered by first appearance in the trace
    plate_samples = {
        plate: [name for name in trace if name in members]
        for plate, members in _resolve_plate_samples(plate_samples).items()
    }

    def get_log_probs(sample, seed=0):
        # Run the model with `sample` substituted in; sites absent from
        # `sample` are drawn using `seed`.
        with handlers.trace() as tr, handlers.seed(model, seed), handlers.substitute(
            data=sample
        ):
            model(*args, **kwargs)
        return {
            name: site["fn"].log_prob(site["value"])
            for name, site in tr.items()
            if site["type"] == "sample"
        }

    # continuous -> any dependency, via gradients of the log probabilities
    continuous_latents = {
        name: site["value"]
        for name, site in sample_sites.items()
        if not (site["is_observed"] or site["fn"].is_discrete)
    }
    if continuous_latents:
        log_prob_grads = jax.jacobian(get_log_probs)(continuous_latents)
    else:
        log_prob_grads = {
            name: {} for name in get_log_probs(continuous_latents)
        }

    sample_deps = {
        name: {
            other
            for other, grad in grads.items()
            if other != name and (grad != 0).any()
        }
        for name, grads in log_prob_grads.items()
    }

    # find discrete -> continuous dependency
    all_values = {name: site["value"] for name, site in sample_sites.items()}
    discrete_sites = [
        name
        for name, site in sample_sites.items()
        if not site["is_observed"] and site["fn"].is_discrete
    ]
    log_probs_prototype = get_log_probs(all_values)
    for name in discrete_sites:
        # Fix every site except `name`, which gets re-drawn on each try.
        values_without_site = {k: v for k, v in all_values.items() if k != name}
        for seed in range(1, num_tries + 1):
            log_probs = get_log_probs(values_without_site, seed=seed)
            for var in all_values:
                if var != name and (log_probs[var] != log_probs_prototype[var]).any():
                    sample_deps[var].add(name)

    # report upstream sites in trace order
    sample_sample = {
        name: [var for var in all_values if var in sample_deps[name]]
        for name in all_values
    }

    return {
        "sample_sample": sample_sample,
        "sample_dist": sample_dist,
        "plate_sample": plate_samples,
        "observed": obs_sites,
    }
def generate_graph_specification(model_relations):
    """
    Turn model relations into a plain-dict description of a graph.

    :param dict model_relations: Mapping with the keys "sample_sample",
        "sample_dist", "plate_sample" and "observed".
    :return: A dict with the keys

        - "plate_groups": plate name -> list of RVs in it, plus a ``None``
          entry listing the RVs that are in no plate;
        - "plate_data": plate name -> ``{"parent": <plate name or None>}``;
        - "node_data": RV name -> ``{"is_observed": bool,
          "distribution": <name>}``;
        - "edge_list": list of ``(source, target)`` RV pairs.
    :rtype: dict
    """
    # group nodes by plate (copy so the input mapping is left alone)
    plate_groups = dict(model_relations["plate_sample"])
    dependencies = model_relations["sample_sample"]
    rvs_in_some_plate = set(itertools.chain.from_iterable(plate_groups.values()))
    # RVs which are in no plate
    plate_groups[None] = [rv for rv in dependencies if rv not in rvs_in_some_plate]

    # retain node metadata
    node_data = {
        rv: {
            "is_observed": rv in model_relations["observed"],
            "distribution": model_relations["sample_dist"][rv],
        }
        for rv in dependencies
    }

    # infer plate structure
    # (when the order of plates cannot be determined from subset relations,
    # it follows the order in which plates appear in trace)
    named_plates = [plate for plate in plate_groups if plate is not None]
    members = {plate: set(plate_groups[plate]) for plate in named_plates}
    plate_data = {}
    for first, second in itertools.combinations(named_plates, 2):
        if members[first] < members[second]:
            plate_data[first] = {"parent": second}
        elif members[first] >= members[second]:
            plate_data[second] = {"parent": first}
    # plates that were never nested under another one are top-level
    for plate in named_plates:
        plate_data.setdefault(plate, {"parent": None})

    # infer RV edges
    edge_list = [
        (source, target)
        for target, sources in dependencies.items()
        for source in sources
    ]

    return {
        "plate_groups": plate_groups,
        "plate_data": plate_data,
        "node_data": node_data,
        "edge_list": edge_list,
    }
def render_graph(graph_specification, render_distributions=False):
    """
    Create a graphviz object given a graph specification.

    The specification is only read: it is not modified, so the same
    specification can be rendered more than once.

    :param dict graph_specification: Mapping with the keys "plate_groups",
        "plate_data", "node_data" and "edge_list".
    :param bool render_distributions: Show distribution of each RV in plot.
    :return: The assembled ``graphviz.Digraph``.
    :raises ImportError: If the ``graphviz`` package is not installed.
    :raises ValueError: If the parent links in "plate_data" form a cycle.
    """
    try:
        import graphviz  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looks like you want to use graphviz (https://graphviz.org/) "
            "to render your model. "
            "You need to install `graphviz` to be able to use this feature. "
            "It can be installed with `pip install graphviz`."
        ) from e

    plate_groups = graph_specification["plate_groups"]
    # Shallow copy: the nesting loop below removes entries as it goes and
    # must not empty the caller's specification.
    plate_data = dict(graph_specification["plate_data"])
    node_data = graph_specification["node_data"]
    edge_list = graph_specification["edge_list"]

    graph = graphviz.Digraph()

    # add plates
    plate_graph_dict = {
        plate: graphviz.Digraph(name=f"cluster_{plate}")
        for plate in plate_groups
        if plate is not None
    }
    for plate, plate_graph in plate_graph_dict.items():
        # cloned plates are displayed under the name of the plate they came from
        plate_graph.attr(label=plate.split("__CLONE")[0], labeljust="r", labelloc="b")

    # RVs outside every plate go straight into the top-level graph
    plate_graph_dict[None] = graph

    # add nodes
    for plate, rv_list in plate_groups.items():
        cur_graph = plate_graph_dict[plate]

        for rv in rv_list:
            color = "grey" if node_data[rv]["is_observed"] else "white"
            cur_graph.node(
                rv, label=rv, shape="ellipse", style="filled", fillcolor=color
            )

    # add leaf nodes first: a plate is attached to its parent only once no
    # remaining plate names it as parent, i.e. once its own children are in
    while plate_data:
        remaining_parents = {data["parent"] for data in plate_data.values()}
        leaf = next(
            (plate for plate in plate_data if plate not in remaining_parents), None
        )
        if leaf is None:
            # every remaining plate is someone's parent, so no progress is possible
            raise ValueError(
                "Cyclic plate nesting in graph specification: "
                f"{sorted(map(str, plate_data))}"
            )
        parent_plate = plate_data[leaf]["parent"]
        plate_graph_dict[parent_plate].subgraph(plate_graph_dict[leaf])
        del plate_data[leaf]

    # add edges
    for source, target in edge_list:
        graph.edge(source, target)

    # render distributions if requested
    if render_distributions:
        # one "<rv> ~ <dist>" entry per RV, each terminated by a literal `\l`
        entries = []
        for rv, data in node_data.items():
            rv_dist = data["distribution"]
            entries.append(rf"{rv} ~ {rv_dist}\l")
        dist_label = "".join(entries)

        graph.node("distribution_description_node", label=dist_label, shape="plaintext")

    # return whole graph
    return graph
def render_model(
    model,
    model_args=None,
    model_kwargs=None,
    filename=None,
    render_distributions=False,
    num_tries=10,
):
    """
    Wrap all functions needed to automatically render a model.

    .. warning:: This utility does not support the
        :func:`~numpyro.contrib.control_flow.scan` primitive yet.

    .. warning:: Currently, this utility uses a heuristic approach,
        which will work for most cases, to detect dependencies in a NumPyro model.

    :param model: Model to render.
    :param model_args: Positional arguments to pass to the model.
    :param model_kwargs: Keyword arguments to pass to the model.
    :param str filename: File to save rendered model in. The extension
        selects the output format; any directory component is kept, so
        ``"out/model.png"`` is written inside ``out``.
    :param bool render_distributions: Whether to include RV distribution annotations in the plot.
    :param int num_tries: Times to trace model to detect discrete -> continuous dependency.
    :return: The graph object produced by :func:`render_graph`.
    """
    relations = get_model_relations(
        model, model_args=model_args, model_kwargs=model_kwargs, num_tries=num_tries
    )
    graph_spec = generate_graph_specification(relations)
    graph = render_graph(graph_spec, render_distributions=render_distributions)

    if filename is not None:
        filename = Path(filename)
        # Strip only the extension; using `.stem` here would also drop the
        # directory and silently write into the current working directory.
        target = str(filename.with_suffix(""))
        graph.render(
            target,
            view=False,
            cleanup=True,
            format=filename.suffix[1:],  # remove leading period from suffix
        )

    return graph