File size: 771 Bytes
ab4488b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from autograd.extend import primitive, defvjp, vspace
from autograd.builtins import tuple
from autograd import make_vjp

@primitive
def fixed_point(f, a, x0, distance, tol):
    _f = f(a)
    x, x_prev = _f(x0), x0
    while distance(x, x_prev) > tol:
        x, x_prev = _f(x), x
    return x

def fixed_point_vjp(ans, f, a, x0, distance, tol):
    def rev_iter(params):
        a, x_star, x_star_bar = params
        vjp_x, _ = make_vjp(f(a))(x_star)
        vs = vspace(x_star)
        return lambda g: vs.add(vjp_x(g), x_star_bar)
    vjp_a, _ = make_vjp(lambda x, y: f(x)(y))(a, ans)
    return lambda g: vjp_a(fixed_point(rev_iter, tuple((a, ans, g)),
                           vspace(x0).zeros(), distance, tol))

defvjp(fixed_point, None, fixed_point_vjp, None)