807 lines
23 KiB
Python
807 lines
23 KiB
Python
import operator
|
|
from numba.core import types
|
|
from numba.core.typing.npydecl import (parse_dtype, parse_shape,
|
|
register_number_classes,
|
|
register_numpy_ufunc,
|
|
trigonometric_functions,
|
|
comparison_functions,
|
|
math_operations,
|
|
bit_twiddling_functions)
|
|
from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate,
|
|
AbstractTemplate, CallableTemplate,
|
|
signature, Registry)
|
|
from numba.cuda.types import dim3
|
|
from numba.core.typeconv import Conversion
|
|
from numba import cuda
|
|
from numba.cuda.compiler import declare_device_function_template
|
|
|
|
# Typing registry for the CUDA target. Templates and attribute resolvers
# declared in this module are collected here and later installed into the
# CUDA typing context.
registry = Registry()
register = registry.register
register_attr = registry.register_attr
register_global = registry.register_global

# Make the Numba number classes (int32, float64, ...) resolvable as globals
# inside CUDA kernels.
register_number_classes(register_global)
|
|
|
|
|
|
class Cuda_array_decl(CallableTemplate):
    """Base typing template for CUDA array declarations
    (``cuda.shared.array`` / ``cuda.local.array``).

    The ``shape`` argument must be an integer literal or a (uni)tuple of
    integer literals, since the extents must be known at compile time;
    anything else fails to type. The resulting type is always a
    C-contiguous array of the parsed dtype and dimensionality.
    """

    def generic(self):
        def typer(shape, dtype):
            # Only integer literals and tuples of integer literals are valid
            # shapes
            if isinstance(shape, types.Integer):
                if not isinstance(shape, types.IntegerLiteral):
                    return None
            elif isinstance(shape, (types.Tuple, types.UniTuple)):
                # Generator expression: no need to materialize a list just
                # to feed any().
                if any(not isinstance(s, types.IntegerLiteral)
                       for s in shape):
                    return None
            else:
                return None

            ndim = parse_shape(shape)
            nb_dtype = parse_dtype(dtype)
            if nb_dtype is not None and ndim is not None:
                return types.Array(dtype=nb_dtype, ndim=ndim, layout='C')

        return typer
|
|
|
|
|
|
@register
class Cuda_shared_array(Cuda_array_decl):
    """Typing for ``cuda.shared.array(shape, dtype)``."""
    key = cuda.shared.array
|
|
|
|
|
|
@register
class Cuda_local_array(Cuda_array_decl):
    """Typing for ``cuda.local.array(shape, dtype)``."""
    key = cuda.local.array
|
|
|
|
|
|
@register
class Cuda_const_array_like(CallableTemplate):
    """Typing for ``cuda.const.array_like(arr)``.

    The result has exactly the same array type as the argument; only the
    memory placement differs, which typing does not track.
    """
    key = cuda.const.array_like

    def generic(self):
        def typer(ndarray):
            return ndarray
        return typer
|
|
|
|
|
|
@register
class Cuda_threadfence_device(ConcreteTemplate):
    """Typing for ``cuda.threadfence()``: no arguments, returns None."""
    key = cuda.threadfence
    cases = [signature(types.none)]
|
|
|
|
|
|
@register
class Cuda_threadfence_block(ConcreteTemplate):
    """Typing for ``cuda.threadfence_block()``: no arguments, returns None."""
    key = cuda.threadfence_block
    cases = [signature(types.none)]
|
|
|
|
|
|
@register
class Cuda_threadfence_system(ConcreteTemplate):
    """Typing for ``cuda.threadfence_system()``: no arguments, returns None."""
    key = cuda.threadfence_system
    cases = [signature(types.none)]
|
|
|
|
|
|
@register
class Cuda_syncwarp(ConcreteTemplate):
    """Typing for ``cuda.syncwarp``: callable with no arguments or with a
    single i4 membership mask; returns None."""
    key = cuda.syncwarp
    cases = [signature(types.none), signature(types.none, types.i4)]
