Source code for numpyro.contrib.render

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import itertools
from pathlib import Path

import jax

from numpyro import handlers
import numpyro.distributions as dist


def get_model_relations(model, model_args=None, model_kwargs=None, num_tries=10):
    """
    Infer relations of RVs and plates from given model and optionally data.
    See https://github.com/pyro-ppl/numpyro/issues/949 for more details.

    This returns a dictionary with keys:

    -  "sample_sample" maps each downstream sample site to a list of the upstream
       sample sites on which it depends;
    -  "sample_dist" maps each sample site to the name of the distribution at
       that site;
    -  "plate_sample" maps each plate name to a list of the sample sites within
       that plate; and
    -  "observed" is a list of observed sample sites.

    For example for the model::

        def model(data):
            m = numpyro.sample('m', dist.Normal(0, 1))
            sd = numpyro.sample('sd', dist.LogNormal(m, 1))
            with numpyro.plate('N', len(data)):
                numpyro.sample('obs', dist.Normal(m, sd), obs=data)

    the relation is::

        {'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
         'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
         'plate_sample': {'N': ['obs']},
         'observed': ['obs']}

    :param callable model: A model to inspect.
    :param model_args: Optional tuple of model args.
    :param model_kwargs: Optional dict of model kwargs.
    :param int num_tries: Optional number of times to re-trace the model to
        detect discrete -> continuous dependency.
    :rtype: dict
    """
    model_args = model_args or ()
    model_kwargs = model_kwargs or {}

    trace = handlers.trace(handlers.seed(model, 0)).get_trace(
        *model_args, **model_kwargs
    )
    obs_sites = [
        name
        for name, site in trace.items()
        if site["type"] == "sample" and site["is_observed"]
    ]

    def _get_dist_name(fn):
        # Unwrap wrapper distributions so the underlying family name is reported.
        if isinstance(
            fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)
        ):
            return _get_dist_name(fn.base_dist)
        return type(fn).__name__

    sample_dist = {
        name: _get_dist_name(site["fn"])
        for name, site in trace.items()
        if site["type"] == "sample"
    }

    sample_plates = {
        name: [frame.name for frame in site["cond_indep_stack"]]
        for name, site in trace.items()
        if site["type"] == "sample"
    }
    plate_samples = {
        k: {name for name, plates in sample_plates.items() if k in plates}
        for k in trace
        if trace[k]["type"] == "plate"
    }

    def _resolve_plate_samples(plate_samples):
        # Repeatedly split partially overlapping plates: the shared sites stay in
        # `q` and the rest go into a new "<q>__CLONE" entry. On return every pair
        # of plates is either disjoint or nested.
        for p, pv in plate_samples.items():
            for q, qv in plate_samples.items():
                if len(pv & qv) > 0 and len(pv - qv) > 0 and len(qv - pv) > 0:
                    plate_samples_ = plate_samples.copy()
                    plate_samples_[q] = pv & qv
                    plate_samples_[q + "__CLONE"] = qv - pv
                    return _resolve_plate_samples(plate_samples_)
        return plate_samples

    plate_samples = _resolve_plate_samples(plate_samples)
    # convert set to list to keep order of variables
    plate_samples = {
        k: [name for name in trace if name in v] for k, v in plate_samples.items()
    }

    def get_log_probs(sample, seed=0):
        # Re-run the model with the given site values substituted and return the
        # log-density of every sample site.
        with handlers.trace() as tr, handlers.seed(model, seed), handlers.substitute(
            data=sample
        ):
            model(*model_args, **model_kwargs)
        return {
            name: site["fn"].log_prob(site["value"])
            for name, site in tr.items()
            if site["type"] == "sample"
        }

    # Continuous latent dependencies: site `name` depends on latent `n` when the
    # gradient of log_prob(name) w.r.t. the value of `n` is non-zero.
    samples = {
        name: site["value"]
        for name, site in trace.items()
        if site["type"] == "sample"
        and not site["is_observed"]
        and not site["fn"].is_discrete
    }
    if samples:
        log_prob_grads = jax.jacobian(get_log_probs)(samples)
    else:
        log_prob_grads = {k: {} for k in get_log_probs(samples)}
    sample_deps = {}
    for name, grads in log_prob_grads.items():
        sample_deps[name] = {n for n in grads if n != name and (grads[n] != 0).any()}

    # find discrete -> continuous dependency: gradients are unavailable for
    # discrete sites, so leave one discrete site un-substituted, resample it with
    # different seeds and check which other log-densities change.
    samples = {
        name: site["value"]
        for name, site in trace.items()
        if site["type"] == "sample"
    }
    discrete_sites = [
        name
        for name, site in trace.items()
        if site["type"] == "sample"
        and not site["is_observed"]
        and site["fn"].is_discrete
    ]
    log_probs_prototype = get_log_probs(samples)
    for name in discrete_sites:
        samples_ = samples.copy()
        samples_.pop(name)
        for i in range(num_tries):
            log_probs = get_log_probs(samples_, seed=i + 1)
            for var in samples:
                if var == name:
                    continue
                if (log_probs[var] != log_probs_prototype[var]).any():
                    sample_deps[var].add(name)

    # Order upstream sites by their position in the trace.
    sample_sample = {
        name: [var for var in samples if var in sample_deps[name]]
        for name in samples
    }
    return {
        "sample_sample": sample_sample,
        "sample_dist": sample_dist,
        "plate_sample": plate_samples,
        "observed": obs_sites,
    }
