324 lines
12 KiB
Python
324 lines
12 KiB
Python
from numba import typeof
|
|
from numba.core import types
|
|
from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
|
|
from numba.np.ufunc.sigparse import parse_signature
|
|
from numba.np.ufunc.ufunc_base import UfuncBase, UfuncLowererBase
|
|
from numba.np.numpy_support import ufunc_find_matching_loop
|
|
from numba.core import serialize, errors
|
|
from numba.core.typing import npydecl
|
|
from numba.core.typing.templates import signature, AbstractTemplate
|
|
import functools
|
|
|
|
|
|
def make_gufunc_kernel(_dufunc):
    """Build and return a kernel class bound to the gufunc ``_dufunc``.

    The returned class derives from ``npyimpl._Kernel``; its ``__name__`` is
    the base name ``GUFuncKernel`` with ``_dufunc.__name__`` appended.
    """
    # Imported at call time rather than at module import time.
    from numba.np import npyimpl

    class GUFuncKernel(npyimpl._Kernel):
        """
        Kernel (an npyimpl._Kernel subclass) that lowers the element-wise
        function of a gufunc inside the broadcast loop produced by
        npyimpl.numpy_gufunc_kernel().
        """
        dufunc = _dufunc

        def __init__(self, context, builder, outer_sig):
            super().__init__(context, builder, outer_sig)
            # Look up the compiled element-wise function that matches the
            # element types of the call-site arguments.
            owner = self.dufunc
            lookup = owner.find_ewise_function(
                owner._get_ewise_dtypes(outer_sig.args))
            self.inner_sig, self.cres = lookup

        def cast(self, val, fromty, toty):
            # When an array is cast to a non-array type, the cast starts
            # from the array's element type instead of the array type.
            array_to_scalar = (isinstance(fromty, types.Array)
                               and not isinstance(toty, types.Array))
            source_ty = fromty.dtype if array_to_scalar else fromty
            return super().cast(val, source_ty, toty)

        def generate(self, *args):
            compiled = self.cres
            if compiled.objectmode:
                raise errors.NumbaRuntimeError(
                    'Calling a guvectorize function in object mode is not '
                    'supported yet.')
            # Register the compiled library for linking before delegating
            # code generation to the base kernel.
            self.context.add_linking_libs((compiled.library,))
            return super().generate(*args)

    GUFuncKernel.__name__ = GUFuncKernel.__name__ + _dufunc.__name__
    return GUFuncKernel
|
|
|
|
|
|
class GUFuncLowerer(UfuncLowererBase):
    """Callable object in charge of lowering calls to one specific gufunc."""

    def __init__(self, gufunc):
        # Imported at construction time rather than at module import time.
        from numba.np import npyimpl

        kernel_factory = make_gufunc_kernel
        loop_builder = npyimpl.numpy_gufunc_kernel
        super().__init__(gufunc, kernel_factory, loop_builder)
