from numba import typeof
from numba.core import types
from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
from numba.np.ufunc.sigparse import parse_signature
from numba.np.ufunc.ufunc_base import UfuncBase, UfuncLowererBase
from numba.np.numpy_support import ufunc_find_matching_loop
from numba.core import serialize, errors
from numba.core.typing import npydecl
from numba.core.typing.templates import signature, AbstractTemplate
import functools


def make_gufunc_kernel(_dufunc):
    """Build and return a ``GUFuncKernel`` class bound to ``_dufunc``.

    The returned class has ``_dufunc`` stored as its ``dufunc`` class
    attribute and ``_dufunc.__name__`` appended to its own ``__name__``.
    """
    # Imported lazily, as in the rest of this module.
    from numba.np import npyimpl

    class GUFuncKernel(npyimpl._Kernel):
        """
        npyimpl._Kernel subclass responsible for lowering a gufunc kernel
        (element-wise function) inside a broadcast loop (which is
        generated by npyimpl.numpy_gufunc_kernel()).
        """
        dufunc = _dufunc

        def __init__(self, context, builder, outer_sig):
            super().__init__(context, builder, outer_sig)
            # Look up the compiled element-wise function matching the
            # element types of the call-site signature.
            ewise_types = self.dufunc._get_ewise_dtypes(outer_sig.args)
            self.inner_sig, self.cres = self.dufunc.find_ewise_function(
                ewise_types)

        def cast(self, val, fromty, toty):
            # Handle the case where "fromty" is an array and "toty" a scalar
            if isinstance(fromty, types.Array) and not \
                    isinstance(toty, types.Array):
                return super().cast(val, fromty.dtype, toty)
            return super().cast(val, fromty, toty)

        def generate(self, *args):
            if self.cres.objectmode:
                msg = ('Calling a guvectorize function in object mode is not '
                       'supported yet.')
                raise errors.NumbaRuntimeError(msg)
            self.context.add_linking_libs((self.cres.library,))
            return super().generate(*args)

    GUFuncKernel.__name__ += _dufunc.__name__
    return GUFuncKernel


class GUFuncLowerer(UfuncLowererBase):
    '''Callable class responsible for lowering calls to a specific gufunc.
    '''
    def __init__(self, gufunc):
        from numba.np import npyimpl
        super().__init__(gufunc,
                         make_gufunc_kernel,
                         npyimpl.numpy_gufunc_kernel)


