import numpy as onp

from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
                         tensordot_adjoint_1, nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace,
                             JVPNode, register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

for fun in nograd_functions:
    register_notrace(JVPNode, fun)

defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

defjvp_argnum(anp.array_from_args,
    lambda argnum, g, ans, args, kwargs: untake(g, argnum-2, vspace(ans)))
defjvp(anp._array_from_scalar_or_array, None, None,
    lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(args, kwargs, g))

# ----- Functions that are constant w.r.t. continuous inputs -----
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
def_linear(anp.multiply)

# ----- Binary ufuncs -----
defjvp(anp.add,        lambda g, ans, x, y : broadcast(g, ans),
                       lambda g, ans, x, y : broadcast(g, ans))
defjvp(anp.subtract,   lambda g, ans, x, y : broadcast(g, ans),
                       lambda g, ans, x, y : broadcast(-g, ans))
defjvp(anp.divide,     'same',
                       lambda g, ans, x, y : - g * x / y**2)
defjvp(anp.maximum,    lambda g, ans, x, y : g * balanced_eq(x, ans, y),
                       lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.minimum,    lambda g, ans, x, y : g * balanced_eq(x, ans, y),
                       lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.fmax,       lambda g, ans, x, y : g * balanced_eq(x, ans, y),
                       lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.fmin,       lambda g, ans, x, y : g * balanced_eq(x, ans, y),
                       lambda g, ans, x, y : g * balanced_eq(y, ans, x))
defjvp(anp.logaddexp,  lambda g, ans, x, y : g * anp.exp(x-ans),
                       lambda g, ans, x, y : g * anp.exp(y-ans))
defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(x-ans),
                       lambda g, ans, x, y : g * 2**(y-ans))
defjvp(anp.true_divide,'same',
                       lambda g, ans, x, y : - g * x / y**2)
defjvp(anp.mod,        lambda g, ans, x, y : broadcast(g, ans),
                       lambda g, ans, x, y : -g * anp.floor(x/y))
defjvp(anp.remainder,  lambda g, ans, x, y : broadcast(g, ans),
                       lambda g, ans, x, y : -g * anp.floor(x/y))
defjvp(anp.power,      lambda g, ans, x, y : g * y * x ** anp.where(y, y - 1, 1.),
                       lambda g, ans, x, y : g * anp.log(replace_zero(x, 1.)) * ans)
defjvp(anp.arctan2,    lambda g, ans, x, y : g * y / (x**2 + y**2),
                       lambda g, ans, x, y : g * -x / (x**2 + y**2))

# ----- Simple grads (linear) -----
defjvp(anp.negative, 'same')
defjvp(anp.rad2deg, 'same')
defjvp(anp.degrees, 'same')
defjvp(anp.deg2rad, 'same')
defjvp(anp.radians, 'same')
defjvp(anp.reshape, 'same')
defjvp(anp.roll, 'same')
defjvp(anp.array_split, 'same')
defjvp(anp.split, 'same')
defjvp(anp.vsplit, 'same')
defjvp(anp.hsplit, 'same')
defjvp(anp.dsplit, 'same')
defjvp(anp.ravel, 'same')
defjvp(anp.expand_dims, 'same')
defjvp(anp.squeeze, 'same')
defjvp(anp.diag, 'same')
defjvp(anp.diagonal, 'same')
defjvp(anp.make_diagonal, 'same')
defjvp(anp.flipud, 'same')
defjvp(anp.fliplr, 'same')
defjvp(anp.rot90, 'same')
defjvp(anp.trace, 'same')
defjvp(anp.full, 'same', argnums=(1,))
defjvp(anp.triu, 'same')
defjvp(anp.tril, 'same')
defjvp(anp.swapaxes, 'same')
defjvp(anp.rollaxis, 'same')
defjvp(anp.moveaxis, 'same')
defjvp(anp.broadcast_to, 'same')
def_linear(anp.cross)

# ----- Simple grads -----
defjvp(anp.abs,
    lambda g, ans, x : anp.real(g * replace_zero(anp.conj(x), 0.)) / replace_zero(ans, 1.))
defjvp(anp.fabs,       lambda g, ans, x : anp.sign(x) * g)  # fabs doesn't take complex numbers.
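# A minimal usage sketch (illustrative comment only, not executed at import):
# the JVPs registered above are what autograd's forward mode consults.
# Assuming a standard autograd install:
#
#   from autograd import make_jvp
#   import autograd.numpy as np
#
#   f = lambda x: np.maximum(x, 0.) + np.abs(x)
#   value, tangent = make_jvp(f)(2.0)(1.0)  # f(2.0) and df/dx at 2.0, direction 1.0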
defjvp(anp.absolute,   lambda g, ans, x : anp.real(g * anp.conj(x)) / ans)
defjvp(anp.reciprocal, lambda g, ans, x : - g / x**2)
defjvp(anp.exp,        lambda g, ans, x : ans * g)
defjvp(anp.exp2,       lambda g, ans, x : ans * anp.log(2) * g)
defjvp(anp.expm1,      lambda g, ans, x : (ans + 1) * g)
defjvp(anp.log,        lambda g, ans, x : g / x)
defjvp(anp.log2,       lambda g, ans, x : g / x / anp.log(2))
defjvp(anp.log10,      lambda g, ans, x : g / x / anp.log(10))
defjvp(anp.log1p,      lambda g, ans, x : g / (x + 1))
defjvp(anp.sin,        lambda g, ans, x : g * anp.cos(x))
defjvp(anp.cos,        lambda g, ans, x : - g * anp.sin(x))
defjvp(anp.tan,        lambda g, ans, x : g / anp.cos(x) **2)
defjvp(anp.arcsin,     lambda g, ans, x : g / anp.sqrt(1 - x**2))
defjvp(anp.arccos,     lambda g, ans, x :-g / anp.sqrt(1 - x**2))
defjvp(anp.arctan,     lambda g, ans, x : g / (1 + x**2))
defjvp(anp.sinh,       lambda g, ans, x : g * anp.cosh(x))
defjvp(anp.cosh,       lambda g, ans, x : g * anp.sinh(x))
defjvp(anp.tanh,       lambda g, ans, x : g / anp.cosh(x) **2)
defjvp(anp.arcsinh,    lambda g, ans, x : g / anp.sqrt(x**2 + 1))
defjvp(anp.arccosh,    lambda g, ans, x : g / anp.sqrt(x**2 - 1))
defjvp(anp.arctanh,    lambda g, ans, x : g / (1 - x**2))
defjvp(anp.square,     lambda g, ans, x : g * 2 * x)
defjvp(anp.sqrt,       lambda g, ans, x : g * 0.5 * x**-0.5)
defjvp(anp.sinc,       lambda g, ans, x : g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2))
defjvp(anp.clip,       lambda g, ans, x, a_min, a_max : g * anp.logical_and(ans != a_min, ans != a_max))
defjvp(anp.real_if_close, lambda g, ans, x : match_complex(ans, g))
defjvp(anp.real,   lambda g, ans, x : anp.real(g))
defjvp(anp.imag,   lambda g, ans, x : match_complex(ans, -1j * g))
defjvp(anp.conj,   lambda g, ans, x : anp.conj(g))
defjvp(anp.angle,  lambda g, ans, x : match_complex(ans, g * anp.conj(x * 1j) / anp.abs(x)**2))
defjvp(anp.where, None,
       lambda g, ans, c, x=None, y=None : anp.where(c, g, anp.zeros(anp.shape(g))),
       lambda g, ans, c, x=None, y=None : anp.where(c, anp.zeros(anp.shape(g)), g))

# ----- Trickier grads -----
defjvp(anp.kron,      'same', 'same')
defjvp(anp.diff,      'same')
defjvp(anp.gradient,  'same')
defjvp(anp.repeat,    'same')
defjvp(anp.tile,      'same')
defjvp(anp.transpose, 'same')
defjvp(anp.sum,       'same')
defjvp(anp.mean,      'same')
defjvp(anp.prod, lambda g, ans, x, axis=None, keepdims=False:
       ans * anp.sum(g / x, axis=axis, keepdims=keepdims))
defjvp(anp.linspace,
       lambda g, ans, start, stop, *args, **kwargs: anp.linspace(g, 0, *args, **kwargs),
       lambda g, ans, start, stop, *args, **kwargs: anp.linspace(0, g, *args, **kwargs))

# Forward-mode derivative of np.var:
# d var = 2 * sum(Re(g * conj(x - mean(x)))) / (N - ddof).
def forward_grad_np_var(g, ans, x, axis=None, ddof=0, keepdims=False):
    if axis is None:
        num_reps = anp.size(g)
    elif isinstance(axis, int):
        num_reps = anp.shape(g)[axis]
    elif isinstance(axis, tuple):
        num_reps = anp.prod(anp.array(anp.shape(g))[list(axis)])

    x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
    return (2.0 * anp.sum(anp.real(g * x_minus_mean), axis=axis, keepdims=keepdims)
            / (num_reps - ddof))
defjvp(anp.var, forward_grad_np_var)

# Forward-mode derivative of np.std: d std = d var / (2 * std).
def forward_grad_np_std(g, ans, x, axis=None, ddof=0, keepdims=False):
    if axis is None:
        num_reps = anp.size(g)
    elif isinstance(axis, int):
        num_reps = anp.shape(g)[axis]
    elif isinstance(axis, tuple):
        num_reps = anp.prod(anp.array(anp.shape(g))[list(axis)])

    if num_reps <= 1:
        return anp.zeros_like(ans)
    x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
    return (anp.sum(anp.real(g * x_minus_mean), axis=axis, keepdims=keepdims)
            / ((num_reps - ddof) * ans))
defjvp(anp.std, forward_grad_np_std)
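# A hedged finite-difference check of the var JVP above (illustrative comment
# only, not executed at import). Assuming a standard autograd install:
#
#   from autograd import make_jvp
#   import autograd.numpy as np
#   import numpy.random as npr
#
#   x, v, eps = npr.randn(5), npr.randn(5), 1e-6
#   _, jvp_val = make_jvp(np.var)(x)(v)
#   fd = (np.var(x + eps*v) - np.var(x - eps*v)) / (2 * eps)  # central difference
#   assert np.allclose(jvp_val, fd, atol=1e-4)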
# JVP for max/min-style reductions: spread the tangent uniformly over all
# locations that attain the extremum, so ties share the derivative equally.
def fwd_grad_chooser(g, ans, x, axis=None, keepdims=False):
    if anp.isscalar(x):
        return g
    if not keepdims:
        if isinstance(axis, int):
            ans = anp.expand_dims(ans, axis)
        elif isinstance(axis, tuple):
            for ax in sorted(axis):
                ans = anp.expand_dims(ans, ax)
    chosen_locations = x == ans
    return (anp.sum((g * chosen_locations), axis=axis, keepdims=keepdims) /
            anp.sum(chosen_locations, axis=axis, keepdims=keepdims))

defjvp(anp.max,  fwd_grad_chooser)
defjvp(anp.min,  fwd_grad_chooser)
defjvp(anp.amax, fwd_grad_chooser)
defjvp(anp.amin, fwd_grad_chooser)

defjvp(anp.cumsum, 'same')

def_linear(anp.inner)
def_linear(anp.matmul)
def_linear(anp.dot)
def_linear(anp.tensordot)
def_linear(anp.outer)

def_linear(dot_adjoint_0)
def_linear(dot_adjoint_1)

def_linear(tensordot_adjoint_0)
def_linear(tensordot_adjoint_1)

# The tangent of concatenate w.r.t. argument `argnum` is the concatenation of
# that argument's tangent with zeros in place of all the other arguments.
def fwd_grad_concatenate_args(argnum, g, ans, axis_args, kwargs):
    result = []
    for i in range(1, len(axis_args)):
        if i == argnum:
            result.append(g)
        else:
            result.append(anp.zeros_like(axis_args[i]))
    return anp.concatenate_args(axis_args[0], *result)
defjvp_argnum(anp.concatenate_args, fwd_grad_concatenate_args)

# Sorting permutes entries, so the tangent is permuted the same way.
def fwd_grad_sort(g, ans, x, axis=-1, kind='quicksort', order=None):
    sort_perm = anp.argsort(x, axis, kind, order)
    return g[sort_perm]
defjvp(anp.sort, fwd_grad_sort)
if onp.lib.NumpyVersion(onp.__version__) < '2.0.0':
    defjvp(anp.msort, lambda g, ans, x: fwd_grad_sort(g, ans, x, axis=0))

def fwd_grad_partition(g, ans, x, kth, axis=-1, kind='introselect', order=None):
    partition_perm = anp.argpartition(x, kth, axis, kind, order)
    return g[partition_perm]
defjvp(anp.partition, fwd_grad_partition)

# atleast_* only reshapes its input, so the tangent is reshaped the same way.
def atleast_jvpmaker(fun):
    def jvp(g, ans, *arys):
        if len(arys) > 1:
            raise NotImplementedError("Can't handle multiple arguments yet.")
        return fun(g)
    return jvp
defjvp(anp.atleast_1d, atleast_jvpmaker(anp.atleast_1d))
defjvp(anp.atleast_2d, atleast_jvpmaker(anp.atleast_2d))
defjvp(anp.atleast_3d, atleast_jvpmaker(anp.atleast_3d))

def_linear(anp.einsum)

# Broadcast x up to the shape (and complexness) of target by prepending
# singleton axes and repeating along size-1 axes.
# TODO(mattjj): can we call np.broadcast_to or a related function instead?
def broadcast(x, target):
    target_shape, target_ndim, target_dtype, target_iscomplex = anp.metadata(target)
    while anp.ndim(x) < target_ndim:
        x = anp.expand_dims(x, 0)
    for axis, size in enumerate(anp.shape(x)):
        if size == 1:
            x = anp.repeat(x, target_shape[axis], axis=axis)
    if target_iscomplex and not anp.iscomplexobj(x):
        x = x + 0j  # TODO(mattjj): this might promote the dtype
    return x

defjvp(anp.pad, lambda g, ans, array, width, mode, **kwargs: anp.pad(g, width, mode))
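# An end-to-end sanity-check sketch (illustrative comment only, not executed
# at import): forward mode through the sort JVP above should agree with
# reverse mode. Assuming a standard autograd install:
#
#   from autograd import make_jvp, grad
#   import autograd.numpy as np
#
#   f = lambda x: np.sum(np.sort(x) ** 2)
#   x = np.array([3., 1., 2.])
#   _, jvp_val = make_jvp(f)(x)(np.array([1., 0., 0.]))
#   assert np.isclose(jvp_val, grad(f)(x)[0])  # d f / d x_0, both modes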