# Python source (37 lines, 864 B).
from __future__ import annotations
|
|
|
|
from functools import wraps
|
|
from typing import TypeVar
|
|
|
|
# Type variable shared by the decorators in this module so the decorated
# name keeps the static type of the function passed in (``fn: T -> T``).
T = TypeVar("T")

# Name of the marker attribute that ``builtin`` sets to ``True`` on its
# wrapper and that ``is_builtin`` reads back.
TRITON_BUILTIN = "__triton_builtin__"
|
|
|
|
|
|
def builtin(fn: T) -> T:
    """Mark a function as a builtin.

    Wraps ``fn`` so that every call must pass a non-``None`` ``_builder``
    keyword argument, and tags the wrapper with the ``TRITON_BUILTIN``
    marker attribute (set to ``True``) so ``is_builtin`` recognizes it.

    Args:
        fn: The callable to wrap.

    Returns:
        The wrapper function. It is annotated as ``T`` so static type
        checkers treat the decorated name like the original function.

    Raises:
        TypeError: If ``fn`` is not callable.
    """
    # Raise rather than ``assert``: assertions are stripped under ``python -O``,
    # which would let a non-callable through and defer the failure to the
    # first call of the wrapper.
    if not callable(fn):
        raise TypeError(f"builtin() expects a callable, got {type(fn).__name__}")

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A missing ``_builder`` key and an explicit ``_builder=None`` are
        # rejected alike.
        if kwargs.get("_builder") is None:
            raise ValueError(
                "Did you forget to add @triton.jit ? "
                "(`_builder` argument must be provided outside of JIT functions.)"
            )
        return fn(*args, **kwargs)

    setattr(wrapper, TRITON_BUILTIN, True)
    return wrapper
|
|
|
|
|
|
def is_builtin(fn) -> bool:
    """Is this a registered triton builtin function?

    Returns ``True`` when ``fn`` carries a truthy ``TRITON_BUILTIN`` marker
    attribute (as set by ``builtin``), and ``False`` otherwise, including
    for objects that have no such attribute at all.
    """
    # ``bool`` normalizes the raw attribute value, so the result always
    # honors the ``-> bool`` annotation even if a non-bool was stored there.
    return bool(getattr(fn, TRITON_BUILTIN, False))
|
|
|
|
|
|
def extern(fn: T) -> T:
    """Decorate an external function.

    External functions receive exactly the same treatment as builtins: the
    call is delegated to ``builtin`` and its wrapper is returned unchanged.
    """
    decorated = builtin(fn)
    return decorated
|