class GUFunc(serialize.ReduceMixin, UfuncBase):
    """
    Dynamic generalized universal function (GUFunc)
    intended to act like a normal Numpy gufunc, but capable
    of call-time (just-in-time) compilation of fast loops
    specialized to inputs.
    """

    def __init__(self, py_func, signature, identity=None, cache=None,
                 is_dynamic=False, targetoptions=None, writable_args=()):
        """Wrap ``py_func`` as a gufunc with layout string ``signature``.

        ``targetoptions`` defaults to a fresh empty dict per instance
        (``None`` is normalised to ``{}``) so that instances never share
        one mutable default object.
        """
        if targetoptions is None:
            targetoptions = {}
        self.ufunc = None
        self._frozen = False
        self._is_dynamic = is_dynamic
        self._identity = identity

        # GUFunc cannot inherit from GUFuncBuilder because "identity"
        # is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder
        # object here
        self.gufunc_builder = GUFuncBuilder(
            py_func, signature, identity, cache, targetoptions, writable_args)

        self.__name__ = self.gufunc_builder.py_func.__name__
        self.__doc__ = self.gufunc_builder.py_func.__doc__
        self._dispatcher = self.gufunc_builder.nb_func
        self._initialize(self._dispatcher)
        functools.update_wrapper(self, py_func)

    def _initialize(self, dispatcher):
        """Build the ufunc, then install typing and lowering for it."""
        self.build_ufunc()
        self._install_type()
        self._lower_me = GUFuncLowerer(self)
        self._install_cg()

    def _reduce_states(self):
        """Return the keyword state consumed by ``_rebuild``."""
        gb = self.gufunc_builder
        dct = dict(
            py_func=gb.py_func,
            signature=gb.signature,
            identity=self._identity,
            cache=gb.cache,
            is_dynamic=self._is_dynamic,
            targetoptions=gb.targetoptions,
            writable_args=gb.writable_args,
            typesigs=gb._sigs,
            frozen=self._frozen,
        )
        return dct

    @classmethod
    def _rebuild(cls, py_func, signature, identity, cache, is_dynamic,
                 targetoptions, writable_args, typesigs, frozen):
        """Recreate an instance from the state made by ``_reduce_states``."""
        self = cls(py_func=py_func, signature=signature, identity=identity,
                   cache=cache, is_dynamic=is_dynamic,
                   targetoptions=targetoptions, writable_args=writable_args)
        # Re-register every recorded type signature before rebuilding.
        for sig in typesigs:
            self.add(sig)
        self.build_ufunc()
        self._frozen = frozen
        return self

    def __repr__(self):
        # The previous body returned an empty string, which also blanked
        # out error messages that interpolate ``{self}``.
        return f"<numba._GUFunc '{self.__name__}'>"

    def _install_type(self, typingctx=None):
        """Constructs and installs a typing class for a gufunc object in the
        input typing context.  If no typing context is given, then
        _install_type() installs into the typing context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if typingctx is None:
            typingctx = self._dispatcher.targetdescr.typing_context
        _ty_cls = type('GUFuncTyping_' + self.__name__,
                       (AbstractTemplate,),
                       dict(key=self, generic=self._type_me))
        typingctx.insert_user_function(self, _ty_cls)

    def add(self, fty):
        """Register the function type ``fty`` with the gufunc builder."""
        self.gufunc_builder.add(fty)

    def build_ufunc(self):
        """(Re)build ``self.ufunc`` from the builder and return ``self``."""
        self.ufunc = self.gufunc_builder.build_ufunc()
        return self

    def expected_ndims(self):
        """Return ``(input_ndims, output_ndims)`` from the layout signature.

        Each element is a tuple holding, per operand, the length of the
        corresponding entry of the parsed signature.
        """
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return (tuple(map(len, parsed_sig[0])),
                tuple(map(len, parsed_sig[1])))

    def _type_me(self, argtys, kws):
        """
        Implement AbstractTemplate.generic() for the typing class
        built by gufunc._install_type().

        Return the call-site signature after either validating the
        element-wise signature or compiling for it.

        Raises ``errors.TypingError`` when an array operand has fewer
        dimensions than the layout requires, when the gufunc is frozen and
        no matching loop exists, or when compiling for the given types
        does not produce a matching loop.
        """
        assert not kws
        ufunc = self.ufunc
        sig = self.gufunc_builder.signature
        inp_ndims, out_ndims = self.expected_ndims()
        ndims = inp_ndims + out_ndims

        # NOTE(review): this only asserts that ``argtys`` is non-empty (the
        # second operand is the assertion message).  It looks like
        # ``len(argtys) == len(ndims)`` was intended -- confirm before
        # tightening, since that would change which error callers see.
        assert len(argtys), len(ndims)
        for idx, arg in enumerate(argtys):
            if isinstance(arg, types.Array) and arg.ndim < ndims[idx]:
                kind = "Input" if idx < len(inp_ndims) else "Output"
                i = idx if idx < len(inp_ndims) else idx - len(inp_ndims)
                msg = (
                    f"{self.__name__}: {kind} operand {i} does not have "
                    f"enough dimensions (has {arg.ndim}, gufunc core with "
                    f"signature {sig} requires {ndims[idx]})")
                raise errors.TypingError(msg)

        _handle_inputs_result = npydecl.Numpy_rules_ufunc._handle_inputs(
            ufunc, argtys, kws)
        ewise_types, _, _, _ = _handle_inputs_result
        sig, _ = self.find_ewise_function(ewise_types)

        if sig is None:
            # Matching element-wise signature was not found; must
            # compile.
            if self._frozen:
                msg = f"cannot call {self} with types {argtys}"
                raise errors.TypingError(msg)
            self._compile_for_argtys(ewise_types)
            # double check to ensure there is a match
            sig, _ = self.find_ewise_function(ewise_types)
            # ``sig`` is the first element of the lookup result, so test it
            # against None; comparing it to ``(None, None)`` never matched.
            if sig is None:
                msg = f"Fail to compile {self.__name__} with types {argtys}"
                raise errors.TypingError(msg)

        assert sig is not None

        return signature(types.none, *argtys)

    def _compile_for_argtys(self, argtys, return_type=None):
        # Compile a new guvectorize function! Use the gufunc signature
        # i.e. (n,m),(m)->(n)
        # plus ewise_types to build a numba function type
        fnty = self._get_function_type(*argtys)
        self.gufunc_builder.add(fnty)

    def match_signature(self, ewise_types, sig):
        """Return True when the element types of ``sig.args`` equal
        ``ewise_types``."""
        dtypes = self._get_ewise_dtypes(sig.args)
        return tuple(dtypes) == tuple(ewise_types)

    @property
    def is_dynamic(self):
        return self._is_dynamic

    def _get_ewise_dtypes(self, args):
        """Return the list of element types for ``args``.

        Values that are not already Numba types are converted with
        ``typeof``; array types contribute their ``dtype``, anything else
        is kept as is.
        """
        tys = []
        for arg in args:
            argty = arg if isinstance(arg, types.Type) else typeof(arg)
            if isinstance(argty, types.Array):
                tys.append(argty.dtype)
            else:
                tys.append(argty)
        return tys

    def _num_args_match(self, *args):
        """Return True when ``args`` has one entry per input and output of
        the layout signature."""
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return len(args) == len(parsed_sig[0]) + len(parsed_sig[1])

    def _get_function_type(self, *args):
        """Build the ``types.none(...)`` function type for ``args`` using
        the dimensionality given by the layout signature."""
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        # ewise_types is a list of [int32, int32, int32, ...]
        ewise_types = self._get_ewise_dtypes(args)

        # first time calling the gufunc
        # generate a signature based on input arguments
        arg_types = []
        for idx, sig_dim in enumerate(parsed_sig[0]):
            ndim = len(sig_dim)
            if ndim == 0:  # append scalar
                arg_types.append(ewise_types[idx])
            else:
                arg_types.append(types.Array(ewise_types[idx], ndim, 'A'))

        offset = len(parsed_sig[0])
        # add return type to signature
        for idx, sig_dim in enumerate(parsed_sig[1]):
            retty = ewise_types[idx + offset]
            ret_ndim = len(sig_dim) or 1  # small hack to return scalars
            arg_types.append(types.Array(retty, ret_ndim, 'A'))

        return types.none(*arg_types)

    def __call__(self, *args, **kwargs):
        """Call the gufunc, compiling a new loop first when this is a
        dynamic, non-frozen gufunc with no loop matching ``args``.

        Raises ``TypeError`` when a dynamic gufunc is called without all
        of its input and output arguments.
        """
        # If compilation is disabled OR it is NOT a dynamic gufunc
        # call the underlying gufunc
        if self._frozen or not self.is_dynamic:
            # Do not unwrap the ufunc if the argument is a wrapper that will
            # potentially pickle the ufunc after it receives it in
            # __array_ufunc__. The same logic in theory should be replicated
            # for reduce(), outer(), etc., but they're not implemented in dask.
            if args and _is_array_wrapper(args[0]):
                return args[0].__array_ufunc__(
                    self, "__call__", *args, **kwargs
                )
            else:
                return self.ufunc(*args, **kwargs)
        elif "out" in kwargs:
            # If "out" argument is supplied
            args += (kwargs.pop("out"),)

        if not self._num_args_match(*args):
            # It is not allowed to call a dynamic gufunc without
            # providing all the arguments
            # see: https://github.com/numba/numba/pull/5938#discussion_r506429392  # noqa: E501
            msg = (
                f"Too few arguments for function '{self.__name__}'. "
                "Note that the pattern `out = gufunc(Arg1, Arg2, ..., ArgN)` "
                "is not allowed. Use `gufunc(Arg1, Arg2, ..., ArgN, out) "
                "instead.")
            raise TypeError(msg)

        # at this point we know the gufunc is a dynamic one
        ewise = self._get_ewise_dtypes(args)
        if not (self.ufunc and ufunc_find_matching_loop(self.ufunc, ewise)):
            # A previous call (@njit -> @guvectorize) may have compiled a
            # version for the element-wise dtypes. In this case, we don't need
            # to compile it again, just build the (g)ufunc
            if self.find_ewise_function(ewise) == (None, None):
                sig = self._get_function_type(*args)
                self.add(sig)
            self.build_ufunc()

        return self.ufunc(*args, **kwargs)


def _is_array_wrapper(obj):
    """Return True if obj wraps around numpy or another numpy-like library
    and is likely going to apply the ufunc to the wrapped array; False
    otherwise.

    At the moment, this returns True for

    - dask.array.Array
    - dask.dataframe.DataFrame
    - dask.dataframe.Series
    - xarray.DataArray
    - xarray.Dataset
    - xarray.Variable
    - pint.Quantity
    - other potential wrappers around dask array or dask dataframe

    We may need to add other libraries that pickle ufuncs from their
    __array_ufunc__ method in the future.

    Note that the below test is a lot more naive than
    `dask.base.is_dask_collection`
    (https://github.com/dask/dask/blob/5949e54bc04158d215814586a44d51e0eb4a964d/dask/base.py#L209-L249),  # noqa: E501
    because it doesn't need to find out if we're actually dealing with
    a dask collection, only that we're dealing with a wrapper.
    Namely, it will return True for a pint.Quantity wrapping around a plain
    float, a numpy.ndarray, or a dask.array.Array, and it's OK because in all
    cases Quantity.__array_ufunc__ is going to forward the ufunc call inwards.
    """
    return (
        not isinstance(obj, type)
        and hasattr(obj, "__dask_graph__")
        and hasattr(obj, "__array_ufunc__")
    )