283 lines
6.4 KiB
Python
283 lines
6.4 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import torch
|
|
# pointwise operators can go through a faster pathway
|
|
|
|
tensor_magic_methods = [
|
|
'add',
|
|
''
|
|
]
|
|
pointwise_magic_methods_with_reverse = (
|
|
'add', 'sub', 'mul', 'floordiv', 'div', 'truediv', 'mod',
|
|
'pow', 'lshift', 'rshift', 'and', 'or', 'xor'
|
|
)
|
|
pointwise_magic_methods = (
|
|
*(x for m in pointwise_magic_methods_with_reverse for x in (m, 'r' + m)),
|
|
'eq', 'gt', 'le', 'lt', 'ge', 'gt', 'ne', 'neg', 'pos',
|
|
'abs', 'invert',
|
|
'iadd', 'isub', 'imul', 'ifloordiv', 'idiv',
|
|
'itruediv', 'imod', 'ipow', 'ilshift', 'irshift', 'iand',
|
|
'ior', 'ixor',
|
|
'int', 'long', 'float', 'complex',
|
|
)
|
|
|
|
pointwise_methods = (
|
|
*(f'__{m}__' for m in pointwise_magic_methods),
|
|
)
|
|
|
|
pointwise = (
|
|
*(getattr(torch.Tensor, m) for m in pointwise_methods),
|
|
torch.nn.functional.dropout,
|
|
torch.where,
|
|
torch.Tensor.abs,
|
|
torch.abs,
|
|
torch.Tensor.acos,
|
|
torch.acos,
|
|
torch.Tensor.acosh,
|
|
torch.acosh,
|
|
torch.Tensor.add,
|
|
torch.add,
|
|
torch.Tensor.addcdiv,
|
|
torch.addcdiv,
|
|
torch.Tensor.addcmul,
|
|
torch.addcmul,
|
|
torch.Tensor.addr,
|
|
torch.addr,
|
|
torch.Tensor.angle,
|
|
torch.angle,
|
|
torch.Tensor.asin,
|
|
torch.asin,
|
|
torch.Tensor.asinh,
|
|
torch.asinh,
|
|
torch.Tensor.atan,
|
|
torch.atan,
|
|
torch.Tensor.atan2,
|
|
torch.atan2,
|
|
torch.Tensor.atanh,
|
|
torch.atanh,
|
|
torch.Tensor.bitwise_and,
|
|
torch.bitwise_and,
|
|
torch.Tensor.bitwise_left_shift,
|
|
torch.bitwise_left_shift,
|
|
torch.Tensor.bitwise_not,
|
|
torch.bitwise_not,
|
|
torch.Tensor.bitwise_or,
|
|
torch.bitwise_or,
|
|
torch.Tensor.bitwise_right_shift,
|
|
torch.bitwise_right_shift,
|
|
torch.Tensor.bitwise_xor,
|
|
torch.bitwise_xor,
|
|
torch.Tensor.ceil,
|
|
torch.ceil,
|
|
torch.celu,
|
|
torch.nn.functional.celu,
|
|
torch.Tensor.clamp,
|
|
torch.clamp,
|
|
torch.Tensor.clamp_max,
|
|
torch.clamp_max,
|
|
torch.Tensor.clamp_min,
|
|
torch.clamp_min,
|
|
torch.Tensor.copysign,
|
|
torch.copysign,
|
|
torch.Tensor.cos,
|
|
torch.cos,
|
|
torch.Tensor.cosh,
|
|
torch.cosh,
|
|
torch.Tensor.deg2rad,
|
|
torch.deg2rad,
|
|
torch.Tensor.digamma,
|
|
torch.digamma,
|
|
torch.Tensor.div,
|
|
torch.div,
|
|
torch.dropout,
|
|
torch.nn.functional.dropout,
|
|
torch.nn.functional.elu,
|
|
torch.Tensor.eq,
|
|
torch.eq,
|
|
torch.Tensor.erf,
|
|
torch.erf,
|
|
torch.Tensor.erfc,
|
|
torch.erfc,
|
|
torch.Tensor.erfinv,
|
|
torch.erfinv,
|
|
torch.Tensor.exp,
|
|
torch.exp,
|
|
torch.Tensor.exp2,
|
|
torch.exp2,
|
|
torch.Tensor.expm1,
|
|
torch.expm1,
|
|
torch.feature_dropout,
|
|
torch.Tensor.float_power,
|
|
torch.float_power,
|
|
torch.Tensor.floor,
|
|
torch.floor,
|
|
torch.Tensor.floor_divide,
|
|
torch.floor_divide,
|
|
torch.Tensor.fmod,
|
|
torch.fmod,
|
|
torch.Tensor.frac,
|
|
torch.frac,
|
|
torch.Tensor.frexp,
|
|
torch.frexp,
|
|
torch.Tensor.gcd,
|
|
torch.gcd,
|
|
torch.Tensor.ge,
|
|
torch.ge,
|
|
torch.nn.functional.gelu,
|
|
torch.nn.functional.glu,
|
|
torch.Tensor.gt,
|
|
torch.gt,
|
|
torch.Tensor.hardshrink,
|
|
torch.hardshrink,
|
|
torch.nn.functional.hardshrink,
|
|
torch.nn.functional.hardsigmoid,
|
|
torch.nn.functional.hardswish,
|
|
torch.nn.functional.hardtanh,
|
|
torch.Tensor.heaviside,
|
|
torch.heaviside,
|
|
torch.Tensor.hypot,
|
|
torch.hypot,
|
|
torch.Tensor.i0,
|
|
torch.i0,
|
|
torch.Tensor.igamma,
|
|
torch.igamma,
|
|
torch.Tensor.igammac,
|
|
torch.igammac,
|
|
torch.Tensor.isclose,
|
|
torch.isclose,
|
|
torch.Tensor.isfinite,
|
|
torch.isfinite,
|
|
torch.Tensor.isinf,
|
|
torch.isinf,
|
|
torch.Tensor.isnan,
|
|
torch.isnan,
|
|
torch.Tensor.isneginf,
|
|
torch.isneginf,
|
|
torch.Tensor.isposinf,
|
|
torch.isposinf,
|
|
torch.Tensor.isreal,
|
|
torch.isreal,
|
|
torch.Tensor.kron,
|
|
torch.kron,
|
|
torch.Tensor.lcm,
|
|
torch.lcm,
|
|
torch.Tensor.ldexp,
|
|
torch.ldexp,
|
|
torch.Tensor.le,
|
|
torch.le,
|
|
torch.nn.functional.leaky_relu,
|
|
torch.Tensor.lerp,
|
|
torch.lerp,
|
|
torch.Tensor.lgamma,
|
|
torch.lgamma,
|
|
torch.Tensor.log,
|
|
torch.log,
|
|
torch.Tensor.log10,
|
|
torch.log10,
|
|
torch.Tensor.log1p,
|
|
torch.log1p,
|
|
torch.Tensor.log2,
|
|
torch.log2,
|
|
torch.nn.functional.logsigmoid,
|
|
torch.Tensor.logical_and,
|
|
torch.logical_and,
|
|
torch.Tensor.logical_not,
|
|
torch.logical_not,
|
|
torch.Tensor.logical_or,
|
|
torch.logical_or,
|
|
torch.Tensor.logical_xor,
|
|
torch.logical_xor,
|
|
torch.Tensor.logit,
|
|
torch.logit,
|
|
torch.Tensor.lt,
|
|
torch.lt,
|
|
torch.Tensor.maximum,
|
|
torch.maximum,
|
|
torch.Tensor.minimum,
|
|
torch.minimum,
|
|
torch.nn.functional.mish,
|
|
torch.Tensor.mvlgamma,
|
|
torch.mvlgamma,
|
|
torch.Tensor.nan_to_num,
|
|
torch.nan_to_num,
|
|
torch.Tensor.ne,
|
|
torch.ne,
|
|
torch.Tensor.neg,
|
|
torch.neg,
|
|
torch.Tensor.nextafter,
|
|
torch.nextafter,
|
|
torch.Tensor.outer,
|
|
torch.outer,
|
|
torch.polar,
|
|
torch.Tensor.polygamma,
|
|
torch.polygamma,
|
|
torch.Tensor.positive,
|
|
torch.positive,
|
|
torch.Tensor.pow,
|
|
torch.pow,
|
|
torch.Tensor.prelu,
|
|
torch.prelu,
|
|
torch.nn.functional.prelu,
|
|
torch.Tensor.rad2deg,
|
|
torch.rad2deg,
|
|
torch.Tensor.reciprocal,
|
|
torch.reciprocal,
|
|
torch.Tensor.relu,
|
|
torch.relu,
|
|
torch.nn.functional.relu,
|
|
torch.nn.functional.relu6,
|
|
torch.Tensor.remainder,
|
|
torch.remainder,
|
|
torch.Tensor.round,
|
|
torch.round,
|
|
torch.rrelu,
|
|
torch.nn.functional.rrelu,
|
|
torch.Tensor.rsqrt,
|
|
torch.rsqrt,
|
|
torch.rsub,
|
|
torch.selu,
|
|
torch.nn.functional.selu,
|
|
torch.Tensor.sgn,
|
|
torch.sgn,
|
|
torch.Tensor.sigmoid,
|
|
torch.sigmoid,
|
|
torch.nn.functional.sigmoid,
|
|
torch.Tensor.sign,
|
|
torch.sign,
|
|
torch.Tensor.signbit,
|
|
torch.signbit,
|
|
torch.nn.functional.silu,
|
|
torch.Tensor.sin,
|
|
torch.sin,
|
|
torch.Tensor.sinc,
|
|
torch.sinc,
|
|
torch.Tensor.sinh,
|
|
torch.sinh,
|
|
torch.nn.functional.softplus,
|
|
torch.nn.functional.softshrink,
|
|
torch.Tensor.sqrt,
|
|
torch.sqrt,
|
|
torch.Tensor.square,
|
|
torch.square,
|
|
torch.Tensor.sub,
|
|
torch.sub,
|
|
torch.Tensor.tan,
|
|
torch.tan,
|
|
torch.Tensor.tanh,
|
|
torch.tanh,
|
|
torch.nn.functional.tanh,
|
|
torch.threshold,
|
|
torch.nn.functional.threshold,
|
|
torch.trapz,
|
|
torch.Tensor.true_divide,
|
|
torch.true_divide,
|
|
torch.Tensor.trunc,
|
|
torch.trunc,
|
|
torch.Tensor.xlogy,
|
|
torch.xlogy,
|
|
torch.rand_like,
|
|
)
|