File size: 5,459 Bytes
ab4488b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
from __future__ import absolute_import
import types
import warnings
from autograd.extend import primitive, notrace_primitive
import numpy as _np
import autograd.builtins as builtins
from numpy.core.einsumfunc import _parse_einsum_input
# numpy functions that only report metadata (rank, shape, complexness,
# result dtype) about their argument.  wrap_namespace wraps these with
# notrace_primitive instead of primitive.
notrace_functions = [_np.ndim, _np.shape, _np.iscomplexobj, _np.result_type]
def wrap_intdtype(cls):
    """Return a subclass of ``cls`` whose ``__new__`` is a notrace_primitive.

    Used by wrap_namespace for numpy's integer scalar types, so that
    constructing one of them goes through ``notrace_primitive`` rather than
    ``cls.__new__`` directly.
    """
    untraced_new = notrace_primitive(cls.__new__)

    class IntdtypeSubclass(cls):
        __new__ = untraced_new

    return IntdtypeSubclass
def wrap_namespace(old, new):
    """Copy entries of the mapping ``old`` into ``new``, wrapping callables.

    Each value is classified by the first matching rule:
      1. it is in ``notrace_functions``      -> ``notrace_primitive(value)``
      2. callable, and its type is not exactly ``type``
                                             -> ``primitive(value)``
      3. its type is exactly ``type`` and it is one of the integer scalar
         types listed below                  -> ``wrap_intdtype(value)``
      4. its exact type is float, int, NoneType or ``type``
                                             -> copied unchanged
    Values matching none of the rules are not copied.
    """
    passthrough_types = {float, int, type(None), type}
    integer_dtypes = {_np.int8, _np.int16, _np.int32, _np.int64, _np.integer}
    for attr_name, value in old.items():
        is_plain_class = type(value) is type
        if value in notrace_functions:
            wrapped = notrace_primitive(value)
        elif callable(value) and not is_plain_class:
            wrapped = primitive(value)
        elif is_plain_class and value in integer_dtypes:
            wrapped = wrap_intdtype(value)
        elif type(value) in passthrough_types:
            wrapped = value
        else:
            continue
        new[attr_name] = wrapped


# Populate this module's globals from numpy's namespace; several names are
# redefined further down in this file and override the wrapped versions.
wrap_namespace(_np.__dict__, globals())
# ----- Special treatment of list-input functions -----
@primitive
def concatenate_args(axis, *args):
    """Variadic concatenation primitive: join the arrays in ``args`` along ``axis``.

    The numpy result is returned as a ``.view(ndarray)``, using the
    ``ndarray`` name bound in this module's namespace.
    """
    joined = _np.concatenate(args, axis)
    return joined.view(ndarray)
def concatenate(arr_list, axis=0):
    """Join the arrays in ``arr_list`` along ``axis`` (default 0).

    Unpacks ``arr_list`` so that each array is passed to the variadic
    ``concatenate_args`` primitive as its own positional argument.
    """
    return concatenate_args(axis, *arr_list)


def vstack(tup):
    """Stack arrays vertically: promote each to at least 2-D, then join on axis 0."""
    return concatenate([atleast_2d(_m) for _m in tup], axis=0)


# Kept as an alias, exactly as before (``row_stack is vstack``).
row_stack = vstack
def hstack(tup):
    """Stack arrays horizontally (column-wise), mirroring ``numpy.hstack``.

    Every input is first promoted with ``atleast_1d``.  If the first promoted
    array is 1-D, the arrays are joined along axis 0; otherwise along axis 1.

    An empty ``tup`` no longer raises IndexError from ``arrs[0]``.  Like
    ``numpy.hstack``, it falls through to ``concatenate`` with an empty list.
    """
    arrs = [atleast_1d(_m) for _m in tup]
    # The ``arrs and`` guard matches numpy.hstack's handling of empty input.
    if arrs and arrs[0].ndim == 1:
        return concatenate(arrs, 0)
    return concatenate(arrs, 1)