|
|
|
|
|
|
@register
class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
    """Typing for the warp shuffle intrinsic.

    Every case returns a (value, predicate) tuple. The shuffled value (the
    third argument) may be i4, i8, f4 or f8; the remaining four arguments
    are always i4.
    """
    key = cuda.shfl_sync_intrinsic
    cases = [
        signature(types.Tuple((types.i4, types.b1)),
                  types.i4, types.i4, types.i4, types.i4, types.i4),
        signature(types.Tuple((types.i8, types.b1)),
                  types.i4, types.i4, types.i8, types.i4, types.i4),
        signature(types.Tuple((types.f4, types.b1)),
                  types.i4, types.i4, types.f4, types.i4, types.i4),
        signature(types.Tuple((types.f8, types.b1)),
                  types.i4, types.i4, types.f8, types.i4, types.i4),
    ]
|
|
|
|
|
|
@register
class Cuda_vote_sync_intrinsic(ConcreteTemplate):
    """Typing for the warp vote intrinsic: (i4, i4, b1) -> (i4, b1)."""
    key = cuda.vote_sync_intrinsic
    cases = [signature(types.Tuple((types.i4, types.b1)),
                       types.i4, types.i4, types.b1)]
|
|
|
|
|
|
@register
class Cuda_match_any_sync(ConcreteTemplate):
    """Typing for ``cuda.match_any_sync``: an i4 mask plus a 32- or 64-bit
    integer or float value; returns an i4."""
    key = cuda.match_any_sync
    cases = [
        signature(types.i4, types.i4, types.i4),
        signature(types.i4, types.i4, types.i8),
        signature(types.i4, types.i4, types.f4),
        signature(types.i4, types.i4, types.f8),
    ]
|
|
|
|
|
|
@register
class Cuda_match_all_sync(ConcreteTemplate):
    """Typing for ``cuda.match_all_sync``: an i4 mask plus a 32- or 64-bit
    integer or float value; returns an (i4, b1) tuple."""
    key = cuda.match_all_sync
    cases = [
        signature(types.Tuple((types.i4, types.b1)), types.i4, types.i4),
        signature(types.Tuple((types.i4, types.b1)), types.i4, types.i8),
        signature(types.Tuple((types.i4, types.b1)), types.i4, types.f4),
        signature(types.Tuple((types.i4, types.b1)), types.i4, types.f8),
    ]
|
|
|
|
|
|
@register
class Cuda_activemask(ConcreteTemplate):
    """Typing for ``cuda.activemask()``: no arguments, returns uint32."""
    key = cuda.activemask
    cases = [signature(types.uint32)]
|
|
|
|
|
|
@register
class Cuda_lanemask_lt(ConcreteTemplate):
    """Typing for ``cuda.lanemask_lt()``: no arguments, returns uint32."""
    key = cuda.lanemask_lt
    cases = [signature(types.uint32)]
|
|
|
|
|
|
@register
class Cuda_popc(ConcreteTemplate):
    """
    Typing for ``cuda.popc`` (population count). Accepts any 8/16/32/64-bit
    signed or unsigned integer and returns the same type.

    Supported types from `llvm.popc`
    [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
    """
    key = cuda.popc
    cases = [
        signature(types.int8, types.int8),
        signature(types.int16, types.int16),
        signature(types.int32, types.int32),
        signature(types.int64, types.int64),
        signature(types.uint8, types.uint8),
        signature(types.uint16, types.uint16),
        signature(types.uint32, types.uint32),
        signature(types.uint64, types.uint64),
    ]
|
|
|
|
|
|
@register
class Cuda_fma(ConcreteTemplate):
    """
    Typing for ``cuda.fma`` (fused multiply-add) on float32 / float64.

    Supported types from `llvm.fma`
    [here](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#standard-c-library-intrinics)
    """
    key = cuda.fma
    cases = [
        signature(types.float32, types.float32, types.float32, types.float32),
        signature(types.float64, types.float64, types.float64, types.float64),
    ]
