import pyro.distributions as dist
from pyro import sample
import torch
Purpose
Following the blog post https://willcrichton.net/notes/probabilistic-programming-under-the-hood/
# A fair coin flip: draw one value from Bernoulli(0.5) at the site "coinflip"
coinflip = sample("coinflip", dist.Bernoulli(probs=0.5))
print(f'coinflip - {coinflip}')
# Noisy sample: draw one value from Normal(0, 1) at the site "noisy_sample"
noisy_sample = sample("noisy_sample", dist.Normal(loc=0, scale=1))
print(f'Noisy Sample - {noisy_sample}')
coinflip - 1.0
Noisy Sample - -0.21278521418571472
def sleep_model():
    """Sample an amount of sleep from a small generative model.

    The model draws ``feeling_lazy`` from Bernoulli(0.9). When that draw is
    truthy it also draws ``ignore_alarm`` from Bernoulli(0.8) and then
    ``amount_slept`` from Normal(8 + 2 * ignore_alarm, 1); otherwise
    ``amount_slept`` is drawn from Normal(6, 1) and no ``ignore_alarm``
    site is created.

    Returns:
        The value sampled at the ``amount_slept`` site.
    """
    # very likely to feel lazy
    feeling_lazy = sample('feeling_lazy', dist.Bernoulli(0.9))
    if feeling_lazy:
        # only going to possibly ignore alarm if I am feeling lazy
        ignore_alarm = sample('ignore_alarm', dist.Bernoulli(0.8))
        # will sleep more if the alarm is ignored
        amount_slept = sample('amount_slept', dist.Normal(8 + 2 * ignore_alarm, 1))
    else:
        amount_slept = sample('amount_slept', dist.Normal(6, 1))
    return amount_slept
# Draw and print three independent samples from the sleep model.
for _ in range(3):
    print(sleep_model())
tensor(9.6567)
tensor(8.2041)
tensor(11.7966)
Traces and conditioning
On the unconditional sleep model, we could ask a few questions, like:
- Joint probability of a sample: what is the probability that `feeling_lazy` = 1, `ignore_alarm` = 0, `amount_slept` = 10?
- Joint probability distribution: what is the probability for any possible assignment to all variables?
- Marginal probability of a sample: what is the probability that `feeling_lazy` is true?
- Marginal probability distribution: what is the probability over all values of `amount_slept`?
First, we need the ability to evaluate the probability of a joint assignment to each variable.
from pyro.poutine import trace
from pprint import pprint
# Runs the sleep model and collects a trace
tr = trace(sleep_model).get_trace()

# For every sample site in the trace, show the sampled value and the
# probability (density) that the site's distribution assigns to it.
pprint({
    name: {
        'value': props['value'],
        'prob': props['fn'].log_prob(props['value']).exp(),
    }
    for (name, props) in tr.nodes.items()
    if props['type'] == 'sample'
})
{'amount_slept': {'prob': tensor(0.3937), 'value': tensor(7.8368)},
'feeling_lazy': {'prob': tensor(0.9000), 'value': tensor(1.)},
'ignore_alarm': {'prob': tensor(0.2000), 'value': tensor(0.)}}
conditional probabilities
from pyro import condition
# Fix every sample site of the sleep model to a chosen value. All values are
# passed as tensors so the three conditioned sites are treated uniformly.
cond_model = condition(sleep_model, {
    'feeling_lazy': torch.tensor(1.0),
    'ignore_alarm': torch.tensor(0.0),
    'amount_slept': torch.tensor(10.0),
})
# Joint probability of that assignment: exponentiate the summed log-probs.
trace(cond_model).get_trace().log_prob_sum().exp()
tensor(0.0303)
Now we can produce an approximate answer to any of our questions by sampling from the distribution enough times. For example, we can look at the marginal distribution over each variable.
import pandas as pd
import matplotlib.pyplot as plt
# Run the model 1000 times and record the sampled value of every site.
traces = []
for _ in range(1000):
    tr = trace(sleep_model).get_trace()
    values = {
        name: props['value'].item()
        for (name, props) in tr.nodes.items()
        if props['type'] == 'sample'
    }
    traces.append(values)
