Source code for numpyro.infer.initialization
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
import warnings
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions import biject_to
from numpyro.util import find_stack_level
[docs]
def init_to_median(site=None, num_samples=15):
"""
Initialize to the prior median. For priors with no `.sample` method implemented,
we defer to the :func:`init_to_uniform` strategy.
:param int num_samples: number of prior points to calculate median.
"""
if site is None:
return partial(init_to_median, num_samples=num_samples)
if (
site["type"] == "sample"
and not site["is_observed"]
and not site["fn"].support.is_discrete
):
if site["value"] is not None:
warnings.warn(
f"init_to_median() skipping initialization of site '{site['name']}'"
" which already stores a value.",
stacklevel=find_stack_level(),
)
return site["value"]
rng_key = site["kwargs"].get("rng_key")
sample_shape = site["kwargs"].get("sample_shape")
try:
samples = site["fn"](
sample_shape=(num_samples,) + sample_shape, rng_key=rng_key
)
return jnp.median(samples, axis=0)
except NotImplementedError:
return init_to_uniform(site)
[docs]
def init_to_mean(site=None):
"""
Initialize to the prior mean. For priors with no `.mean` property implemented,
we defer to the :func:`init_to_median` strategy.
"""
if site is None:
return partial(init_to_mean)
if (
site["type"] == "sample"
and not site["is_observed"]
and not site["fn"].support.is_discrete
):
if site["value"] is not None:
warnings.warn(
f"init_to_mean() skipping initialization of site '{site['name']}'"
" which already stores a value.",
stacklevel=find_stack_level(),
)
return site["value"]
try:
# Try .mean property.
value = site["fn"].mean
sample_shape = site["kwargs"].get("sample_shape")
if sample_shape:
value = jnp.broadcast_to(value, sample_shape + jnp.shape(value))
except (NotImplementedError, ValueError):
return init_to_median(site)
[docs]
def init_to_sample(site=None):
"""
Initialize to a prior sample. For priors with no `.sample` method implemented,
we defer to the :func:`init_to_uniform` strategy.
"""
return init_to_median(site, num_samples=1)
[docs]
def init_to_uniform(site=None, radius=2):
"""
Initialize to a random point in the area `(-radius, radius)` of unconstrained domain.
:param float radius: specifies the range to draw an initial point in the unconstrained domain.
"""
if site is None:
return partial(init_to_uniform, radius=radius)
if (
site["type"] == "sample"
and not site["is_observed"]
and not site["fn"].support.is_discrete
):
if site["value"] is not None:
warnings.warn(
f"init_to_uniform() skipping initialization of site '{site['name']}'"
" which already stores a value.",
stacklevel=find_stack_level(),
)
return site["value"]
# XXX: we import here to avoid circular import
from numpyro.infer.util import helpful_support_errors
rng_key = site["kwargs"].get("rng_key")
sample_shape = site["kwargs"].get("sample_shape")
with helpful_support_errors(site):
transform = biject_to(site["fn"].support)
unconstrained_shape = transform.inverse_shape(site["fn"].shape())
unconstrained_samples = dist.Uniform(-radius, radius)(
rng_key=rng_key, sample_shape=sample_shape + unconstrained_shape
)
return transform(unconstrained_samples)
[docs]
def init_to_feasible(site=None):
"""
Initialize to an arbitrary feasible point, ignoring distribution
parameters.
"""
return init_to_uniform(site, radius=0)
[docs]
def init_to_value(site=None, values={}):
"""
Initialize to the value specified in `values`. We defer to
:func:`init_to_uniform` strategy for sites which do not appear in `values`.
:param dict values: dictionary of initial values keyed by site name.
"""
if site is None:
return partial(init_to_value, values=values)
if site["type"] == "sample" and not site["is_observed"]:
if site["name"] in values:
return values[site["name"]]
else: # defer to default strategy
return init_to_uniform(site)