118 lines
3.8 KiB
Python
118 lines
3.8 KiB
Python
"""This file exports ONNX ops for opset 16.

Note [ONNX Operators that are added/updated in opset 16]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set

New operators:
    GridSample https://github.com/onnx/onnx/pull/3557

Updated operators:
    Identity
    If
    LeakyRelu
    Loop
    PRelu
    RoiAlign
    Scan
    ScatterElements
    ScatterND
    Where
    GreaterOrEqual
    LessOrEqual
"""
|
|
|
|
# EDITING THIS FILE? READ THIS FIRST!
|
|
# see Note [Edit Symbolic Files] in README.md
|
|
|
|
import functools
|
|
|
|
import torch
|
|
from torch.nn.functional import (
|
|
GRID_SAMPLE_INTERPOLATION_MODES,
|
|
GRID_SAMPLE_PADDING_MODES,
|
|
)
|
|
from torch.onnx import _type_utils, symbolic_helper
|
|
from torch.onnx._internal import _beartype, jit_utils, registration
|
|
|
|
# Registration helper pinned to opset 16: ``_onnx_symbolic(name)`` is the
# same as ``registration.onnx_symbolic(name, opset=16)`` and is applied as a
# decorator (e.g. ``@_onnx_symbolic("aten::grid_sampler")``) below.
_onnx_symbolic = functools.partial(
    registration.onnx_symbolic,
    opset=16,
)
|
|
|
|
|
|
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    """Export ``aten::grid_sampler`` as an ONNX ``GridSample`` node.

    Args:
        g: Graph context used to emit the ONNX node.
        input: Input tensor value. A rank-5 (volumetric) input is rejected.
        grid: Sampling-grid value, forwarded to ``GridSample`` unchanged.
        mode_enum: Interpolation mode as an int (parsed by ``parse_args``);
            translated back to its name via ``GRID_SAMPLE_INTERPOLATION_MODES``.
        padding_mode_enum: Padding mode as an int (parsed by ``parse_args``);
            translated back to its name via ``GRID_SAMPLE_PADDING_MODES``.
        align_corners: Boolean flag, emitted as the integer attribute
            ``align_corners``.

    Returns:
        The ``GridSample`` output value. For 5-D input the result of
        ``symbolic_helper._onnx_unsupported`` is returned instead.
    """
    # Reject volumetric input before doing any attribute translation.
    input_rank = symbolic_helper._get_tensor_rank(input)
    if input_rank == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")

    # The torch tables map mode name -> int code; ONNX wants the names, so
    # build code -> name views of both tables and look the codes up.
    interpolation_names = dict(
        zip(
            GRID_SAMPLE_INTERPOLATION_MODES.values(),
            GRID_SAMPLE_INTERPOLATION_MODES.keys(),
        )
    )
    padding_names = dict(
        zip(
            GRID_SAMPLE_PADDING_MODES.values(),
            GRID_SAMPLE_PADDING_MODES.keys(),
        )
    )
    mode_name = interpolation_names[mode_enum]
    padding_name = padding_names[padding_mode_enum]

    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_name,
        padding_mode_s=padding_name,
    )
|
|
|
|
|
|
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter_add`` as ONNX ``ScatterElements`` with ``reduction="add"``.

    Args:
        g: Graph context used to emit ONNX nodes.
        self: Destination tensor value that ``src`` is accumulated into.
        dim: Axis to scatter along (a Python int, per ``parse_args``).
        index: Index tensor value; must have the same rank as ``src``.
        src: Values to add into ``self`` at the positions given by ``index``.

    Returns:
        The ``ScatterElements`` output value. If the Caffe2 ATen fallback is
        active, an ATen ``scatter_add`` node is emitted instead. If the ranks
        of ``index`` and ``src`` are unknown or differ, the result of
        ``symbolic_helper._unimplemented`` is returned.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        # Must dispatch to ``scatter_add`` (accumulate). ``scatter.src`` would
        # overwrite the destination values instead of adding to them.
        return g.at("scatter_add", self, dim, index, src)

    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    # NOTE(review): _get_tensor_sizes can return None (e.g. when the rank is
    # not known). Report that as unsupported instead of crashing on len(None).
    if src_sizes is None or index_sizes is None:
        return symbolic_helper._unimplemented(
            "scatter_add",
            "the ranks of `index` and `src` must be statically known",
        )

    if len(src_sizes) != len(index_sizes):
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    # PyTorch only allows index shape <= src shape, so we can only consider
    # taking index as subset size to src, like PyTorch does. When sizes for src
    # and index are not matched or there are dynamic axes, we take index shape to
    # slice src to accommodate.
    if src_sizes != index_sizes or None in index_sizes:
        adjusted_shape = g.op("Shape", index)
        starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
        src = g.op("Slice", src, starts, adjusted_shape)

    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        # src was unwrapped into a scalar constant. PyTorch allows a scalar src
        # to have a different dtype than self (but not a tensor src), so insert
        # a Cast to self's dtype when they differ.
        self_type = _type_utils.JitScalarType.from_value(self)
        if self_type != src_type:
            src = g.op("Cast", src, to_i=self_type.onnx_type())

    return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
|