def column_stack(tup):
    """Stack 1-D/2-D inputs as columns of a 2-D array, joined along axis 1.

    Inputs with fewer than two dimensions are turned into a single column
    (``array(..., ndmin=2).T``); higher-dimensional inputs are used as-is.
    """
    def _as_column(item):
        converted = array(item)
        return array(converted, ndmin=2).T if converted.ndim < 2 else converted

    return concatenate([_as_column(item) for item in tup], 1)
def array(A, *args, **kwargs):
    """Build an array from ``A``, recursing into lists and tuples.

    If ``builtins.type(A)`` is ``list`` or ``tuple``, each element is
    converted with ``array`` (without the extra arguments).  The results are
    then combined by ``array_from_args``, which receives ``args``/``kwargs``.
    Any other input is passed straight to ``_array_from_scalar_or_array``.
    """
    if builtins.type(A) not in (list, tuple):
        return _array_from_scalar_or_array(args, kwargs, A)
    converted_items = [array(item) for item in A]
    return array_from_args(args, kwargs, *converted_items)
def wrap_if_boxes_inside(raw_array, slow_op_name=None):
    """Rebuild an object-dtype array through the ``array_from_args`` primitive.

    Args:
        raw_array: a numpy array, e.g. the result of ``np.r_[...]``.
        slow_op_name: if truthy and ``raw_array`` has object dtype, a
            warning naming this operation is emitted before rebuilding.

    Returns:
        ``raw_array`` itself when its dtype is not object.  Otherwise a new
        array built from its flattened elements via ``array_from_args`` and
        reshaped back to ``raw_array.shape``.

    Object dtype is used as the signal that boxed values may be inside
    (per the function name).
    """
    # Compare dtypes with ``==`` rather than ``is``: identity only works
    # because numpy happens to cache builtin dtype instances.
    if raw_array.dtype == _np.dtype('O'):
        if slow_op_name:
            warnings.warn("{0} is slow for array inputs. "
                          "np.concatenate() is faster.".format(slow_op_name))
        return array_from_args((), {}, *raw_array.ravel()).reshape(raw_array.shape)
    return raw_array
@primitive
def _array_from_scalar_or_array(array_args, array_kwargs, scalar):
    """Primitive form of ``_np.array(scalar, *array_args, **array_kwargs)``.

    Despite the parameter name, ``scalar`` may also be an array; the caller
    routes every non-list, non-tuple input here.
    """
    extra_positional, extra_keywords = array_args, array_kwargs
    return _np.array(scalar, *extra_positional, **extra_keywords)
@primitive
def array_from_args(array_args, array_kwargs, *args):
    """Primitive that builds one array from its variadic elements.

    The positional ``args`` are collected as a tuple and passed to
    ``_np.array`` together with the forwarded ``array_args``/``array_kwargs``.
    """
    elements = args
    return _np.array(elements, *array_args, **array_kwargs)
def select(condlist, choicelist, default=0):
    """``np.select`` whose result is re-assembled element-wise with ``array``.

    numpy's ``select`` is evaluated first.  Its flattened elements are then
    passed through this module's ``array`` and reshaped to the numpy
    result's shape.
    """
    chosen = _np.select(list(condlist), list(choicelist), default=default)
    flat_items = list(chosen.ravel())
    return array(flat_items).reshape(chosen.shape)
def stack(arrays, axis=0):
    """Join same-shaped arrays along a new axis, like ``numpy.stack``.

    Adapted from numpy/core/shape_base.py's ``stack``, but expressed with
    this module's own ``array`` and ``concatenate`` so it is built from the
    primitives defined in this file.

    Raises:
        ValueError: if ``arrays`` is empty or the shapes differ.
        IndexError: if ``axis`` is outside ``[-(ndim+1), ndim+1)``.
    """
    arrays = [array(arr) for arr in arrays]
    if not arrays:
        raise ValueError('need at least one array to stack')

    if len({item.shape for item in arrays}) != 1:
        raise ValueError('all input arrays must have the same shape')

    out_ndim = arrays[0].ndim + 1
    if axis < -out_ndim or axis >= out_ndim:
        raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, out_ndim))
    if axis < 0:
        axis += out_ndim

    # Index with a ``None`` at position ``axis`` to insert a length-1 axis
    # there, then concatenate along it.
    expander = (slice(None),) * axis + (None,)
    return concatenate([item[expander] for item in arrays], axis=axis)