|
|
|
|
|
|
@register
class Cuda_hfma(ConcreteTemplate):
    """Typing for ``cuda.fp16.hfma``: fused multiply-add on float16."""
    key = cuda.fp16.hfma
    cases = [
        signature(types.float16, types.float16, types.float16, types.float16)
    ]
|
|
|
|
|
|
@register
class Cuda_cbrt(ConcreteTemplate):
    """Typing for ``cuda.cbrt`` (cube root) on float32 / float64."""
    key = cuda.cbrt
    cases = [
        signature(types.float32, types.float32),
        signature(types.float64, types.float64),
    ]
|
|
|
|
|
|
@register
class Cuda_brev(ConcreteTemplate):
    """Typing for ``cuda.brev`` (bit reverse) on uint32 / uint64."""
    key = cuda.brev
    cases = [
        signature(types.uint32, types.uint32),
        signature(types.uint64, types.uint64),
    ]
|
|
|
|
|
|
@register
class Cuda_clz(ConcreteTemplate):
    """
    Typing for ``cuda.clz`` (count leading zeros). Accepts any 8/16/32/64-bit
    signed or unsigned integer and returns the same type.

    Supported types from `llvm.ctlz`
    [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
    """
    key = cuda.clz
    cases = [
        signature(types.int8, types.int8),
        signature(types.int16, types.int16),
        signature(types.int32, types.int32),
        signature(types.int64, types.int64),
        signature(types.uint8, types.uint8),
        signature(types.uint16, types.uint16),
        signature(types.uint32, types.uint32),
        signature(types.uint64, types.uint64),
    ]
|
|
|
|
|
|
@register
class Cuda_ffs(ConcreteTemplate):
    """
    Typing for ``cuda.ffs`` (find first set bit). Accepts any 8/16/32/64-bit
    signed or unsigned integer; the result is always uint32.

    Supported types from `llvm.cttz`
    [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
    """
    key = cuda.ffs
    cases = [
        signature(types.uint32, types.int8),
        signature(types.uint32, types.int16),
        signature(types.uint32, types.int32),
        signature(types.uint32, types.int64),
        signature(types.uint32, types.uint8),
        signature(types.uint32, types.uint16),
        signature(types.uint32, types.uint32),
        signature(types.uint32, types.uint64),
    ]
|
|
|
|
|
|
@register
class Cuda_selp(AbstractTemplate):
    """Typing for ``cuda.selp(test, a, b)``: select one of two operands of
    identical type based on a predicate. Only types listed in the PTX
    ``selp`` documentation are accepted."""
    key = cuda.selp

    def generic(self, args, kws):
        assert not kws
        pred, lhs, rhs = args

        # per docs
        # http://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp
        ok_types = (types.float64, types.float32,
                    types.int16, types.uint16,
                    types.int32, types.uint32,
                    types.int64, types.uint64)

        # Both operands must share one of the supported types; anything
        # else fails to type (implicit None return).
        if lhs == rhs and lhs in ok_types:
            return signature(lhs, pred, lhs, lhs)
|
|
|
|
|
|
def _genfp16_unary(l_key):
    """Register and return a ConcreteTemplate typing *l_key* as a unary
    float16 -> float16 function."""
    @register
    class Cuda_fp16_unary(ConcreteTemplate):
        key = l_key
        cases = [signature(types.float16, types.float16)]

    return Cuda_fp16_unary
|
|
|
|
|
|
def _genfp16_unary_operator(l_key):
    """Register and return an AbstractTemplate typing the unary operator
    *l_key* (e.g. ``operator.neg``) for a single float16 operand."""
    @register_global(l_key)
    class Cuda_fp16_unary(AbstractTemplate):
        key = l_key

        def generic(self, args, kws):
            assert not kws
            # Only match a single float16 argument; other types are left to
            # the default operator templates.
            if len(args) == 1 and args[0] == types.float16:
                return signature(types.float16, types.float16)

    return Cuda_fp16_unary
|
|
|
|
|
|
def _genfp16_binary(l_key):
    """Register and return a ConcreteTemplate typing *l_key* as a binary
    (float16, float16) -> float16 function."""
    @register
    class Cuda_fp16_binary(ConcreteTemplate):
        key = l_key
        cases = [signature(types.float16, types.float16, types.float16)]

    return Cuda_fp16_binary
