Files
ytts/venv/lib/python3.11/site-packages/numba/np/ufunc/ufunc_base.py
2025-04-02 21:44:17 -07:00

114 lines
3.3 KiB
Python

from numba.np import numpy_support
from numba.core import types
class UfuncLowererBase:
    '''Callable that lowers calls to one particular (g)ufunc.

    The kernel is built once, eagerly, by ``make_kernel_fn`` when the
    object is created. Each call then hands code generation over to
    ``make_ufunc_kernel_fn``, passing the wrapped ufunc and that kernel.
    '''

    def __init__(self, ufunc, make_kernel_fn, make_ufunc_kernel_fn):
        # The (g)ufunc whose calls this object lowers.
        self.ufunc = ufunc
        # Hook invoked on every call to generate the lowered code.
        self.make_ufunc_kernel_fn = make_ufunc_kernel_fn
        # Kernel built from the wrapped ufunc at construction time.
        self.kernel = make_kernel_fn(ufunc)
        # Starts empty; this class never reads or appends to it itself.
        self.libs = []

    def __call__(self, context, builder, sig, args):
        '''Lower one call of the wrapped ufunc.

        Forwards the lowering arguments, followed by the wrapped ufunc and
        its prebuilt kernel, to ``make_ufunc_kernel_fn`` and returns
        whatever that hook returns.
        '''
        lowering_args = (context, builder, sig, args, self.ufunc, self.kernel)
        return self.make_ufunc_kernel_fn(*lowering_args)
class UfuncBase:
    """Base class for numba ufunc wrappers such as GUFunc/DUFunc objects.

    The read-only NumPy ufunc attributes below are forwarded to
    ``self.ufunc``. The methods rely on the concrete class to provide
    ``self.ufunc``, ``self._dispatcher``, ``self._frozen``,
    ``self._lower_me`` and ``self.match_signature``; none of those are
    defined here.
    """

    # -- attributes forwarded to the wrapped ``self.ufunc`` ---------------

    @property
    def nin(self):
        """``nin`` of the wrapped ufunc."""
        return self.ufunc.nin

    @property
    def nout(self):
        """``nout`` of the wrapped ufunc."""
        return self.ufunc.nout

    @property
    def nargs(self):
        """``nargs`` of the wrapped ufunc."""
        return self.ufunc.nargs

    @property
    def ntypes(self):
        """``ntypes`` of the wrapped ufunc."""
        return self.ufunc.ntypes

    @property
    def types(self):
        """``types`` of the wrapped ufunc."""
        return self.ufunc.types

    @property
    def identity(self):
        """``identity`` of the wrapped ufunc."""
        return self.ufunc.identity

    @property
    def signature(self):
        """``signature`` of the wrapped ufunc."""
        return self.ufunc.signature

    @property
    def accumulate(self):
        """``accumulate`` method of the wrapped ufunc."""
        return self.ufunc.accumulate

    @property
    def at(self):
        """``at`` method of the wrapped ufunc."""
        return self.ufunc.at

    @property
    def outer(self):
        """``outer`` method of the wrapped ufunc."""
        return self.ufunc.outer

    @property
    def reduce(self):
        """``reduce`` method of the wrapped ufunc."""
        return self.ufunc.reduce

    @property
    def reduceat(self):
        """``reduceat`` method of the wrapped ufunc."""
        return self.ufunc.reduceat

    # -- compilation control / lowering -----------------------------------

    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.

        Raises:
            AssertionError: if no signature has been compiled yet. Once
                frozen, ``find_ewise_function`` only searches existing
                overloads, so with none it could never match anything.
        """
        # Explicit check rather than ``assert`` so the precondition is still
        # enforced under ``python -O``; AssertionError is kept so callers see
        # the same exception type as before.
        if len(self._dispatcher.overloads) == 0:
            raise AssertionError(
                "disable_compile() requires at least one compiled signature")
        self._frozen = True

    def _install_cg(self, targetctx=None):
        """
        Install an implementation function for a GUFunc/DUFunc object in the
        given target context. If no target context is given, then
        _install_cg() installs into the target context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if targetctx is None:
            targetctx = self._dispatcher.targetdescr.target_context
        _any = types.Any
        _arr = types.Array
        # Either all outputs are explicit or none of them are:
        # sig0 = every input plus one array per output, sig1 = inputs only.
        sig0 = (_any,) * self.ufunc.nin + (_arr,) * self.ufunc.nout
        sig1 = (_any,) * self.ufunc.nin
        targetctx.insert_func_defn(
            [(self._lower_me, self, sig) for sig in (sig0, sig1)])

    def find_ewise_function(self, ewise_types):
        """
        Given a tuple of element-wise argument types, find a matching
        signature in the dispatcher.

        Return a 2-tuple containing the matching signature, and
        compilation result. Will return two None's if no matching
        signature was found.
        """
        if self._frozen:
            # If we cannot compile, coerce to the best matching loop
            loop = numpy_support.ufunc_find_matching_loop(self, ewise_types)
            if loop is None:
                return None, None
            # Take the loop's types, keeping only as many as were passed in.
            ewise_types = tuple(loop.inputs + loop.outputs)[:len(ewise_types)]
        # First overload accepted by match_signature wins.
        for sig, cres in self._dispatcher.overloads.items():
            if self.match_signature(ewise_types, sig):
                return sig, cres
        return None, None