# One row per run, one column per sample site; runs that never reached a
# site have no entry for it in that row.
df = pd.DataFrame(traces)
df.hist()
array([[<AxesSubplot:title={'center':'feeling_lazy'}>,
<AxesSubplot:title={'center':'ignore_alarm'}>],
[<AxesSubplot:title={'center':'amount_slept'}>, <AxesSubplot:>]],
dtype=object)
df.head(2)
| | feeling_lazy | ignore_alarm | amount_slept |
---|---|---|---|
0 | 1.0 | 1.0 | 10.285290 |
1 | 1.0 | 1.0 | 11.648011 |
Sampling conditional distributions
- Given I slept 6 hours, what is the probability I was feeling lazy?
- What is the probability of me sleeping exactly 7.65 hours?
What is the problem with this? First, as the number of marginalized variables grows, we have an exponential increase in summation terms.
But the second issue is that for continuous variables, computing this marginal probability can quickly become intractable. For example, if `feeling_lazy` were a real-valued laziness score between 0 and 1 (presumably a more realistic model), then marginalizing that variable would require an integral instead of a sum. In general, producing an exact estimate of a conditional probability for a complex probabilistic program is not computationally feasible.
Approximate Inference
The main idea is that instead of exactly computing the conditional probability distribution (or “posterior”) of interest, we can approximate it using a variety of techniques. Generally, these fall into two camps: sampling methods and variational methods. The CS 228 (Probabilistic Graphical Models at Stanford) course notes go in depth on both (sampling, variational).
Essentially, for sampling methods, you use algorithms that continually draw samples from a changing probability distribution until eventually they converge on the true posterior of interest. The time to convergence is not known ahead of time. For variational methods, you use a simpler function that can be optimized to match the true posterior using standard optimization techniques like gradient descent.
Where to use what? Please have a look at the original blog post.
Variational inference 1: autodifferentiation
# Draw a single value from a standard normal at the site 'x'.
norm = dist.Normal(0, 1)
x = sample('x', norm)
x
tensor(-0.8995)
However, let’s say I know the value of x = 5 and I want to find a mean μ to the normal distribution that maximizes the probability of seeing that x. For that, we can use a parameter:
from pyro import param
# Register a parameter named "mu", initialised to 0.0, and use it as the
# mean of the normal distribution that 'x' is sampled from.
mu = param("mu", torch.tensor(0.0))
norm = dist.Normal(mu, 1)
x = sample('x', norm)
Our goal is to update mu such that the probability of the value 5 under the normal distribution norm is maximized.
import torch.nn as nn
import torch.distributions as dist
from torch.optim import Adam
class NormalDistModel(nn.Module):
    """Unit-scale normal distribution whose mean ``mu`` is a learnable parameter.

    ``forward`` returns the log-probability of the fixed observation 5.0
    under ``Normal(mu, 1)``.
    """

    def __init__(self, mu):
        """Store ``mu`` as a float32 ``nn.Parameter``.

        Args:
            mu: Initial value of the mean (a Python number).
        """
        super().__init__()
        self.mu = nn.Parameter(torch.tensor(mu, dtype=torch.float32))

    @property
    def normal(self):
        """Return ``Normal(mu, 1)`` built from the current parameter.

        The distribution is created on access instead of being cached in
        ``__init__`` so that it always reflects the live ``mu``, even after
        the parameter's underlying data has been replaced.
        """
        return dist.Normal(self.mu, 1)

    def forward(self):
        """Return the log-probability of 5.0 under ``Normal(mu, 1)``."""
        return self.normal.log_prob(torch.tensor(5.0))
# Start from mu = 1 and evaluate the log-probability of observing 5.
model = NormalDistModel(1)
model.forward()
tensor(-8.9189, grad_fn=<SubBackward0>)
# let's now optimize the model to maximize the probability of 5
optimizer = Adam(model.parameters(), lr=0.01)

mus = []
losses = []
for _ in range(1000):
    # instead of maximizing the probability, we minimize the negative log of the probability
    loss = -model.forward()

    loss.backward()
    # update parameters
    optimizer.step()
    # zero all parameter gradients so that they do not accumulate
    optimizer.zero_grad()
    # record the loss (computed before this step) and mu (after this step)
    losses.append(loss.detach().item())
    mus.append(model.mu.detach().item())
df = pd.DataFrame({"mu": mus, "loss": losses})
df.plot(subplots=True)
array([<AxesSubplot:>, <AxesSubplot:>], dtype=object)