Files
2025-04-02 21:44:17 -07:00

199 lines
5.8 KiB
Python

from llvmlite import ir
from numba import cuda, types
from numba.core import cgutils
from numba.core.errors import RequireLiteralValue, NumbaValueError
from numba.core.typing import signature
from numba.core.extending import overload_attribute
from numba.cuda import nvvmutils
from numba.cuda.extending import intrinsic
#-------------------------------------------------------------------------------
# Grid functions
def _type_grid_function(ndim):
    '''Build the typing signature used by the ``grid`` and ``gridsize``
    intrinsics.

    *ndim* must expose a ``literal_value`` attribute holding the requested
    number of dimensions. A value of 1 produces an ``int64`` return type,
    while 2 or 3 produces a ``UniTuple`` of that many ``int64`` values. The
    single argument of the signature is always typed as ``int32``.

    Raises ``NumbaValueError`` when the literal is not 1, 2 or 3.
    '''
    count = ndim.literal_value

    if count == 1:
        return signature(types.int64, types.int32)

    if count in (2, 3):
        return signature(types.UniTuple(types.int64, count), types.int32)

    raise NumbaValueError('argument can only be 1, 2, 3')
@intrinsic
def grid(typingctx, ndim):
    '''grid(ndim)

    Return the absolute position of the calling thread within the whole grid
    of blocks. *ndim* should match the number of dimensions the kernel was
    instantiated with. For *ndim* equal to 1 a single integer is returned;
    for 2 or 3 a tuple with that many integers is returned.

    The first integer is computed as::

        cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x

    and the remaining ones are computed the same way from the ``y`` and ``z``
    attributes.
    '''
    # The return type depends on the value of ndim, so typing needs a
    # compile-time literal rather than a plain integer type.
    if not isinstance(ndim, types.IntegerLiteral):
        raise RequireLiteralValue(ndim)

    grid_sig = _type_grid_function(ndim)

    def codegen(context, builder, sig, args):
        result_type = sig.return_type

        if result_type == types.int64:
            # Scalar result: request a one-dimensional global id.
            return nvvmutils.get_global_id(builder, dim=1)

        if isinstance(result_type, types.UniTuple):
            # Tuple result: one id per dimension, packed into an array value.
            indices = nvvmutils.get_global_id(builder,
                                              dim=result_type.count)
            return cgutils.pack_array(builder, indices)

    return grid_sig, codegen
@intrinsic
def gridsize(typingctx, ndim):
    '''gridsize(ndim)

    Return the absolute size (or shape), in threads, of the whole grid of
    blocks. *ndim* should match the number of dimensions the kernel was
    instantiated with. For *ndim* equal to 1 a single integer is returned;
    for 2 or 3 a tuple with that many integers is returned.

    The first integer is computed as::

        cuda.blockDim.x * cuda.gridDim.x

    and the remaining ones are computed the same way from the ``y`` and ``z``
    attributes.
    '''
    # The return type depends on the value of ndim, so typing needs a
    # compile-time literal rather than a plain integer type.
    if not isinstance(ndim, types.IntegerLiteral):
        raise RequireLiteralValue(ndim)

    size_sig = _type_grid_function(ndim)

    def _nthreads_for_dim(builder, dim):
        # Read the ``ntid`` and ``nctaid`` special registers for *dim* and
        # widen both to 64 bits so the product is computed at 64-bit width.
        i64 = ir.IntType(64)
        ntid = nvvmutils.call_sreg(builder, f"ntid.{dim}")
        nctaid = nvvmutils.call_sreg(builder, f"nctaid.{dim}")
        wide_ntid = builder.sext(ntid, i64)
        wide_nctaid = builder.sext(nctaid, i64)
        return builder.mul(wide_ntid, wide_nctaid)

    def codegen(context, builder, sig, args):
        result_type = sig.return_type

        # The x extent is needed for every supported return type.
        size_x = _nthreads_for_dim(builder, 'x')
        if result_type == types.int64:
            return size_x

        if not isinstance(result_type, types.UniTuple):
            return None

        size_y = _nthreads_for_dim(builder, 'y')
        if result_type.count == 2:
            return cgutils.pack_array(builder, (size_x, size_y))

        if result_type.count == 3:
            size_z = _nthreads_for_dim(builder, 'z')
            return cgutils.pack_array(builder, (size_x, size_y, size_z))

        return None

    return size_sig, codegen
@intrinsic
def _warpsize(typingctx):
    '''Intrinsic taking no arguments that returns the value of the
    ``warpsize`` special register, typed as ``int32``.
    '''
    def codegen(context, builder, sig, args):
        return nvvmutils.call_sreg(builder, 'warpsize')

    return signature(types.int32), codegen
@overload_attribute(types.Module(cuda), 'warpsize', target='cuda')
def cuda_warpsize(mod):
    '''
    The size of a warp. All architectures implemented to date have a warp size
    of 32.

    Registered as the ``warpsize`` attribute of the ``cuda`` module for the
    CUDA target; the attribute getter simply calls the ``_warpsize``
    intrinsic.
    '''
    def impl(mod):
        return _warpsize()

    return impl
#-------------------------------------------------------------------------------
# syncthreads
@intrinsic
def syncthreads(typingctx):
    '''
    Synchronize all threads in the same thread block. This works like a
    barrier in traditional multi-threaded programming: the call blocks until
    every thread in the block has reached it, and only then returns control
    to all of its callers.
    '''
    def codegen(context, builder, sig, args):
        # Declare (or reuse) the void, zero-argument barrier function in the
        # current module, then emit the call.
        barrier_type = ir.FunctionType(ir.VoidType(), ())
        barrier = cgutils.get_or_insert_function(
            builder.module, barrier_type, 'llvm.nvvm.barrier0'
        )
        builder.call(barrier, ())
        # The signature returns ``none``, so hand back the dummy value.
        return context.get_dummy_value()

    return signature(types.none), codegen
def _syncthreads_predicate(typingctx, predicate, fname):
    '''Shared typing and lowering for the predicate-taking barrier
    intrinsics.

    *predicate* is the Numba type of the argument; when it is an integer
    type, a ``(signature, codegen)`` pair is returned in which both the
    argument and the result are typed ``i4`` and the generated code calls
    the function named *fname*. For any other type ``None`` is returned.
    '''
    if isinstance(predicate, types.Integer):
        barrier_sig = signature(types.i4, types.i4)

        def codegen(context, builder, sig, args):
            # Declare (or reuse) the ``i32 (i32)`` function called *fname*
            # in the current module and call it with the lowered arguments.
            i32 = ir.IntType(32)
            barrier_type = ir.FunctionType(i32, (i32,))
            barrier = cgutils.get_or_insert_function(
                builder.module, barrier_type, fname
            )
            return builder.call(barrier, args)

        return barrier_sig, codegen

    return None
@intrinsic
def syncthreads_count(typingctx, predicate):
    '''
    syncthreads_count(predicate)

    An extension to numba.cuda.syncthreads whose return value is the number
    of threads for which predicate is true.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.popc'
    )
@intrinsic
def syncthreads_and(typingctx, predicate):
    '''
    syncthreads_and(predicate)

    An extension to numba.cuda.syncthreads that returns 1 when predicate is
    true for all threads, and 0 otherwise.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.and'
    )
@intrinsic
def syncthreads_or(typingctx, predicate):
    '''
    syncthreads_or(predicate)

    An extension to numba.cuda.syncthreads that returns 1 when predicate is
    true for any thread, and 0 otherwise.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.or'
    )