|
|
|
|
|
|
@register_global(float)
class Float(AbstractTemplate):
    """Typing for the ``float()`` builtin applied to a float16 value:
    the argument types as itself (float16 -> float16). Other argument
    types are left to the default templates."""

    def generic(self, args, kws):
        assert not kws

        arg, = args

        # Guard clause: only float16 is handled here.
        if arg != types.float16:
            return None

        return signature(arg, arg)
|
|
|
|
|
|
def _genfp16_binary_comparison(l_key):
    """Register and return a ConcreteTemplate typing *l_key* as a float16
    comparison: (float16, float16) -> b1."""
    @register
    class Cuda_fp16_cmp(ConcreteTemplate):
        key = l_key

        cases = [
            signature(types.b1, types.float16, types.float16)
        ]
    return Cuda_fp16_cmp
|
|
|
|
# If multiple ConcreteTemplates provide typing for a single function, then
|
|
# function resolution will pick the first compatible typing it finds even if it
|
|
# involves inserting a cast that would be considered undesirable (in this
|
|
# specific case, float16s could be cast to float32s for comparisons).
|
|
#
|
|
# To work around this, we instead use an AbstractTemplate that implements
|
|
# exactly the casting logic that we desire. The AbstractTemplate gets
|
|
# considered in preference to ConcreteTemplates during typing.
|
|
#
|
|
# This is tracked as Issue #7863 (https://github.com/numba/numba/issues/7863) -
|
|
# once this is resolved it should be possible to replace this AbstractTemplate
|
|
# with a ConcreteTemplate to simplify the logic.
|
|
|
|
|
|
def _fp16_binary_operator(l_key, retty):
    """Register typing for the binary operator *l_key* with float16 operands.

    The operator matches when at least one operand is a float16 and the
    other operand converts to float16 exactly, by promotion, or safely;
    both operands are then unified to float16 and the result type is
    *retty*. Returns the registered template class.
    """
    @register_global(l_key)
    class Cuda_fp16_operator(AbstractTemplate):
        key = l_key

        def generic(self, args, kws):
            assert not kws

            if len(args) == 2 and \
                    (args[0] == types.float16 or args[1] == types.float16):
                # Check whether the non-fp16 operand (if any) can be
                # converted to the fp16 operand's type.
                if args[0] == types.float16:
                    convertible = self.context.can_convert(args[1], args[0])
                else:
                    convertible = self.context.can_convert(args[0], args[1])

                # We allow three cases here:
                #
                # 1. fp16 to fp16 - Conversion.exact
                # 2. fp16 to other types fp16 can be promoted to
                #    - Conversion.promote
                # 3. fp16 to int8 (safe conversion) - Conversion.safe
                if convertible in (Conversion.exact, Conversion.promote,
                                   Conversion.safe):
                    return signature(retty, types.float16, types.float16)

    return Cuda_fp16_operator
|
|
|
|
|
|
def _genfp16_comparison_operator(op):
    """Register typing for the fp16 comparison operator *op* (result b1)."""
    return _fp16_binary_operator(op, types.b1)
|
|
|
|
|
|
def _genfp16_binary_operator(op):
    """Register typing for the fp16 arithmetic operator *op* (result float16)."""
    return _fp16_binary_operator(op, types.float16)
