Files
ytts/venv/lib/python3.11/site-packages/triton/impl/base.py
2025-04-02 21:44:17 -07:00

37 lines
864 B
Python

from __future__ import annotations
from functools import wraps
from typing import TypeVar
# Generic type variable: lets the decorators in this module be annotated as
# returning the same type they are given.
T = TypeVar("T")
# Name of the marker attribute that `builtin` sets on its wrapper and that
# `is_builtin` looks up.
TRITON_BUILTIN = "__triton_builtin__"
def builtin(fn: T) -> T:
    """Mark a function as a builtin.

    The returned wrapper forwards every call to ``fn`` unchanged, but first
    checks that a non-``None`` ``_builder`` keyword argument was supplied.
    The wrapper carries the ``TRITON_BUILTIN`` marker attribute (set to
    ``True``) so that ``is_builtin`` can recognise it.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper around ``fn`` that preserves its metadata via
        ``functools.wraps``.

    Raises:
        TypeError: If ``fn`` is not callable.
        ValueError: (from the wrapper, at call time) if ``_builder`` is
            missing from the keyword arguments or is ``None``.
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let a non-callable be wrapped silently and
    # only fail later, when the wrapper is invoked.
    if not callable(fn):
        raise TypeError(
            f"builtin() expects a callable, got {type(fn).__name__}"
        )

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A missing key and an explicit `_builder=None` are both rejected.
        if kwargs.get("_builder") is None:
            raise ValueError(
                "Did you forget to add @triton.jit ? "
                "(`_builder` argument must be provided outside of JIT functions.)"
            )
        return fn(*args, **kwargs)

    setattr(wrapper, TRITON_BUILTIN, True)
    return wrapper
def is_builtin(fn) -> bool:
    """Report whether ``fn`` carries the triton builtin marker.

    Looks up the attribute named by ``TRITON_BUILTIN`` on ``fn`` and returns
    its value, or ``False`` when the attribute is absent.
    """
    marker = getattr(fn, TRITON_BUILTIN, False)
    return marker
def extern(fn: T) -> T:
    """Decorate an external function.

    Delegates entirely to ``builtin``: the result is whatever ``builtin(fn)``
    returns.
    """
    decorated = builtin(fn)
    return decorated