|
from __future__ import absolute_import |
|
|
|
import autograd.numpy as np |
|
import scipy.stats |
|
from autograd.extend import primitive, defvjp |
|
from autograd.numpy.numpy_vjps import unbroadcast_f |
|
|
|
# Autograd primitives wrapping SciPy's Poisson distribution functions.
# Wrapping with ``primitive`` makes autograd treat each call as an opaque
# node; derivatives (w.r.t. the rate ``mu`` only) are attached via ``defvjp``.
pmf = primitive(scipy.stats.poisson.pmf)
logpmf = primitive(scipy.stats.poisson.logpmf)
cdf = primitive(scipy.stats.poisson.cdf)
|
|
|
def grad_poisson_logpmf(k, mu):
    """Return the derivative of the Poisson log-pmf with respect to ``mu``.

    On the support (non-negative integer ``k``) the log-pmf is
    ``k * log(mu) - mu - log(k!)``, whose derivative is ``k / mu - 1``.
    Off the support (negative or non-integer ``k``) the log-pmf is ``-inf``
    for every ``mu``, so its derivative is taken to be 0.

    Args:
        k: Count(s) at which the log-pmf is evaluated; scalar or array.
        mu: Poisson rate(s); broadcast against ``k``.

    Returns:
        ``d logpmf / d mu`` broadcast over ``k`` and ``mu`` (a 0-d array for
        scalar inputs, as produced by ``np.where``).
    """
    # Both kinds of off-support point (non-integer AND negative integer k)
    # contribute no gradient; previously negative integers slipped through
    # and received the on-support formula.
    on_support = np.logical_and(k % 1 == 0, k >= 0)
    # NOTE(review): mu == 0 makes k / mu divide by zero (nan/inf); not
    # special-cased here.
    return np.where(on_support, k / mu - 1, 0)
|
|
|
# Only the rate ``mu`` (positional argument 1) is differentiated; ``k`` is a
# discrete count and ``loc`` a shift. Each VJP maker accepts SciPy's optional
# ``loc`` (autograd forwards the primitive's positional args and kwargs to
# it), so located distributions can be differentiated; the default
# ``loc=0`` reproduces the unshifted computation exactly.

# The Poisson CDF at k is a finite sum of pmf terms up to floor(k - loc);
# its derivative in mu telescopes to minus the last term.
defvjp(cdf,
       lambda ans, k, mu, loc=0: unbroadcast_f(
           mu, lambda g: g * -pmf(np.floor(k - loc), mu)),
       argnums=[1])

# d/dmu logpmf is computed by grad_poisson_logpmf on the unshifted count.
defvjp(logpmf,
       lambda ans, k, mu, loc=0: unbroadcast_f(
           mu, lambda g: g * grad_poisson_logpmf(k - loc, mu)),
       argnums=[1])

# d/dmu pmf = pmf * d/dmu logpmf, reusing the already computed pmf ``ans``.
defvjp(pmf,
       lambda ans, k, mu, loc=0: unbroadcast_f(
           mu, lambda g: g * ans * grad_poisson_logpmf(k - loc, mu)),
       argnums=[1])
|
|