|
|
|
|
|
|
# Register typing for the cuda.fp16 intrinsic functions and the
# corresponding Python operators (+, -, *, /, unary -, abs, comparisons),
# so that standard operators accept float16 operands.
Cuda_hadd = _genfp16_binary(cuda.fp16.hadd)
Cuda_add = _genfp16_binary_operator(operator.add)
Cuda_iadd = _genfp16_binary_operator(operator.iadd)
Cuda_hsub = _genfp16_binary(cuda.fp16.hsub)
Cuda_sub = _genfp16_binary_operator(operator.sub)
Cuda_isub = _genfp16_binary_operator(operator.isub)
Cuda_hmul = _genfp16_binary(cuda.fp16.hmul)
Cuda_mul = _genfp16_binary_operator(operator.mul)
Cuda_imul = _genfp16_binary_operator(operator.imul)
Cuda_hmax = _genfp16_binary(cuda.fp16.hmax)
Cuda_hmin = _genfp16_binary(cuda.fp16.hmin)
Cuda_hneg = _genfp16_unary(cuda.fp16.hneg)
Cuda_neg = _genfp16_unary_operator(operator.neg)
Cuda_habs = _genfp16_unary(cuda.fp16.habs)
Cuda_abs = _genfp16_unary_operator(abs)
Cuda_heq = _genfp16_binary_comparison(cuda.fp16.heq)
_genfp16_comparison_operator(operator.eq)
Cuda_hne = _genfp16_binary_comparison(cuda.fp16.hne)
_genfp16_comparison_operator(operator.ne)
Cuda_hge = _genfp16_binary_comparison(cuda.fp16.hge)
_genfp16_comparison_operator(operator.ge)
Cuda_hgt = _genfp16_binary_comparison(cuda.fp16.hgt)
_genfp16_comparison_operator(operator.gt)
Cuda_hle = _genfp16_binary_comparison(cuda.fp16.hle)
_genfp16_comparison_operator(operator.le)
Cuda_hlt = _genfp16_binary_comparison(cuda.fp16.hlt)
_genfp16_comparison_operator(operator.lt)
_genfp16_binary_operator(operator.truediv)
_genfp16_binary_operator(operator.itruediv)
|
|
|
|
|
|
def _resolve_wrapped_unary(fname):
    """Return a Function type for the external device function
    ``__numba_wrapper_<fname>`` with signature float16 -> float16."""
    decl = declare_device_function_template(f'__numba_wrapper_{fname}',
                                            types.float16,
                                            (types.float16,))
    return types.Function(decl)
|
|
|
|
|
|
def _resolve_wrapped_binary(fname):
    """Return a Function type for the external device function
    ``__numba_wrapper_<fname>`` with signature (float16, float16) -> float16."""
    decl = declare_device_function_template(f'__numba_wrapper_{fname}',
                                            types.float16,
                                            (types.float16, types.float16,))
    return types.Function(decl)
|
|
|
|
|
|
# Typing for the fp16 math functions, declared as externally-compiled
# device functions (__numba_wrapper_<name>). These Function types are
# returned directly by CudaFp16Template's resolve_* methods below.
hsin_device = _resolve_wrapped_unary('hsin')
hcos_device = _resolve_wrapped_unary('hcos')
hlog_device = _resolve_wrapped_unary('hlog')
hlog10_device = _resolve_wrapped_unary('hlog10')
hlog2_device = _resolve_wrapped_unary('hlog2')
hexp_device = _resolve_wrapped_unary('hexp')
hexp10_device = _resolve_wrapped_unary('hexp10')
hexp2_device = _resolve_wrapped_unary('hexp2')
hsqrt_device = _resolve_wrapped_unary('hsqrt')
hrsqrt_device = _resolve_wrapped_unary('hrsqrt')
hfloor_device = _resolve_wrapped_unary('hfloor')
hceil_device = _resolve_wrapped_unary('hceil')
hrcp_device = _resolve_wrapped_unary('hrcp')
hrint_device = _resolve_wrapped_unary('hrint')
htrunc_device = _resolve_wrapped_unary('htrunc')
hdiv_device = _resolve_wrapped_binary('hdiv')
|
|
|
|
|
|
# generate atomic operations
def _gen(l_key, supported_types):
    """Register and return an AbstractTemplate typing the atomic operation
    *l_key* on arrays whose dtype is in *supported_types*.

    The call form is ``op(ary, idx, val)``. For 1-D arrays the index is
    typed as intp; for higher dimensions the index type is passed through
    unchanged. The return type is the array's dtype (the old value).
    """
    @register
    class Cuda_atomic(AbstractTemplate):
        key = l_key

        def generic(self, args, kws):
            assert not kws
            ary, idx, val = args

            if ary.dtype not in supported_types:
                return

            if ary.ndim == 1:
                return signature(ary.dtype, ary, types.intp, ary.dtype)
            elif ary.ndim > 1:
                return signature(ary.dtype, ary, idx, ary.dtype)
    return Cuda_atomic