def append(arr, values, axis=None):
    """Append ``values`` to ``arr``, mirroring ``numpy.append``.

    With an explicit ``axis`` the two operands are concatenated along it.
    With ``axis=None``, ``arr`` is flattened if it is not already 1-D,
    ``values`` is flattened, and the two are joined end to end.
    """
    arr = array(arr)
    if axis is not None:
        return concatenate((arr, values), axis=axis)

    if ndim(arr) != 1:
        arr = ravel(arr)
    flat_values = ravel(array(values))
    return concatenate((arr, flat_values), axis=ndim(arr) - 1)
# ----- Enable functions called using [] ----
class r_class():
    """Index-style helper: ``r_[...]`` evaluates ``np.r_[...]`` and passes the
    result to ``wrap_if_boxes_inside`` with ``slow_op_name="r_"``."""

    def __getitem__(self, args):
        return wrap_if_boxes_inside(_np.r_[args], slow_op_name="r_")


r_ = r_class()
class c_class():
    """Index-style helper: ``c_[...]`` evaluates ``np.c_[...]`` and passes the
    result to ``wrap_if_boxes_inside`` with ``slow_op_name="c_"``."""

    def __getitem__(self, args):
        return wrap_if_boxes_inside(_np.c_[args], slow_op_name="c_")


c_ = c_class()
# ----- misc -----
@primitive
def make_diagonal(D, offset=0, axis1=0, axis2=1):
    """Create an array whose last two axes are diagonal matrices built from ``D``.

    This is a (partial) complement to ``np.diagonal``; numpy has no built-in
    for it.  The result ``A`` has shape ``D.shape + (D.shape[-1],)``, with
    zeros off the diagonal, and ``np.diagonal(A, 0, -1, -2)`` equals ``D``.
    It is needed for the gradient of ``np.diagonal``.

    Only ``offset=0, axis1=-1, axis2=-2`` is supported.  The default
    arguments are not that combination, so callers must pass them explicitly.

    Raises:
        NotImplementedError: for any other offset/axis combination.
    """
    if not (offset == 0 and axis1 == -1 and axis2 == -2):
        raise NotImplementedError("Currently make_diagonal only supports offset=0, axis1=-1, axis2=-2")

    # Previously the buffer was always float64, so complex ``D`` silently
    # lost its imaginary part on assignment.  Complex inputs now get a
    # complex buffer; all other inputs keep the float64 result as before.
    if D.dtype.kind == 'c':
        out_dtype = _np.promote_types(D.dtype, _np.float64)
    else:
        out_dtype = _np.float64

    # Trick: np.diagonal returns a view on the original array, so writing to
    # it fills the diagonal in place (only valid for numpy >= 1.10).
    new_array = _np.zeros(D.shape + (D.shape[-1],), dtype=out_dtype)
    new_array_diag = _np.diagonal(new_array, offset=0, axis1=-1, axis2=-2)
    new_array_diag.flags.writeable = True
    new_array_diag[:] = D
    return new_array
@notrace_primitive
def metadata(A):
    """Return ``(shape, ndim, result_type, is_complex)`` describing ``A``.

    Registered via ``notrace_primitive``; every field is computed with
    numpy's raw (unwrapped) functions.
    """
    shape = _np.shape(A)
    rank = _np.ndim(A)
    dtype = _np.result_type(A)
    is_complex = _np.iscomplexobj(A)
    return shape, rank, dtype, is_complex
@notrace_primitive
def parse_einsum_input(*args):
    """Pass all einsum operands, as one tuple, to numpy's private
    ``_parse_einsum_input`` and return its result unchanged."""
    operands = args
    return _parse_einsum_input(operands)
@primitive
def _astype(A, dtype, order='K', casting='unsafe', subok=True, copy=True):
    """Primitive wrapper around ``A.astype``; every option is forwarded as-is."""
    return A.astype(dtype, order=order, casting=casting, subok=subok, copy=copy)
|