|
from itertools import repeat |
|
from autograd.wrap_util import wraps |
|
from autograd.util import subvals, toposort |
|
from autograd.tracer import trace, Node |
|
from functools import partial |
|
|
|
class ConstGraphNode(Node):
    """Trace node that can recompute its value from fresh parent values.

    Each non-root node stores its parent nodes and a closure,
    ``partial_fun``, that re-invokes the traced primitive with new values
    substituted at the traced argument positions.
    """

    __slots__ = ['parents', 'partial_fun']

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Record how to re-evaluate ``fun`` once parent values are known.

        ``value`` is accepted as part of the node signature but is not
        stored. ``parent_argnums`` are the positions in ``args`` that are
        paired with ``parents``.
        """
        # Overwrite the traced positions with None so the closure does not
        # keep the values seen during tracing.
        # NOTE(review): assumes subvals returns a new sequence with the given
        # (index, value) pairs substituted -- confirm against autograd.util.
        blanked = zip(parent_argnums, repeat(None))
        args = subvals(args, blanked)

        def partial_fun(partial_args):
            # Fill the traced positions with the supplied parent values,
            # then call the primitive with the original keyword arguments.
            substitutions = zip(parent_argnums, partial_args)
            call_args = subvals(args, substitutions)
            return fun(*call_args, **kwargs)

        self.partial_fun = partial_fun
        self.parents = parents

    def initialize_root(self):
        """Set up the root node: it has no parents and no ``partial_fun``."""
        self.parents = []
|
|
|
def const_graph_unary(fun):
    """Wrap a one-argument function so it is traced once and replayed after.

    The first successful call traces ``fun`` with ``ConstGraphNode`` nodes
    and stores the resulting node list. Every later call skips ``fun`` and
    instead re-runs each stored node's ``partial_fun`` on the new input.

    Args:
        fun: callable taking a single argument.

    Returns:
        A callable taking a single argument ``x``.

    Raises:
        Exception: on a tracing call for which ``trace`` reports no end
            node ("Output is independent of input").
    """
    graph = []    # becomes [node_list] once a trace has been recorded
    _fun = [fun]  # one-slot holder so the reference can be dropped later

    def maybe_cached_fun(x):
        if graph:
            _graph = graph[0]
            # The first stored node stands for the input itself.
            vals = {_graph[0]: x}
            for node in _graph[1:]:
                vals[node] = node.partial_fun([vals[p] for p in node.parents])
            # Look up the last node explicitly instead of relying on the loop
            # variable: when the recorded graph holds a single node the loop
            # body never runs and ``node`` would be unbound.
            return vals[_graph[-1]]

        start_node = ConstGraphNode.new_root()
        end_value, end_node = trace(start_node, _fun[0], x)
        if end_node is None:
            raise Exception("Output is independent of input")
        graph.append(list(toposort(end_node))[::-1])
        # Release the function only after the graph is stored, so a failed
        # trace can be retried instead of hitting an empty holder.
        _fun.pop()
        return end_value

    return maybe_cached_fun
|
|
|
def const_graph(fun, *args, **kwargs):
    """Return a version of ``fun`` that goes through ``const_graph_unary``.

    ``args`` and ``kwargs`` given here are bound to ``fun`` up front with
    ``functools.partial``. The positional arguments of each later call are
    packed into one tuple, because ``const_graph_unary`` works on
    single-argument functions.
    """
    bound_fun = partial(fun, *args, **kwargs)

    def unary_fun(packed_args):
        # Unpack the single tuple argument back into positional arguments.
        return bound_fun(*packed_args)

    cached_unary = const_graph_unary(unary_fun)

    @wraps(fun)
    def _fun(*call_args):
        return cached_unary(call_args)

    return _fun
|
|
|
class FullGraphNode(Node):
    """Trace node that keeps its value and the full recipe that produced it.

    ``recipe`` is a 4-tuple ``(fun, args, kwargs, parent_pairs)`` where
    ``parent_pairs`` is a list of ``(argnum, parent_node)`` pairs.
    """

    __slots__ = ['value', 'recipe']

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Store the computed value and the call that produced it."""
        self.value = value
        # Materialize the pairs: a bare zip() is a one-shot iterator on
        # Python 3, so a second read of the recipe would see no parents.
        # A list also matches the list used for the root node.
        parent_pairs = list(zip(parent_argnums, parents))
        self.recipe = (fun, args, kwargs, parent_pairs)

    def initialize_root(self):
        """Set up the root node: no value, identity function, no parents."""
        self.value = None
        self.recipe = (lambda x: x, (), {}, [])
|
|
|
def full_graph(fun, *args, **kwargs):
    """Trace ``fun(*args, **kwargs)`` and return the end node of the trace.

    The positional arguments are traced together as one tuple; ``kwargs``
    are passed through to ``fun`` untraced. The value computed by the call
    is discarded and only the node returned by ``trace`` is handed back.
    """
    def unary_fun(packed_args):
        # Spread the traced tuple back into positional arguments.
        return fun(*packed_args, **kwargs)

    root = FullGraphNode.new_root()
    _, end_node = trace(root, unary_fun, args)
    return end_node
|
|