|
|
|
|
|
|
# Type groups accepted by the atomic operations.
all_numba_types = (types.float64, types.float32,
                   types.int32, types.uint32,
                   types.int64, types.uint64)

integer_numba_types = (types.int32, types.uint32,
                       types.int64, types.uint64)

unsigned_int_numba_types = (types.uint32, types.uint64)

# Register typing for each atomic operation with its supported dtypes.
Cuda_atomic_add = _gen(cuda.atomic.add, all_numba_types)
Cuda_atomic_sub = _gen(cuda.atomic.sub, all_numba_types)
Cuda_atomic_max = _gen(cuda.atomic.max, all_numba_types)
Cuda_atomic_min = _gen(cuda.atomic.min, all_numba_types)
Cuda_atomic_nanmax = _gen(cuda.atomic.nanmax, all_numba_types)
Cuda_atomic_nanmin = _gen(cuda.atomic.nanmin, all_numba_types)
Cuda_atomic_and = _gen(cuda.atomic.and_, integer_numba_types)
Cuda_atomic_or = _gen(cuda.atomic.or_, integer_numba_types)
Cuda_atomic_xor = _gen(cuda.atomic.xor, integer_numba_types)
Cuda_atomic_inc = _gen(cuda.atomic.inc, unsigned_int_numba_types)
Cuda_atomic_dec = _gen(cuda.atomic.dec, unsigned_int_numba_types)
Cuda_atomic_exch = _gen(cuda.atomic.exch, integer_numba_types)
|
|
|
|
|
|
@register
class Cuda_atomic_compare_and_swap(AbstractTemplate):
    """Typing for ``cuda.atomic.compare_and_swap(ary, old, val)``.

    Only 1-D arrays of the integer types in ``integer_numba_types`` are
    accepted; the return type is the array's dtype.
    """
    key = cuda.atomic.compare_and_swap

    def generic(self, args, kws):
        assert not kws
        ary, old, val = args
        dty = ary.dtype

        if dty in integer_numba_types and ary.ndim == 1:
            return signature(dty, ary, dty, dty)
|
|
|
|
|
|
@register
class Cuda_atomic_cas(AbstractTemplate):
    """Typing for ``cuda.atomic.cas(ary, idx, old, val)``.

    Unlike ``compare_and_swap``, this takes an explicit index and supports
    multi-dimensional arrays. Only integer dtypes are accepted; the return
    type is the array's dtype.
    """
    key = cuda.atomic.cas

    def generic(self, args, kws):
        assert not kws
        ary, idx, old, val = args
        dty = ary.dtype

        if dty not in integer_numba_types:
            return

        if ary.ndim == 1:
            # 1-D arrays are indexed with a single intp.
            return signature(dty, ary, types.intp, dty, dty)
        elif ary.ndim > 1:
            return signature(dty, ary, idx, dty, dty)
|
|
|
|
|
|
@register
class Cuda_nanosleep(ConcreteTemplate):
    """Typing for ``cuda.nanosleep(ns)``: takes a uint32, returns nothing."""
    key = cuda.nanosleep

    cases = [signature(types.void, types.uint32)]
|
|
|
|
|
|
@register_attr
class Dim3_attrs(AttributeTemplate):
    """Attribute typing for dim3 values (threadIdx, blockIdx, etc.):
    the x, y and z components are all int32."""
    key = dim3

    def resolve_x(self, mod):
        return types.int32

    def resolve_y(self, mod):
        return types.int32

    def resolve_z(self, mod):
        return types.int32
|
|
|
|
|
|
@register_attr
class CudaSharedModuleTemplate(AttributeTemplate):
    """Attribute typing for the ``cuda.shared`` module."""
    key = types.Module(cuda.shared)

    def resolve_array(self, mod):
        return types.Function(Cuda_shared_array)