def generate_graph_specification(model_relations):
    """
    Convert model relations into data structure which can be readily
    converted into a network.

    :param dict model_relations: Relations as returned by
        :func:`get_model_relations` (keys ``"sample_sample"``,
        ``"sample_dist"``, ``"plate_sample"`` and ``"observed"``).
        It is not modified.
    :return: A dict with keys

        - ``"plate_groups"``: plate name -> list of RVs in that plate; the key
          ``None`` holds the RVs that are in no plate;
        - ``"plate_data"``: plate name -> ``{"parent": enclosing plate or None}``;
        - ``"node_data"``: RV name -> ``{"is_observed": bool, "distribution": str}``;
        - ``"edge_list"``: list of ``(source, target)`` RV pairs.
    :rtype: dict
    """
    # group nodes by plate
    plate_groups = dict(model_relations["plate_sample"])
    plate_rvs = {rv for rvs in plate_groups.values() for rv in rvs}
    plate_groups[None] = [
        rv for rv in model_relations["sample_sample"] if rv not in plate_rvs
    ]  # RVs which are in no plate

    # retain node metadata
    observed = set(model_relations["observed"])
    node_data = {
        rv: {
            "is_observed": rv in observed,
            "distribution": model_relations["sample_dist"][rv],
        }
        for rv in model_relations["sample_sample"]
    }

    # infer plate structure
    # The parent of a plate is the *smallest* plate whose RV set contains it.
    # (When the order of plates cannot be determined from subset relations,
    # i.e. for identical RV sets, it follows the order in which plates appear
    # in trace: the earlier plate is the outer one.)
    plate_names = [plate for plate in plate_groups if plate is not None]
    plate_sets = {plate: set(plate_groups[plate]) for plate in plate_names}
    position = {plate: i for i, plate in enumerate(plate_names)}
    parents = dict.fromkeys(plate_names)

    def _nesting_rank(plate):
        # Lower rank = more deeply nested: fewer RVs first, then later in trace.
        return len(plate_sets[plate]), -position[plate]

    for plate1, plate2 in itertools.combinations(plate_names, 2):
        if plate_sets[plate1] < plate_sets[plate2]:
            child, candidate = plate1, plate2
        elif plate_sets[plate1] >= plate_sets[plate2]:
            child, candidate = plate2, plate1
        else:
            continue  # plates are unrelated (disjoint)
        # Keep the innermost enclosing plate instead of whichever pair came last;
        # otherwise e.g. C < B < A listed as C, B, A would give C the parent A.
        if parents[child] is None or _nesting_rank(candidate) < _nesting_rank(
            parents[child]
        ):
            parents[child] = candidate
    plate_data = {plate: {"parent": parents[plate]} for plate in plate_names}

    # infer RV edges
    edge_list = [
        (source, target)
        for target, source_list in model_relations["sample_sample"].items()
        for source in source_list
    ]

    return {
        "plate_groups": plate_groups,
        "plate_data": plate_data,
        "node_data": node_data,
        "edge_list": edge_list,
    }


