File size: 2,758 Bytes
ab4488b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Gradients of the normal distribution."""
from __future__ import absolute_import
import scipy.stats
import autograd.numpy as anp
from autograd.extend import primitive, defvjp
from autograd.numpy.numpy_vjps import unbroadcast_f

# Expose the scipy.stats.norm functions as autograd primitives so that the
# gradient rules registered with defvjp further down can be attached to them.
_norm = scipy.stats.norm

# Density, distribution function and survival function.
pdf = primitive(_norm.pdf)
cdf = primitive(_norm.cdf)
sf = primitive(_norm.sf)

# Logarithmic counterparts of the three functions above.
logpdf = primitive(_norm.logpdf)
logcdf = primitive(_norm.logcdf)
logsf = primitive(_norm.logsf)

def _pdf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``pdf`` with respect to ``x``.

    ``ans`` is the already-computed ``pdf(x, loc, scale)``; the derivative is
    ``-pdf * (x - loc) / scale**2``.
    """
    def vjp(g):
        return -g * ans * (x - loc) / scale**2
    return unbroadcast_f(x, vjp)


def _pdf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``pdf`` with respect to ``loc``.

    Same magnitude as the ``x`` derivative with the opposite sign:
    ``pdf * (x - loc) / scale**2``.
    """
    def vjp(g):
        return g * ans * (x - loc) / scale**2
    return unbroadcast_f(loc, vjp)


def _pdf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``pdf`` with respect to ``scale``.

    The derivative is ``pdf * (z**2 - 1) / scale`` with
    ``z = (x - loc) / scale``.
    """
    def vjp(g):
        return g * ans * (((x - loc)/scale)**2 - 1.0)/scale
    return unbroadcast_f(scale, vjp)


# One VJP builder per positional argument, in the order (x, loc, scale).
defvjp(pdf, _pdf_vjp_x, _pdf_vjp_loc, _pdf_vjp_scale)

def _cdf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``cdf`` with respect to ``x``.

    The derivative of the distribution function is the density
    ``pdf(x, loc, scale)``.
    """
    def vjp(g):
        return g * pdf(x, loc, scale)
    return unbroadcast_f(x, vjp)


def _cdf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``cdf`` with respect to ``loc``.

    Shifting ``loc`` has the opposite effect of shifting ``x``, hence the
    negated density.
    """
    def vjp(g):
        return -g * pdf(x, loc, scale)
    return unbroadcast_f(loc, vjp)


def _cdf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``cdf`` with respect to ``scale``.

    The derivative is ``-pdf(x, loc, scale) * (x - loc) / scale``.
    """
    def vjp(g):
        return -g * pdf(x, loc, scale)*(x-loc)/scale
    return unbroadcast_f(scale, vjp)


# One VJP builder per positional argument, in the order (x, loc, scale).
defvjp(cdf, _cdf_vjp_x, _cdf_vjp_loc, _cdf_vjp_scale)

def _logpdf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logpdf`` with respect to ``x``.

    The derivative is ``-(x - loc) / scale**2``; ``ans`` is not needed.
    """
    def vjp(g):
        return -g * (x - loc) / scale**2
    return unbroadcast_f(x, vjp)


def _logpdf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logpdf`` with respect to ``loc``.

    The derivative is ``(x - loc) / scale**2``.
    """
    def vjp(g):
        return g * (x - loc) / scale**2
    return unbroadcast_f(loc, vjp)


def _logpdf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logpdf`` with respect to ``scale``.

    The derivative is ``-1/scale + (x - loc)**2 / scale**3``.
    """
    def vjp(g):
        return g * (-1.0/scale + (x - loc)**2/scale**3)
    return unbroadcast_f(scale, vjp)


# One VJP builder per positional argument, in the order (x, loc, scale).
defvjp(logpdf, _logpdf_vjp_x, _logpdf_vjp_loc, _logpdf_vjp_scale)

# Gradient rules for logcdf. The common factor is pdf / cdf, evaluated in log
# space as exp(logpdf - logcdf). ``ans`` already holds logcdf(x, loc, scale),
# so it is reused instead of evaluating the forward function a second time on
# every backward pass (the pdf rule reuses ``ans`` in the same way).

def _logcdf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logcdf`` with respect to ``x``: ``pdf / cdf``."""
    def vjp(g):
        return g * anp.exp(logpdf(x, loc, scale) - ans)
    return unbroadcast_f(x, vjp)


def _logcdf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logcdf`` with respect to ``loc``: ``-pdf / cdf``."""
    def vjp(g):
        return -g * anp.exp(logpdf(x, loc, scale) - ans)
    return unbroadcast_f(loc, vjp)


def _logcdf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logcdf`` with respect to ``scale``.

    The derivative is ``-(pdf / cdf) * (x - loc) / scale``.
    """
    def vjp(g):
        return -g * anp.exp(logpdf(x, loc, scale) - ans)*(x-loc)/scale
    return unbroadcast_f(scale, vjp)


# One VJP builder per positional argument, in the order (x, loc, scale).
defvjp(logcdf, _logcdf_vjp_x, _logcdf_vjp_loc, _logcdf_vjp_scale)

# Gradient rules for logsf. The common factor is pdf / sf, evaluated in log
# space as exp(logpdf - logsf). ``ans`` already holds logsf(x, loc, scale),
# so it is reused instead of evaluating the forward function a second time on
# every backward pass (the pdf rule reuses ``ans`` in the same way).

def _logsf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logsf`` with respect to ``x``: ``-pdf / sf``."""
    def vjp(g):
        return -g * anp.exp(logpdf(x, loc, scale) - ans)
    return unbroadcast_f(x, vjp)


def _logsf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logsf`` with respect to ``loc``: ``pdf / sf``."""
    def vjp(g):
        return g * anp.exp(logpdf(x, loc, scale) - ans)
    return unbroadcast_f(loc, vjp)


def _logsf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``logsf`` with respect to ``scale``.

    The derivative is ``(pdf / sf) * (x - loc) / scale``.
    """
    def vjp(g):
        return g * anp.exp(logpdf(x, loc, scale) - ans) * (x - loc) / scale
    return unbroadcast_f(scale, vjp)


# One VJP builder per positional argument, in the order (x, loc, scale).
defvjp(logsf, _logsf_vjp_x, _logsf_vjp_loc, _logsf_vjp_scale)

def _sf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``sf`` with respect to ``x``.

    The survival function is ``1 - cdf``, so its derivative is the negated
    density ``-pdf(x, loc, scale)``.
    """
    def vjp(g):
        return -g * pdf(x, loc, scale)
    return unbroadcast_f(x, vjp)


def _sf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``sf`` with respect to ``loc``: ``pdf(x, loc, scale)``."""
    def vjp(g):
        return g * pdf(x, loc, scale)
    return unbroadcast_f(loc, vjp)


def _sf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Build the VJP of ``sf`` with respect to ``scale``.

    The derivative is ``pdf(x, loc, scale) * (x - loc) / scale``.
    """
    def vjp(g):
        return g * pdf(x, loc, scale)*(x-loc)/scale
    return unbroadcast_f(scale, vjp)


# One VJP builder per positional argument, in the order (x, loc, scale).
defvjp(sf, _sf_vjp_x, _sf_vjp_loc, _sf_vjp_scale)