|
|
|
|
|
|
@register_attr
class CudaConstModuleTemplate(AttributeTemplate):
    """Attribute typing for the ``cuda.const`` module."""
    key = types.Module(cuda.const)

    def resolve_array_like(self, mod):
        return types.Function(Cuda_const_array_like)
|
|
|
|
|
|
@register_attr
class CudaLocalModuleTemplate(AttributeTemplate):
    """Attribute typing for the ``cuda.local`` module."""
    key = types.Module(cuda.local)

    def resolve_array(self, mod):
        return types.Function(Cuda_local_array)
|
|
|
|
|
|
@register_attr
class CudaAtomicTemplate(AttributeTemplate):
    """Attribute typing for the ``cuda.atomic`` module: each resolver maps
    an attribute name to the corresponding atomic typing template."""
    key = types.Module(cuda.atomic)

    def resolve_add(self, mod):
        return types.Function(Cuda_atomic_add)

    def resolve_sub(self, mod):
        return types.Function(Cuda_atomic_sub)

    def resolve_and_(self, mod):
        return types.Function(Cuda_atomic_and)

    def resolve_or_(self, mod):
        return types.Function(Cuda_atomic_or)

    def resolve_xor(self, mod):
        return types.Function(Cuda_atomic_xor)

    def resolve_inc(self, mod):
        return types.Function(Cuda_atomic_inc)

    def resolve_dec(self, mod):
        return types.Function(Cuda_atomic_dec)

    def resolve_exch(self, mod):
        return types.Function(Cuda_atomic_exch)

    def resolve_max(self, mod):
        return types.Function(Cuda_atomic_max)

    def resolve_min(self, mod):
        return types.Function(Cuda_atomic_min)

    def resolve_nanmin(self, mod):
        return types.Function(Cuda_atomic_nanmin)

    def resolve_nanmax(self, mod):
        return types.Function(Cuda_atomic_nanmax)

    def resolve_compare_and_swap(self, mod):
        return types.Function(Cuda_atomic_compare_and_swap)

    def resolve_cas(self, mod):
        return types.Function(Cuda_atomic_cas)
|
|
|
|
|
|
@register_attr
class CudaFp16Template(AttributeTemplate):
    """Attribute typing for the ``cuda.fp16`` module.

    Intrinsics with template classes above resolve to those templates;
    the math functions (hsin, hcos, hlog, ...) resolve to the externally
    declared `*_device` Function types.
    """
    key = types.Module(cuda.fp16)

    def resolve_hadd(self, mod):
        return types.Function(Cuda_hadd)

    def resolve_hsub(self, mod):
        return types.Function(Cuda_hsub)

    def resolve_hmul(self, mod):
        return types.Function(Cuda_hmul)

    def resolve_hdiv(self, mod):
        return hdiv_device

    def resolve_hneg(self, mod):
        return types.Function(Cuda_hneg)

    def resolve_habs(self, mod):
        return types.Function(Cuda_habs)

    def resolve_hfma(self, mod):
        return types.Function(Cuda_hfma)

    def resolve_hsin(self, mod):
        return hsin_device

    def resolve_hcos(self, mod):
        return hcos_device

    def resolve_hlog(self, mod):
        return hlog_device

    def resolve_hlog10(self, mod):
        return hlog10_device

    def resolve_hlog2(self, mod):
        return hlog2_device

    def resolve_hexp(self, mod):
        return hexp_device

    def resolve_hexp10(self, mod):
        return hexp10_device

    def resolve_hexp2(self, mod):
        return hexp2_device

    def resolve_hfloor(self, mod):
        return hfloor_device

    def resolve_hceil(self, mod):
        return hceil_device

    def resolve_hsqrt(self, mod):
        return hsqrt_device

    def resolve_hrsqrt(self, mod):
        return hrsqrt_device

    def resolve_hrcp(self, mod):
        return hrcp_device

    def resolve_hrint(self, mod):
        return hrint_device

    def resolve_htrunc(self, mod):
        return htrunc_device

    def resolve_heq(self, mod):
        return types.Function(Cuda_heq)

    def resolve_hne(self, mod):
        return types.Function(Cuda_hne)

    def resolve_hge(self, mod):
        return types.Function(Cuda_hge)

    def resolve_hgt(self, mod):
        return types.Function(Cuda_hgt)

    def resolve_hle(self, mod):
        return types.Function(Cuda_hle)

    def resolve_hlt(self, mod):
        return types.Function(Cuda_hlt)

    def resolve_hmax(self, mod):
        return types.Function(Cuda_hmax)

    def resolve_hmin(self, mod):
        return types.Function(Cuda_hmin)
