import torch

import triton
import triton.language as tl


def num_warps(n):
    if n <= 128:
        return 1
    if n <= 256:
        return 2
    if n <= 512:
        return 4
    if n <= 4096:
        return 8
    return 16


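# Both kernels below consume the look-up table (LUT) built by _softmax.make_lut.
# Its header stores one (size, offset) pair per block-row:
#   LUT[2 * r + 0] = number of non-zero blocks in block-row r
#   LUT[2 * r + 1] = index of that row's first non-zero block in the packed block array
# and the flattened column indices of all non-zero blocks follow the header.
# Each kernel is launched on a (heads, rows, batch) grid, one program per row.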
@triton.jit
def _blocksparse_softmax_fwd(
    Out, A, stride_xz, LUT,
    R, extent, stride_zr, stride_hr,  # relative attention
    scale, is_causal,
    ROW_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_DENSE: tl.constexpr,
):
    h = tl.program_id(0)
    m = tl.program_id(1)
    z = tl.program_id(2)
    # create index ranges
    hm = h * tl.num_programs(1) + m
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
    # extract information from LUT
    header = LUT + (hm // BLOCK_SIZE) * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # pointer offset
    off_a = z * stride_xz
    off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE  # block index
    off_a += (m % BLOCK_SIZE) * BLOCK_SIZE  # row index
    # column indices do not need to be read in the dense case
    if IS_DENSE:
        ns = tl.arange(0, ROW_SIZE)
    else:
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
        start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
        ns = start_n * BLOCK_SIZE + lane_n
    # load A
    mask = block_n < size
    a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
    a = a.to(tl.float32)
    # compute
    out = a
    out *= scale
    # apply relative attention
    if R is not None:
        R += z * stride_zr
        R += h * stride_hr
        off_lo = (extent - m - 1) + ns
        mask_lo = (off_lo >= 0) & (off_lo < extent)
        rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
        out += rel_logits
    out = out.to(tl.float32)
    # apply causal mask
    out = tl.where((ns > m) & is_causal, -float("inf"), out)
    # softmax
    out = tl.softmax(out)
    # write-back
    tl.store(Out + off_a + lane_n, out, mask=mask)


@triton.jit
def _blocksparse_softmax_bwd(
    DA, stride_zdx,
    DOut, stride_zdout,
    Out, stride_zout,
    scale,
    LUT,
    DR, extent, stride_zr, stride_hr, stride_er,
    is_causal,
    ROW_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_DENSE: tl.constexpr,
):
    h = tl.program_id(0)
    m = tl.program_id(1)
    z = tl.program_id(2)
    # create index ranges
    hm = h * tl.num_programs(1) + m
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
    # extract information from LUT
    header = LUT + (hm // BLOCK_SIZE) * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # row-col offset
    off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
    off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
    mask = block_n < size
    # pointers
    As = Out + z * stride_zout + off_mn
    DOuts = DOut + z * stride_zdout + off_mn
    # column indices do not need to be read in the dense case
    if IS_DENSE:
        ns = tl.arange(0, ROW_SIZE)
    else:
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
        start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
        ns = start_n * BLOCK_SIZE + lane_n
    # load data
    a = tl.load(As + lane_n, mask=mask, other=0.0)
    a = a.to(tl.float32)
    dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
    dout = dout.to(tl.float32)
    # zero out causally-masked probabilities, then apply the softmax
    # Jacobian-vector product: with p = softmax(x) stored in `a`,
    # dx = p * (dout - sum(p * dout))
    a = tl.where((ns > m) & is_causal & (a == a), 0., a)
    da = a * (dout - tl.sum(a * dout, 0))
    # apply relative attention
    if DR is not None:
        DR += z * stride_zr
        DR += h * stride_hr
        off_lo = (extent - m - 1) + ns
        mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
        tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
    # chain rule for the scale applied to the logits in the forward pass
    da = da * scale
    # write-back
    DAs = DA + z * stride_zdx + off_mn
    tl.store(DAs + lane_n, da, mask=mask)


class _softmax(torch.autograd.Function):
    @staticmethod
    def make_lut(layout, block, device):
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        sizes = _empty.clone()
        # sizes along rows
        for h in range(layout.shape[0]):
            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
        total_sizes = sizes * block
        # offsets in block format
        offsets = torch.zeros_like(sizes)
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
        # block indices
        columns = layout.nonzero(as_tuple=False)[:, 2]
        header = torch.stack((sizes, offsets), dim=1).view(-1)
        lut = torch.cat((header, columns)).type(torch.int32).to(device)
        return lut, int(total_sizes.max())
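    # Illustrative example (not executed): for layout = torch.tensor([[[1, 0], [1, 1]]])
    # and block = 16, make_lut returns
    #   lut    = [1, 0, 2, 1, 0, 0, 1]  # (size, offset) per block-row, then column indices
    #   maxlut = 32                     # densest row: 2 blocks * 16 elements
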
    @staticmethod
    def forward(
        ctx, a, scale, rel_logits, is_causal,
        spdims, block, lut, maxlut, is_dense
    ):
        if scale is not None and isinstance(scale, torch.Tensor):
            assert scale.device.type == "cpu"
            scale = scale.item()
        M = a.shape[0]
        grid = [spdims[0], spdims[1] * block, M]
        rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
        rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
        # enqueue kernel
        out = torch.empty_like(a)
        _blocksparse_softmax_fwd[grid](
            out, a, a.stride(0), lut,
            rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn
            scale,
            is_causal,
            BLOCK_SIZE=block,
            ROW_SIZE=triton.next_power_of_2(maxlut),
            IS_DENSE=is_dense,
            num_warps=num_warps(maxlut)
        )
        # save to context
        # ctx.mark_dirty(x)
        ctx.save_for_backward(out, lut)
        ctx.spdims = spdims
        ctx.block = block
        ctx.maxlut = maxlut
        ctx.scale = scale
        ctx.rel_shape = rel_shape
        ctx.rel_strides = rel_strides
        ctx.rel_dtype = a.dtype
        ctx.is_dense = is_dense
        ctx.is_causal = is_causal
        return out

    @staticmethod
    def backward(ctx, dout):
        # retrieve from context
        out, lut = ctx.saved_tensors
        # relative logits gradients: rel_logits is the third input of forward,
        # i.e. index 2 of needs_input_grad once ctx is dropped
        dr = None
        if ctx.needs_input_grad[2]:
            dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
        # run kernel
        M = out.shape[0]
        grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
        da = torch.empty_like(dout)
        _blocksparse_softmax_bwd[grid](
            da, da.stride(0),
            dout, dout.stride(0),
            out, out.stride(0),
            ctx.scale,
            lut,
            dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
            ctx.is_causal,
            BLOCK_SIZE=ctx.block,
            ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
            IS_DENSE=ctx.is_dense,
            num_warps=num_warps(ctx.maxlut)
        )
        # one gradient per forward input:
        # (a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense)
        return (da, None, dr, None, None, None, None, None, None)


class softmax:
    def __init__(self, layout, block, device, is_dense=False):
        self.spdims = layout.shape
        self.layout = layout
        self.block = block
        self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
        self.is_dense = is_dense

    def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
        if rel_logits is not None and rel_logits.dtype != a.dtype:
            raise ValueError(f"relative position embedding must be {a.dtype}")
        a = _softmax.apply(
            a, scale, rel_logits, is_causal,
            self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
        )
        return a


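if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module). It assumes
    # a CUDA device and a working Triton installation. The sparsity layout below tiles
    # a 2-head, 64x64 attention matrix into 16x16 blocks with a lower-triangular pattern.
    block = 16
    layout = torch.tril(torch.ones(2, 4, 4, dtype=torch.int64))
    sparse_softmax = softmax(layout, block, device="cuda")
    # Inputs are stored in block-sparse format, (batch, non-zero blocks, block, block),
    # matching the pointer arithmetic used by the kernels above.
    a = torch.randn(3, int(layout.sum()), block, block, device="cuda", requires_grad=True)
    out = sparse_softmax(a, scale=0.5, is_causal=True)
    out.sum().backward()
    print(out.shape, a.grad.shape)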