def render_graph(graph_specification, render_distributions=False):
    """
    Create a graphviz object given a graph specification.

    :param dict graph_specification: Graph specification as returned by
        :func:`generate_graph_specification`. It is not modified.
    :param bool render_distributions: Show distribution of each RV in plot.
    :return: The assembled graph.
    :rtype: graphviz.Digraph
    :raises ImportError: If the ``graphviz`` package is not installed.
    :raises ValueError: If the plate parent relation contains a cycle.
    """
    try:
        import graphviz  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looks like you want to use graphviz (https://graphviz.org/) "
            "to render your model. "
            "You need to install `graphviz` to be able to use this feature. "
            "It can be installed with `pip install graphviz`."
        ) from e

    plate_groups = graph_specification["plate_groups"]
    plate_data = graph_specification["plate_data"]
    node_data = graph_specification["node_data"]
    edge_list = graph_specification["edge_list"]

    graph = graphviz.Digraph()

    # add plates
    plate_graph_dict = {
        plate: graphviz.Digraph(name=f"cluster_{plate}")
        for plate in plate_groups
        if plate is not None
    }
    for plate, plate_graph in plate_graph_dict.items():
        plate_graph.attr(label=plate.split("__CLONE")[0], labeljust="r", labelloc="b")

    plate_graph_dict[None] = graph

    # add nodes
    for plate, rv_list in plate_groups.items():
        cur_graph = plate_graph_dict[plate]

        for rv in rv_list:
            color = "grey" if node_data[rv]["is_observed"] else "white"
            cur_graph.node(
                rv, label=rv, shape="ellipse", style="filled", fillcolor=color
            )

    # add leaf plates first: graphviz's `subgraph` adds the child's *current*
    # content, so a plate must be complete before it is attached to its parent.
    # Work on a copy so the caller's graph_specification is left intact.
    pending = dict(plate_data)
    while pending:
        referenced = {data["parent"] for data in pending.values()}
        leaf = next((plate for plate in pending if plate not in referenced), None)
        if leaf is None:
            raise ValueError("Cyclic plate nesting in graph specification.")
        parent_plate = pending.pop(leaf)["parent"]
        plate_graph_dict[parent_plate].subgraph(plate_graph_dict[leaf])

    # add edges
    for source, target in edge_list:
        graph.edge(source, target)

    # render distributions if requested
    if render_distributions:
        dist_label = ""
        for rv, data in node_data.items():
            rv_dist = data["distribution"]
            dist_label += rf"{rv} ~ {rv_dist}\l"

        graph.node("distribution_description_node", label=dist_label, shape="plaintext")

    # return whole graph
    return graph
def render_model(
    model,
    model_args=None,
    model_kwargs=None,
    filename=None,
    render_distributions=False,
    num_tries=10,
):
    """
    Wrap all functions needed to automatically render a model.

    .. warning:: This utility does not support the
        :func:`~numpyro.contrib.control_flow.scan` primitive yet.

    .. warning:: Currently, this utility uses a heuristic approach,
        which will work for most cases, to detect dependencies in a NumPyro model.

    :param model: Model to render.
    :param model_args: Positional arguments to pass to the model.
    :param model_kwargs: Keyword arguments to pass to the model.
    :param str filename: File to save rendered model in. The extension selects
        the output format (e.g. ``"plots/model.png"``) and the directory part of
        the path is kept. Without an extension the graph's default format is used.
    :param bool render_distributions: Whether to include RV distribution annotations in the plot.
    :param int num_tries: Times to trace model to detect discrete -> continuous dependency.
    :return: The rendered graph.
    """
    relations = get_model_relations(
        model, model_args=model_args, model_kwargs=model_kwargs, num_tries=num_tries
    )
    graph_spec = generate_graph_specification(relations)
    graph = render_graph(graph_spec, render_distributions=render_distributions)

    if filename is not None:
        filename = Path(filename)
        # graphviz appends the format as the file extension itself, so pass the
        # path without its suffix. `with_suffix("")` keeps the parent directory,
        # which `.stem` would silently drop.
        graph.render(
            str(filename.with_suffix("")),
            view=False,
            cleanup=True,
            # remove leading period from suffix; an empty suffix falls back to
            # the graph's default format instead of an invalid "" format
            format=filename.suffix[1:] or None,
        )

    return graph