|
|
|
|
|
|
@register_attr
class CudaModuleTemplate(AttributeTemplate):
    """Attribute typing for the top-level ``cuda`` module inside kernels:
    indexing registers (threadIdx, blockIdx, ...), intrinsic functions,
    and the nested modules (shared, local, const, atomic, fp16, cg)."""
    key = types.Module(cuda)

    def resolve_cg(self, mod):
        return types.Module(cuda.cg)

    def resolve_threadIdx(self, mod):
        return dim3

    def resolve_blockIdx(self, mod):
        return dim3

    def resolve_blockDim(self, mod):
        return dim3

    def resolve_gridDim(self, mod):
        return dim3

    def resolve_laneid(self, mod):
        return types.int32

    def resolve_shared(self, mod):
        return types.Module(cuda.shared)

    def resolve_popc(self, mod):
        return types.Function(Cuda_popc)

    def resolve_brev(self, mod):
        return types.Function(Cuda_brev)

    def resolve_clz(self, mod):
        return types.Function(Cuda_clz)

    def resolve_ffs(self, mod):
        return types.Function(Cuda_ffs)

    def resolve_fma(self, mod):
        return types.Function(Cuda_fma)

    def resolve_cbrt(self, mod):
        return types.Function(Cuda_cbrt)

    def resolve_threadfence(self, mod):
        return types.Function(Cuda_threadfence_device)

    def resolve_threadfence_block(self, mod):
        return types.Function(Cuda_threadfence_block)

    def resolve_threadfence_system(self, mod):
        return types.Function(Cuda_threadfence_system)

    def resolve_syncwarp(self, mod):
        return types.Function(Cuda_syncwarp)

    def resolve_shfl_sync_intrinsic(self, mod):
        return types.Function(Cuda_shfl_sync_intrinsic)

    def resolve_vote_sync_intrinsic(self, mod):
        return types.Function(Cuda_vote_sync_intrinsic)

    def resolve_match_any_sync(self, mod):
        return types.Function(Cuda_match_any_sync)

    def resolve_match_all_sync(self, mod):
        return types.Function(Cuda_match_all_sync)

    def resolve_activemask(self, mod):
        return types.Function(Cuda_activemask)

    def resolve_lanemask_lt(self, mod):
        return types.Function(Cuda_lanemask_lt)

    def resolve_selp(self, mod):
        return types.Function(Cuda_selp)

    def resolve_nanosleep(self, mod):
        return types.Function(Cuda_nanosleep)

    def resolve_atomic(self, mod):
        return types.Module(cuda.atomic)

    def resolve_fp16(self, mod):
        return types.Module(cuda.fp16)

    def resolve_const(self, mod):
        return types.Module(cuda.const)

    def resolve_local(self, mod):
        return types.Module(cuda.local)
|
|
|
|
|
|
# Expose the `cuda` module itself as a typed global inside kernels.
register_global(cuda, types.Module(cuda))


# NumPy

# The trigonometric, comparison and bit-twiddling ufunc groups are
# registered wholesale; iterate them in the same order as before.
for _ufunc_group in (trigonometric_functions,
                     comparison_functions,
                     bit_twiddling_functions):
    for func in _ufunc_group:
        register_numpy_ufunc(func, register_global)

# Of the math operations, only the logarithms are registered here.
for func in math_operations:
    if func in ('log', 'log2', 'log10'):
        register_numpy_ufunc(func, register_global)
|