"""Gradients of the normal distribution.""" from __future__ import absolute_import import scipy.stats import autograd.numpy as anp from autograd.extend import primitive, defvjp from autograd.numpy.numpy_vjps import unbroadcast_f pdf = primitive(scipy.stats.norm.pdf) cdf = primitive(scipy.stats.norm.cdf) sf = primitive(scipy.stats.norm.sf) logpdf = primitive(scipy.stats.norm.logpdf) logcdf = primitive(scipy.stats.norm.logcdf) logsf = primitive(scipy.stats.norm.logsf) defvjp(pdf, lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * ans * (x - loc) / scale**2), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * ans * (x - loc) / scale**2), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(scale, lambda g: g * ans * (((x - loc)/scale)**2 - 1.0)/scale)) defvjp(cdf, lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * pdf(x, loc, scale)) , lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: -g * pdf(x, loc, scale)), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(scale, lambda g: -g * pdf(x, loc, scale)*(x-loc)/scale)) defvjp(logpdf, lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * (x - loc) / scale**2), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * (x - loc) / scale**2), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(scale, lambda g: g * (-1.0/scale + (x - loc)**2/scale**3))) defvjp(logcdf, lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g:-g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(scale, lambda g:-g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))*(x-loc)/scale)) defvjp(logsf, lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale))), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale))), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(scale, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale)) * (x - loc) / scale)) defvjp(sf, lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * pdf(x, loc, scale)) , lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * pdf(x, loc, scale)), lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(scale, lambda g: g * pdf(x, loc, scale)*(x-loc)/scale))