|
|
|
|
|
|
class GUFunc(serialize.ReduceMixin, UfuncBase):
    """
    Dynamic generalized universal function (GUFunc)
    intended to act like a normal Numpy gufunc, but capable
    of call-time (just-in-time) compilation of fast loops
    specialized to inputs.
    """

    def __init__(self, py_func, signature, identity=None, cache=None,
                 is_dynamic=False, targetoptions=None, writable_args=()):
        """Create the gufunc wrapper around ``py_func``.

        Parameters
        ----------
        py_func
            The Python function to wrap; forwarded to ``GUFuncBuilder``.
        signature
            The gufunc layout signature (e.g. ``(n,m),(m)->(n)``); it is
            later parsed with ``parse_signature``.
        identity, cache, writable_args
            Forwarded unchanged to ``GUFuncBuilder``.
        is_dynamic
            When true (and the object is not frozen), ``__call__`` compiles
            missing loops for the argument types at call time.
        targetoptions
            Mapping forwarded to ``GUFuncBuilder``. ``None`` (the default)
            means an empty dict; a fresh dict is created per instance so
            that no dict object is shared between instances.
        """
        self.ufunc = None
        self._frozen = False
        self._is_dynamic = is_dynamic
        self._identity = identity

        if targetoptions is None:
            targetoptions = {}

        # GUFunc cannot inherit from GUFuncBuilder because "identity"
        # is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder
        # object here
        self.gufunc_builder = GUFuncBuilder(
            py_func, signature, identity, cache, targetoptions, writable_args)

        self.__name__ = self.gufunc_builder.py_func.__name__
        self.__doc__ = self.gufunc_builder.py_func.__doc__
        self._dispatcher = self.gufunc_builder.nb_func
        self._initialize(self._dispatcher)
        functools.update_wrapper(self, py_func)

    def _initialize(self, dispatcher):
        """Build the ufunc, then install the typing and lowering hooks.

        ``dispatcher`` is accepted but not used by this implementation.
        """
        self.build_ufunc()
        self._install_type()
        self._lower_me = GUFuncLowerer(self)
        self._install_cg()

    def _reduce_states(self):
        """Return the state dict consumed by ``_rebuild`` (keys match its
        parameter names)."""
        gb = self.gufunc_builder
        return dict(
            py_func=gb.py_func,
            signature=gb.signature,
            identity=self._identity,
            cache=gb.cache,
            is_dynamic=self._is_dynamic,
            targetoptions=gb.targetoptions,
            writable_args=gb.writable_args,
            typesigs=gb._sigs,
            frozen=self._frozen,
        )

    @classmethod
    def _rebuild(cls, py_func, signature, identity, cache, is_dynamic,
                 targetoptions, writable_args, typesigs, frozen):
        """Recreate an instance from the state produced by
        ``_reduce_states``: construct it, re-add every recorded type
        signature, rebuild the ufunc and restore the frozen flag."""
        self = cls(py_func=py_func, signature=signature, identity=identity,
                   cache=cache, is_dynamic=is_dynamic,
                   targetoptions=targetoptions, writable_args=writable_args)
        for sig in typesigs:
            self.add(sig)
        self.build_ufunc()
        self._frozen = frozen
        return self

    def __repr__(self):
        return f"<numba._GUFunc '{self.__name__}'>"

    def _install_type(self, typingctx=None):
        """Constructs and installs a typing class for a gufunc object in the
        input typing context.  If no typing context is given, then
        _install_type() installs into the typing context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if typingctx is None:
            typingctx = self._dispatcher.targetdescr.typing_context
        _ty_cls = type('GUFuncTyping_' + self.__name__,
                       (AbstractTemplate,),
                       dict(key=self, generic=self._type_me))
        typingctx.insert_user_function(self, _ty_cls)

    def add(self, fty):
        """Forward the function type ``fty`` to the underlying builder."""
        self.gufunc_builder.add(fty)

    def build_ufunc(self):
        """(Re)build ``self.ufunc`` from the builder and return ``self``."""
        self.ufunc = self.gufunc_builder.build_ufunc()
        return self

    def expected_ndims(self):
        """Return ``(input_ndims, output_ndims)``: two tuples holding the
        number of core dimensions of each input and each output operand,
        as declared in the gufunc signature."""
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return (tuple(map(len, parsed_sig[0])), tuple(map(len, parsed_sig[1])))

    def _type_me(self, argtys, kws):
        """
        Implement AbstractTemplate.generic() for the typing class
        built by gufunc._install_type().

        Return the call-site signature after either validating the
        element-wise signature or compiling for it.

        Raises ``errors.TypingError`` when an array operand has fewer
        dimensions than the signature requires, when the gufunc is frozen
        and no matching loop exists, or when compiling for the argument
        types does not produce a matching loop.
        """
        assert not kws
        ufunc = self.ufunc
        layout_sig = self.gufunc_builder.signature
        inp_ndims, out_ndims = self.expected_ndims()
        ndims = inp_ndims + out_ndims

        # NOTE(review): this is ``assert <condition>, <message>``; it only
        # verifies that ``argtys`` is non-empty and does not compare the two
        # lengths. Kept as-is because the intended arity rule is not visible
        # here -- confirm before tightening.
        assert len(argtys), len(ndims)
        for idx, arg in enumerate(argtys):
            if isinstance(arg, types.Array) and arg.ndim < ndims[idx]:
                kind = "Input" if idx < len(inp_ndims) else "Output"
                i = idx if idx < len(inp_ndims) else idx - len(inp_ndims)
                msg = (
                    f"{self.__name__}: {kind} operand {i} does not have "
                    f"enough dimensions (has {arg.ndim}, gufunc core with "
                    f"signature {layout_sig} requires {ndims[idx]})")
                raise errors.TypingError(msg)

        _handle_inputs_result = npydecl.Numpy_rules_ufunc._handle_inputs(
            ufunc, argtys, kws)
        ewise_types, _, _, _ = _handle_inputs_result
        ewise_sig, _ = self.find_ewise_function(ewise_types)

        if ewise_sig is None:
            # Matching element-wise signature was not found; must
            # compile.
            if self._frozen:
                msg = f"cannot call {self} with types {argtys}"
                raise errors.TypingError(msg)
            self._compile_for_argtys(ewise_types)
            # double check to ensure there is a match
            ewise_sig, _ = self.find_ewise_function(ewise_types)
            # Only the first element of the lookup result is held here, so
            # it must be tested against None (comparing it with the tuple
            # (None, None) could never succeed and hid this error).
            if ewise_sig is None:
                msg = f"Fail to compile {self.__name__} with types {argtys}"
                raise errors.TypingError(msg)

        assert ewise_sig is not None

        return signature(types.none, *argtys)

    def _compile_for_argtys(self, argtys, return_type=None):
        """Compile a new loop for the element-wise types ``argtys``.

        ``return_type`` is accepted but not used by this implementation.
        """
        # Compile a new guvectorize function! Use the gufunc signature
        # i.e. (n,m),(m)->(n)
        # plus ewise_types to build a numba function type
        fnty = self._get_function_type(*argtys)
        self.gufunc_builder.add(fnty)

    def match_signature(self, ewise_types, sig):
        """Return True when the element types of ``sig.args`` equal
        ``ewise_types`` (compared as tuples)."""
        dtypes = self._get_ewise_dtypes(sig.args)
        return tuple(dtypes) == tuple(ewise_types)

    @property
    def is_dynamic(self):
        return self._is_dynamic

    def _get_ewise_dtypes(self, args):
        """Return the list of element-wise types for ``args``.

        Each argument that is not already a numba type is converted with
        ``typeof``; array types are replaced by their ``dtype`` and every
        other type is kept unchanged.
        """
        tys = []
        for arg in args:
            argty = arg if isinstance(arg, types.Type) else typeof(arg)
            tys.append(
                argty.dtype if isinstance(argty, types.Array) else argty)
        return tys

    def _num_args_match(self, *args):
        """Return True when ``len(args)`` equals the number of inputs plus
        outputs declared in the gufunc signature."""
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return len(args) == len(parsed_sig[0]) + len(parsed_sig[1])

    def _get_function_type(self, *args):
        """Build the numba function type (returning ``types.none``) that
        corresponds to ``args`` laid out according to the gufunc
        signature."""
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        # ewise_types is a list of [int32, int32, int32, ...]
        ewise_types = self._get_ewise_dtypes(args)

        # first time calling the gufunc
        # generate a signature based on input arguments
        arg_types = []
        for idx, sig_dim in enumerate(parsed_sig[0]):
            ndim = len(sig_dim)
            if ndim == 0:  # append scalar
                arg_types.append(ewise_types[idx])
            else:
                arg_types.append(types.Array(ewise_types[idx], ndim, 'A'))

        offset = len(parsed_sig[0])
        # add return type to signature
        for idx, sig_dim in enumerate(parsed_sig[1]):
            retty = ewise_types[idx + offset]
            ret_ndim = len(sig_dim) or 1  # small hack to return scalars
            arg_types.append(types.Array(retty, ret_ndim, 'A'))

        return types.none(*arg_types)

    def __call__(self, *args, **kwargs):
        """Call the gufunc, compiling a loop first if it is dynamic.

        For a frozen or non-dynamic gufunc the call goes straight to the
        underlying ufunc (or to the first argument's ``__array_ufunc__``
        when that argument is an array wrapper). For a dynamic gufunc an
        ``out`` keyword is appended to the positional arguments, the
        argument count is validated, and a loop is compiled when none
        matches the argument element types.

        Raises ``TypeError`` when a dynamic gufunc is called with a number
        of arguments that differs from the number of operands in its
        signature.
        """
        # If compilation is disabled OR it is NOT a dynamic gufunc
        # call the underlying gufunc
        if self._frozen or not self.is_dynamic:
            # Do not unwrap the ufunc if the argument is a wrapper that will
            # potentially pickle the ufunc after it receives it in
            # __array_ufunc__. The same logic in theory should be replicated
            # for reduce(), outer(), etc., but they're not implemented in dask.
            if args and _is_array_wrapper(args[0]):
                return args[0].__array_ufunc__(
                    self, "__call__", *args, **kwargs
                )
            return self.ufunc(*args, **kwargs)
        elif "out" in kwargs:
            # If "out" argument is supplied
            args += (kwargs.pop("out"),)

        if not self._num_args_match(*args):
            # It is not allowed to call a dynamic gufunc without
            # providing all the arguments
            # see: https://github.com/numba/numba/pull/5938#discussion_r506429392  # noqa: E501
            msg = (
                f"Too few arguments for function '{self.__name__}'. "
                "Note that the pattern `out = gufunc(Arg1, Arg2, ..., ArgN)` "
                "is not allowed. Use `gufunc(Arg1, Arg2, ..., ArgN, out) "
                "instead.")
            raise TypeError(msg)

        # at this point we know the gufunc is a dynamic one
        ewise = self._get_ewise_dtypes(args)
        if not (self.ufunc and ufunc_find_matching_loop(self.ufunc, ewise)):
            # A previous call (@njit -> @guvectorize) may have compiled a
            # version for the element-wise dtypes. In this case, we don't need
            # to compile it again, just build the (g)ufunc
            if self.find_ewise_function(ewise) == (None, None):
                sig = self._get_function_type(*args)
                self.add(sig)
            self.build_ufunc()

        return self.ufunc(*args, **kwargs)
|
|
|
|
|
|
def _is_array_wrapper(obj):
|
|
"""Return True if obj wraps around numpy or another numpy-like library
|
|
and is likely going to apply the ufunc to the wrapped array; False
|
|
otherwise.
|
|
|
|
At the moment, this returns True for
|
|
|
|
- dask.array.Array
|
|
- dask.dataframe.DataFrame
|
|
- dask.dataframe.Series
|
|
- xarray.DataArray
|
|
- xarray.Dataset
|
|
- xarray.Variable
|
|
- pint.Quantity
|
|
- other potential wrappers around dask array or dask dataframe
|
|
|
|
We may need to add other libraries that pickle ufuncs from their
|
|
__array_ufunc__ method in the future.
|
|
|
|
Note that the below test is a lot more naive than
|
|
`dask.base.is_dask_collection`
|
|
(https://github.com/dask/dask/blob/5949e54bc04158d215814586a44d51e0eb4a964d/dask/base.py#L209-L249), # noqa: E501
|
|
because it doesn't need to find out if we're actually dealing with
|
|
a dask collection, only that we're dealing with a wrapper.
|
|
Namely, it will return True for a pint.Quantity wrapping around a plain float, a
|
|
numpy.ndarray, or a dask.array.Array, and it's OK because in all cases
|
|
Quantity.__array_ufunc__ is going to forward the ufunc call inwards.
|
|
"""
|
|
return (
|
|
not isinstance(obj, type)
|
|
and hasattr(obj, "__dask_graph__")
|
|
and hasattr(obj, "__array_ufunc__")
|
|
)
|